diff --git a/denoising_diffusion_pytorch/attend.py b/denoising_diffusion_pytorch/attend.py index fa689f906..333b75113 100644 --- a/denoising_diffusion_pytorch/attend.py +++ b/denoising_diffusion_pytorch/attend.py @@ -72,7 +72,7 @@ def flash_attn(self, q, k, v): if exists(self.scale): default_scale = q.shape[-1] - q = q * (scale / default_scale) + q = q * (self.scale / default_scale) q, k, v = map(lambda t: t.contiguous(), (q, k, v))