
Compilation with full CUDA graphs (without breaks) #74

Open
bm-synth opened this issue Dec 20, 2024 · 17 comments
Labels
enhancement New feature or request

Comments

@bm-synth

bm-synth commented Dec 20, 2024

Hi @jason-huang03
(cc @akedia in #73 )

I've been trying to compile a model that uses SageAttention with torch.compile(model, mode='max-autotune'), following the Python custom ops tutorial. Out of the box, max-autotune yields several graph-break warnings that are detrimental to performance, although the output "seems" correct. Following that tutorial, we try to create a custom torch operator for Sage, but it leads to wrong output. As of the new 2.0.1 release, max-autotune-no-cudagraphs works well; any compilation mode that uses CUDA graphs does not. This is what I tried:


import torch
import sageattention as sa

# just in case: load the C libraries beforehand
torch.ops.load_library('/mnt/SageAttention/sageattention/_fused.cpython-311-x86_64-linux-gnu.so')
torch.ops.load_library('/mnt/SageAttention/sageattention/_qattn.cpython-311-x86_64-linux-gnu.so')

@torch.library.custom_op("mylib::sage_attention", mutates_args={}, device_types="cuda")
def sage_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, tensor_layout: str) -> torch.Tensor:
    return sa.sageattn_qk_int8_pv_fp8_cuda(q, k, v, tensor_layout=tensor_layout, pv_accum_dtype="fp32+fp32", smooth_v=False)
    # return sa.sageattn_qk_int8_pv_fp16_cuda(q, k, v, tensor_layout=tensor_layout, pv_accum_dtype="fp32", smooth_v=False)
    # return sa.sageattn_qk_int8_pv_fp16_triton(q, k, v, tensor_layout=tensor_layout, pv_accum_dtype="fp32", smooth_v=False)

@torch.library.register_fake("mylib::sage_attention")
def _(q, k, v, tensor_layout):
    return torch.empty(*q.shape, device=q.device, dtype=q.dtype)

# random inputs
q = k = v = torch.rand([1, 48, 9676, 64], device="cuda", dtype=torch.bfloat16)

# test the registration
torch.library.opcheck(sage_attention, [q, k, v, "HND"])

# call the registered op
out = torch.ops.mylib.sage_attention(q, k, v, tensor_layout="NHD")

# call the regular function
out = sage_attention(q, k, v, tensor_layout="NHD")
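
For completeness, this is roughly how I exercise the op under torch.compile (a sketch: Attention is just a stand-in for my model's attention block):

# Sketch: a stand-in attention module that routes through the registered op,
# compiled with and without CUDA graphs.
class Attention(torch.nn.Module):
    def forward(self, q, k, v):
        return torch.ops.mylib.sage_attention(q, k, v, tensor_layout="NHD")

model = Attention().cuda()

# works, but leaves CUDA graphs on the table
out_ok = torch.compile(model, mode="max-autotune-no-cudagraphs")(q, k, v)

# compiles, but this is the path that gives me wrong outputs
out_bad = torch.compile(model, mode="max-autotune")(q, k, v)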

@jason-huang03 is this something you'd like to make possible in your code base?

@bm-synth bm-synth changed the title Support CUDA graph breaks How to support compilation with full CUDA graphs (without breaks) Dec 20, 2024
@bm-synth bm-synth changed the title How to support compilation with full CUDA graphs (without breaks) Compilation with full CUDA graphs (without breaks) Dec 20, 2024
@jason-huang03
Member

Yes, we are working on making SageAttention compatible with torch.compile.

@jason-huang03
Member

Thank you for your advice!

@bm-synth
Author

@jason-huang03 I just noticed my suggestion does not return the correct output. I'm trying to find out why; any idea?

@jason-huang03
Member

Perhaps it is the empty_like, which might give incorrect stride information. Can you try torch.empty instead?

@bm-synth
Author

Still the same noisy output :( Do you have an example that worked for you?

@jason-huang03
Member

Not currently. I have tested torch.compile in non-CUDA-graph mode and it works fine (wrapping sage_attn with @torch.compiler.disable). I shall try compiling with CUDA graphs and let you know the result.
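
Roughly, the setup I tested looks like this (a sketch; sage_attn here is just my wrapper around sageattn):

import torch
from sageattention import sageattn

# keep SageAttention out of the compiled graph so Dynamo never traces into it
@torch.compiler.disable
def sage_attn(q, k, v, is_causal=False):
    return sageattn(q, k, v, is_causal=is_causal)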

@bm-synth
Author

Same here: Sage with max-autotune-no-cudagraphs works fine, but it isn't faster than scaled_dot_product_attention with max-autotune in my model. I'm also trying to compile with CUDA graphs and I'll update you if I find out how.

@jason-huang03
Member

Sure!

@bm-synth
Author

bm-synth commented Dec 20, 2024

So far, no success:

  • No compilation, no fake tensors: works
  • No compilation, with fake tensors: works
  • Model compiled with torch.compile default, no fake tensors: wrong results
  • Model compiled with torch.compile default, with fake tensors: wrong results
  • Model compiled with torch.compile default, with @torch.compiler.disable around Sage: works
  • Model compiled with max-autotune-no-cudagraphs, with @torch.compiler.disable around Sage: fails
  • Model compiled with max-autotune, with @torch.compiler.disable around Sage: works with low performance (similar to no CUDA graphs) and prints several CUDA graph warnings:
~/.venv/lib/python3.11/site-packages/torch/cuda/graphs.py:84: UserWarning: The CUDA Graph is empty. This usually means that the graph was attempted to be captured on wrong device or stream. (Triggered internally at ../aten/src/ATen/cuda/CUDAGraph.cpp:208.)

@bm-synth
Author

@jason-huang03 success!

To make it work, I needed to declare q, k, and v as mutable parameters, and, as you suggested, have the fake kernel return torch.empty(*q.shape, dtype=q.dtype, device=q.device) instead of torch.empty_like(q) 🚀

I updated the issue text with the correct code.
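
Concretely, the registration after the fix looks like this (a sketch, assuming sageattention is imported as sa as in the code above):

import torch
import sageattention as sa

# declare q, k, v as mutated so the compiler does not assume they are unchanged
# across the call, and have the fake kernel allocate a fresh tensor instead of
# using torch.empty_like(q)
@torch.library.custom_op("mylib::sage_attention", mutates_args={"q", "k", "v"}, device_types="cuda")
def sage_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, tensor_layout: str) -> torch.Tensor:
    return sa.sageattn_qk_int8_pv_fp8_cuda(q, k, v, tensor_layout=tensor_layout, pv_accum_dtype="fp32+fp32", smooth_v=False)

@torch.library.register_fake("mylib::sage_attention")
def _(q, k, v, tensor_layout):
    return torch.empty(*q.shape, device=q.device, dtype=q.dtype)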

@jason-huang03
Member

@bm-synth
Awesome! We will include this feature in version 2.0.1 which will be available in a day or two.

@jason-huang03
Member

Hey @bm-synth I still get a blurry result 😂
My code is as follows:

import torch
import torch.nn.functional as F
from sageattention import sageattn

@torch.library.custom_op("mylib::sa_wrapper", mutates_args={"q", "k", "v"}, device_types="cuda")
def sa_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, is_causal: bool = False) -> torch.Tensor:
    return sageattn(q, k, v, is_causal=is_causal)

@sa_wrapper.register_fake
def _(q, k, v, is_causal=False):
    return torch.empty(*q.shape, device=q.device, dtype=q.dtype)

# monkey-patch SDPA so the model routes attention through SageAttention
F.scaled_dot_product_attention = sa_wrapper

Is there anything wrong with my code?

@bm-synth
Author

bm-synth commented Dec 20, 2024

@jason-huang03 would that blurring happen if you call sa_wrapper directly? Also try removing default argument values such as is_causal: bool = False, and try adding torch.compiler.cudagraph_mark_step_begin() before every iteration.
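
The last point would look roughly like this (a sketch; compiled_model and num_steps stand in for your pipeline):

import torch

compiled_model = torch.compile(model, mode="max-autotune")  # `model`: your pipeline

for step in range(num_steps):
    # tell the CUDA graph machinery that a new iteration starts, so outputs
    # from the previous step are not silently overwritten
    torch.compiler.cudagraph_mark_step_begin()
    out = compiled_model(q, k, v)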

@jason-huang03 jason-huang03 added the enhancement New feature or request label Dec 21, 2024
@bm-synth
Author

bm-synth commented Dec 24, 2024

@jason-huang03 you're right: max-autotune does not yield graph-break warnings, but it returns noisy outputs.

I've been doing some debugging.

I checked the number of calls and the input sizes when calling the compiled version of sage_attention, and they match the non-compiled implementation.

I also looked at The Custom Operators Manual; it suggests calling a function registered with @torch.library.custom_op("mylib::sage_attention", ...) via torch.ops.mylib.sage_attention(q, k, v, ...). I tried that; it didn't work. So I ran torch.library.opcheck to check whether there were any registration errors. There are none.

I also tried loading the C libraries beforehand, as also recommended in the documentation, via torch.ops.load_library(*.so). No luck.

As a side note, I updated this issue's text with that code.

For sanity checking, can you double-check that your implementation follows the "How to add CPU/CUDA/Backend implementations" sections in The Custom Operators Manual and in Custom C++ and CUDA Operators?

I will now try to compare the input and output tensor values of the compiled vs. uncompiled version. Meanwhile, we should bring this, together with the CogVideoX example, to the torch developers for further help.
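
The comparison I have in mind is roughly this (a sketch; the tolerances are arbitrary):

import torch

# run the op eagerly and under torch.compile on the same inputs,
# then look at the largest deviation between the two outputs
eager_out = torch.ops.mylib.sage_attention(q, k, v, tensor_layout="NHD")

compiled_op = torch.compile(
    lambda q, k, v: torch.ops.mylib.sage_attention(q, k, v, tensor_layout="NHD"),
    mode="max-autotune",
)
compiled_out = compiled_op(q, k, v)

print("allclose:", torch.allclose(eager_out, compiled_out, atol=1e-2, rtol=1e-2))
print("max abs diff:", (eager_out - compiled_out).abs().max().item())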

@bm-synth
Author

> I will now try to compare the input and output tensor values of the compiled vs uncompiled version.

OK, I checked: it's the output of sageattn_qk_int8_pv_fp8_cuda that differs in the fake-tensor use case:

  • if you compile with max-autotune without fake tensors (and with graph-break warnings), with max-autotune-no-cudagraphs without fake tensors, or with max-autotune-no-cudagraphs with fake tensors, the outputs are the same.
  • if you compile with max-autotune with fake tensors, then the input to sageattn_qk_int8_pv_fp8_cuda is the same but the output is not.

We should ask torch devs for help. For the time being, can we allow the code to compile with cuda graphs and ignore sage?

@jason-huang03
Member

Nice observation. What do you mean by "allow code to compile with cuda graphs and ignore sage"? Do you mean using the Triton kernel instead of the CUDA kernel?

@bm-synth
Author

bm-synth commented Dec 24, 2024

"allow code to compile with cuda graphs and ignore sage"

I meant that instead of trying to capture a CUDA graph of the whole run, which includes the Sage operation and yields bad results, we could capture a CUDA graph of everything except Sage (maybe by disabling compilation of it?).

> do you mean using the triton kernel instead of the cuda kernel?

Triton kernels work well with max-autotune and the fake tensors (I am not sure here, I will test again later), but at the moment we only have the Triton kernel for fp16 quantization. I believe the visible speedups from Sage (in my use case and in CogVideoX) come from the int8 quantization. Are you planning to write the int8 Sage kernel in Triton as well?
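
For reference, the Triton variant I mean is just the same registration pointed at the Triton kernel (a sketch; untested on my side, as I said):

import torch
import sageattention as sa

# same wrapper as before, but routed through the fp16 Triton kernel
@torch.library.custom_op("mylib::sage_attention_triton", mutates_args={"q", "k", "v"}, device_types="cuda")
def sage_attention_triton(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, tensor_layout: str) -> torch.Tensor:
    return sa.sageattn_qk_int8_pv_fp16_triton(q, k, v, tensor_layout=tensor_layout, pv_accum_dtype="fp32", smooth_v=False)

@torch.library.register_fake("mylib::sage_attention_triton")
def _(q, k, v, tensor_layout):
    return torch.empty(*q.shape, device=q.device, dtype=q.dtype)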
