diff --git a/tests/acceptance/test_multi_gpu.py b/tests/acceptance/test_multi_gpu.py
index 0863ed0fa..ca9e99805 100644
--- a/tests/acceptance/test_multi_gpu.py
+++ b/tests/acceptance/test_multi_gpu.py
@@ -90,11 +90,17 @@ def test_device_separation_and_cache(gpt2_medium_on_1_device, n_devices):
     gpt2_logits_n_devices, gpt2_cache_n_devices = model_n_devices.run_with_cache(
         gpt2_tokens, remove_batch_dim=True)
 
+    # Make sure the tensors in cache remain on their respective devices
+    for i in range(model_n_devices.cfg.n_layers):
+        expected_device = get_device_for_block_index(i, cfg=model_n_devices.cfg)
+        cache_device = gpt2_cache_n_devices[f"blocks.{i}.mlp.hook_post"].device
+        assert cache_device == expected_device
+
     assert torch.allclose(gpt2_logits_1_device.to("cpu"),
                           gpt2_logits_n_devices.to("cpu"))
     for key in gpt2_cache_1_device.keys():
-        assert torch.allclose(gpt2_cache_1_device[key],
-                              gpt2_cache_n_devices[key])
+        assert torch.allclose(gpt2_cache_1_device[key].to("cpu"),
+                              gpt2_cache_n_devices[key].to("cpu"))
 
     cuda_devices = set()
     n_params_on_device = {}
@@ -114,3 +120,27 @@ def test_device_separation_and_cache(gpt2_medium_on_1_device, n_devices):
     print(
         f"Number of devices: {n_devices}, Model loss (1 device): {loss_1_device}, Model loss ({n_devices} devices): {loss_n_devices}, Time taken (1 device): {elapsed_time_1_device:.4f} seconds, Time taken ({n_devices} devices): {elapsed_time_n_devices:.4f} seconds"
     )
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 CUDA devices")
+def test_cache_device():
+    model = HookedTransformer.from_pretrained("gpt2-small", device="cuda:1")
+
+    logits, cache = model.run_with_cache("Hello there")
+    assert norm_device(cache["blocks.0.mlp.hook_post"].device) == norm_device(torch.device("cuda:1"))
+
+    logits, cache = model.run_with_cache("Hello there", device="cpu")
+    assert norm_device(cache["blocks.0.mlp.hook_post"].device) == norm_device(torch.device("cpu"))
+
+    model.to("cuda")
+    logits, cache = model.run_with_cache("Hello there")
+    assert norm_device(cache["blocks.0.mlp.hook_post"].device) == norm_device(logits.device)
+
+
+def norm_device(device):
+    """
+    Convenience function to normalize device strings for comparison.
+    """
+    device_str = str(device)
+    if device_str.startswith("cuda") and ':' not in device_str:
+        device_str += ':0'
+    return device_str
diff --git a/transformer_lens/head_detector.py b/transformer_lens/head_detector.py
index 8a419385f..341f46432 100644
--- a/transformer_lens/head_detector.py
+++ b/transformer_lens/head_detector.py
@@ -122,9 +122,10 @@ def get_duplicate_token_head_detection_pattern(model: HookedTransformer, sequenc
         model: Model being used.
         sequence: String being fed to the model."""
 
-    sequence = model.to_tokens(sequence)
-    token_pattern = [np.array(sequence) for i in range(sequence.shape[-1])]
-    token_pattern = np.concatenate(token_pattern, axis=0)
+    sequence = model.to_tokens(sequence).detach().cpu()
+
+    # Repeat sequence to create a square matrix.
+    token_pattern = sequence.repeat(sequence.shape[-1], 1).numpy()
 
     # If token_pattern[i][j] matches its transpose, then token j and token i are duplicates.
     eq_mask = np.equal(token_pattern, token_pattern.T).astype(int)
diff --git a/transformer_lens/hook_points.py b/transformer_lens/hook_points.py
index f386c071b..d983f50cf 100644
--- a/transformer_lens/hook_points.py
+++ b/transformer_lens/hook_points.py
@@ -280,15 +280,13 @@ def add_caching_hooks(
         Args:
            names_filter (NamesFilter, optional): Which activations to cache. Can be a list of strings (hook names) or a filter function mapping hook names to booleans. Defaults to lambda name: True.
            incl_bwd (bool, optional): Whether to also do backwards hooks. Defaults to False.
-            device (_type_, optional): The device to store on. Defaults to CUDA if available else CPU.
+            device (_type_, optional): The device to store on. Defaults to same device as model.
            remove_batch_dim (bool, optional): Whether to remove the batch dimension (only works for batch_size==1). Defaults to False.
            cache (Optional[dict], optional): The cache to store activations in, a new dict is created by default. Defaults to None.
 
        Returns:
            cache (dict): The cache where activations will be stored.
        """
-        if device is None:
-            device = "cuda" if torch.cuda.is_available() else "cpu"
        if cache is None:
            cache = {}
 
@@ -340,8 +338,7 @@ def run_with_cache(
                list of str, or a function that takes a string and returns a bool. Defaults to None, which
                means cache everything.
            device (str or torch.Device, optional): The device to cache activations on. Defaults to the
-                model device. Note that this must be set if the model does not have a model.cfg.device
-                attribute. WARNING: Setting a different device than the one used by the model leads to
+                model device. WARNING: Setting a different device than the one used by the model leads to
                significant performance degradation.
            remove_batch_dim (bool, optional): If True, removes the batch dimension when caching. Only
                makes sense with batch_size=1 inputs. Defaults to False.
@@ -382,7 +379,7 @@ def get_caching_hooks(
        Args:
            names_filter (NamesFilter, optional): Which activations to cache. Can be a list of strings (hook names) or a filter function mapping hook names to booleans. Defaults to lambda name: True.
            incl_bwd (bool, optional): Whether to also do backwards hooks. Defaults to False.
-            device (_type_, optional): The device to store on. Defaults to CUDA if available else CPU.
+            device (_type_, optional): The device to store on. Keeps on the same device as the layer if None.
            remove_batch_dim (bool, optional): Whether to remove the batch dimension (only works for batch_size==1). Defaults to False.
            cache (Optional[dict], optional): The cache to store activations in, a new dict is created by default. Defaults to None.
 
@@ -391,8 +388,6 @@ def get_caching_hooks(
        Returns:
            fwd_hooks (list): The forward hooks.
            bwd_hooks (list): The backward hooks. Empty if incl_bwd is False.
        """
-        if device is None:
-            device = "cuda" if torch.cuda.is_available() else "cpu"
        if cache is None:
            cache = {}
diff --git a/transformer_lens/utilities/devices.py b/transformer_lens/utilities/devices.py
index 928c21f37..9c3fdad28 100644
--- a/transformer_lens/utilities/devices.py
+++ b/transformer_lens/utilities/devices.py
@@ -9,10 +9,25 @@ def get_device_for_block_index(
     cfg: HookedTransformerConfig,
     device: Optional[Union[torch.device, str]] = None,
 ):
+    """
+    Determine the device for a given layer index based on the model configuration.
+
+    This function assists in distributing model layers across multiple devices. The distribution
+    is based on the configuration's number of layers (cfg.n_layers) and devices (cfg.n_devices).
+
+    Args:
+        index (int): Model layer index.
+        cfg (HookedTransformerConfig): Model and device configuration.
+        device (Optional[Union[torch.device, str]], optional): Initial device used for determining the target device.
+            If not provided, the function uses the device specified in the configuration (cfg.device).
+
+    Returns:
+        torch.device: The device for the specified layer index.
+    """
     assert cfg.device is not None
     layers_per_device = cfg.n_layers // cfg.n_devices
     if device is None:
         device = cfg.device
-    if isinstance(device, torch.device):
-        device = device.type
-    return torch.device(device, index // layers_per_device)
+    device = torch.device(device)
+    device_index = (device.index or 0) + (index // layers_per_device)
+    return torch.device(device.type, device_index)