Skip to content

Commit

Permalink
Update citation
Browse files Browse the repository at this point in the history
  • Loading branch information
tridao committed May 26, 2024
1 parent e2e4333 commit 320fb59
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
7 changes: 4 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -400,12 +400,13 @@ If you use this codebase, or otherwise found our work valuable, please cite:
@inproceedings{dao2022flashattention,
title={Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
booktitle={Advances in Neural Information Processing Systems},
booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
year={2022}
}
@article{dao2023flashattention2,
@inproceedings{dao2023flashattention2,
title={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning},
author={Dao, Tri},
year={2023}
booktitle={International Conference on Learning Representations (ICLR)},
year={2024}
}
```
7 changes: 6 additions & 1 deletion flash_attn/utils/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@
from einops import rearrange, repeat
from torch import Tensor
from torch.profiler import ProfilerActivity, profile, record_function
from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput

try:
from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput
except ImportError:
GreedySearchDecoderOnlyOutput = namedtuple("GreedySearchDecoderOnlyOutput", ["sequences", "scores"])
SampleDecoderOnlyOutput = namedtuple("SampleDecoderOnlyOutput", ["sequences", "scores"])


@dataclass
Expand Down

0 comments on commit 320fb59

Please sign in to comment.