Skip to content

Commit

Permalink
feat: natively support Granite models (#2682)
Browse files Browse the repository at this point in the history
* feat: natively support Granite models

* Update doc
  • Loading branch information
OlivierDehaene authored Oct 23, 2024
1 parent f58eb70 commit 03c9388
Show file tree
Hide file tree
Showing 9 changed files with 816 additions and 634 deletions.
1 change: 1 addition & 0 deletions docs/source/supported_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Text Generation Inference enables serving optimized models. The following sectio
- [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal)
- [Llama](https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f)
- [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
- [Granite](https://huggingface.co/ibm-granite/granite-3.0-8b-instruct)
- [Gemma](https://huggingface.co/google/gemma-7b)
- [PaliGemma](https://huggingface.co/google/paligemma-3b-pt-224)
- [Gemma2](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315)
Expand Down
2 changes: 2 additions & 0 deletions nix/impure-shell.nix
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ mkShell {
[
cuda_cccl
cuda_cudart
cuda_nvrtc
cuda_nvtx
cuda_profiler_api
cudnn
libcublas
libcusolver
Expand Down
1 change: 1 addition & 0 deletions router/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ pub enum Config {
Idefics2(Idefics2),
Ssm,
GptBigcode,
Granite,
Santacoder,
Bloom,
Mpt,
Expand Down
1,355 changes: 744 additions & 611 deletions server/poetry.lock

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions server/requirements_cuda.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.4.0 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
Expand All @@ -10,7 +10,7 @@ googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version <
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.67.0 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
Expand Down Expand Up @@ -38,14 +38,14 @@ pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13"
rich==13.9.3 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==75.2.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.20.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.45.2 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
Expand Down
12 changes: 6 additions & 6 deletions server/requirements_intel.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.4.0 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
Expand All @@ -10,7 +10,7 @@ googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version <
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.67.0 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
Expand Down Expand Up @@ -38,14 +38,14 @@ pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13"
rich==13.9.3 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==75.2.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.20.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.45.2 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
Expand Down
12 changes: 6 additions & 6 deletions server/requirements_rocm.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.4.0 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
Expand All @@ -10,7 +10,7 @@ googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version <
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.67.0 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
Expand Down Expand Up @@ -38,14 +38,14 @@ pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13"
rich==13.9.3 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==75.2.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.20.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.45.2 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
Expand Down
16 changes: 14 additions & 2 deletions server/text_generation_server/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,11 @@ class ModelType(enum.Enum):
"name": "Phi 3",
"url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
}
GRANITE = {
"type": "granite",
"name": "Granite",
"url": "https://huggingface.co/ibm-granite/granite-3.0-8b-instruct",
}
GEMMA = {
"type": "gemma",
"name": "Gemma",
Expand Down Expand Up @@ -862,7 +867,12 @@ def get_model(
trust_remote_code=trust_remote_code,
)

elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
elif (
model_type == LLAMA
or model_type == BAICHUAN
or model_type == PHI3
or model_type == GRANITE
):
if FLASH_ATTENTION:
return FlashCausalLM(
model_id=model_id,
Expand All @@ -876,7 +886,9 @@ def get_model(
lora_adapter_ids=lora_adapter_ids,
)
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format(f"Sharded {model_type}")
)
else:
return CausalLM.fallback(
model_id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,10 @@ def __init__(
device=weights.device,
)

self.softmax_scale = self.head_size**-0.5
# `config.attention_multiplier` is used in Granite
self.softmax_scale = getattr(
config, "attention_multiplier", self.head_size**-0.5
)

if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
Expand All @@ -180,7 +183,7 @@ def __init__(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
bias=getattr(config, "attention_bias", False),
)

self.o_proj = TensorParallelAdapterRowLinear.load(
Expand Down Expand Up @@ -436,6 +439,11 @@ def __init__(self, index, prefix, config, weights):
eps=config.rms_norm_eps,
)

# Used in Granite
# This could eventually be baked into the weights like we do for the embeddings/lm_head
# but this would mean modifying the lora code
self.residual_multiplier = getattr(config, "residual_multiplier", None)

def forward(
self,
hidden_states,
Expand Down Expand Up @@ -466,13 +474,16 @@ def forward(
max_s,
adapter_data,
)
if self.residual_multiplier is not None:
attn_output *= self.residual_multiplier

# faster post attention rms norm
normed_attn_res_output, attn_res = self.post_attention_layernorm(
attn_output, res
)

mlp_output = self.dense(normed_attn_res_output, adapter_data)
if self.residual_multiplier is not None:
mlp_output *= self.residual_multiplier

return mlp_output, attn_res

Expand Down Expand Up @@ -624,13 +635,28 @@ def __init__(self, prefix: str, config, weights):
else:
suffix = "lm_head"

# Used in Granite
embedding_multiplier = getattr(config, "embedding_multiplier", None)
if embedding_multiplier is not None:
self.embed_tokens.weight.data *= embedding_multiplier

with no_fp8(weights):
self.lm_head = SpeculativeHead.load(
config,
prefix=suffix if not prefix else f"{prefix}.{suffix}",
weights=weights,
)

# Used in Granite
self.logits_scaling = getattr(config, "logits_scaling", None)
if self.logits_scaling is not None and self.lm_head.head is not None:
try:
# Scale the weights directly
self.lm_head.head.linear.weight.data /= self.logits_scaling
self.logits_scaled = True
except Exception:
self.logits_scaled = False

def forward(
self,
input_ids: torch.Tensor,
Expand Down Expand Up @@ -664,4 +690,11 @@ def forward(
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.lm_head(hidden_states)

# Used in Granite
if not self.logits_scaled:
logits /= self.logits_scaling
if speculative_logits is not None:
speculative_logits /= self.logits_scaling

return logits, speculative_logits

0 comments on commit 03c9388

Please sign in to comment.