Enable Vanilla Bwd and Refactor (#86)
* Vanilla BWD

Vanilla BWD

This is a combination of 79 commits.

save test_flash_attn_output

use impl functions

pass layout

add ref

move around impls

fix stride issue

save oai kernel

add baseline impl

save bwd kernel working

remove old impl

remove block_ptrs from bwd

pass padded dmodel and apply masking. the old test cases work but cases with small d don't work

save

save

more prints

rename M to L

save

add notes

add old_bwd back

fa failure reproduces in kernels too

isolate new bwd and keep old bwd in place

clean up

softmax_lse does not match reference

LOG flag

softmax_lse with LN2

move qk_scale to loop

pass ln2 to fwd

just print kernel input

test softmax output from forward

test exp_scores_triton

save all the ref

create ref USE_EXP2 path
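
A sketch of what such a USE_EXP2 reference path computes, assuming a plain row softmax over the attention scores (function and constant names here are illustrative, not the repo's):

```python
import torch

RCP_LN2 = 1.4426950408889634  # 1/ln(2), so exp(x) == exp2(x * RCP_LN2)

def ref_softmax(scores: torch.Tensor, use_exp2: bool = False):
    """Row softmax with an optional exp2 path; illustrative, not the repo's ref impl."""
    row_max = scores.amax(dim=-1, keepdim=True)
    shifted = scores - row_max
    # exp2(x * 1/ln2) is mathematically identical to exp(x); kernels favor exp2
    # because the hardware exposes it directly
    p = torch.exp2(shifted * RCP_LN2) if use_exp2 else torch.exp(shifted)
    softmax_lse = row_max.squeeze(-1) + torch.log(p.sum(dim=-1))  # natural-log LSE either way
    return p / p.sum(dim=-1, keepdim=True), softmax_lse
```

Either path should give the same probabilities and softmax_lse up to floating-point error, which is what the matching commits above and below are chasing.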

return scores

mask scores when returning them. Basic impl test passes

scores and output match

show max_diff

returned scores need to be adjusted as we find new maxes
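
The adjustment referred to here is the usual online-softmax bookkeeping: whenever a later KV block raises a row max, the probabilities and output accumulated so far must be scaled down. A plain-PyTorch illustration of that bookkeeping (qk scaling omitted; no claim to match the kernel's internals):

```python
import torch

def online_softmax_attention(q, k, v, block: int = 128):
    """Illustrative online softmax over KV blocks, not the Triton kernel itself."""
    q, k, v = q.float(), k.float(), v.float()
    m = torch.full(q.shape[:-1], float("-inf"), device=q.device)  # running row max
    l = torch.zeros(q.shape[:-1], device=q.device)                # running row sum
    acc = torch.zeros_like(q)                                     # running P @ V
    for start in range(0, k.shape[-2], block):
        kb = k[..., start:start + block, :]
        vb = v[..., start:start + block, :]
        s = q @ kb.transpose(-1, -2)
        m_new = torch.maximum(m, s.amax(dim=-1))
        alpha = torch.exp(m - m_new)            # rescale everything seen so far
        p = torch.exp(s - m_new.unsqueeze(-1))
        l = l * alpha + p.sum(dim=-1)
        acc = acc * alpha.unsqueeze(-1) + p @ vb
        m = m_new
    return acc / l.unsqueeze(-1), m + torch.log(l)  # output, softmax_lse
```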

all good outputs. old style RCP2 example

prep bwd_impl test

save

try openai

save

fix softmax_lse bug

test_op_bwd_impl starting to work!

new kernel. exp2 works but exp is failing

fix bwd exp2

add m and n masks. small cases still don't work

match old and new kernel prints

compare old and new

print inputs

save

old kernel match on dv

dq works

compare to pytorch including softmax in forward

fix bwd impl bug

small sizes in bwd impl work

old bwd test pass. Moving on to kernel tests

dq, dk and dv are filled in place if given. Need to cast to match fa

fix non bug

fix dv mismatch. use_exp2 was set to true in fwd

fix case up 128

refactor and clean up a bit more

issue is that dq and dk are not zeros

dq must be zeroed out
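
This pairs with the earlier note that dq, dk and dv are filled in place if given: the backward pass accumulates partial dq contributions across KV blocks, so a caller-supplied dq buffer has to start at zero. A hypothetical helper sketching that contract (not the repo's actual code):

```python
import torch

def prepare_grad_buffers(q, k, v, dq=None, dk=None, dv=None):
    """Hypothetical sketch: reuse caller-provided buffers, but zero dq because the
    backward pass accumulates into it across KV blocks; dk/dv are written outright
    here, though which buffers need zeroing depends on the kernel's loop structure."""
    dq = torch.zeros_like(q) if dq is None else dq.zero_()
    dk = torch.empty_like(k) if dk is None else dk
    dv = torch.empty_like(v) if dv is None else dv
    return dq, dk, dv
```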

ignore segfaults

fa ref and my ref match!

all tests run

use tolerance 1e-3
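
A comparison along these lines, assuming the 1e-3 figure is applied as both absolute and relative tolerance (how the actual tests split it is not shown here):

```python
import torch

def check_close(out_triton: torch.Tensor, out_ref: torch.Tensor) -> None:
    # illustrative check only; the real tests may use a different atol/rtol split
    torch.testing.assert_close(out_triton, out_ref, atol=1e-3, rtol=1e-3)
```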

we need to figure out preprocessing

save

clean up

save

test delta diff

move old impl out

new preprocess function

preprocessing_use_o flag

working _bwd_preprocess_use_p
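
The preprocessing in question is the standard FlashAttention-2 backward precomputation of delta, one scalar per query row, consumed later when forming dQ and dK. A reference sketch of the O-based variant (the `_bwd_preprocess_use_p` route would derive the same quantity from P and dP):

```python
import torch

def bwd_preprocess_delta(o: torch.Tensor, do: torch.Tensor) -> torch.Tensor:
    """delta_i = sum_d O[i, d] * dO[i, d]; reference sketch of the O-based preprocess."""
    return (o.float() * do.float()).sum(dim=-1)
```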

basic cases pass

all green

fwd exp2 usage is done right before exp

* refactor

* refactor 2

* refactor 3

* fix bug

* try ci

* add flag

* rename to utils

* skip test_op_fwd_decode_int4_kv

* reduce head size

* try again

* go back to old head sizes

* Use Strides

Use Strides

This is a combination of 11 commits.

use strides in bwd

add layout test in forward

fix shape layout function

smaller tests

save

fix varlen error

no headsize passed to bwd

deal with varlen layout
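
The point of this group of commits is to drive the kernels off explicit strides rather than assuming contiguous tensors, so the standard batched layout and the packed varlen layout can share one code path. A hypothetical helper showing the idea ("thd" comes from the commit messages below; "bshd" and the helper itself are assumptions, not the repo's API):

```python
import torch

def qkv_strides(q: torch.Tensor, layout: str):
    """Hypothetical sketch: extract batch/head/seq/dim strides for a given layout.
    'bshd': (batch, seqlen, nheads, head_dim)  - standard layout
    'thd' : (total_tokens, nheads, head_dim)   - varlen packed layout, no batch dim"""
    if layout == "bshd":
        stride_b, stride_s, stride_h, stride_d = q.stride()
    elif layout == "thd":
        stride_b = 0  # sequences are packed; per-sequence offsets come from cu_seqlens instead
        stride_s, stride_h, stride_d = q.stride()
    else:
        raise ValueError(f"unknown layout: {layout}")
    return stride_b, stride_h, stride_s, stride_d
```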

save

save

save

save

* use gen scripts

* varlen fwd passing

* core fwd ref impl

* fix minor bugs

* wrap varlen launcher attention_forward_pytorch_ref_impl

* varlen backward ref added

* add offsets for varlen
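
The offsets here are the cumulative sequence lengths (`cu_seqlens`) that delimit each sequence inside the packed varlen tensors. A reference-style sketch of how a varlen wrapper can reuse a dense per-sequence implementation (`dense_ref` is a stand-in, not a repo function):

```python
import torch

def varlen_ref_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, dense_ref):
    """Slice each packed sequence out via cumulative-length offsets and run a
    dense reference attention on it; illustrative only."""
    outs = []
    batch = cu_seqlens_q.numel() - 1
    for b in range(batch):
        qs, qe = cu_seqlens_q[b].item(), cu_seqlens_q[b + 1].item()
        ks, ke = cu_seqlens_k[b].item(), cu_seqlens_k[b + 1].item()
        # each slice is (seqlen_i, nheads, head_dim)
        outs.append(dense_ref(q[qs:qe], k[ks:ke], v[ks:ke]))
    return torch.cat(outs, dim=0)  # back to packed (total_tokens, nheads, head_dim)
```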

* fix delta bug

* varlen bwd working

* save

* runs on Mi200

* just test basics

* save

* fix bug

* fix varlen int64 bug

* add ref

* test_impl working with causal

* fix qkvpacked issue

* qkvpacked run tests

* remove test_backward

* save

* just test output

* dump into tensors

* softmaxlse layout for varlen

* small cases working

* bwd thd green. although maybe some oom

* forward out and lse are good. Something wrong with backward ref

* make varlen ref work

* save work, ref is working mostly

* 91 failed, 6542 passed, 6336 skipped, 1 warning

* ref is all green

* debug flag in utils

* found bad softmax_lse in varlen fwd

* fix bug in softmax lse. strides in varlen were not right

* add causal tests and 32x32 bwd does not segfault

* save

* fix oom by reducing block size for small heads

* bwd ref with causal working

* test impl

* causal test passes

* causal working

* fix tests

* nicer bench

* fix qvpacked error

* fix varlen qvpacked bug

* fix minor bug

* bench prefill and prefill_old using the same script

* autotune configs for fwd

* autotune flag
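
The flag is the `FLASH_ATTENTION_AUTOTUNE` environment variable that the CI workflow below sets to 0. A sketch of how such a flag can gate Triton autotuning (config values, key names, and the default are assumptions; the real configs live in the repo's kernels):

```python
import os
import triton

AUTOTUNE = os.getenv("FLASH_ATTENTION_AUTOTUNE", "0") == "1"  # default shown here is an assumption

def fwd_configs():
    # a couple of plausible tile shapes; the actual tuned list is in the repo's kernels
    return [
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
    ]

def maybe_autotune(kernel):
    """Apply @triton.autotune to a @triton.jit kernel only when the flag is on."""
    if AUTOTUNE:
        return triton.autotune(configs=fwd_configs(), key=["MAX_SEQLEN_Q", "MAX_SEQLEN_K"])(kernel)
    return kernel
```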

* clean up decode impl

* clean up

* clean up more

* bench everything by default and return time
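
A minimal sketch of what "return time" can look like with Triton's benchmarking helper (not the repo's bench.py itself):

```python
import triton

def bench(fn, *args, **kwargs):
    """Return the runtime of fn(*args, **kwargs) in milliseconds via triton.testing.do_bench."""
    return triton.testing.do_bench(lambda: fn(*args, **kwargs))
```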

* clean up readmes
micmelesse authored Oct 28, 2024
1 parent 75b5360 commit b2a2dff
Showing 21 changed files with 6,374 additions and 2,395 deletions.
21 changes: 16 additions & 5 deletions .github/workflows/amd_tests.yml
@@ -56,13 +56,24 @@ jobs:
cd ..
- name: Build
run: |
export FLASH_ATTENTION_USE_TRITON_ROCM="TRUE"
export FLASH_ATTENTION_USE_TRITON_AMD="TRUE"
python setup.py install
- name: Flash Attention Tests Using Reference Impl
run: |
export FLASH_ATTENTION_USE_TRITON_AMD="TRUE"
export FLASH_ATTENTION_USE_REF=1
pytest tests/test_flash_attn_triton.py
- name: Flash Attention Tests
run: |
export FLASH_ATTENTION_USE_TRITON_ROCM="TRUE"
pytest tests/test_flash_attn.py
export FLASH_ATTENTION_USE_TRITON_AMD="TRUE"
export FLASH_ATTENTION_AUTOTUNE=0
pytest tests/test_flash_attn_triton.py
- name: AMD Kernel Tests
run: |
pytest -v -s flash_attn/flash_attn_triton_kernel_decode_amd.py::test_op_fwd
pytest -v -s flash_attn/flash_attn_triton_kernel_prefill_amd.py
export FLASH_ATTENTION_USE_TRITON_AMD="TRUE"
export FLASH_ATTENTION_AUTOTUNE=0
pytest -v -s flash_attn/flash_attn_triton_amd/test.py
- name: AMD Kernel Bench
run: |
export FLASH_ATTENTION_USE_TRITON_AMD="TRUE"
python flash_attn/flash_attn_triton_amd/bench.py
5 changes: 5 additions & 0 deletions .gitignore
@@ -27,3 +27,8 @@ var/
# Dev
venv
scripts
*.log
core.*
*.csv
*.png
*.html
30 changes: 25 additions & 5 deletions README.md
@@ -126,11 +126,31 @@ FlashAttention-2 ROCm CK backend currently supports:
1. MI200 or MI300 GPUs.
2. Datatype fp16 and bf16
3. Forward's head dimensions up to 256. Backward head dimensions up to 128.

#### Triton Backend
FlashAttention-2 ROCm Triton backend is a work in progress.
It current supports Forwards only. However some features like PagedAttention and Sliding Window are missing. It can run on both MI and Navi Machines. We are working on backwards.
The Triton implementation of the [Flash Attention v2](https://tridao.me/publications/flash2/flash2.pdf) is currently a work in progress.

It supports AMD's CDNA (MI200, MI300) and RDNA GPUs using fp16 and bf16 datatypes.

These features are supported in Fwd and Bwd
1) Fwd and Bwd with causal masking
2) Variable sequence lengths
3) Arbitrary Q and KV sequence lengths
4) Arbitrary head sizes

These features are supported in Fwd for now. We will add them to backward soon.
1) Multi and grouped query attention
2) ALiBi and matrix bias

These features are in development
1) Paged Attention
2) Sliding Window
3) Rotary embeddings
4) Dropout
5) Performance Improvements

Inorder to use the triton backend for rocm, follow the steps below.
#### Getting Started
To get started with the triton backend for AMD, follow the steps below.

First install the recommended Triton [commit](https://github.com/triton-lang/triton/commit/2e9f2c2d20601c24b91a4c32a7b97ad1f8a55d88).

@@ -140,10 +160,10 @@ cd triton
git checkout 2e9f2c2d20601c24b91a4c32a7b97ad1f8a55d88
pip install --verbose -e python
```
Then install and test Flash Attention with the flag `FLASH_ATTENTION_USE_TRITON_ROCM` set to `"TRUE"`.
Then install and test Flash Attention with the flag `FLASH_ATTENTION_USE_TRITON_AMD` set to `"TRUE"`.

```
export FLASH_ATTENTION_USE_TRITON_ROCM="TRUE"
export FLASH_ATTENTION_USE_TRITON_AMD="TRUE"
cd flash-attention
python setup.py install
pytest tests/test_flash_attn.py
4 changes: 2 additions & 2 deletions flash_attn/flash_attn_interface.py
@@ -8,9 +8,9 @@

# isort: off
# We need to import the CUDA kernels after importing torch
USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_USE_TRITON_ROCM", "FALSE") == "TRUE"
USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_USE_TRITON_AMD", "FALSE") == "TRUE"
if USE_TRITON_ROCM:
from flash_attn import flash_attn_triton_interface_amd as flash_attn_gpu
from .flash_attn_triton_amd import interface_fa as flash_attn_gpu
else:
import flash_attn_2_cuda as flash_attn_gpu

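With the env-var rename above in place, a minimal usage sketch on an AMD GPU (shapes are arbitrary; the variable must be set before `flash_attn` is imported):

```python
import os
os.environ["FLASH_ATTENTION_USE_TRITON_AMD"] = "TRUE"  # must be set before importing flash_attn

import torch
from flash_attn import flash_attn_func

# (batch, seqlen, nheads, head_dim), fp16 on the ROCm device
q, k, v = (torch.randn(2, 128, 4, 64, dtype=torch.float16, device="cuda") for _ in range(3))
out = flash_attn_func(q, k, v, causal=True)
print(out.shape)  # torch.Size([2, 128, 4, 64])
```
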
49 changes: 49 additions & 0 deletions flash_attn/flash_attn_triton_amd/README.md
@@ -0,0 +1,49 @@
Flash Attention Triton Kernel
===============

#### Introduction
The Triton implementation of the [Flash Attention v2](https://tridao.me/publications/flash2/flash2.pdf) is currently a work in progress.

It supports AMD's CDNA (MI200, MI300) and RDNA GPUs using fp16 and bf16 datatypes.

These features are supported in Fwd and Bwd
1) Fwd and Bwd with causal masking
2) Variable sequence lengths
3) Arbitrary Q and KV sequence lengths
4) Arbitrary head sizes

These features are supported in Fwd for now. We will add them to backward soon.
1) Multi and grouped query attention
2) ALiBi and matrix bias

These features are in development
1) Paged Attention
2) Sliding Window
3) Rotary embeddings
4) Dropout
5) Performance Improvements

#### Getting Started
To get started with the triton backend for AMD, follow the steps below.

First install the recommended Triton [commit](https://github.com/triton-lang/triton/commit/2e9f2c2d20601c24b91a4c32a7b97ad1f8a55d88).

```
git clone https://github.com/triton-lang/triton
cd triton
git checkout 2e9f2c2d20601c24b91a4c32a7b97ad1f8a55d88
pip install --verbose -e python
```
Then install and test Flash Attention with the flag `FLASH_ATTENTION_USE_TRITON_AMD` set to `"TRUE"`.

```
export FLASH_ATTENTION_USE_TRITON_AMD="TRUE"
cd flash-attention
python setup.py install
pytest tests/test_flash_attn.py
```

#### Credits
AMD Triton kernels team

OpenAI kernel team
