Add phi3 support #183

Open · wants to merge 2 commits into base: main
42 changes: 42 additions & 0 deletions awq/quantize/auto_scale.py
@@ -6,6 +6,7 @@
from transformers.models.opt.modeling_opt import OPTDecoderLayer
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm
from transformers.activations import GELUActivation
from transformers.models.phi3.modeling_phi3 import Phi3RMSNorm

from .qmodule import ScaledActivation
from ..utils.module import get_op_by_name, get_op_name, set_op_by_name
@@ -439,6 +440,45 @@ def _auto_get_scale(prev_op, layers, inp, module2inspect=None, kwargs={}):
inp=input_feat["mlp.dense_4h_to_h"],
)
)
elif "phi" in str(module.__class__).lower():
# attention input
scales_list.append(
_auto_get_scale(
prev_op=module.input_layernorm,
layers=[module.self_attn.qkv_proj],
inp=input_feat["self_attn.qkv_proj"],
module2inspect=module.self_attn,
kwargs=module_kwargs,
)
)
# attn out
# Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
if module.self_attn.qkv_proj.weight.shape == module.self_attn.o_proj.weight.shape:
scales_list.append(
_auto_get_scale(
prev_op=module.self_attn.qkv_proj,
layers=[module.self_attn.o_proj],
inp=input_feat["self_attn.o_proj"],
)
)
# fc1
scales_list.append(
_auto_get_scale(
prev_op=module.post_attention_layernorm,
layers=[module.mlp.gate_up_proj],
inp=input_feat["mlp.gate_up_proj"],
module2inspect=module.mlp,
)
)
# fc2
scales_list.append(
_auto_get_scale(
prev_op=module.mlp.gate_up_proj,
layers=[module.mlp.down_proj],
inp=input_feat["mlp.down_proj"],
)
)

else:
raise NotImplementedError(f"{type(module)} not supported yet!")

@@ -464,6 +504,8 @@ def apply_scale(module, scales_list, input_feat_dict=None):
new_module = ScaledActivation(prev_op, scales)
set_op_by_name(module, prev_op_name, new_module)
scale_gelu_fc(prev_op, layers[0], scales)
elif "rmsnorm" in str(prev_op.__class__).lower():
scale_ln_fcs(prev_op, layers, scales)
else:
raise NotImplementedError(f"prev_op {type(prev_op)} not supported yet!")

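The phi3 branch above assumes Phi-3's fused projection layout: self_attn.qkv_proj feeding o_proj and mlp.gate_up_proj feeding down_proj, with Phi3RMSNorm picked up by the new rmsnorm case in apply_scale. A minimal sanity check of those submodule names, as a sketch rather than part of this PR (the checkpoint id is only an example, and a recent transformers release with Phi-3 support is assumed):

# Sanity-check sketch (not part of this PR): confirm the submodule names the phi3 branch relies on.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-mini-128k-instruct", torch_dtype="auto", trust_remote_code=True
)
layer = model.model.layers[0]
print(type(layer).__name__)                                     # expected: Phi3DecoderLayer
print([name for name, _ in layer.self_attn.named_children()])   # expects qkv_proj and o_proj
print([name for name, _ in layer.mlp.named_children()])         # expects gate_up_proj and down_proj
print(type(layer.input_layernorm).__name__)                     # expected: Phi3RMSNorm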
4 changes: 4 additions & 0 deletions awq/quantize/pre_quant.py
@@ -39,6 +39,8 @@ def get_blocks(model):
layers = model.transformer.h
elif "neox" in str(model.__class__).lower():
layers = model.gpt_neox.layers
elif "phi" in str(model.__class__).lower():
layers = model.model.layers
else:
raise NotImplementedError(type(model))
return layers
@@ -73,6 +75,8 @@ def move_embed(model, device):
model.gpt_neox.embed_in = model.gpt_neox.embed_in.to(device)
model.gpt_neox.emb_dropout = model.gpt_neox.emb_dropout.to(device)
model.embed_out = model.embed_out.to(device)
elif "phi" in str(model.__class__).lower():
model.model.embed_tokens = model.model.embed_tokens.to(device)
else:
raise NotImplementedError(type(model))

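Both new branches dispatch on the lowercased class string, so any model whose class repr contains "phi" is routed to model.model.layers and model.model.embed_tokens. That also matches Phi-1.5/Phi-2 (PhiForCausalLM), whose decoder layers use separate q_proj/k_proj/v_proj modules rather than the fused qkv_proj assumed in auto_scale.py, so a stricter "phi3" check may be safer. A small illustration, assuming a transformers version that ships both model families:

# Illustration only: how the class-string dispatch in get_blocks()/move_embed() resolves.
from transformers import Phi3ForCausalLM, PhiForCausalLM

print("phi" in str(Phi3ForCausalLM).lower())   # True, phi3 branch taken
print("phi" in str(PhiForCausalLM).lower())    # also True, although Phi-2 layers differ
print("phi3" in str(PhiForCausalLM).lower())   # False, so a "phi3" check would be stricter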
11 changes: 10 additions & 1 deletion tinychat/demo.py
@@ -10,10 +10,11 @@
from tinychat.utils.load_quant import load_awq_model, load_awq_llama_fast
from tinychat.utils.prompt_templates import get_prompter, get_stop_token_ids
from tinychat.utils.tune import device_warmup, tune_all_wqlinears
from transformers import Phi3ForCausalLM

import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# opt_params in TinyLLMEngine
gen_params = AttributeDict(
@@ -125,6 +126,7 @@ def stream_output(output_stream):
"llama",
"falcon",
"mpt",
"phi3"
], "We only support llama & falcon & mpt now"
assert args.precision in ["W4A16", "W16A16"], "We only support W4A16/W16A16 now"

@@ -168,6 +170,7 @@ def skip(*args, **kwargs):
"llama": LlamaForCausalLM,
"falcon": FalconForCausalLM,
"mpt": MPTForCausalLM,
"phi3": Phi3ForCausalLM
}

if args.precision == "W4A16":
@@ -205,6 +208,12 @@ def skip(*args, **kwargs):

make_quant_attn(model, args.device)
make_quant_norm(model)

if args.precision == "W4A16" and args.model_type.lower() == "phi3":
from tinychat.modules import make_quant_norm, make_quant_attn

make_quant_attn(model, args.device)
make_quant_norm(model)

if args.max_seq_len <= 1024:
short_prompt = True
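Note that demo.py now imports Phi3ForCausalLM at module level, so the demo fails to start on transformers versions without Phi-3 classes even for the other model types. A guarded import is one way around that; this is a sketch, not part of the PR:

# Sketch: tolerate transformers installs that predate the Phi-3 classes.
try:
    from transformers import Phi3ForCausalLM
except ImportError:          # transformers without Phi-3 support
    Phi3ForCausalLM = None   # the "phi3" model_type simply becomes unavailable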
8 changes: 4 additions & 4 deletions tinychat/scripts/llama2_demo.sh
@@ -2,10 +2,10 @@ MODEL_PATH=/data/llm/checkpoints/llama2-hf
MODEL_NAME=llama-2-7b-chat

# # Perform AWQ search and save search results (we already did it for you):
# mkdir -p awq_cache
# python -m awq.entry --model_path $MODEL_PATH/$MODEL_NAME \
# --w_bit 4 --q_group_size 128 \
# --run_awq --dump_awq awq_cache/llama-2-7b-chat-w4-g128.pt
mkdir -p awq_cache
python -m awq.entry --model_path $MODEL_PATH/$MODEL_NAME \
--w_bit 4 --q_group_size 128 \
--run_awq --dump_awq awq_cache/llama-2-7b-chat-w4-g128.pt

# Generate real quantized weights (INT4):
mkdir -p quant_cache
34 changes: 34 additions & 0 deletions tinychat/scripts/phi3_demo.sh
@@ -0,0 +1,34 @@
MODEL_PATH=/data/lujunli/hf_download/
MODEL_NAME=Phi-3-mini-128k-instruct

export CUDA_VISIBLE_DEVICES=1,2

# # Perform AWQ search and save search results (we already did it for you):
# mkdir -p awq_cache
# python -m awq.entry --model_path $MODEL_PATH/$MODEL_NAME \
# --w_bit 4 --q_group_size 128 \
# --run_awq --dump_awq awq_cache/phi-3-chat-w4-g128.pt

# # Generate real quantized weights (INT4):
# mkdir -p quant_cache
# python -m awq.entry --model_path $MODEL_PATH/$MODEL_NAME \
# --w_bit 4 --q_group_size 128 \
# --load_awq awq_cache/phi-3-chat-w4-g128.pt \
# --q_backend real --dump_quant quant_cache/phi-3-chat-w4-g128-awq.pt

# # Run the TinyChat demo:

CUDA_VISIBLE_DEVICES=1 python demo.py --model_type phi3 \
--model_path $MODEL_PATH/$MODEL_NAME \
--q_group_size 128 --load_quant quant_cache/phi-3-chat-w4-g128-awq-v2.pt \
--precision W4A16

# # Split checkpoint into shards for mem-efficient loading:
# python split_ckpt.py --input_path quant_cache/phi-3-chat-w4-g128-awq.pt \
# --output_path quant_cache/phi-3-chat-w4-g128-awq

# # Run the TinyChat demo in mem_efficient_load mode:
# python demo.py --model_type llama \
# --model_path $MODEL_PATH/$MODEL_NAME \
# --q_group_size 128 --load_quant quant_cache/phi-3-chat-w4-g128-awq \
# --precision W4A16 --mem_efficient_load
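For context, the awq_cache file produced by --run_awq/--dump_awq stores the searched scales and clips that --load_awq re-applies before real quantization. A rough Python equivalent of that re-apply step, assuming the apply_awq helper used by awq/entry.py and reusing the paths from this script:

# Sketch: re-apply cached AWQ search results outside of awq.entry.
import torch
from transformers import AutoModelForCausalLM
from awq.quantize.pre_quant import apply_awq  # helper assumed, as used by awq/entry.py

model = AutoModelForCausalLM.from_pretrained(
    "/data/lujunli/hf_download/Phi-3-mini-128k-instruct", torch_dtype=torch.float16
)
awq_results = torch.load("awq_cache/phi-3-chat-w4-g128.pt", map_location="cpu")
apply_awq(model, awq_results)  # scales and clips are applied to the model in place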
1 change: 1 addition & 0 deletions tinychat/utils/load_quant.py
@@ -89,6 +89,7 @@ def load_awq_model(model, checkpoint, w_bit, group_size, device):
"MPTBlock",
"DecoderLayer",
"CLIPEncoderLayer",
"Phi3DecoderLayer"
],
).to(device)
return model
36 changes: 36 additions & 0 deletions tinychat/utils/prompt_templates.py
@@ -233,6 +233,8 @@ def __init__(self):
super().__init__(system_inst, role1, role2, sen_spliter, qa_spliter)




class FalconPrompter(BasePrompter):
def __init__(self):
system_inst = (
@@ -260,6 +262,36 @@ def __init__(self):
super().__init__(system_inst, role1, role2, sen_spliter, qa_spliter)


class Phi3Prompter(BasePrompter):
"""
Example:
<|start_header_id|>Human<|end_header_id|>

What do you think is the meaning of life?<|eot_id|>

<|start_header_id|>Phi<|end_header_id|>

That's a profound question that philosophers have grappled with for millennia. While there may not be one single answer, here are a few perspectives to consider:
- Some believe the meaning of life is to seek happiness and fulfillment. This could be through relationships, experiences, achieving goals, or finding inner peace.
- Others see the meaning as making a positive impact in the world, whether through grand achievements or everyday acts of kindness.
- From a religious or spiritual view, the meaning may involve connecting with or serving a higher power.
- An existentialist perspective is that life has no inherent meaning, and we must create our own purpose.
Ultimately, I believe the meaning of life is deeply personal. It's up to each of us to reflect and decide what gives our existence significance. What are your thoughts on this? I'm curious to hear your perspective!<|eot_id|>

<|start_header_id|>Human<|end_header_id|>

"""
def __init__(self):
system_inst = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are Phi, a witty and philosophical AI assistant created by Anthropic. " + \
"Your aim is to engage in thoughtful discussions, provide insightful perspectives, and explore profound topics with curiosity and openness. " + \
"You enjoy pondering big questions about life, meaning, and the human experience."
role1 = "<|start_header_id|>Human<|end_header_id|>\n\n"
role2 = "<|start_header_id|>Phi<|end_header_id|>\n\n"
sen_spliter = "<|eot_id|>"
qa_spliter = ""
colon = ""
super().__init__(system_inst, role1, role2, sen_spliter, qa_spliter, colon=colon)

class MPTChatPrompter(BasePrompter):
def __init__(self):
system_inst = (
@@ -301,6 +333,8 @@ def get_prompter(model_type, model_path="", short_prompt=False, empty_prompt=Fal
return MPTChatPrompter()
else:
return MPTPrompter()
elif model_type.lower() == "phi3":
return Phi3Prompter()
else:
raise ValueError(f"model type {model_type} is not supported")

@@ -318,5 +352,7 @@ def get_stop_token_ids(model_type, model_path=""):
return [50278, 0]
else:
return []
elif model_type.lower() == "phi3":
return [50256, 50257]
else:
raise ValueError(f"model type {model_type} is not supported")
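One caveat on the template and stop ids above: the header strings are Llama-3-style markers and 50256/50257 are GPT-2-range token ids, while Phi-3-mini ships a 32k-vocabulary tokenizer with its own chat markers and end-of-turn tokens. A tokenizer-driven way to inspect what a Phi-3 checkpoint itself expects, as a sketch rather than part of this PR (the checkpoint id is an example):

# Sketch: let the checkpoint's own chat template build the prompt and report stop ids.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct")
messages = [{"role": "user", "content": "What do you think is the meaning of life?"}]
prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)                                # shows the <|user|> ... <|end|> <|assistant|> markers
print(tok.eos_token_id)                      # end-of-text id for this tokenizer
print(tok.convert_tokens_to_ids("<|end|>"))  # end-of-turn id, a natural stop token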