Question about performance of qwen2-vl on A10 #54
See #48.
We have not tested our kernel on the A10. Can you run the benchmarking code in the ./bench directory? By the way, what is your sequence length?
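Sequence length matters because attention is only one part of each forward pass. As a rough sketch, assume a standard transformer layer with 4·d² attention-projection weights and 8·d² MLP weights. Qwen2-VL-2B actually uses grouped-query attention and a different MLP width, so the numbers are illustrative only. Under that assumption the attention matmuls cost about 4·L²·d FLOPs per layer versus 24·L·d² for the linear layers, so attention's share is L / (L + 6d). Amdahl's law then bounds the end-to-end gain from a faster attention kernel. The hidden size `d_model = 1536` and the 2x kernel speedup below are assumptions, and a FLOP share is only a proxy for a share of wall-clock time.

```python
def attention_flop_fraction(seq_len: int, d_model: int) -> float:
    """Share of per-layer FLOPs spent in QK^T and PV (rough model, ignores causal masking)."""
    attn_flops = 4 * seq_len**2 * d_model      # QK^T plus PV matmuls
    linear_flops = 24 * seq_len * d_model**2   # Q/K/V/O projections plus an assumed 4x MLP
    return attn_flops / (attn_flops + linear_flops)


def amdahl_speedup(fraction: float, kernel_speedup: float) -> float:
    """End-to-end speedup when only `fraction` of the work becomes `kernel_speedup` times faster."""
    return 1.0 / ((1.0 - fraction) + fraction / kernel_speedup)


if __name__ == "__main__":
    d_model = 1536  # assumed hidden size, for illustration only
    for seq_len in (1_000, 8_000, 32_000):
        f = attention_flop_fraction(seq_len, d_model)
        print(f"L={seq_len:>6}: attention share {f:.2f}, "
              f"overall speedup with a 2x kernel {amdahl_speedup(f, 2.0):.2f}x")
```

Under these assumptions, a 1,000-token prompt spends roughly 10% of its FLOPs in attention, so even a 2x faster attention kernel would cut end-to-end latency by only about 5%. The bound only becomes substantial at much longer sequence lengths.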
Never mind, we will benchmark on an A10 GPU later. It seems the A10 is also a popular GPU.
I have benchmarked on an A10 GPU as follows.
I benchmarked Qwen2-VL model inference with sageattn on an A10, but I do not see any speed improvement.
Env
- branch: master
- commit: 5f88df2
- GPU: A10 24 GB
- Python 3.9
- CUDA 12.3
- torch 2.5.1
- triton 3.1
Model
https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct
How I ran inference
Adapted from the model card at https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct:
```python
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

# Load the model with the SDPA attention path (the path patched to call sageattn)
model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",
    torch_dtype="auto",
    device_map="auto",
    attn_implementation="sdpa",
)
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
            },
            {"type": "text", "text": "Describe this image."},
        ],
    }
]

# Build the chat prompt and preprocess the image
text = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
)
inputs = inputs.to("cuda")

# max_new_tokens=1, so the run is dominated by the prefill pass
generated_ids = model.generate(**inputs, max_new_tokens=1)
# Strip the prompt tokens and keep only the newly generated ones
generated_ids_trimmed = [
    out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)
```
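The script above runs a single `generate` call but does not time it. One way to compare the `sdpa` baseline against the sageattn-patched model on identical inputs is a small wall-clock helper like the sketch below. CUDA kernels launch asynchronously, so the helper accepts an optional `sync` callable (for example `torch.cuda.synchronize`) that is called before the clock stops. `benchmark` and its parameters are hypothetical names, not part of transformers or SageAttention.

```python
import statistics
import time
from typing import Callable, Optional


def benchmark(fn: Callable[[], object], warmup: int = 3, iters: int = 10,
              sync: Optional[Callable[[], None]] = None) -> float:
    """Return the median wall-clock time of fn() in milliseconds."""
    # Warm-up runs absorb one-off costs such as kernel compilation and allocator growth
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()  # make sure warm-up work has finished before timing starts
    times_ms = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()  # wait for queued GPU work before stopping the clock
        times_ms.append((time.perf_counter() - start) * 1e3)
    return statistics.median(times_ms)
```

For example, `benchmark(lambda: model.generate(**inputs, max_new_tokens=1), sync=torch.cuda.synchronize)` could be run once per attention implementation. The median is used so that occasional outliers skew the result less than they would a mean.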
Qwen2-VL supports the following `attn_implementation` values. I do not know how to integrate sageattn into the `flash_attention_2` path, so I used `sdpa` for the benchmark:
```python
QWEN2_VL_ATTENTION_CLASSES = {
    "eager": Qwen2VLAttention,
    "flash_attention_2": Qwen2VLFlashAttention2,
    "sdpa": Qwen2VLSdpaAttention,
}
```
To use sageattn, I replaced `attn_output = torch.nn.functional.scaled_dot_product_attention` with `attn_output = sageattn` at https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L810.
In addition, with batch_size > 1 the `sdpa` path runs out of memory (OOM), whereas `flash_attention_2` handles batch_size=20.
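One possible contributor to the OOM, offered as an assumption rather than a diagnosis: if batched inputs carry a padding mask and the SDPA dispatcher falls back to a backend that materializes the full attention score matrix, each layer needs B·H·L² elements for the scores alone. FlashAttention-style kernels process the scores in tiles and never store the whole matrix. The arithmetic below shows how quickly that term grows. The head count of 12 and the 4,096-token padded length are illustrative assumptions, not values measured in this run.

```python
def score_matrix_gib(batch: int, heads: int, seq_len: int, bytes_per_elem: int = 2) -> float:
    """Size in GiB of one materialized (batch, heads, seq_len, seq_len) attention score tensor."""
    return batch * heads * seq_len * seq_len * bytes_per_elem / 2**30


if __name__ == "__main__":
    # fp16/bf16 scores (2 bytes each); 12 heads and 4,096 tokens are assumed for illustration
    for batch in (1, 20):
        print(f"batch={batch:>2}: {score_matrix_gib(batch, 12, 4096):.3f} GiB per layer")
```

Under these assumptions, a batch of 20 would need 7.5 GiB for a single layer's score tensor, about a third of the A10's 24 GB, before counting weights, activations, or the KV cache.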
Am I integrating sageattn correctly? If not, could you provide an example of integrating sageattn into the Qwen2 model?
Thanks!
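On whether a plain name swap is correct: the swap keeps the SDPA call's keyword arguments (`attn_mask`, `dropout_p`, `is_causal`), while the usage shown in the SageAttention README is `sageattn(q, k, v, tensor_layout="HND", is_causal=False)`, with no attention-mask argument. It is therefore worth checking that the mask is actually `None` on the path being benchmarked. Below is a minimal sketch of a guarded drop-in. It assumes that `sageattn` accepts only `tensor_layout` and `is_causal` besides q/k/v and does not handle arbitrary masks, and that `"HND"` matches SDPA's (batch, heads, seq_len, head_dim) layout. The wrapper calls the fast kernel only when there is no mask, no dropout, and no custom scale, and otherwise falls back to the reference function. `make_sdpa_compatible` is a hypothetical helper, not part of either library.

```python
from typing import Any, Callable


def make_sdpa_compatible(fast_attn: Callable[..., Any],
                         fallback_attn: Callable[..., Any]) -> Callable[..., Any]:
    """Wrap a fast attention kernel behind an SDPA-style signature (hypothetical helper)."""
    def attention(query, key, value, attn_mask=None, dropout_p=0.0,
                  is_causal=False, scale=None, **kwargs):
        # Assumption: the fast kernel supports only q, k, v, tensor_layout and is_causal,
        # so masks, dropout, custom scales and extra options go to the reference path.
        if attn_mask is None and dropout_p == 0.0 and scale is None and not kwargs:
            return fast_attn(query, key, value, tensor_layout="HND", is_causal=is_causal)
        return fallback_attn(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p,
                             is_causal=is_causal, scale=scale, **kwargs)
    return attention
```

With `sdpa_attn = make_sdpa_compatible(sageattn, torch.nn.functional.scaled_dot_product_attention)`, the patched line in `modeling_qwen2_vl.py` could call `sdpa_attn(...)` with its original arguments unchanged. If most calls take the fallback branch (for example because a padding mask is present), the benchmark would mostly be measuring SDPA. That is one possible reason for seeing no speedup.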