diff --git a/awq/quantize/auto_scale.py b/awq/quantize/auto_scale.py
index bb66e10..8a1dac1 100644
--- a/awq/quantize/auto_scale.py
+++ b/awq/quantize/auto_scale.py
@@ -6,6 +6,7 @@
 from transformers.models.opt.modeling_opt import OPTDecoderLayer
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm
 from transformers.activations import GELUActivation
+from transformers.models.phi3.modeling_phi3 import Phi3RMSNorm
 
 from .qmodule import ScaledActivation
 from ..utils.module import get_op_by_name, get_op_name, set_op_by_name
@@ -439,6 +440,45 @@ def _auto_get_scale(prev_op, layers, inp, module2inspect=None, kwargs={}):
                 inp=input_feat["mlp.dense_4h_to_h"],
             )
         )
+    elif "phi" in str(module.__class__).lower():
+        # attention input
+        scales_list.append(
+            _auto_get_scale(
+                prev_op=module.input_layernorm,
+                layers=[module.self_attn.qkv_proj],
+                inp=input_feat["self_attn.qkv_proj"],
+                module2inspect=module.self_attn,
+                kwargs=module_kwargs,
+            )
+        )
+        # attn out
+        # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
+        if module.self_attn.qkv_proj.weight.shape == module.self_attn.o_proj.weight.shape:
+            scales_list.append(
+                _auto_get_scale(
+                    prev_op=module.self_attn.qkv_proj,
+                    layers=[module.self_attn.o_proj],
+                    inp=input_feat["self_attn.o_proj"],
+                )
+            )
+        # fc1
+        scales_list.append(
+            _auto_get_scale(
+                prev_op=module.post_attention_layernorm,
+                layers=[module.mlp.gate_up_proj],
+                inp=input_feat["mlp.gate_up_proj"],
+                module2inspect=module.mlp,
+            )
+        )
+        # fc2
+        scales_list.append(
+            _auto_get_scale(
+                prev_op=module.mlp.gate_up_proj,
+                layers=[module.mlp.down_proj],
+                inp=input_feat["mlp.down_proj"],
+            )
+        )
+
     else:
         raise NotImplementedError(f"{type(module)} not supported yet!")
 
@@ -464,6 +504,8 @@ def apply_scale(module, scales_list, input_feat_dict=None):
             new_module = ScaledActivation(prev_op, scales)
             set_op_by_name(module, prev_op_name, new_module)
             scale_gelu_fc(prev_op, layers[0], scales)
+        elif "rmsnorm" in str(prev_op.__class__).lower():
+            scale_ln_fcs(prev_op, layers, scales)
         else:
             raise NotImplementedError(f"prev_op {type(prev_op)} not supported yet!")
 
diff --git a/awq/quantize/pre_quant.py b/awq/quantize/pre_quant.py
index f35531d..45335a5 100644
--- a/awq/quantize/pre_quant.py
+++ b/awq/quantize/pre_quant.py
@@ -39,6 +39,8 @@ def get_blocks(model):
         layers = model.transformer.h
     elif "neox" in str(model.__class__).lower():
         layers = model.gpt_neox.layers
+    elif "phi" in str(model.__class__).lower():
+        layers = model.model.layers
     else:
         raise NotImplementedError(type(model))
     return layers
@@ -73,6 +75,8 @@ def move_embed(model, device):
         model.gpt_neox.embed_in = model.gpt_neox.embed_in.to(device)
         model.gpt_neox.emb_dropout = model.gpt_neox.emb_dropout.to(device)
         model.embed_out = model.embed_out.to(device)
+    elif "phi" in str(model.__class__).lower():
+        model.model.embed_tokens = model.model.embed_tokens.to(device)
     else:
         raise NotImplementedError(type(model))
 
diff --git a/tinychat/demo.py b/tinychat/demo.py
index 4cec8d1..b8fa37f 100644
--- a/tinychat/demo.py
+++ b/tinychat/demo.py
@@ -10,10 +10,11 @@
 from tinychat.utils.load_quant import load_awq_model, load_awq_llama_fast
 from tinychat.utils.prompt_templates import get_prompter, get_stop_token_ids
 from tinychat.utils.tune import device_warmup, tune_all_wqlinears
+from transformers import Phi3ForCausalLM
 
 import os
 
-os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
 
 # opt_params in TinyLLMEngine
 gen_params = AttributeDict(
@@ -125,6 +126,7 @@ def stream_output(output_stream):
         "llama",
         "falcon",
         "mpt",
+        "phi3"
     ], "We only support llama & falcon & mpt now"
     assert args.precision in ["W4A16", "W16A16"], "We only support W4A16/W16A16 now"
 
@@ -168,6 +170,7 @@ def skip(*args, **kwargs):
         "llama": LlamaForCausalLM,
         "falcon": FalconForCausalLM,
         "mpt": MPTForCausalLM,
+        "phi3": Phi3ForCausalLM
     }
 
     if args.precision == "W4A16":
@@ -205,6 +208,12 @@ def skip(*args, **kwargs):
 
         make_quant_attn(model, args.device)
         make_quant_norm(model)
+
+    if args.precision == "W4A16" and args.model_type.lower() == "phi3":
+        from tinychat.modules import make_quant_norm, make_quant_attn
+
+        make_quant_attn(model, args.device)
+        make_quant_norm(model)
 
     if args.max_seq_len <= 1024:
         short_prompt = True
diff --git a/tinychat/scripts/llama2_demo.sh b/tinychat/scripts/llama2_demo.sh
index 95cf734..1dc2dab 100755
--- a/tinychat/scripts/llama2_demo.sh
+++ b/tinychat/scripts/llama2_demo.sh
@@ -2,10 +2,10 @@ MODEL_PATH=/data/llm/checkpoints/llama2-hf
 MODEL_NAME=llama-2-7b-chat
 
 # # Perform AWQ search and save search results (we already did it for you):
-# mkdir -p awq_cache
-# python -m awq.entry --model_path $MODEL_PATH/$MODEL_NAME \
-#     --w_bit 4 --q_group_size 128 \
-#     --run_awq --dump_awq awq_cache/llama-2-7b-chat-w4-g128.pt
+mkdir -p awq_cache
+python -m awq.entry --model_path $MODEL_PATH/$MODEL_NAME \
+    --w_bit 4 --q_group_size 128 \
+    --run_awq --dump_awq awq_cache/llama-2-7b-chat-w4-g128.pt
 
 # Generate real quantized weights (INT4):
 mkdir -p quant_cache
diff --git a/tinychat/scripts/phi3_demo.sh b/tinychat/scripts/phi3_demo.sh
new file mode 100644
index 0000000..3fbed4c
--- /dev/null
+++ b/tinychat/scripts/phi3_demo.sh
@@ -0,0 +1,34 @@
+MODEL_PATH=/data/lujunli/hf_download/
+MODEL_NAME=Phi-3-mini-128k-instruct
+
+export CUDA_VISIBLE_DEVICES=1,2
+
+# # Perform AWQ search and save search results (we already did it for you):
+# mkdir -p awq_cache
+# python -m awq.entry --model_path $MODEL_PATH/$MODEL_NAME \
+#     --w_bit 4 --q_group_size 128 \
+#     --run_awq --dump_awq awq_cache/phi-3-chat-w4-g128.pt
+
+# # Generate real quantized weights (INT4):
+# mkdir -p quant_cache
+# python -m awq.entry --model_path $MODEL_PATH/$MODEL_NAME \
+#     --w_bit 4 --q_group_size 128 \
+#     --load_awq awq_cache/phi-3-chat-w4-g128.pt \
+#     --q_backend real --dump_quant quant_cache/phi-3-chat-w4-g128-awq.pt
+
+# # Run the TinyChat demo:
+
+CUDA_VISIBLE_DEVICES=1 python demo.py --model_type phi3 \
+    --model_path $MODEL_PATH/$MODEL_NAME \
+    --q_group_size 128 --load_quant quant_cache/phi-3-chat-w4-g128-awq-v2.pt \
+    --precision W4A16
+
+# # Split checkpoint into shards for mem-efficient loading:
+# python split_ckpt.py --input_path quant_cache/phi-3-chat-w4-g128-awq.pt \
+#     --output_path quant_cache/phi-3-chat-w4-g128-awq
+
+# # Run the TinyChat demo in mem_efficient_load mode:
+# python demo.py --model_type phi3 \
+#     --model_path $MODEL_PATH/$MODEL_NAME \
+#     --q_group_size 128 --load_quant quant_cache/phi-3-chat-w4-g128-awq \
+#     --precision W4A16 --mem_efficient_load
diff --git a/tinychat/utils/load_quant.py b/tinychat/utils/load_quant.py
index 017a9a4..862bc15 100644
--- a/tinychat/utils/load_quant.py
+++ b/tinychat/utils/load_quant.py
@@ -89,6 +89,7 @@ def load_awq_model(model, checkpoint, w_bit, group_size, device):
             "MPTBlock",
             "DecoderLayer",
             "CLIPEncoderLayer",
+            "Phi3DecoderLayer"
         ],
     ).to(device)
     return model
diff --git a/tinychat/utils/prompt_templates.py b/tinychat/utils/prompt_templates.py
index 3e431ed..e4cbeb3 100644
--- a/tinychat/utils/prompt_templates.py
+++ b/tinychat/utils/prompt_templates.py
@@ -233,6 +233,8 @@ def __init__(self):
         super().__init__(system_inst, role1, role2, sen_spliter, qa_spliter)
 
 
+
+
 class FalconPrompter(BasePrompter):
     def __init__(self):
         system_inst = (
@@ -260,6 +262,36 @@ def __init__(self):
         super().__init__(system_inst, role1, role2, sen_spliter, qa_spliter)
 
 
+class Phi3Prompter(BasePrompter):
+    """
+    Example:
+    <|start_header_id|>Human<|end_header_id|>
+
+    What do you think is the meaning of life?<|eot_id|>
+
+    <|start_header_id|>Phi<|end_header_id|>
+
+    That's a profound question that philosophers have grappled with for millennia. While there may not be one single answer, here are a few perspectives to consider:
+    - Some believe the meaning of life is to seek happiness and fulfillment. This could be through relationships, experiences, achieving goals, or finding inner peace.
+    - Others see the meaning as making a positive impact in the world, whether through grand achievements or everyday acts of kindness.
+    - From a religious or spiritual view, the meaning may involve connecting with or serving a higher power.
+    - An existentialist perspective is that life has no inherent meaning, and we must create our own purpose.
+    Ultimately, I believe the meaning of life is deeply personal. It's up to each of us to reflect and decide what gives our existence significance. What are your thoughts on this? I'm curious to hear your perspective!<|eot_id|>
+
+    <|start_header_id|>Human<|end_header_id|>
+
+    """
+    def __init__(self):
+        system_inst = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are Phi, a witty and philosophical AI assistant created by Microsoft. " + \
+            "Your aim is to engage in thoughtful discussions, provide insightful perspectives, and explore profound topics with curiosity and openness. " + \
+            "You enjoy pondering big questions about life, meaning, and the human experience."
+        role1 = "<|start_header_id|>Human<|end_header_id|>\n\n"
+        role2 = "<|start_header_id|>Phi<|end_header_id|>\n\n"
+        sen_spliter = "<|eot_id|>"
+        qa_spliter = ""
+        colon = ""
+        super().__init__(system_inst, role1, role2, sen_spliter, qa_spliter, colon=colon)
+
 class MPTChatPrompter(BasePrompter):
     def __init__(self):
         system_inst = (
@@ -301,6 +333,8 @@ def get_prompter(model_type, model_path="", short_prompt=False, empty_prompt=Fal
             return MPTChatPrompter()
         else:
             return MPTPrompter()
+    elif model_type.lower() == "phi3":
+        return Phi3Prompter()
     else:
         raise ValueError(f"model type {model_type} is not supported")
 
@@ -318,5 +352,7 @@ def get_stop_token_ids(model_type, model_path=""):
             return [50278, 0]
         else:
             return []
+    elif model_type.lower() == "phi3":
+        return [50256, 50257]
     else:
         raise ValueError(f"model type {model_type} is not supported")
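
For reference, the new "phi" branches in awq/quantize/pre_quant.py (get_blocks and move_embed) can be exercised without loading any weights. The sketch below is illustrative only and not part of the patch: the checkpoint id and the meta-device construction are assumptions, and it expects the patch to be applied on top of a transformers release that ships Phi-3 support.

# Minimal smoke test for the phi3 dispatch added above (assumptions: patch
# applied, transformers with Phi-3 support installed, checkpoint id reachable).
import torch
from transformers import AutoConfig, AutoModelForCausalLM

from awq.quantize.pre_quant import get_blocks, move_embed

ckpt = "microsoft/Phi-3-mini-128k-instruct"  # illustrative checkpoint id
config = AutoConfig.from_pretrained(ckpt)
with torch.device("meta"):
    # Build the architecture only; no weights are materialized on the meta device.
    model = AutoModelForCausalLM.from_config(config)

blocks = get_blocks(model)  # Phi3ForCausalLM hits the new "phi" branch -> model.model.layers
print(type(model).__name__, "->", len(blocks), "decoder layers")
move_embed(model, "meta")   # moves model.model.embed_tokens via the new branch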