diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py
index 8e32a9cd6..1c2a4a481 100644
--- a/transformer_lens/HookedTransformer.py
+++ b/transformer_lens/HookedTransformer.py
@@ -1070,6 +1070,7 @@ def from_pretrained(
         default_prepend_bos: bool = True,
         default_padding_side: Literal["left", "right"] = "right",
         dtype="float32",
+        first_n_layers: Optional[int] = None,
         **from_pretrained_kwargs,
     ) -> "HookedTransformer":
         """Load in a Pretrained Model.
@@ -1204,6 +1205,7 @@ def from_pretrained(
                 the model.
             default_padding_side: Which side to pad on when tokenizing. Defaults to
                 "right".
+            first_n_layers: If specified, only load the first n layers of the model.
         """
 
         assert not (
@@ -1261,6 +1263,7 @@ def from_pretrained(
             n_devices=n_devices,
             default_prepend_bos=default_prepend_bos,
             dtype=dtype,
+            first_n_layers=first_n_layers,
             **from_pretrained_kwargs,
         )
 
diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py
index 7c36efdd7..e7ebea947 100644
--- a/transformer_lens/loading_from_pretrained.py
+++ b/transformer_lens/loading_from_pretrained.py
@@ -1389,6 +1389,7 @@ def get_pretrained_model_config(
     n_devices: int = 1,
     default_prepend_bos: bool = True,
     dtype: torch.dtype = torch.float32,
+    first_n_layers: Optional[int] = None,
     **kwargs,
 ):
     """Returns the pretrained model config as an HookedTransformerConfig object.
@@ -1501,6 +1502,8 @@ def get_pretrained_model_config(
         cfg_dict["default_prepend_bos"] = default_prepend_bos
     if hf_cfg is not None:
         cfg_dict["load_in_4bit"] = hf_cfg.get("quantization_config", {}).get("load_in_4bit", False)
+    if first_n_layers is not None:
+        cfg_dict["n_layers"] = first_n_layers
     cfg = HookedTransformerConfig.from_dict(cfg_dict)
     return cfg
 
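
Usage sketch (not part of the diff): the snippet below illustrates how the new first_n_layers argument could be used once this change lands. The model name ("gpt2") and the layer count are illustrative choices; the only behavior taken from the diff is that first_n_layers overrides n_layers in the config dict before HookedTransformerConfig.from_dict is called.

from transformer_lens import HookedTransformer

# Illustrative example: load only the first two transformer blocks of GPT-2.
# Per this diff, first_n_layers overrides cfg_dict["n_layers"] inside
# get_pretrained_model_config, so the HookedTransformer is constructed with
# just that many blocks.
model = HookedTransformer.from_pretrained("gpt2", first_n_layers=2)

print(model.cfg.n_layers)  # 2
print(len(model.blocks))   # 2

# The truncated model still runs end to end (embed -> blocks -> final LN -> unembed),
# though its logits differ from the full 12-layer GPT-2 since later blocks are skipped.
logits = model("Hello, world!")

Note that the diff itself only covers the config-side override; the rest of the loading path (state-dict conversion and weight assignment) already iterates over cfg.n_layers, which is presumably why changing the config alone is enough to load a truncated model.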