[QST] Cutlass kernel causes no grad in torch backward pass #1980

Open
MinghaoYan opened this issue Dec 10, 2024 · 2 comments

Comments

@MinghaoYan

What is your question?

import cutlass

import torch
from torch.autograd import Function

class GroupedGemm(Function):
    @staticmethod
    def forward(ctx, As, Bs):
        # Validate inputs
        assert len(As) == len(Bs), "Number of A and B matrices must match"
        for A, B in zip(As, Bs):
            assert A.size(-1) == B.size(-2), f"Incompatible dimensions for GEMM: {A.size()} and {B.size()}"

        # Save inputs for backward
        ctx.save_for_backward(*As, *Bs)

        # Prepare CUTLASS plan
        plan = cutlass.op.GroupedGemm(element=As[0].dtype, layout=cutlass.LayoutType.RowMajor)

        # Prepare Cs and Ds for CUTLASS GEMM
        Cs = [torch.zeros(A.size(0), B.size(-1), device=A.device, dtype=A.dtype) for A, B in zip(As, Bs)]
        Ds = [torch.empty_like(C) for C in Cs]

        # Run CUTLASS grouped GEMM
        plan.run(As, Bs, Cs, Ds)

        result = torch.cat([d for d in Ds], dim=0)

        return result

    @staticmethod
    def backward(ctx, grad):
        """
        Compute gradients using CUTLASS plan.run in the backward pass.
        """
        grad = grad.contiguous()
        num_problems = len(ctx.saved_tensors) // 2
        As = ctx.saved_tensors[:num_problems]
        Bs = ctx.saved_tensors[num_problems:]

        # Prepare CUTLASS plan
        plan = cutlass.op.GroupedGemm(element=grad.dtype, layout=cutlass.LayoutType.RowMajor)

        # Compute gradient w.r.t. As
        agrad_list = []
        if ctx.needs_input_grad[0]:
            for grad_i, B in zip(torch.split(grad, [A.size(0) for A in As]), Bs):
                B_transposed = B.transpose(-2, -1)
                A_grad = torch.zeros(grad_i.size(0), B_transposed.size(1), device=grad.device, dtype=grad.dtype)
                plan.run([grad_i], [B_transposed], [A_grad], [A_grad])
                agrad_list.append(A_grad)

        # Compute gradient w.r.t. Bs
        bgrad_list = []
        if ctx.needs_input_grad[1]:
            for A, grad_i in zip(As, torch.split(grad, [A.size(0) for A in As])):
                A_transposed = A.transpose(-2, -1)
                B_grad = torch.zeros_like(Bs[0])
                plan.run([A_transposed], [grad_i], [B_grad], [B_grad])
                bgrad_list.append(B_grad)

        # Return gradients
        agrad = agrad_list if ctx.needs_input_grad[0] else None
        bgrad = bgrad_list if ctx.needs_input_grad[1] else None
        return agrad, bgrad

def test_lora_gradient():
    A = torch.randn(10, 5, requires_grad=True, device="cuda")
    B = torch.randn(5, 2, requires_grad=True, device="cuda")
    C = GroupedGemm.apply([A], [B])
    print(C, f"C require grad is {C[0].requires_grad}")
    loss = C.sum()
    loss.backward()

    print(A.grad)
    print(B.grad)

    assert A.grad is not None, "Gradient not computed for A"
    assert B.grad is not None, "Gradient not computed for B"

test_lora_gradient()

I'm trying to build a custom grouped GEMM operator by wrapping the CUTLASS grouped GEMM kernel, but the output of the CUTLASS grouped GEMM call has no grad_fn, which causes a problem during the backward pass. The above test case yields the following error:

Traceback (most recent call last):
  File "/home/ubuntu/torchtune/tests/torchtune/modules/peft/test_grouped_gemm_backward.py", line 87, in <module>
    test_lora_gradient()
  File "/home/ubuntu/torchtune/tests/torchtune/modules/peft/test_grouped_gemm_backward.py", line 79, in test_lora_gradient
    loss.backward()
  File "/home/ubuntu/.conda/envs/torchtune/lib/python3.10/site-packages/torch/_tensor.py", line 581, in backward
    torch.autograd.backward(
  File "/home/ubuntu/.conda/envs/torchtune/lib/python3.10/site-packages/torch/autograd/__init__.py", line 347, in backward
    _engine_run_backward(
  File "/home/ubuntu/.conda/envs/torchtune/lib/python3.10/site-packages/torch/autograd/graph.py", line 825, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

What is the best way to wrap a CUTLASS kernel?
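
For context on why the output has no grad_fn: torch.autograd.Function only builds the autograd graph from Tensor arguments passed directly to apply(); tensors nested inside Python lists, as in GroupedGemm.apply([A], [B]), are not visible to autograd, so the result is never connected to the graph. Below is a minimal, hedged sketch of the unpacked-argument pattern; the CUTLASS call is replaced by plain torch.matmul purely to illustrate the autograd wiring.

import torch
from torch.autograd import Function

class UnpackedGroupedGemm(Function):
    # Same math as above, but every Tensor is a direct argument to apply().
    # Autograd only inspects Tensors passed directly to Function.apply;
    # tensors hidden inside Python lists do not connect the output to the graph.

    @staticmethod
    def forward(ctx, num_problems, *tensors):
        As = tensors[:num_problems]
        Bs = tensors[num_problems:]
        ctx.save_for_backward(*As, *Bs)
        ctx.num_problems = num_problems
        # Placeholder math: torch.matmul stands in for the CUTLASS grouped GEMM.
        Ds = [A @ B for A, B in zip(As, Bs)]
        return torch.cat(Ds, dim=0)

    @staticmethod
    def backward(ctx, grad):
        n = ctx.num_problems
        As = ctx.saved_tensors[:n]
        Bs = ctx.saved_tensors[n:]
        grads = torch.split(grad, [A.size(0) for A in As])
        a_grads = [g @ B.t() for g, B in zip(grads, Bs)]
        b_grads = [A.t() @ g for A, g in zip(As, grads)]
        # One gradient per forward argument: None for num_problems, then *tensors.
        return (None, *a_grads, *b_grads)

A = torch.randn(10, 5, requires_grad=True, device="cuda")
B = torch.randn(5, 2, requires_grad=True, device="cuda")
C = UnpackedGroupedGemm.apply(1, A, B)  # C.requires_grad is now True
C.sum().backward()
assert A.grad is not None and B.grad is not None

With the tensors unpacked, the output of apply() carries a grad_fn and loss.backward() reaches the custom backward as expected.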

@jackkosaian
Contributor

I'm not familiar with how to ensure that a custom PyTorch layer works with autograd. However, the recommended way to use a CUTLASS kernel in PyTorch is to export it as a PyTorch CUDA extension. You can see an example of this here.
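
For readers who want to follow the extension route, the CUTLASS Python interface can emit such an extension directly from a plan. The sketch below is assembled from memory of the CUTLASS Python examples; the cutlass.emit.pytorch name, its arguments, and the .run method of the emitted module are assumptions to verify against the example linked above.

import cutlass
import torch

# Build the grouped GEMM plan as in the original snippet.
dtype = torch.float16  # assumption: fp16 operands; adjust as needed
plan = cutlass.op.GroupedGemm(element=dtype, layout=cutlass.LayoutType.RowMajor)

# Emit and JIT-compile the kernel as a PyTorch CUDA extension
# (names and arguments below are assumptions, not a confirmed API).
op = plan.construct()
grouped_gemm = cutlass.emit.pytorch(
    op, name="grouped_gemm", cc=plan.cc, sourcedir="out", jit=True
)

# The compiled module behaves like an ordinary torch operator, so it can be
# wrapped in a torch.autograd.Function (with tensors passed as direct
# arguments) to define the backward pass explicitly.
As = [torch.randn(10, 5, device="cuda", dtype=dtype)]
Bs = [torch.randn(5, 2, device="cuda", dtype=dtype)]
Ds = grouped_gemm.run(As, Bs)  # assumption: emitted module exposes .run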


This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.
