Skip to content

Commit

Permalink
Accept Hopper matmuls and update default heuristic (#3579)
Browse files Browse the repository at this point in the history
This updates the default (non-plugin) matmul heuristic to support Hopper
matmuls. This change means that we can not run matmuls on Hopper
similarly to how we do it on Ampere and Turing, including using the
Python interface.

I tried to make the default heuristic somewhat thoughtful and not just a
placeholder. Here are some notes about the Hopper heuristic in its
current form:
- I set the macro to Hopper_64_64_16. I intended to always use the
largest macro for which the N size divided the problem's N, but this led
to lower perf on the handful of examples I looked at. We should
benchmark more and find out why this is once we have warp specialization
and register stealing fully plumbed in, but for the time being I simply
left it at N=64.
- Once the instruction tile is set we set the warp tile equal to the
instruction tile (we can revisit this in the future). Then to find the
CTA tile we double the instruction tile in the M or N dimension until we
run out of registers.
- We start with 8 circular buffering stages and decrease until the
circular buffers fit into smem.
- We use `use_smem_epilogue` when possible. Whenever that is possible we
_always_ use `promote_prologue_smem_reuse` even if it's not needed. This
is to try and avoid bugs like #3602.
- I set the tile rasterization order so that the fast axis is the axis
with the fewest tiles, which should encourage more L2 hits unless there
are tons of tiles in each dimension.
- I cannot yet set grid swizzling due to #3671, but I placed a TODO
comment and some code to do the proper swizzling.

---------

Co-authored-by: Ryan Spring <[email protected]>
  • Loading branch information
jacobhinkle and rdspring1 authored Jan 8, 2025
1 parent e172781 commit 9ce2112
Show file tree
Hide file tree
Showing 7 changed files with 457 additions and 47 deletions.
2 changes: 1 addition & 1 deletion csrc/scheduler/matmul_heuristic.h
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ class MatmulParams : public HeuristicParams {

//! This is the CGA size on Hopper+ devices. This parameter is ignored on
//! Ampere and Turing.
std::tuple<int64_t, int64_t, int64_t> cluster_dims = {2, 1, 1};
std::tuple<int64_t, int64_t, int64_t> cluster_dims = {1, 1, 1};

std::string toString() const override {
std::stringstream ss;
Expand Down
6 changes: 2 additions & 4 deletions csrc/scheduler/matmul_heuristic_plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,10 @@ std::string rolesToPrecisionString(
std::string precision = " ";
const std::vector<TensorView*>& a_operands =
tensor_roles.at(MatmulTensorRole::OPERAND_A);
NVF_ERROR(
a_operands.size() == 1, "We currently require exactly one A operand");
NVF_ERROR(!a_operands.empty(), "We currently require at least one A operand");
const std::vector<TensorView*>& b_operands =
tensor_roles.at(MatmulTensorRole::OPERAND_B);
NVF_ERROR(
b_operands.size() == 1, "We currently require exactly one B operand");
NVF_ERROR(!b_operands.empty(), "We currently require at least one B operand");
TensorView* a = a_operands.front();
TensorView* b = b_operands.front();
NVF_CHECK(
Expand Down
275 changes: 249 additions & 26 deletions csrc/scheduler/matmul_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,14 @@
#include <ir/interface_nodes.h>
#include <ir/internal_nodes.h>
#include <ir/utils.h>
#include <mma_type.h>
#include <options.h>
#include <runtime/executor_utils.h>
#include <scheduler/mma_utils.h>
#include <type.h>
#include <utils.h>
#include <val_graph.h>

#include <algorithm>
#include <deque>
#include <iostream>
Expand All @@ -32,11 +37,8 @@
#include <type_traits>
#include <utility>
#include <variant>
#include "ATen/cuda/CUDAContext.h"
#include "mma_type.h"
#include "mma_utils.h"
#include "type.h"
#include "utils.h"

#include <ATen/cuda/CUDAContext.h>

namespace nvfuser {
namespace matmul_utils {
Expand All @@ -49,25 +51,44 @@ using ProblemShape = std::array<int64_t, 4>;
inline std::optional<MmaMacro> getMmaOp(
const int dev_version,
const ProblemShape& problem) {
using MacroType = MmaMacro;
const int64_t n_extent = problem[(size_t)MatmulDimRole::N];

// NOTE: A temp condition
const ProblemShape::value_type n_extend = problem[(size_t)MatmulDimRole::N];
const bool use_small_n = ((n_extend % 8) == 0) && ((n_extend % 16) != 0);
MmaMacroEncode macro_encode{MmaMacroEncode::Arch::NoMma, 16, 8, 16};

switch (dev_version) {
case 75:
return (use_small_n) ? MacroType::Turing_16_8_16
: MacroType::Turing_16_16_16;
macro_encode.arch = MmaMacroEncode::Arch::Turing;
if ((n_extent % 16) == 0) {
macro_encode.n = 16;
}
break;
case 80:
case 86:
case 89:
case 90: // NOTE: temp use ampere matmul for hopper
return (use_small_n) ? MacroType::Ampere_16_8_16
: MacroType::Ampere_16_16_16;
macro_encode.arch = MmaMacroEncode::Arch::Ampere;
if ((n_extent % 16) == 0) {
macro_encode.n = 16;
}
break;
case 90:
macro_encode.arch = MmaMacroEncode::Arch::Hopper;
macro_encode.m = 64;
// Find the largest instruction tile that divides the problem size and is
// a power of two
macro_encode.n = 64;
// TODO: enable instructions smaller than 64_64_16
while (macro_encode.n > 64) {
if (n_extent % macro_encode.n != 0) {
macro_encode.n /= 2;
} else {
break;
}
}
break;
default:
return std::nullopt;
}
return macro_encode;
}

//! Find the number of circular buffer stages for shared memory operands, so
Expand All @@ -93,9 +114,9 @@ void limitCircularBufferingSmemOperands(
mparams->circular_buffer_options.smem_circular_buffer_stage = (int)stages;
}

//! A wrapper for core heuristics initialization.
//! We should have already set mparams->mma_macro before calling this function.
inline bool initCoreHeuristics(
namespace {

bool fillDefaultAmpereHeuristic(
MatmulParams* mparams,
const ProblemShape& problem_shape,
const mma_utils::TensorRolesMap& tensor_roles,
Expand Down Expand Up @@ -170,6 +191,7 @@ inline bool initCoreHeuristics(
}
return min_size_bytes;
};
// Use cp.async on Ampere if possible
mparams->async_gmem_load_operands = isCpAsyncOperandLoadSupported(
mparams,
std::min(
Expand All @@ -186,6 +208,180 @@ inline bool initCoreHeuristics(
return true;
}

bool fillDefaultHopperHeuristic(
MatmulParams* mparams,
const ProblemShape& problem_shape,
const mma_utils::TensorRolesMap& tensor_roles,
const size_t num_problems) {
const auto device_prop = at::cuda::getCurrentDeviceProperties();

const GemmTile instruction_tile = getMmaOpShape(mparams->mma_macro);
GemmTile warp_tile = {-1, -1, -1};
GemmTile cta_tile = {-1, -1, -1};

using DimType = decltype(GemmTile::m);

// We typically use larger macros on Hopper. By default we will set the
// warp tile equal to the macro and increase the CTA tile until we hit
// a limit. The limits are given by the maximum number of threads per CTA.

// TODO: it might be advantageous in some cases to issue multiple wgmma
// instructions per warp group
warp_tile = instruction_tile;

// The MmaOp output is a 32-bit float which requires one register per value

// total accumulator registers for warp group
const size_t accum_regs_per_warp_group =
warp_tile.m * warp_tile.n * num_problems;

// The cta tile is a multiple of the warp tile. This lambda checks that cta
// tile given by warp_tile and multiple fits on the SM.
const auto validate_cta_tile_multiple = [&](const DimType m_ratio,
const DimType n_ratio) {
DimType cta_m = warp_tile.m * m_ratio;
DimType cta_n = warp_tile.n * n_ratio;
DimType num_compute_warp_groups = m_ratio * n_ratio;

// This assumes warp specialization:
// tma warp group + compute warp groups
DimType num_warp_groups = num_compute_warp_groups + 1;

const int64_t threads_per_sm = num_warp_groups * 128;
const size_t max_registers_per_sm =
getRegPerThreadGivenThreadsPerSM(threads_per_sm) * threads_per_sm;
return
// We store one float per CTA tile element for each matmul problem we
// compute
num_warp_groups * accum_regs_per_warp_group < max_registers_per_sm
// TMA box dimensions must be less than or equal to 256
&& cta_m <= 256 &&
cta_n <= 256
// Each warp group is 128 threads. We can only have a maximum of 1024
// threads per SM, or 8 warp groups.
&& num_warp_groups <= 8 &&
// Don't extend the CTA tile beyond the problem size
cta_m <= problem_shape[(size_t)MatmulDimRole::M] &&
cta_n <= problem_shape[(size_t)MatmulDimRole::N];
};

DimType m_ratio = 1;
DimType n_ratio = 1;

bool increased = true;
while (increased) {
DimType cta_m = warp_tile.m * m_ratio;
DimType cta_n = warp_tile.n * n_ratio;
increased = false;

const auto try_increaseM = [&]() {
if (validate_cta_tile_multiple(m_ratio * 2, n_ratio)) {
m_ratio *= 2;
increased = true;
}
return increased;
};
const auto try_increaseN = [&]() {
if (validate_cta_tile_multiple(m_ratio, n_ratio * 2)) {
n_ratio *= 2;
increased = true;
}
return increased;
};

if (cta_m < cta_n) {
// Try to increase smaller tile dimension first since square tiles are
// optimal for reducing operand load redundancy
if (try_increaseM()) {
continue;
}
try_increaseN();
} else {
if (try_increaseN()) {
continue;
}
try_increaseM();
}
}

cta_tile = {warp_tile.m * m_ratio, warp_tile.n * n_ratio, warp_tile.k};

mparams->tile_sizes = {cta_tile, warp_tile};

// stages and async mem copy
mparams->circular_buffer_options.smem_circular_buffer_stage = 8;

// TODO: We should take the main loop structure into account here to get a
// more accurate estimate in case of horizontal fusion
int64_t operand_smem_per_stage =
(int64_t)num_problems * 2 * (cta_tile.m + cta_tile.n) * cta_tile.k;
// We leave a bit of space for semaphores
int64_t max_operand_smem =
(int64_t)device_prop->sharedMemPerBlock - (1L << 7);

while (mparams->circular_buffer_options.smem_circular_buffer_stage *
operand_smem_per_stage >
max_operand_smem) {
mparams->circular_buffer_options.smem_circular_buffer_stage--;
}

mparams->circular_buffer_options.circular_buffer_smem_write =
mparams->circular_buffer_options.smem_circular_buffer_stage > 1;

// Always use TMA on Hopper
mparams->async_gmem_load_operands = true;

// See here for more information:
// https://research.colfax-intl.com/cutlass-tutorial-wgmma-hopper/

// We count the number of tiles in each dimension to determine the
// rasterization order. The fast rasterization axis is the shortest axis, to
// encourage L2 hits by looping over the same rows or cols more frequently.
int64_t Mtiles = ceilDiv(problem_shape[(size_t)MatmulDimRole::M], cta_tile.m);
int64_t Ntiles = ceilDiv(problem_shape[(size_t)MatmulDimRole::N], cta_tile.n);

mparams->cta_order = Ntiles >= Mtiles
? MatmulParams::TileRasterizationOrder::ColumnMajor
: MatmulParams::TileRasterizationOrder::RowMajor;

// We also swizzle the tiles as much as possible up to 4 tiles. Like choosing
// the rasterization order, this is used to increase L2 locality
mparams->grid_swizzle_factor = 4L;
while (Mtiles % mparams->grid_swizzle_factor != 0 ||
Ntiles % mparams->grid_swizzle_factor != 0) {
// Decrease the swizzle factor if it would result in nondivisible splits,
// since this would unnecessarily increase the grid size.
mparams->grid_swizzle_factor /= 2L;
}
// TODO: grid swizzling is currently disabled on Hopper since we cannot
// properly inline when we swizzle unmapped loop broadcasts
mparams->grid_swizzle_factor = 1L;

// TODO: Finally, we set the CGA size

return true;
}

} // namespace

//! A wrapper for core heuristics initialization.
//! We should have already set mparams->mma_macro before calling this function.
inline bool initCoreHeuristics(
MatmulParams* mparams,
const ProblemShape& problem_shape,
const mma_utils::TensorRolesMap& tensor_roles,
const size_t num_problems) {
if (isHopper(mparams->mma_macro)) {
return fillDefaultHopperHeuristic(
mparams, problem_shape, tensor_roles, num_problems);
} else if (isAmpere(mparams->mma_macro) || isTuring(mparams->mma_macro)) {
return fillDefaultAmpereHeuristic(
mparams, problem_shape, tensor_roles, num_problems);
}
// Unsupported arch
return false;
}

//! A helper for getting problem shape from fusion and runtime info.
//!
//! For a given domain, try to find the size by evaluating the extent of an
Expand Down Expand Up @@ -790,7 +986,15 @@ std::unique_ptr<MatmulParams> getMatmulHeuristics(
mma_utils::generateSharedMemoryEpilogueHeuristics(
mparams->tile_sizes,
mparams->circular_buffer_options.smem_circular_buffer_stage,
tensor_roles);
tensor_roles,
/*ignore_occupancy_drop=*/true);
if (isHopper(mparams->mma_macro)) {
// Always promote smem reuse for Hopper. This is needed because we use TMA
// which has higher alignment requirements, so it's important that we place
// our TMA buffers at an offset that's a multiple of 64 (like 0) if
// possible.
mparams->promote_prologue_smem_reuse = true;
}

if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) {
debug() << mparams->toString() << std::endl;
Expand Down Expand Up @@ -842,13 +1046,25 @@ std::string getMatmulCompileTimeRejectReason(Fusion* fusion) {
{
for (const mma_utils::MatmulPattern& pattern : patterns) {
Expr* op = pattern.output->definition();
if (device_prop->major >= 9 && op->isA<ReductionOp>()) {
bool found_reduction = false;
for (size_t dim : c10::irange((size_t)pattern.output->nDims())) {
if (found_reduction &&
!pattern.output->axis((int64_t)dim)->isReduction()) {
return "Mul+Sum patterns can only be translated to MmaOp "
"on Hopper if the reduction dim is innermost";
if (device_prop->major >= 9) {
for (TensorView* operand : {pattern.A, pattern.B}) {
if (!operand->isFusionInput() &&
(operand->definition() == nullptr ||
!operand->definition()->isA<LoadStoreOp>() ||
!operand->definition()->input(0)->isFusionInput() ||
operand->hasRoot())) {
return "Operand " + operand->toString() +
" must be a fusion input or non-permuting LoadStoreOp of an input on Hopper";
}
}
if (op->isA<ReductionOp>()) {
bool found_reduction = false;
for (size_t dim : c10::irange((size_t)pattern.output->nDims())) {
if (found_reduction &&
!pattern.output->axis((int64_t)dim)->isReduction()) {
return "Mul+Sum patterns can only be translated to MmaOp "
"on Hopper if the reduction dim is innermost";
}
}
}
}
Expand Down Expand Up @@ -922,7 +1138,14 @@ std::string getMatmulRunTimeRejectReason(
Fusion* fusion,
HeuristicDataCache* data_cache,
SchedulerRuntimeInfo& runtime_info) {
// TODO: add proper set of checks
const auto device_prop = at::cuda::getCurrentDeviceProperties();

if (device_prop->major >= 9 &&
runtime_info.getIndexType() != DataType::Int32) {
// See https://github.com/NVIDIA/Fuser/issues/3595
return "Hopper matmul is not yet supported with problem sizes requiring 64-bit indexing";
}

return "";
}

Expand Down
Loading

0 comments on commit 9ce2112

Please sign in to comment.