Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Enable Vanilla Bwd and Refactor (#86)
* Vanilla BWD Vanilla BWD This is a combination of 79 commits. save test_flash_attn_output use impl functions pass layout add ref move arround impls fix stride issue save oai kernel add baseline impl save bwd kernel working remove old impl remove block_ptrs from bwd pass padded dmodel and apply masking. the old test cases work but cases with small d don't work save save more prints rename to M to L save add notes add old_bwd back fa failure fails in kernels too isolate new bwd and keep old bwd in place clean up softmax_lse doesnot match refernce LOG flag softmax_lse with LN2 move qk_scale to loop pass ln2 to fwd just print kernel input test softmax output from forward test exp_scores_triton save all the ref create ref USE_EXP2 path return scores mask scores when returning them. Basic impl test passes scores and output match show max_diff return score needs to be adjusted as we find new maxes all good outputs. old style RCP2 example prep bwd_impl test save try openai save fix softmax_lse bug test_op_bwd_impl starting to work! new kernel. exp2 works but exp is faliing fix bwd exp2 add m and n masks. small cases still don't work match old and new kernel prints compare old and new print inputs save old kernel match on dv dq works compare to pytorch including softmax in forward fix bwd impl bug small sizes in bwd impl work old bwd test pass. Moving on to kernel tests dq, dk and dv are filled in place if given. Need to match cast to match fa fix non bug fix dv mismatch. use_exp2 was set to true in fwd fix case up 128 refactor and clean up a bit more issue is that dq and dk are not zeros dq must be zeroed out ignore segfaults fa ref and my ref match! all tests run use tolerance 1e-3 we need to figure out preprocessing save clean up save test delta diff move old impl out new preprocess function preprocessing_use_o flag working _bwd_preprocess_use_p basic cases pass all green fwd exp2 usage is done right before exp * refactor * refactor 2 * refactor 3 * fix bug * try ci * add flag * rename to utils * skip test_op_fwd_decode_int4_kv * reduce head size * try again * go back to old head sizes * Use Strides Use Strides This is a combination of 11 commits. use strides in bwd add layout test in forward fix shape layout function smaller tests save fix varlen error no headsize passed to bwd deal with varlen layout save save save save * use gen scripts * varlen fwd passing * core fwd ref impl * fix minor bugs * wrap varlen- launcher attention_forward_pytorch_ref_impl * varlen backward ref added * add offsets for varlen * fix delta bug * varlen bwd working * save * runs on Mi200 * just test basics * save * fix bug * fix varlen in64 bug * add ref * test_impl working with causal * fix qkvpacked issue * qkvpacked run tests * remove test_backward * save * just test output * dump into tensors * softmaxlse layout for varlen * small cases working * bwd thd green. although maybe some oom * forward out and lse are good. Something wrong with backward ref * make varlen ref work * save work, ref is working mostly * 91 failed, 6542 passed, 6336 skipped, 1 warning * ref is all green * debug flag in utils * found bad softmax_lse in varlen fwd * fix bug in softmax lse. strides in varlen werenot right * add causal tests and 32*32 bwd doesnot have segfault * save * fix oom by reducing block size for small heads * bwd ref with causal working * test impl * causal test passes * causal working * fix tests * nicer bench * fix qvpacked error * fix varlen qvpacked bug * fix minor bug * bench prefill and prefill_old using the same script * autotune configs for fwd * autotune flag * clean up decode impl * clean up * clean up more * bench everything by default and return time * clean up readmes
- Loading branch information