What is your question?
```python
import cutlass
import torch
from torch.autograd import Function


class GroupedGemm(Function):
    @staticmethod
    def forward(ctx, As, Bs):
        # Validate inputs
        assert len(As) == len(Bs), "Number of A and B matrices must match"
        for A, B in zip(As, Bs):
            assert A.size(-1) == B.size(-2), f"Incompatible dimensions for GEMM: {A.size()} and {B.size()}"

        # Save inputs for backward
        ctx.save_for_backward(*As, *Bs)

        # Prepare CUTLASS plan
        plan = cutlass.op.GroupedGemm(element=As[0].dtype, layout=cutlass.LayoutType.RowMajor)

        # Prepare Cs (accumulators) and Ds (outputs) for the CUTLASS GEMM
        Cs = [torch.zeros(A.size(0), B.size(-1), device=A.device, dtype=A.dtype) for A, B in zip(As, Bs)]
        Ds = [torch.empty_like(C) for C in Cs]

        # Run CUTLASS grouped GEMM
        plan.run(As, Bs, Cs, Ds)
        return torch.cat(Ds, dim=0)

    @staticmethod
    def backward(ctx, grad):
        """Compute gradients using CUTLASS plan.run in the backward pass."""
        grad = grad.contiguous()
        num_problems = len(ctx.saved_tensors) // 2
        As = ctx.saved_tensors[:num_problems]
        Bs = ctx.saved_tensors[num_problems:]

        # Prepare CUTLASS plan
        plan = cutlass.op.GroupedGemm(element=grad.dtype, layout=cutlass.LayoutType.RowMajor)

        # Compute gradient w.r.t. As: dA_i = dD_i @ B_i^T
        agrad_list = []
        if ctx.needs_input_grad[0]:
            for grad_i, B in zip(torch.split(grad, [A.size(0) for A in As]), Bs):
                B_transposed = B.transpose(-2, -1)
                A_grad = torch.zeros(grad_i.size(0), B_transposed.size(1), device=grad.device, dtype=grad.dtype)
                plan.run([grad_i], [B_transposed], [A_grad], [A_grad])
                agrad_list.append(A_grad)

        # Compute gradient w.r.t. Bs: dB_i = A_i^T @ dD_i
        bgrad_list = []
        if ctx.needs_input_grad[1]:
            for A, B, grad_i in zip(As, Bs, torch.split(grad, [A.size(0) for A in As])):
                A_transposed = A.transpose(-2, -1)
                B_grad = torch.zeros_like(B)
                plan.run([A_transposed], [grad_i], [B_grad], [B_grad])
                bgrad_list.append(B_grad)

        # Return gradients (None for inputs that do not need them)
        agrad = agrad_list if ctx.needs_input_grad[0] else None
        bgrad = bgrad_list if ctx.needs_input_grad[1] else None
        return agrad, bgrad


def test_lora_gradient():
    A = torch.randn(10, 5, requires_grad=True, device="cuda")
    B = torch.randn(5, 2, requires_grad=True, device="cuda")
    C = GroupedGemm.apply([A], [B])
    print(C, f"C requires_grad is {C.requires_grad}")
    loss = C.sum()
    loss.backward()
    print(A.grad)
    print(B.grad)
    assert A.grad is not None, "Gradient not computed for A"
    assert B.grad is not None, "Gradient not computed for B"


test_lora_gradient()
```
I'm trying to build a custom grouped GEMM operator by wrapping the CUTLASS grouped GEMM kernel, but the output of the CUTLASS grouped GEMM call does not have a grad_fn, which causes a problem during the backward pass. The above test case yields the following error:
Traceback (most recent call last):
File "/home/ubuntu/torchtune/tests/torchtune/modules/peft/test_grouped_gemm_backward.py", line 87, in <module>
test_lora_gradient()
File "/home/ubuntu/torchtune/tests/torchtune/modules/peft/test_grouped_gemm_backward.py", line 79, in test_lora_gradient
loss.backward()
File "/home/ubuntu/.conda/envs/torchtune/lib/python3.10/site-packages/torch/_tensor.py", line 581, in backward
torch.autograd.backward(
File "/home/ubuntu/.conda/envs/torchtune/lib/python3.10/site-packages/torch/autograd/__init__.py", line 347, in backward
_engine_run_backward(
File "/home/ubuntu/.conda/envs/torchtune/lib/python3.10/site-packages/torch/autograd/graph.py", line 825, in _engine_run_backward
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
What is the best way to wrap a CUTLASS kernel?
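For context, a likely cause of the error above: `torch.autograd.Function.apply` only tracks tensors passed as top-level arguments, so tensors hidden inside a Python list are invisible to autograd, none of the inputs appear to require grad, and the output comes back without a `grad_fn`. Below is a minimal sketch of the same wrapper with the lists unpacked; the class name is hypothetical, and plain matmuls stand in for the CUTLASS calls purely for illustration:

```python
import torch
from torch.autograd import Function


class GroupedGemmUnpacked(Function):
    @staticmethod
    def forward(ctx, num_problems, *tensors):
        # Tensors must be top-level arguments to apply(); tensors wrapped
        # in a list are not seen by autograd, so the output of apply()
        # would never get a grad_fn.
        As, Bs = tensors[:num_problems], tensors[num_problems:]
        ctx.save_for_backward(*tensors)
        ctx.num_problems = num_problems
        # Stand-in for plan.run(As, Bs, Cs, Ds)
        Ds = [A @ B for A, B in zip(As, Bs)]
        return torch.cat(Ds, dim=0)

    @staticmethod
    def backward(ctx, grad):
        n = ctx.num_problems
        As, Bs = ctx.saved_tensors[:n], ctx.saved_tensors[n:]
        grads = torch.split(grad, [A.size(0) for A in As])
        # dA_i = dD_i @ B_i^T, dB_i = A_i^T @ dD_i
        a_grads = [g @ B.transpose(-2, -1) for g, B in zip(grads, Bs)]
        b_grads = [A.transpose(-2, -1) @ g for A, g in zip(As, grads)]
        # Leading None matches the non-tensor num_problems argument
        return (None, *a_grads, *b_grads)


# Usage: unpack the lists at the call site
# C = GroupedGemmUnpacked.apply(len(As), *As, *Bs)  -> C.requires_grad is True
```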
I'm not familiar with how to ensure that a custom PyTorch layer works with autograd. However, the recommended way to use a CUTLASS kernel in PyTorch is to export it as a PyTorch CUDA extension. You can see an example of this here.
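For illustration, the extension route typically looks something like the sketch below. `torch.utils.cpp_extension.load` is a real PyTorch API, but the source file names and the `run` binding are hypothetical placeholders for whatever the compiled extension actually exposes; once the kernel is callable from Python, the autograd wrapper is an ordinary `Function` over plain tensor arguments:

```python
import torch
from torch.utils.cpp_extension import load

# JIT-compile a C++/CUDA source that launches the CUTLASS kernel and
# binds it via pybind11 (file names and symbol are hypothetical).
grouped_gemm_ext = load(
    name="grouped_gemm_ext",
    sources=["grouped_gemm.cpp", "grouped_gemm_kernel.cu"],
    extra_cuda_cflags=["-O3"],
)


class GroupedGemmExt(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, B):
        ctx.save_for_backward(A, B)
        # Calls into the compiled CUTLASS launcher
        return grouped_gemm_ext.run(A, B)

    @staticmethod
    def backward(ctx, grad):
        A, B = ctx.saved_tensors
        # The backward GEMMs can reuse the kernel or plain matmul
        return grad @ B.transpose(-2, -1), A.transpose(-2, -1) @ grad
```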