Compilation with full CUDA graphs (without breaks) #74
Comments
Yes, we are working on it to make SageAttention compatible with torch.compile.
Thank you for your advice!
@jason-huang03 just noticed my suggestion is not returning the correct output. I'm trying to find out why, any idea?
Perhaps it is the
Still the same noisy output :( Do you have an example that worked for you?
Currently not. I have tested with
Same here, tried Sage with
Sure!
So far, no success:
@jason-huang03 success! To make it work, I needed to specify something additional; I updated the PR message with the correct code.
@bm-synth
Hey @bm-synth I still got a blurry result 😂

```python
from sageattention import sageattn
import torch
import torch.nn.functional as F

@torch.library.custom_op("mylib::sa_wrapper", mutates_args={"q", "k", "v"}, device_types="cuda")
def sa_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, is_causal: bool = False) -> torch.Tensor:
    return sageattn(q, k, v, is_causal=is_causal)

@sa_wrapper.register_fake
def _(q, k, v, is_causal=False):
    return torch.empty(*q.shape, device=q.device, dtype=q.dtype)

F.scaled_dot_product_attention = sa_wrapper
```

Is there anything wrong with my code?
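(For reproduction purposes, a minimal sketch of how this wrapper gets exercised under compilation. The toy module, tensor shapes, and dtype below are illustrative placeholders, not taken from this thread.)

```python
import torch
import torch.nn.functional as F

# Toy module that routes attention through F.scaled_dot_product_attention,
# which the snippet above monkey-patches to the SageAttention custom op.
class TinyAttention(torch.nn.Module):
    def forward(self, q, k, v):
        return F.scaled_dot_product_attention(q, k, v, is_causal=False)

model = TinyAttention().cuda()
compiled = torch.compile(model, mode="max-autotune")  # CUDA graphs enabled in this mode

q = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Compare eager vs. compiled output; loose tolerances since SageAttention quantizes Q/K.
print(torch.allclose(model(q, k, v), compiled(q, k, v), atol=1e-2, rtol=1e-2))
```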
@jason-huang03 would that blurring happen if you call
@jason-huang03 you're right, I've been doing some debugging. I checked the number of calls and the input sizes when calling the compiled version of the op. I also looked at The Custom Operators Manual, which suggests calling a function registered as a custom op. I also tried to load the C libraries beforehand, as also recommended in the documentation.

As a side note, I updated this issue text with that code. For sanity checking, can you double check that your implementation follows the sections "How to add CPU/CUDA/Backend implementations" in The Custom Operators Manual and Custom C++ and CUDA Operators?

I will now try to compare the input and output tensor values of the compiled vs. uncompiled version. Meanwhile, we should bring this, together with the CogVideoX example, to the torch developers for further help.
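For reference, one way to run that comparison (a sketch, not the exact script from this thread; `sa_wrapper` is the custom op defined above, and the shapes are placeholders):

```python
import torch

q = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

def run(q, k, v):
    return sa_wrapper(q, k, v, is_causal=False)

eager_out = run(q, k, v)
compiled_out = torch.compile(run, mode="max-autotune")(q, k, v)

# Loose tolerances because SageAttention quantizes Q/K internally.
torch.testing.assert_close(eager_out, compiled_out, atol=1e-2, rtol=1e-2)

# torch.library.opcheck (available in recent torch releases) also sanity-checks
# a custom op's registrations, including the fake/meta implementation.
torch.library.opcheck(sa_wrapper, (q, k, v), {"is_causal": False})
```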
OK, I checked: it's the output of
We should ask the torch devs for help. For the time being, can we allow the code to compile with CUDA graphs and ignore Sage?
Nice observation. What do you mean by "allow code to compile with CUDA graphs and ignore Sage"? Do you mean using the Triton kernel instead of the CUDA kernel?
I meant: instead of trying to capture a CUDA graph of the whole run, which includes the Sage operation and yields bad results, we can try to capture a CUDA graph of the whole code except Sage (maybe disable compilation of it?).
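If that's the route, a minimal sketch of carving Sage out of the compiled region (assuming the monkey-patch approach from earlier; `torch.compiler.disable` keeps the call out of the captured graph, at the cost of a graph break around it):

```python
import torch
import torch.nn.functional as F
from sageattention import sageattn

@torch.compiler.disable  # run SageAttention eagerly, outside the compiled / CUDA-graphed region
def sage_sdpa(q, k, v, is_causal=False):
    return sageattn(q, k, v, is_causal=is_causal)

F.scaled_dot_product_attention = sage_sdpa

# The rest of the model still compiles; only the attention call falls back to eager.
# compiled_model = torch.compile(model, mode="max-autotune")
```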
Triton kernels work well with
Hi @jason-huang03
(cc @akedia in #73)

I've been trying to compile a model that uses Sage Attention with `torch.compile(model, mode='max-autotune')`, following the Python custom ops tutorial. Out of the box, `max-autotune` yields several graph break warnings that are detrimental to performance, but the output "seems" correct. Following that tutorial, we tried to create a custom torch operator for Sage, but it leads to wrong output. As of the new release 2.0.1, `max-autotune-no-cudagraphs` works well; any compilation mode with CUDA graphs does not. This is what I tried: the custom torch operator registration shown earlier in this thread.

@jason-huang03 is this something you'd like to make possible in your code base?
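As a side note, a small sketch of how those graph break warnings can be surfaced (not from the original post; `model`, `q`, `k`, `v` are placeholders for the real pipeline and inputs):

```python
import torch

# fullgraph=True turns any graph break into a hard error instead of a silent fallback.
compiled = torch.compile(model, mode="max-autotune", fullgraph=True)

# torch._dynamo.explain reports the number of graph breaks and their reasons.
explanation = torch._dynamo.explain(model)(q, k, v)
print(explanation.graph_break_count)
print(explanation.break_reasons)
```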