[QST] CUTLASS support for fp8 sparse matrix (for W) multiplication for A*W=Y with GPU (SM90a/89) sparse tensor core #2029

Open
zhink opened this issue Jan 7, 2025 · 5 comments


@zhink

zhink commented Jan 7, 2025

What is your question?
In the 62_hopper_sparse_gemm example, it seems that matrix A is not the weight, and the original weight is stored row-major. Does fp8 matrix multiplication support sparse weights (and do sparse weights support column-major storage) on SM89/SM90 GPUs?

@hwu36
Collaborator

hwu36 commented Jan 8, 2025

The sparse tensor cores in sm90a and sm89 for fp8 only support the format A: row+sparse x B: col+dense = C: dense. The sparse GEMM kernel is limited by this.
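
For reference, a minimal sketch of how this constraint shows up in the SM90 sparse collective builder (the element types, alignments, TileShape/ClusterShape, Stages, and KernelSchedule aliases are assumed to be defined as in example 62_hopper_sparse_gemm):

// Sketch only: the key point is the fixed layout combination the sparse tensor cores
// require -- A (the sparse operand) is row-major, B (the dense operand) is column-major.
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp,
    ElementA, cutlass::layout::RowMajor,    AlignmentA, // sparse A: row-major only
    ElementB, cutlass::layout::ColumnMajor, AlignmentB, // dense B: column-major only
    ElementAccumulator,
    TileShape, ClusterShape,
    Stages,
    KernelSchedule
  >::CollectiveOp;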

@zhink
Author

zhink commented Jan 8, 2025

Thanks very much. @hwu36 @klevzoff
1. Is there a plan to support FP8 sparse GEMM in the form A (row+dense) x B (sparse) = C (row+dense)?
2. How can sparse matrix multiplication be applied to speed up inference in Llama FP8 models?
3. Are int8 and fp16/bf16 limited to the same format A: row+sparse x B: col+dense = C: dense on sm70/80/90(a)?

@hwu36
Collaborator

hwu36 commented Jan 9, 2025

1. Is there a plan to support FP8 sparse GEMM in the form A (row+dense) x B (sparse) = C (row+dense)?

No. We are limited by the hardware.

2. How can sparse matrix multiplication be applied to speed up inference in Llama FP8 models?

I am not a model guy, so I cannot answer that. A little more background: for dense fp8 GEMM, we also only support A: row x B: col.

3. Are int8 and fp16/bf16 limited to the same format A: row+sparse x B: col+dense = C: dense on sm70/80/90(a)?

sm70 does not have sparse tensor cores. For fp16/bf16, A and B can be any combination of row and col.

When I said the sparse tensor cores in sm90a and sm89 for fp8 only support the format A: row+sparse x B: col+dense = C: dense, the more accurate way to say it is that on sm80/89, the ldmatrix before the mma can only transpose 16-bit data, and all sparse and dense mma instructions are in the form row x col. On sm90a, ldmatrix and mma are more or less merged into one instruction, but the limitation still applies.

@klevzoff
Contributor

klevzoff commented Jan 10, 2025

1. Is there a plan to support FP8 sparse GEMM in the form A (row+dense) x B (sparse) = C (row+dense)?

Assuming B is col-major, you can do this with the swap+transpose trick based on the identity C^T = (A x B)^T = B^T x A^T. Configure your mainloop with swapped A/B arguments, and also transpose the layout tags for all tensors (A/B/C/D):

using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp,
    TileShape, ClusterShape,
    cutlass::epilogue::collective::EpilogueTileAuto,
    ElementAccumulator, ElementAccumulator,
    ElementC, ColumnMajor, AlignmentC, // Note: ColumnMajor instead of RowMajor
    ElementD, ColumnMajor, AlignmentD, // Note: ColumnMajor instead of RowMajor
    EpilogueSchedule
  >::CollectiveOp;

using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp,
    ElementB, RowMajor,    AlignmentB, // Note: B instead of A; RowMajor is the transposed tag of col-major B
    ElementA, ColumnMajor, AlignmentA, // Note: A instead of B; ColumnMajor is the transposed tag of row-major A
    ElementAccumulator,
    TileShape, ClusterShape,
    cutlass::gemm::collective::StageCountAutoCarveout<
      static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
    KernelSchedule
  >::CollectiveOp;

Correspondingly, swap A/B pointers and strides when you construct the mainloop arguments:

typename CollectiveMainloop::Arguments mainloop_args{
  ptr_B, layout_B, // Note: B (the sparse operand) instead of A
  ptr_A, stride_A, // Note: A (the dense operand) instead of B
  ptr_E, layout_E
};

and swap M/N in ProblemShape and when computing strides for C/D. You should get the expected result.
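
A minimal sketch of that last step, assuming m, n, k, l hold the logical dimensions of the original A x B = C problem, StrideC/StrideD come from the transposed epilogue above, and mainloop_args/epilogue_args are built as described (make_cute_packed_stride lives in cutlass/util/packed_stride.hpp):

// Sketch only: with swap+transpose the kernel computes C^T = B^T x A^T,
// so M and N trade places in both the problem shape and the C/D strides.
auto stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(n, m, l)); // (n, m, l) instead of (m, n, l)
auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(n, m, l));

typename Gemm::Arguments arguments{
  cutlass::gemm::GemmUniversalMode::kGemm,
  {n, m, k, l},    // Note: {n, m, k, l} instead of {m, n, k, l}
  mainloop_args,   // swapped A/B arguments as shown above
  epilogue_args    // uses stride_C / stride_D above
};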

@klevzoff
Contributor

klevzoff commented Jan 10, 2025

2. How can sparse matrix multiplication be applied to speed up inference in Llama FP8 models?

Check out this blog post by NeuralMagic and their Sparse-Llama model, which builds on top of these sparse matmuls.
