Skip to content

Commit

Permalink
Allow to_str_tokens of 1x1 tensors and add HookedTransformer.to_singl… (
Browse files Browse the repository at this point in the history
#245)

* Allow to_str_tokens of 1x1 tensors and add HookedTransformer.to_single_str_token

* Implement changes and add unit test

* Fix dim=1 syntax and use unsqueezing instead

* Add torch.Tensor and np.ndarray type hints to the function signature

* Add a couple more test cases

* Add pytest fixtures

* Change type signature

* Upper case T tensor

* Int, lowercase T tensor

* Int, not float

* Make the tensors actually conform to the shape specifications

* Address Neel's comments

* Dummy commit to rerun CI/CD
  • Loading branch information
adamyedidia authored Apr 16, 2023
1 parent c034846 commit 6eca22f
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 1 deletion.
59 changes: 59 additions & 0 deletions tests/unit/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np

import transformer_lens.utils as utils
from transformer_lens import HookedTransformer

ref_tensor = torch.tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]])
shape = ref_tensor.shape
Expand Down Expand Up @@ -52,3 +53,61 @@ def test_indices(self, input_slice, expected_indices):
def test_indices_error(self):
with pytest.raises(ValueError):
_ = utils.Slice(input_slice=[0, 2, 5]).indices()


MODEL = "solu-1l"

model = HookedTransformer.from_pretrained(MODEL)

@pytest.fixture
def nested_list_1():
return [1]


@pytest.fixture
def nested_list_1x1():
return [[6]]


@pytest.fixture
def nested_list_1x3():
return [[1, 2, 3]]


def test_to_str_tokens(nested_list_1, nested_list_1x1, nested_list_1x3):
tensor_1_to_str_tokens = model.to_str_tokens(torch.tensor(nested_list_1))
assert isinstance(tensor_1_to_str_tokens, list)
assert len(tensor_1_to_str_tokens) == 1
assert isinstance(tensor_1_to_str_tokens[0], str)

tensor_1x1_to_str_tokens = model.to_str_tokens(torch.tensor(nested_list_1x1))
assert isinstance(tensor_1x1_to_str_tokens, list)
assert len(tensor_1x1_to_str_tokens) == 1
assert isinstance(tensor_1x1_to_str_tokens[0], str)

ndarray_1_to_str_tokens = model.to_str_tokens(np.array(nested_list_1))
assert isinstance(ndarray_1_to_str_tokens, list)
assert len(ndarray_1_to_str_tokens) == 1
assert isinstance(ndarray_1_to_str_tokens[0], str)

ndarray_1x1_to_str_tokens = model.to_str_tokens(np.array(nested_list_1x1))
assert isinstance(ndarray_1x1_to_str_tokens, list)
assert len(ndarray_1x1_to_str_tokens) == 1
assert isinstance(ndarray_1x1_to_str_tokens[0], str)

single_int_to_single_str_token = model.to_single_str_token(3)
assert isinstance(single_int_to_single_str_token, str)

squeezable_tensor_to_str_tokens = model.to_str_tokens(torch.tensor(nested_list_1x3))
assert isinstance(squeezable_tensor_to_str_tokens, list)
assert len(squeezable_tensor_to_str_tokens) == 3
assert isinstance(squeezable_tensor_to_str_tokens[0], str)
assert isinstance(squeezable_tensor_to_str_tokens[1], str)
assert isinstance(squeezable_tensor_to_str_tokens[2], str)

squeezable_ndarray_to_str_tokens = model.to_str_tokens(np.array(nested_list_1x3))
assert isinstance(squeezable_ndarray_to_str_tokens, list)
assert len(squeezable_ndarray_to_str_tokens) == 3
assert isinstance(squeezable_ndarray_to_str_tokens[0], str)
assert isinstance(squeezable_ndarray_to_str_tokens[1], str)
assert isinstance(squeezable_ndarray_to_str_tokens[2], str)
20 changes: 19 additions & 1 deletion transformer_lens/HookedTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,12 @@ def to_string(

def to_str_tokens(
self,
input: Union[str, Union[Float[torch.Tensor, "pos"], Float[torch.Tensor, "1 pos"]], list],
input: Union[str,
Int[torch.Tensor, "pos"],
Int[torch.Tensor, "1 pos"],
Int[np.ndarray, "pos"],
Int[np.ndarray, "1 pos"],
list],
prepend_bos: bool = True,
) -> List[str]:
"""Method to map text, a list of text or tokens to a list of tokens as strings
Expand Down Expand Up @@ -506,12 +511,18 @@ def to_str_tokens(
elif isinstance(input, torch.Tensor):
tokens = input
tokens = tokens.squeeze() # Get rid of a trivial batch dimension
if tokens.dim() == 0:
# Don't pass dimensionless tensor
tokens = tokens.unsqueeze(0)
assert (
tokens.dim() == 1
), f"Invalid tokens input to to_str_tokens, has shape: {tokens.shape}"
elif isinstance(input, np.ndarray):
tokens = input
tokens = tokens.squeeze() # Get rid of a trivial batch dimension
if tokens.ndim == 0:
# Don't pass dimensionless tensor
tokens = np.expand_dims(tokens, axis=0)
assert (
tokens.ndim == 1
), f"Invalid tokens input to to_str_tokens, has shape: {tokens.shape}"
Expand All @@ -531,6 +542,13 @@ def to_single_token(self, string):
assert not token.shape, f"Input string: {string} is not a single token!"
return token.item()

def to_single_str_token(self, int_token: int) -> str:
# Gives the single token corresponding to an int in string form
assert isinstance(int_token, int)
token = self.to_str_tokens(torch.tensor([int_token]))
assert len(token) == 1
return token[0]

def get_token_position(
self,
single_token: Union[str, int],
Expand Down

0 comments on commit 6eca22f

Please sign in to comment.