From 455c559b43dedc3baedc4ade7ad42b2bb455f669 Mon Sep 17 00:00:00 2001 From: Manish Gupta Date: Tue, 16 Jul 2024 14:14:05 -0700 Subject: [PATCH 1/2] Mainloop schedule uniformly available as GemmKernel::Schedule --- include/cutlass/gemm/kernel/sm70_gemm.hpp | 1 + include/cutlass/gemm/kernel/sm90_gemm_tma.hpp | 1 + include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp | 1 + .../gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp | 1 + .../gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp | 1 + include/cutlass/gemm/kernel/sm90_gemm_warpspecialized.hpp | 1 + .../gemm/kernel/sm90_gemm_warpspecialized_cooperative.hpp | 1 + .../cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp | 1 + 8 files changed, 8 insertions(+) diff --git a/include/cutlass/gemm/kernel/sm70_gemm.hpp b/include/cutlass/gemm/kernel/sm70_gemm.hpp index 954c9cbb6c..eb5b95164e 100644 --- a/include/cutlass/gemm/kernel/sm70_gemm.hpp +++ b/include/cutlass/gemm/kernel/sm70_gemm.hpp @@ -72,6 +72,7 @@ class GemmUniversal< using ElementB = typename CollectiveMainloop::ElementB; using StrideB = typename CollectiveMainloop::StrideB; using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using Schedule = typename DispatchPolicy::Schedule; using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; using MainloopArguments = typename CollectiveMainloop::Arguments; using MainloopParams = typename CollectiveMainloop::Params; diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp index 916f92d707..74899b9abb 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp @@ -91,6 +91,7 @@ class GemmUniversal< using ElementB = typename CollectiveMainloop::ElementB; using StrideB = typename CollectiveMainloop::StrideB; using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using Schedule = typename DispatchPolicy::Schedule; using ElementAccumulator = 
typename CollectiveMainloop::ElementAccumulator; using ClusterShape = typename DispatchPolicy::ClusterShape; using MainloopArguments = typename CollectiveMainloop::Arguments; diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp index 2ddaef13b3..66a18c42cd 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp @@ -80,6 +80,7 @@ class GemmUniversal< using ElementB = typename CollectiveMainloop::ElementB; using StrideB = typename CollectiveMainloop::StrideB; using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using Schedule = typename DispatchPolicy::Schedule; using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; using ClusterShape = typename DispatchPolicy::ClusterShape; using MainloopArguments = typename CollectiveMainloop::Arguments; diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp index b2077db027..19dad373e0 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp @@ -80,6 +80,7 @@ class GemmUniversal< using ElementB = typename CollectiveMainloop::ElementB; using StrideB = typename CollectiveMainloop::StrideB; using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using Schedule = typename DispatchPolicy::Schedule; using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; using ClusterShape = typename DispatchPolicy::ClusterShape; using MainloopArguments = typename CollectiveMainloop::Arguments; diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp index 7924b0131d..ebc512e2d8 100644 --- 
a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp @@ -81,6 +81,7 @@ class GemmUniversal< using ElementB = typename CollectiveMainloop::ElementB; using StrideB = typename CollectiveMainloop::StrideB; using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using Schedule = typename DispatchPolicy::Schedule; using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; using ClusterShape = typename DispatchPolicy::ClusterShape; using MainloopArguments = typename CollectiveMainloop::Arguments; diff --git a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized.hpp b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized.hpp index edff1d8e18..a3c1f8313f 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized.hpp @@ -78,6 +78,7 @@ class GemmUniversal< using ElementB = typename CollectiveMainloop::ElementB; using StrideB = typename CollectiveMainloop::StrideB; using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using Schedule = typename DispatchPolicy::Schedule; using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; using ClusterShape = typename DispatchPolicy::ClusterShape; using MainloopArguments = typename CollectiveMainloop::Arguments; diff --git a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_cooperative.hpp index 877d2c1ddf..c61041117d 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_cooperative.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_cooperative.hpp @@ -79,6 +79,7 @@ class GemmUniversal< using ElementB = typename CollectiveMainloop::ElementB; using StrideB = typename CollectiveMainloop::StrideB; using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using Schedule = typename 
DispatchPolicy::Schedule; using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; using ClusterShape = typename DispatchPolicy::ClusterShape; using MainloopArguments = typename CollectiveMainloop::Arguments; diff --git a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp index abf79e842a..01e85d2445 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp @@ -80,6 +80,7 @@ class GemmUniversal< using ElementB = typename CollectiveMainloop::ElementB; using StrideB = typename CollectiveMainloop::StrideB; using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using Schedule = typename DispatchPolicy::Schedule; using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; using ClusterShape = typename DispatchPolicy::ClusterShape; using MainloopArguments = typename CollectiveMainloop::Arguments; From a0d7eed5b09c903c2017111a29d87e40947256f3 Mon Sep 17 00:00:00 2001 From: Manish Gupta Date: Wed, 17 Jul 2024 09:35:55 -0700 Subject: [PATCH 2/2] tools/ --- tools/library/include/cutlass/library/descriptions.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/library/include/cutlass/library/descriptions.h b/tools/library/include/cutlass/library/descriptions.h index 6a6aab434b..d12d321032 100644 --- a/tools/library/include/cutlass/library/descriptions.h +++ b/tools/library/include/cutlass/library/descriptions.h @@ -34,7 +34,7 @@ #include #include #include - +#include ///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass {