Skip to content

Commit

Permalink
Merge branch 'main' into deepseek_v3
Browse files Browse the repository at this point in the history
  • Loading branch information
baishihao committed Jan 13, 2025
2 parents 6020ab2 + 269631d commit 4a20aeb
Show file tree
Hide file tree
Showing 15 changed files with 365 additions and 566 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"256": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 3}, "32": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 2, "stage2_num_stages": 1}, "64": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 1, "stage2_num_stages": 1}, "128": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 1, "stage2_num_stages": 3}, "256": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 1}}, "512": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 1}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 1, "stage2_num_stages": 3}, "32": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 2, "stage2_num_stages": 1}, "64": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 3}, "128": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 3}, "256": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 2, "stage2_num_stages": 3}}, "1024": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 1, "stage2_num_stages": 3}, "32": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 1, "stage2_num_stages": 1}, "64": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 1}, "128": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 1}, "256": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 2, "stage2_num_stages": 3}}, "2048": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 1, "stage2_num_stages": 3}, "32": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 2, "stage2_num_stages": 1}, "64": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 3}, "128": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 1}, "256": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 2, "stage2_num_stages": 3}}, "4096": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 1}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 2, "stage2_num_stages": 1}, "32": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 2, "stage2_num_stages": 3}, "64": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 2, "stage2_num_stages": 3}, "128": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 2, "stage2_num_stages": 1}}, "8192": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 3}, "32": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 1}, "64": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 3}}}
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"256": {"1": {"BLOCK_SEQ": 32, "BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 8, "stage2_num_stages": 1}, "8": {"BLOCK_SEQ": 32, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 8, "stage2_num_stages": 1}, "16": {"BLOCK_SEQ": 64, "BLOCK_N": 64, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 4, "stage2_num_stages": 1}, "32": {"BLOCK_SEQ": 64, "BLOCK_N": 64, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 4, "stage2_num_stages": 1}, "64": {"BLOCK_SEQ": 128, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}, "128": {"BLOCK_SEQ": 256, "BLOCK_N": 64, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 1}, "256": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 2, "stage2_num_stages": 1}}, "512": {"1": {"BLOCK_SEQ": 32, "BLOCK_N": 32, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 16, "stage2_num_stages": 1}, "8": {"BLOCK_SEQ": 64, "BLOCK_N": 64, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 4, "stage2_num_stages": 1}, "16": {"BLOCK_SEQ": 64, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}, "32": {"BLOCK_SEQ": 128, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 8, "stage2_num_stages": 1}, "64": {"BLOCK_SEQ": 256, "BLOCK_N": 64, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 4, "stage2_num_stages": 1}, "128": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}, "256": {"BLOCK_SEQ": 256, "BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 2, "stage1_num_stages": 3, "stage2_num_warps": 2, "stage2_num_stages": 1}}, "1024": {"1": {"BLOCK_SEQ": 64, "BLOCK_N": 64, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 16, "stage2_num_stages": 1}, "8": {"BLOCK_SEQ": 64, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 16, "stage2_num_stages": 1}, "16": {"BLOCK_SEQ": 128, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 8, "stage2_num_stages": 1}, "32": {"BLOCK_SEQ": 256, "BLOCK_N": 64, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 4, "stage2_num_stages": 1}, "64": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}, "128": {"BLOCK_SEQ": 256, "BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 2, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 1}, "256": {"BLOCK_SEQ": 256, "BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 2, "stage1_num_stages": 3, "stage2_num_warps": 2, "stage2_num_stages": 1}}, "2048": {"1": {"BLOCK_SEQ": 64, "BLOCK_N": 64, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 16, "stage2_num_stages": 1}, "8": {"BLOCK_SEQ": 128, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 16, "stage2_num_stages": 1}, "16": {"BLOCK_SEQ": 256, "BLOCK_N": 64, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 4, "stage2_num_stages": 5}, "32": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}, "64": {"BLOCK_SEQ": 256, "BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 2, "stage1_num_stages": 3, "stage2_num_warps": 8, "stage2_num_stages": 1}, "128": {"BLOCK_SEQ": 256, "BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 2, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 1}, "256": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 5, "stage2_num_warps": 1, "stage2_num_stages": 1}}, "4096": {"1": {"BLOCK_SEQ": 128, "BLOCK_N": 64, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 8, "stage2_num_stages": 1}, "8": {"BLOCK_SEQ": 256, "BLOCK_N": 64, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 4, "stage2_num_stages": 4}, "16": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 8, "stage2_num_stages": 1}, "32": {"BLOCK_SEQ": 256, "BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 2, "stage1_num_stages": 3, "stage2_num_warps": 8, "stage2_num_stages": 1}, "64": {"BLOCK_SEQ": 256, "BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 2, "stage1_num_stages": 3, "stage2_num_warps": 8, "stage2_num_stages": 1}, "128": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}}, "8192": {"1": {"BLOCK_SEQ": 128, "BLOCK_N": 64, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 16, "stage2_num_stages": 1}, "8": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 16, "stage2_num_stages": 1}, "16": {"BLOCK_SEQ": 256, "BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 2, "stage1_num_stages": 3, "stage2_num_warps": 8, "stage2_num_stages": 1}, "32": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 5, "stage2_num_warps": 16, "stage2_num_stages": 1}, "64": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}}}
{"256": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 1}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 1}, "32": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "64": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "128": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 4, "stage2_num_stages": 3}, "256": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 1, "stage2_num_stages": 1}}, "512": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 1}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "32": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "64": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}, "128": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 1, "stage2_num_stages": 1}, "256": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 2, "stage2_num_stages": 1}}, "1024": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 1}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "32": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}, "64": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}, "128": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 1, "stage2_num_stages": 1}, "256": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 1, "stage2_num_stages": 1}}, "2048": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 1, "stage2_num_warps": 4, "stage2_num_stages": 1}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}, "32": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}, "64": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}, "128": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 4, "stage2_num_stages": 3}, "256": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 1}}, "4096": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 1}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 4, "stage2_num_stages": 1}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}, "32": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}, "64": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 1, "stage2_num_stages": 3}, "128": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 2, "stage2_num_stages": 1}}, "8192": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 4, "stage2_num_stages": 1}, "32": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 1, "stage2_num_stages": 3}, "64": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 3}}}
11 changes: 10 additions & 1 deletion lightllm/common/basemodel/cuda_graph.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import torch
import copy
from lightllm.utils.log_utils import init_logger
from lightllm.distributed import custom_comm_ops

Expand Down Expand Up @@ -27,10 +28,18 @@ def capture_decode(self, decode_func, input_ids, infer_state):
infer_state.max_len_in_batch = self.graph_max_len_in_batch
infer_state.total_token_num = self.graph_max_len_in_batch * batch_size
# warmup
# 因为有些推理过程的代码,会通过判断infer_state中是否存在某些属性来在一层上
# 做一些初始化的操作,后续层可以复用这些计算的结果,如
# lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding.py
# 中做的一些操作,所以在 warmup 的时候,需要调用infer_state的copy函数做一个
# 浅拷贝,不然后续传入到cuda graph捕获过程中后,infer_state因为提前拥有了这些属性,
# 导致不会重新初始化,这样捕获过程中会不能捕获这些临时添加到 infer_state 管理对象
# 中的 tensor。
for _ in range(1):
torch.cuda.synchronize()
decode_func(input_ids, infer_state)
decode_func(input_ids, copy.copy(infer_state)) # infer_state must copy()
torch.cuda.synchronize()

with custom_comm_ops.lightllm_capture_graph():
with torch.cuda.graph(graph_obj, pool=self.mempool):
predict_logics = decode_func(input_ids, infer_state)
Expand Down
Loading

0 comments on commit 4a20aeb

Please sign in to comment.