
Question about performance of qwen2-vl on A10 #54

Open
gxm651182644 opened this issue Nov 28, 2024 · 5 comments

@gxm651182644

I benchmarked qwen2-vl model inference with sageattn on an A10, but I do not see a speed improvement.

```python
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto", attn_implementation="sdpa",
)

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
            },
            {"type": "text", "text": "Describe this image."},
        ],
    }
]

text = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
)
inputs = inputs.to("cuda")

# Generate a single token so the measurement is dominated by the prefill stage.
generated_ids = model.generate(**inputs, max_new_tokens=1)
generated_ids_trimmed = [
    out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)
```
Qwen2 supports the following attn_implementation options. I do not know how to integrate sageattn into flash_attention_2, so I used sdpa for the benchmark:

```python
QWEN2_VL_ATTENTION_CLASSES = {
    "eager": Qwen2VLAttention,
    "flash_attention_2": Qwen2VLFlashAttention2,
    "sdpa": Qwen2VLSdpaAttention,
}
```
To support sageattn, I replaced the SDPA call at https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L810, changing `attn_output = torch.nn.functional.scaled_dot_product_attention(...)` to `attn_output = sageattn(...)`.

  • Benchmark result

| attn_implementation | runtime and memory |
| --- | --- |
| sdpa | model: Qwen2-VL-2B-Instruct, batch_size: 1, max_new_tokens: 1, avg message cost: 1.93 s, gpu max allocated: 9.23 GiB, cached_memory: 9.64 GiB |
| sdpa + sageattn enabled | model: Qwen2-VL-2B-Instruct, batch_size: 1, max_new_tokens: 1, avg message cost: 2.10 s, gpu max allocated: 9.23 GiB, cached_memory: 9.64 GiB |
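For reproducibility, here is a minimal sketch of how numbers like "avg message cost", "gpu max allocated", and "cached_memory" could be collected. The `measure_once` helper, the warmup call, and the five-run average are assumptions rather than the script actually used; it reuses `model` and `inputs` from the snippet above.

```python
# Sketch: timing one generate() call and reading peak GPU memory.
# measure_once and the averaging loop are illustrative, not the original script.
import time
import torch

def measure_once(model, inputs, max_new_tokens=1):
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start = time.time()
    model.generate(**inputs, max_new_tokens=max_new_tokens)
    torch.cuda.synchronize()
    elapsed = time.time() - start
    max_alloc_gib = torch.cuda.max_memory_allocated() / 1024**3   # "gpu max allocated"
    cached_gib = torch.cuda.max_memory_reserved() / 1024**3       # "cached_memory"
    return elapsed, max_alloc_gib, cached_gib

measure_once(model, inputs)  # warmup, so one-time costs do not skew the average
runs = [measure_once(model, inputs) for _ in range(5)]
avg_cost = sum(r[0] for r in runs) / len(runs)
print(f"avg message cost: {avg_cost:.3f}s, "
      f"gpu max allocated: {runs[-1][1]:.2f} GiB, cached_memory: {runs[-1][2]:.2f} GiB")
```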

In addition, when batch_size > 1, sdpa runs out of memory (OOM), but flash_attention_2 can support batch_size = 20.

I am wondering whether I have integrated sageattn in the correct way. Could you provide an example of integrating sageattn into the qwen2 model?
Thanks!

@gxm651182644 gxm651182644 changed the title from "Question about improvement of qwen2-vl model on A10" to "Question about performance of qwen2-vl on A10" on Nov 28, 2024
@jason-huang03
Member

See #48

@gxm651182644
Author

> See #48

As you can see in the code block, I set max_new_tokens=1. I want to speed up the prefill/context stage of the attention operation.
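For what it's worth, a minimal sketch of isolating the prefill stage directly; the assumption is that a single forward pass over the full prompt, reusing `model` and `inputs` from the snippet above, approximates the prefill cost without any decode steps.

```python
# Sketch: time only the prefill/context stage with a single forward pass
# (no decoding loop). Reuses `model` and `inputs` from the snippet above.
import time
import torch

torch.cuda.synchronize()
start = time.time()
with torch.no_grad():
    model(**inputs)  # one forward pass over the whole prompt == prefill
torch.cuda.synchronize()
print(f"prefill time: {time.time() - start:.4f}s")
```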

@jason-huang03
Member

We have not tested our kernel on the A10. Can you run the benchmarking code under the ./bench directory? By the way, what is your sequence length?

@jason-huang03
Member

Never mind, we will benchmark on the A10 GPU later. It seems that the A10 is also a popular GPU.

@gxm651182644
Author

I have benchmarked on the A10 GPU as follows.
My sequence length is 1272, and the inputs_embeds shape is torch.Size([1, 1272, 1536]).

All runs used batch: 4, head: 32, headdim: 128; the values are the reported flops at each sequence length.

| cmd | method | is_causal | 1024 | 2048 | 4096 | 8192 | 16384 | 32768 |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| `python bench_baseline.py` | Baseline: fa2 | False | 40.18 | 63.56 | 64.01 | 64.76 | 65.11 | 65.14 |
| `python bench_baseline.py` | Baseline: fa2 | True | 53.74 | 59.07 | 61.00 | 61.47 | 61.31 | 60.84 |
| `python bench_baseline.py --method xformers` | Baseline: xformers | False | 40.21 | 41.02 | 41.76 | 42.21 | 42.40 | 42.52 |
| `python bench_baseline.py --method xformers` | Baseline: xformers | True | 34.94 | 38.67 | 40.81 | 42.01 | 42.34 | 42.05 |
| `python bench_qk_int8_pv_fp16_cuda.py` | CUDA QK Int8 PV FP16 (pv_accum_dtype: fp16) | False | 82.75 | 84.60 | 85.25 | 86.31 | 86.83 | 86.89 |
| `python bench_qk_int8_pv_fp16_cuda.py` | CUDA QK Int8 PV FP16 (pv_accum_dtype: fp16) | True | 65.02 | 75.19 | 80.79 | 83.90 | 85.49 | 86.31 |
| `python bench_qk_int8_pv_fp16_triton.py` | Triton QK Int8 PV FP16 (first run, causal flag not printed) | not printed | 75.20 | 78.59 | 79.97 | 81.27 | 81.91 | 82.04 |
| `python bench_qk_int8_pv_fp16_triton.py` | Triton QK Int8 PV FP16 (second run, causal flag not printed) | not printed | 59.13 | 69.01 | 74.87 | 78.04 | 79.70 | n/a |

`python bench_qk_int8_pv_fp8_cuda.py` failed with: RuntimeError: CUDA error: unspecified launch failure. CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect. For debugging consider passing CUDA_LAUNCH_BLOCKING=1. Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.
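For context, a minimal sketch of how an attention-throughput number like the ones above is commonly derived. The FLOP count (4 * batch * heads * seq^2 * headdim, halved for causal) and the timing helper are assumptions about what the ./bench scripts do, not a copy of them.

```python
# Sketch: how an attention-throughput figure is commonly derived in benchmarks.
# The FLOP formula and timing loop below are assumptions, not the ./bench code.
import time
import torch

def attention_tflops(batch, heads, seq_len, head_dim, is_causal, seconds):
    # QK^T and PV each cost 2 * batch * heads * seq_len^2 * head_dim FLOPs.
    flops = 4 * batch * heads * seq_len * seq_len * head_dim
    if is_causal:
        flops //= 2  # a causal mask skips roughly half of the work
    return flops / seconds / 1e12

def bench(fn, *args, warmup=5, iters=20):
    # Time an attention kernel call after a few warmup iterations.
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        fn(*args)
    torch.cuda.synchronize()
    return (time.time() - start) / iters
```

If the scripts follow this convention, the "flops" values in the table are TFLOPS achieved by each kernel at that shape.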
