Flash2: support cross attention and dropout (#905)
* Support cross attention and dropout

* Fix comments

* Disable cudnn dropout
hanzhi713 authored Jan 8, 2025
1 parent ee7d60d commit 6559036
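
This change adds support in the FlashAttention GPU kernels for cross attention, where the key/value sequence length differs from the query length, and for dropout applied to the attention probabilities; per the commit message, the cuDNN backend's dropout path is disabled. As a reference for the semantics only, here is a minimal plain-JAX sketch. The name and signature are illustrative, not the axlearn kernel API: the fused kernel computes the same result without materializing the full [q_len, kv_len] probability matrix.

import jax
import jax.numpy as jnp

def cross_attention_with_dropout(q, k, v, *, dropout_rate=0.0, key=None):
    """q: [q_len, d]; k, v: [kv_len, d]. q_len may differ from kv_len."""
    logits = q @ k.T / jnp.sqrt(q.shape[-1])   # [q_len, kv_len] scores
    probs = jax.nn.softmax(logits, axis=-1)    # attention probabilities
    if dropout_rate > 0.0 and key is not None:
        # Inverted dropout on the probabilities, the standard attention-dropout form.
        keep = jax.random.bernoulli(key, 1.0 - dropout_rate, probs.shape)
        probs = jnp.where(keep, probs / (1.0 - dropout_rate), 0.0)
    return probs @ v                           # [q_len, d]

For example, q of shape (4, 64) with k and v of shape (9, 64) exercises the cross-attention case, with key=jax.random.PRNGKey(0) supplying the dropout randomness.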
Showing 8 changed files with 804 additions and 700 deletions.
929 changes: 396 additions & 533 deletions axlearn/common/flash_attention/gpu_attention.py

Large diffs are not rendered by default.

138 changes: 69 additions & 69 deletions axlearn/common/flash_attention/gpu_attention_benchmark.py
@@ -10,95 +10,95 @@
"""FlashAttention kernel benchmarks.
To run: python3 gpu_attention_benchmark.py > out.txt
Requires Jax >= 0.4.36. Sample numbers on H100 SXM5 with Jax == 0.4.36:
is_decode=True, use_bwd=False, num_heads=8, num_kv_heads=8, per_head_dim=128, sw_sz=-1
jax axlearn jax-cudnn
bs=1,seq_len=1024 0.020832 0.017536 0.024128
bs=1,seq_len=4096 0.037472 0.021248 0.058656
bs=1,seq_len=8192 0.034016 0.032576 0.108576
bs=1,seq_len=131072 0.229856 0.198944 1.558464
bs=4,seq_len=1024 0.021632 0.023296 0.024352
bs=4,seq_len=4096 0.068064 0.055168 0.061312
bs=4,seq_len=8192 0.080352 0.075968 0.109696
bs=4,seq_len=131072 0.824576 0.703360 1.560768
bs=8,seq_len=1024 0.033536 0.030304 0.024448
bs=8,seq_len=4096 0.089056 0.071712 0.062944
bs=8,seq_len=8192 0.128960 0.114848 0.112736
bs=8,seq_len=131072 1.620032 1.373088 1.566208
bs=16,seq_len=1024 0.050368 0.048064 0.036608
bs=16,seq_len=4096 0.134816 0.116320 0.104320
bs=16,seq_len=8192 0.234880 0.200384 0.191936
bs=16,seq_len=131072 3.219008 2.726912 2.784768
bs=32,seq_len=1024 0.078112 0.070816 0.061568
bs=32,seq_len=4096 0.235648 0.203296 0.191936
bs=32,seq_len=8192 0.442080 0.371936 0.365152
bs=32,seq_len=131072 6.404832 5.448480 5.541504
is_decode=True, use_bwd=False, num_heads=8, seq_len=32768, per_head_dim=128, sw_sz=-1
jax axlearn jax-cudnn
bs=1,num_kv_heads=1 0.027648 0.058464 0.398816
bs=1,num_kv_heads=8 0.076096 0.070368 0.398912
bs=8,num_kv_heads=1 0.101696 0.078560 0.399040
bs=8,num_kv_heads=8 0.426656 0.367616 0.403360
is_decode=True, use_bwd=False, num_heads=8, num_kv_heads=8, per_head_dim=128
jax axlearn jax-cudnn
bs=1,seq_len=131072,sw_sz=-1 0.230336 0.199968 1.559168
bs=1,seq_len=131072,sw_sz=4096 0.235296 0.051296 4.414048
bs=1,seq_len=131072,sw_sz=16384 0.235904 0.062976 4.385216
bs=8,seq_len=131072,sw_sz=-1 1.619008 1.372768 1.570272
bs=8,seq_len=131072,sw_sz=4096 1.635424 0.194720 4.390976
bs=8,seq_len=131072,sw_sz=16384 1.632832 0.321280 4.361984
is_decode=False, use_bwd=False, num_heads=32, num_kv_heads=None, seq_len=4096, per_head_dim=128, sw_sz=-1
jax axlearn jax-cudnn jax-pallas
bs=2 3.583424 0.894912 0.488480 0.852960
bs=4 7.107168 1.712448 0.922592 1.629888
bs=8 14.202400 3.341568 1.801920 3.184064
is_decode=False, use_bwd=False, bs=2, num_kv_heads=None, seq_len=4096, per_head_dim=128, sw_sz=-1
jax axlearn jax-cudnn jax-pallas
num_heads=12 1.287712 0.383200 0.214400 0.365120
num_heads=16 1.803232 0.485408 0.270496 0.463040
num_heads=32 3.578208 0.896576 0.488544 2.468096
num_heads=48 5.346112 1.305856 0.707872 1.241728
num_heads=72 8.001568 1.915776 1.035200 1.820288
is_decode=False, use_bwd=False, bs=2, num_heads=32, num_kv_heads=None, per_head_dim=128, sw_sz=-1
jax axlearn jax-cudnn jax-pallas
seq_len=256 0.049184 0.015360 0.016352 0.015488
seq_len=512 0.110400 0.038624 0.028480 0.037760
seq_len=1024 0.302304 0.094560 0.056736 0.090464
seq_len=2048 0.936832 0.269856 0.154304 0.258944
seq_len=4096 3.584800 0.895776 0.487104 2.462560
seq_len=8192 14.260608 3.268320 1.742048 3.104640
is_decode=False, use_bwd=False, bs=2, num_heads=32, num_kv_heads=None, seq_len=4096, sw_sz=-1
jax axlearn jax-cudnn jax-pallas
per_head_dim=16 3.262592 0.518912 0.356544 0.477120
per_head_dim=32 3.323552 0.563520 0.358944 0.533344
per_head_dim=64 3.411744 0.690464 0.360192 0.635296
per_head_dim=128 3.585920 0.896032 0.488416 2.461696
is_decode=False, use_bwd=True, num_heads=32, num_kv_heads=None, seq_len=4096, per_head_dim=128, sw_sz=-1
jax axlearn jax-cudnn jax-pallas
bs=2 10.878624 3.924992 2.123008 4.504256
bs=4 21.626017 8.043040 4.071552 9.186080
bs=8 43.269279 16.195999 8.124896 18.184799
is_decode=False, use_bwd=True, bs=2, num_kv_heads=None, seq_len=4096, per_head_dim=128, sw_sz=-1
jax axlearn jax-cudnn jax-pallas
num_heads=12 4.159424 1.519680 0.898816 1.711808
num_heads=16 5.486912 2.001952 1.142144 2.256960
num_heads=32 10.886848 3.928896 2.114496 4.502976
num_heads=48 16.224319 6.085408 3.093696 6.888640
num_heads=72 24.367489 9.190560 4.642720 10.323552
is_decode=False, use_bwd=True, bs=2, num_heads=32, num_kv_heads=None, per_head_dim=128, sw_sz=-1
jax axlearn jax-cudnn jax-pallas
seq_len=256 0.094496 0.060096 0.053184 0.065760
seq_len=512 0.297440 0.139328 0.112736 0.161664
seq_len=1024 0.886304 0.361536 0.246848 0.418720
seq_len=2048 2.857952 1.118368 0.675168 1.294144
seq_len=4096 10.880512 3.914048 2.119808 4.503936
seq_len=8192 43.000095 14.913824 7.484128 16.730017
is_decode=False, use_bwd=True, bs=2, num_heads=32, num_kv_heads=None, seq_len=4096, sw_sz=-1
jax axlearn jax-cudnn jax-pallas
per_head_dim=16 10.150080 1.826656 1.288192 1.718688
per_head_dim=32 10.277440 2.028608 1.316512 2.048864
per_head_dim=64 10.463904 2.569408 1.364448 2.540512
per_head_dim=128 10.875328 3.929568 2.124192 4.502912
"""
# pylint: enable=line-too-long
import itertools
@@ -365,8 +365,8 @@ def bench_flash_attention_fwd_bwd(use_bwd: bool):
libraries = ["jax", "axlearn", "jax-cudnn", "jax-pallas"]
benchmark_sweep(libraries, common_kwargs, bs=[2, 4, 8])
benchmark_sweep(libraries, common_kwargs, num_heads=[12, 16, 32, 48, 72])
# 256 to 8192.
benchmark_sweep(libraries, common_kwargs, seq_len=[int(2**i) for i in range(8, 14)])
benchmark_sweep(libraries, common_kwargs, per_head_dim=[16, 32, 64, 128])
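
Judging from the calls above and the row labels in the tables (e.g. bs=1,seq_len=1024), benchmark_sweep expands the given keyword lists on top of common_kwargs and times each library once per configuration. A sketch of that shape, with a hypothetical bench_one standing in for the per-configuration timing routine:

import itertools

def bench_one(library, **config):
    # Hypothetical stand-in: time one (library, config) pair, return ms per call.
    raise NotImplementedError

def benchmark_sweep(libraries, common_kwargs, **sweep):
    axes, value_lists = zip(*sweep.items())
    for values in itertools.product(*value_lists):
        kwargs = {**common_kwargs, **dict(zip(axes, values))}
        label = ",".join(f"{a}={v}" for a, v in zip(axes, values))
        row = [label] + [f"{bench_one(lib, **kwargs):.6f}" for lib in libraries]
        print(" ".join(row))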


