From 589e2b6fc56a0b6e6cf6091cd6d57eb0cd2c4e2a Mon Sep 17 00:00:00 2001 From: chengzeyi Date: Wed, 18 Dec 2024 10:42:05 +0000 Subject: [PATCH] add warning for old flash attn --- setup.py | 2 +- xfuser/model_executor/pipelines/base_pipeline.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 5405d09..4c8ab88 100644 --- a/setup.py +++ b/setup.py @@ -39,7 +39,7 @@ def get_cuda_version(): "imageio", "imageio-ffmpeg", "optimum-quanto", - "flash_attn>=2.7.0" # flash_attn>=2.7.0 with torch>=2.4.0 wraps ops with torch.ops + "flash_attn>=2.6.3" # flash_attn>=2.7.0 with torch>=2.4.0 wraps ops with torch.ops ], extras_require={ "diffusers": [ diff --git a/xfuser/model_executor/pipelines/base_pipeline.py b/xfuser/model_executor/pipelines/base_pipeline.py index 5ca54f3..c497535 100644 --- a/xfuser/model_executor/pipelines/base_pipeline.py +++ b/xfuser/model_executor/pipelines/base_pipeline.py @@ -1,6 +1,8 @@ from abc import ABCMeta, abstractmethod from functools import wraps +from packaging import version from typing import Callable, Dict, List, Optional, Tuple, Union +import sys import torch import torch.distributed import torch.nn as nn @@ -299,6 +301,12 @@ def _convert_transformer_backbone( if enable_torch_compile or enable_onediff: if getattr(transformer, "forward") is not None: if enable_torch_compile: + if "flash_attn" in sys.modules: + import flash_attn + if version.parse(flash_attn.__version__) < version.parse("2.7.0") or version.parse(torch.__version__) < version.parse("2.4.0"): + logger.warning( + "flash-attn or torch version is too old, performance with torch.compile may be suboptimal due to too many graph breaks" + ) optimized_transformer_forward = torch.compile( getattr(transformer, "forward") )