diff --git a/.github/ISSUE_TEMPLATE/bug.md b/.github/ISSUE_TEMPLATE/bug.md index 7706ec9aa..f69b53573 100644 --- a/.github/ISSUE_TEMPLATE/bug.md +++ b/.github/ISSUE_TEMPLATE/bug.md @@ -17,11 +17,11 @@ Please try to provide a minimal example to reproduce the bug. Error messages and Describe the characteristic of your environment: * Describe how `transformer_lens` was installed (pip, docker, source, ...) * What OS are you using? (Linux, MacOS, Windows) - * Python version (We suppourt 3.7 -3.10 currently) + * Python version (We support 3.7-3.10 currently) **Additional context** Add any other context about the problem here. ### Checklist -- [ ] I have checked that there is no similar [issue](https://github.com/TransformerLensOrg/TransformerLens/issues) in the repo (**required**) \ No newline at end of file +- [ ] I have checked that there is no similar [issue](https://github.com/TransformerLensOrg/TransformerLens/issues) in the repo (**required**) diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py index 8e32a9cd6..1c2a4a481 100644 --- a/transformer_lens/HookedTransformer.py +++ b/transformer_lens/HookedTransformer.py @@ -1070,6 +1070,7 @@ def from_pretrained( default_prepend_bos: bool = True, default_padding_side: Literal["left", "right"] = "right", dtype="float32", + first_n_layers: Optional[int] = None, **from_pretrained_kwargs, ) -> "HookedTransformer": """Load in a Pretrained Model. @@ -1204,6 +1205,7 @@ def from_pretrained( the model. default_padding_side: Which side to pad on when tokenizing. Defaults to "right". + first_n_layers: If specified, only load the first n layers of the model. 
""" assert not ( @@ -1261,6 +1263,7 @@ def from_pretrained( n_devices=n_devices, default_prepend_bos=default_prepend_bos, dtype=dtype, + first_n_layers=first_n_layers, **from_pretrained_kwargs, ) diff --git a/transformer_lens/HookedTransformerConfig.py b/transformer_lens/HookedTransformerConfig.py index 6906de38f..cfca7fb72 100644 --- a/transformer_lens/HookedTransformerConfig.py +++ b/transformer_lens/HookedTransformerConfig.py @@ -78,7 +78,7 @@ class HookedTransformerConfig: attention attn_types (List[str], *optional*): the types of attention to use for local attention - weight_init_mode (str): the initialization mode to use for the + init_mode (str): the initialization mode to use for the weights. Only relevant for custom models, ignored for pre-trained. We now support 'gpt2', 'xavier_uniform', 'xavier_normal', 'kaiming_uniform', 'kaiming_normal'. MuP support to come. Defaults to 'gpt2'. @@ -100,7 +100,7 @@ class HookedTransformerConfig: Used to set sources of randomness (Python, PyTorch and NumPy) and to initialize weights. Defaults to None. We recommend setting a seed, so your experiments are reproducible. initializer_range (float): The standard deviation of the normal used to - initialise the weights, initialized to 0.8 / sqrt(d_model). If weight_init_mode is + initialise the weights, initialized to 0.8 / sqrt(d_model). If init_mode is 'xavier_uniform' or 'xavier_normal', this value is instead treated as the `gain` parameter for the weight initialisation (a constant factor to scale the weights by). Defaults to -1.0, which means not set. init_weights (bool): Whether to initialize the weights. 
Defaults to diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 7c36efdd7..e7ebea947 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -1389,6 +1389,7 @@ def get_pretrained_model_config( n_devices: int = 1, default_prepend_bos: bool = True, dtype: torch.dtype = torch.float32, + first_n_layers: Optional[int] = None, **kwargs, ): """Returns the pretrained model config as an HookedTransformerConfig object. @@ -1501,6 +1502,8 @@ def get_pretrained_model_config( cfg_dict["default_prepend_bos"] = default_prepend_bos if hf_cfg is not None: cfg_dict["load_in_4bit"] = hf_cfg.get("quantization_config", {}).get("load_in_4bit", False) + if first_n_layers is not None: + cfg_dict["n_layers"] = first_n_layers cfg = HookedTransformerConfig.from_dict(cfg_dict) return cfg