Dropout (#101)
* Alex's work

This is a combination of 11 commits.

save

fix: dropout=0.0 works

feat: dropout restrictions removed. failing tests

test: reduced tests to simple cases

test: failure is due to query + key padding mask NOT varlen itself

feat: varlen dropout fwd passes

fix: varlen bwd dropout works!

test: discovered bwd error for non-dropout cases at large seqlen

save

save

use triton commit 3ca2f498e98ed7249b82722587c511a5610e00c4 -- now batched layout passes

* Almost Everything works.

This is a combination of 16 commits.

Work so far

This is a combination of 63 commits.

pick test case

save philox offsets into metadata

pass offset to ref

common dropout mask

simple dropout mask

start dropout ref. work on returning SD_Mask next with negative numbers

reference is working

dropout bwd ref failing case

transfer rng_state properly

save changes

one dropout mask function

save

save

minimize diff

save

use torch.where in backward
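
A minimal sketch of what "use torch.where in backward" could look like: applying a boolean keep-mask to the attention probabilities and their gradient with `torch.where` instead of multiplying by a scaled float mask. All names here (`p`, `dp`, `keep`, `dropout_p`) are illustrative, not the repo's actual identifiers.

```python
import torch

torch.manual_seed(0)
dropout_p = 0.2
p = torch.rand(4, 4)                   # attention probabilities (forward)
keep = torch.rand(4, 4) >= dropout_p   # boolean keep-mask (stand-in for the RNG stream)
dp = torch.randn(4, 4)                 # incoming gradient w.r.t. the dropped output

# Forward-style dropout: kept entries are rescaled by 1/(1 - dropout_p)
scale = 1.0 / (1.0 - dropout_p)
dropped = torch.where(keep, p * scale, torch.zeros_like(p))

# Backward: gradient flows only through kept entries, with the same scale
dp_masked = torch.where(keep, dp * scale, torch.zeros_like(dp))
```

The `torch.where` form makes the zeroed positions explicit, which can be easier to compare against a kernel's masked loads than a multiply-by-mask formulation.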

save

save

save

dk works!

passes

reference is working. TODO: attn_ref is broken

varlen ref working

attn failing case

with ones, attn_ref matches; fails with randn. we are seeing failures at large sizes from dv.

save

skip attn matrices

compare the masks and find failing case

rm cdiv_fn

put dropout and alibi in common

save

compare masks

save

save

pytorch ref is using tiles

save

save

tl_rand_ref

cache ref dropout mask

new generate_dropout_mask_ref using tiling
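
An illustrative sketch of a tiled reference mask generator like the one named above (this is not the repo's actual helper). A real reference has to reproduce the kernel's Philox stream exactly; here a per-tile seed stands in for the kernel's (seed, offset) pair, just to show the tile-by-tile structure.

```python
import torch

def generate_dropout_mask_ref(seqlen_q, seqlen_k, dropout_p, seed,
                              block_m=2, block_n=2):
    # Build the keep-mask tile by tile, mirroring how a kernel draws
    # randoms per (BLOCK_M, BLOCK_N) tile rather than per full matrix.
    mask = torch.empty(seqlen_q, seqlen_k, dtype=torch.bool)
    for m in range(0, seqlen_q, block_m):
        for n in range(0, seqlen_k, block_n):
            # Hypothetical per-tile seeding; a real kernel uses philox(seed, offset).
            g = torch.Generator().manual_seed(seed + m * seqlen_k + n)
            tile = torch.rand(min(block_m, seqlen_q - m),
                              min(block_n, seqlen_k - n), generator=g)
            mask[m:m + tile.shape[0], n:n + tile.shape[1]] = tile >= dropout_p
    return mask

mask = generate_dropout_mask_ref(4, 6, 0.5, seed=123)
```

Generating the reference mask with the same tiling as the kernel lets a failing tile be isolated by comparing masks block by block.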

isolate failing varlen case

simple dropout

loop on k

print rng_outputs

save

fwd kernel works

save

dv passed

close to dk

simple ref

save

separate dropped and scaled in ref and triton kernel

ref changes

working delta with dp

find failing dv cases

find failing case due to delta

save

delta from dp working
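
For context on the "delta" term the commits above debug: flash-attention-style backward kernels precompute a per-row value delta_i = sum_j dO_ij * O_ij. A hedged sketch of the O/dO formulation (shapes and names are illustrative):

```python
import torch

torch.manual_seed(0)
o = torch.randn(4, 8)    # forward output rows
do = torch.randn(4, 8)   # gradient of the loss w.r.t. the output

# delta_i = row-wise dot product of dO and O
delta = (do * o).sum(dim=-1)
```

Deriving the same quantity from the dropped probabilities `dp` (as the commit message suggests) must agree with this O/dO form, which makes it a useful cross-check when dropout is enabled.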

bwd impl green

enable test fwd

save

save

delete kernels

save

probably mask application mismatch

dump forward dropout

pass dropout mask tensor to bwd_core

different dropout fraction in fwd and bwd

mismatch found on columns greater than 64

fix dropout bug. philox was not offset
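
A sketch of the bug class named above (illustrative, not the kernel code): if every tile draws from the RNG at the same Philox offset, all tiles receive identical randoms and the dropout pattern repeats across the attention matrix. A `torch.Generator` seeded with `seed + offset` stands in for `philox(seed, offset)`.

```python
import torch

def tile_rand(seed, offset, n):
    # Stand-in for a counter-based philox(seed, offset) draw of n randoms.
    g = torch.Generator().manual_seed(seed + offset)
    return torch.rand(n, generator=g)

seed = 7
buggy_t0 = tile_rand(seed, offset=0, n=4)  # first tile
buggy_t1 = tile_rand(seed, offset=0, n=4)  # bug: offset not advanced -> same draw
fixed_t1 = tile_rand(seed, offset=4, n=4)  # fix: advance offset by elements consumed
```

With the offset advanced per tile, the forward and backward passes can regenerate the identical mask from just (seed, offset) without materializing it.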

run full suite

stop debug and approximate delta

fix drop_mask non-issue

skip attn check

clean up common

bad varlen config

fix varlen bug

save

* fix datatype mismatch

* clean up

* use pytorch dropout
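
A minimal sketch of what a PyTorch-native reference with built-in dropout could look like, per the commit above. Shapes and names are illustrative; `F.dropout` handles both the mask draw and the 1/(1 - p) rescaling.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(1, 4, 8)   # (batch, seqlen, head_dim), illustrative shapes
k = torch.randn(1, 4, 8)
v = torch.randn(1, 4, 8)
dropout_p = 0.1

scores = q @ k.transpose(-1, -2) / (q.shape[-1] ** 0.5)
p = torch.softmax(scores, dim=-1)
p = F.dropout(p, p=dropout_p, training=True)  # mask + rescale in one call
out = p @ v
```

Note that PyTorch's RNG stream differs from the kernel's Philox stream, so such a reference checks statistics and shapes rather than an element-wise mask match.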

* It works on MI300.

* remove _bwd_preprocess_use_p

* fix torch interface bug

---------

Co-authored-by: Alex Kranias <[email protected]>
micmelesse and alexkranias-amd authored Dec 6, 2024
1 parent c57b3d0 commit c7db746
Showing 10 changed files with 634 additions and 440 deletions.
4 changes: 3 additions & 1 deletion .gitignore
@@ -34,4 +34,6 @@ csrc/flash_attn_ck
 core.*
 *.csv
 *.png
-*.html
+*.html
+*.json
+*.txt
