Flash2 and supports cross attention and dropout #905

Merged 3 commits on Jan 8, 2025


@hanzhi713 hanzhi713 commented Jan 6, 2025

This PR updates the Pallas implementation to FlashAttention 2. Additionally, it adds support for the following features:

  1. q_seq_len != kv_seq_len (cross attention)
  2. Dropout for the Triton kernels
  3. cuDNN dropout
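For reviewers who want a reference point for the first two features, here is a minimal, non-flash sketch in JAX. It is not the kernel code from this PR: `reference_attention` and its arguments are hypothetical names, and it assumes dropout is applied to the post-softmax probabilities, as in standard attention dropout.

```python
import jax
import jax.numpy as jnp


def reference_attention(q, k, v, *, dropout_rate=0.0, prng_key=None):
    """Plain attention for one head; q: [q_len, d], k and v: [kv_len, d]."""
    scale = q.shape[-1] ** -0.5
    logits = (q @ k.T) * scale  # [q_len, kv_len]; the two lengths may differ.
    probs = jax.nn.softmax(logits, axis=-1)
    if dropout_rate > 0.0:
        # Assumption: inverted dropout on the attention probabilities.
        keep = jax.random.bernoulli(prng_key, 1.0 - dropout_rate, probs.shape)
        probs = jnp.where(keep, probs / (1.0 - dropout_rate), 0.0)
    return probs @ v  # [q_len, d]


key_q, key_k, key_v, key_drop = jax.random.split(jax.random.PRNGKey(0), 4)
q = jax.random.normal(key_q, (4, 8))  # q_seq_len = 4
k = jax.random.normal(key_k, (6, 8))  # kv_seq_len = 6
v = jax.random.normal(key_v, (6, 8))
out = reference_attention(q, k, v, dropout_rate=0.1, prng_key=key_drop)
```

The output always has the query length, regardless of the key/value length. A reference like this is the kind of thing a kernel unit test could compare against, provided both sides draw the same dropout mask.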

Additional changes:

  1. Added NoPopDict to work around a JAX bug; see jax-ml/jax#25714 ("GPU pallas_call loses compiler params during second call when double jit-wrapped").
  2. Separated dKdV and dQ into two kernels in the flash backward pass. This improves performance by 10-15% when head_dim >= 128. Note that fusing dK, dV, and dQ into a single loop and using atomic adds for dQ is technically the fastest solution, but Pallas atomics are extremely slow according to empirical testing.
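To illustrate the two-kernel split described above, here is a rough sketch in plain JAX with Python loops standing in for the kernel grid. It is not the Pallas code from this PR: the function names, the block size of 2, and the single-head 2-D shapes are assumptions for illustration, and dropout and masking are left out. The point is that each kv block owns its dK/dV and each q block owns its dQ, so neither pass needs atomic adds.

```python
import jax
import jax.numpy as jnp


def forward(q, k, v):
    """Returns the output and per-row logsumexp, which the backward reuses."""
    s = (q @ k.T) * q.shape[-1] ** -0.5
    lse = jax.nn.logsumexp(s, axis=-1)
    return jnp.exp(s - lse[:, None]) @ v, lse


def tile_grads(q, k, v, do, lse, delta):
    """Recomputes P for one (q block, kv block) tile; returns scaled dS and P."""
    scale = q.shape[-1] ** -0.5
    p = jnp.exp((q @ k.T) * scale - lse[:, None])
    ds = p * (do @ v.T - delta[:, None]) * scale
    return ds, p


def dkdv_pass(q, k, v, do, lse, delta, block=2):
    """Pass 1: outer loop over kv blocks, reduction over q blocks."""
    dk, dv = [], []
    for j in range(0, k.shape[0], block):
        kj, vj = k[j:j + block], v[j:j + block]
        dk_j, dv_j = jnp.zeros_like(kj), jnp.zeros_like(vj)
        for i in range(0, q.shape[0], block):
            sl = slice(i, i + block)
            ds, p = tile_grads(q[sl], kj, vj, do[sl], lse[sl], delta[sl])
            dk_j += ds.T @ q[sl]
            dv_j += p.T @ do[sl]
        dk.append(dk_j)
        dv.append(dv_j)
    return jnp.concatenate(dk), jnp.concatenate(dv)


def dq_pass(q, k, v, do, lse, delta, block=2):
    """Pass 2: outer loop over q blocks, reduction over kv blocks."""
    dq = []
    for i in range(0, q.shape[0], block):
        sl = slice(i, i + block)
        dq_i = jnp.zeros_like(q[sl])
        for j in range(0, k.shape[0], block):
            kj, vj = k[j:j + block], v[j:j + block]
            ds, _ = tile_grads(q[sl], kj, vj, do[sl], lse[sl], delta[sl])
            dq_i += ds @ kj
        dq.append(dq_i)
    return jnp.concatenate(dq)


keys = jax.random.split(jax.random.PRNGKey(0), 4)
q = jax.random.normal(keys[0], (4, 8))
k = jax.random.normal(keys[1], (6, 8))
v = jax.random.normal(keys[2], (6, 8))
do = jax.random.normal(keys[3], (4, 8))  # upstream gradient w.r.t. the output

out, lse = forward(q, k, v)
delta = jnp.sum(out * do, axis=-1)  # rowsum(dO * O), shared by both passes
dk, dv = dkdv_pass(q, k, v, do, lse, delta)
dq = dq_pass(q, k, v, do, lse, delta)
```

The cost of the split is that the attention probabilities for each tile are recomputed in both passes. The fused alternative mentioned above would compute them once but would have to accumulate dQ across kv blocks with atomics.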

This PR requires JAX >= 0.4.34 and can only be merged after the JAX upgrade.

Update: cuDNN dropout is disabled in this PR so that it can be merged before the JAX upgrade. A follow-up PR will enable cuDNN dropout after the upgrade.

@hanzhi713 hanzhi713 requested review from ruomingp, markblee and a team as code owners January 6, 2025 20:43
@hanzhi713 hanzhi713 changed the title from "Support cross attention and dropout" to "Flash2 and supports cross attention and dropout" Jan 6, 2025
@hanzhi713
Member Author

Hi @ruomingp, can you re-approve this PR? I removed cuDNN dropout support (but kept the unit tests) so that it can be merged before the JAX upgrade and used by our users.

@hanzhi713 hanzhi713 enabled auto-merge January 8, 2025 20:55
@hanzhi713 hanzhi713 requested a review from ruomingp January 8, 2025 21:43
@hanzhi713 hanzhi713 added this pull request to the merge queue Jan 8, 2025
Merged via the queue into apple:main with commit 6559036 Jan 8, 2025
6 checks passed
@hanzhi713 hanzhi713 deleted the flash-2 branch January 8, 2025 22:30