Updated chunker.py #986
base: main
Changes from all commits
@@ -1,41 +1,86 @@
 from abc import abstractmethod
-from typing import Any
+from typing import Any, List, TypeVar, Generic
+
+T = TypeVar('T')


-class Chunker:
+class Chunker(Generic[T]):
     @abstractmethod
-    def chunk(self, tokens: list[Any]) -> list[Any]:
+    def chunk(self, tokens: List[T]) -> List[List[T]]:
         pass


-class TextOverlapChunker(Chunker):
+class TextOverlapChunker(Chunker[T]):
     """
-    TextOverlapChunker is a class for chunking text into smaller segments while allowing for token overlap.
-
-    This class inherits from the Chunker class and is designed to divide long text tokens into chunks, each containing
-    a specified number of tokens. It allows for a controlled overlap of tokens between adjacent chunks.
+    TextOverlapChunker is a class for chunking sequences into smaller segments with controlled overlap.
+
+    This class inherits from the Chunker class and divides sequences into chunks of a specified size,
+    with configurable overlap between adjacent chunks. The implementation ensures that:
+    1. All tokens are included in at least one chunk
+    2. No chunk exceeds the specified maximum size
+    3. Overlap is consistently maintained between chunks
+    4. The last chunk is handled correctly even if smaller than chunk_token_count

     Args:
         chunk_token_count: The maximum number of tokens to include in each chunk.
-        chunk_overlap_token_count: The number of tokens that can overlap between adjacent chunks.
-            This value must be less than the `chunk_token_count` to ensure meaningful chunking.
+        chunk_overlap_token_count: The number of tokens that should overlap between adjacent chunks.
+            Must be less than chunk_token_count.
+
+    Raises:
+        ValueError: If chunk_overlap_token_count >= chunk_token_count or if either parameter is negative.

     Example:
-        .. code-block:: python
-
-            chunker = TextOverlapChunker(chunk_token_count=1000, chunk_overlap_token_count=100)
-            chunks = chunker.chunk(data)
+        >>> chunker = TextOverlapChunker(chunk_token_count=5, chunk_overlap_token_count=2)
+        >>> tokens = list("ABCDEFGHIJK")
+        >>> chunks = chunker.chunk(tokens)
+        >>> for chunk in chunks: print(''.join(chunk))
+        ABCDE
+        DEFGH
+        GHIJK
     """

     def __init__(self, chunk_token_count: int = 1000, chunk_overlap_token_count: int = 100) -> None:
         super().__init__()
+        if chunk_token_count <= 0:
+            raise ValueError("Chunk token count must be positive")
+        if chunk_overlap_token_count < 0:
+            raise ValueError("Chunk overlap token count must be non-negative")
         if chunk_overlap_token_count >= chunk_token_count:
-            raise Exception("Token overlap count between chunks must be lesser than chunk token count")
+            raise ValueError("Token overlap count between chunks must be less than chunk token count")

         self._chunk_token_count = chunk_token_count
         self._chunk_overlap_token_count = chunk_overlap_token_count

-    def chunk(self, tokens: list[Any]) -> list[Any]:
-        return [
-            tokens[a : a + self._chunk_token_count]
-            for a in range(0, len(tokens), self._chunk_token_count - self._chunk_overlap_token_count)
-        ]
+    def chunk(self, tokens: List[T]) -> List[List[T]]:
+        """
+        Divide the input sequence into overlapping chunks.
+
+        Args:
+            tokens: The input sequence to be chunked.
+
+        Returns:
+            A list of chunks, where each chunk is a list of tokens.
+
+        Note:
+            The last chunk may be smaller than chunk_token_count but will maintain
+            the specified overlap with the previous chunk if possible.
+        """
+        if not tokens:
+            return []
+
+        chunks = []
+        stride = self._chunk_token_count - self._chunk_overlap_token_count
+
+        for start in range(0, len(tokens), stride):
+            # Calculate end index for current chunk
+            end = min(start + self._chunk_token_count, len(tokens))
+            chunk = tokens[start:end]
+
Comment on lines +73 to +75

Python can understand out-of-bounds slices and truncates appropriately, e.g.
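An illustrative sketch of that slice behaviour (assumed values, not the reviewer's original snippet):

    tokens = list("ABCDE")  # illustrative values
    # Slicing past the end does not raise; Python truncates to what is available.
    print(tokens[3:10])  # ['D', 'E']
    # So tokens[start:start + self._chunk_token_count] already truncates at the end
    # of the sequence, making the min(start + self._chunk_token_count, len(tokens)) step unnecessary.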
+            # Add chunk if it's the first chunk, maintains minimum size, or is the last piece
+            if (start == 0 or
+                len(chunk) >= self._chunk_overlap_token_count or
+                end == len(tokens)):
Comment on lines +77 to +79

I'm struggling to understand why we need all these conditions. Given the validations added in __init__, is there ever a case where the size check fails and this is not the last chunk? And if this is the last chunk, then we add it anyway. I think you only need the start == 0 case for when there is only one rather small chunk, in which case it's also the last chunk.
+                chunks.append(chunk)
+
+            # If we've processed all tokens, break
+            if end == len(tokens):
+                break
Comment on lines +83 to +84

We can just exit the for loop because we're at the end of the for loop in this case, right?

edit: I think this is the cause of the test failures - if the natural chunk boundary lines up exactly with the end of the sequence then you'll break prematurely, right?
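A hypothetical illustration of that boundary case (assumed values): with chunk_token_count=4 and chunk_overlap_token_count=2 on a sequence of exactly six tokens, a chunk boundary coincides with the end of the sequence.

    tokens = list("ABCDEF")  # illustrative values
    chunk_token_count, overlap = 4, 2
    stride = chunk_token_count - overlap  # 2
    original = [tokens[a:a + chunk_token_count] for a in range(0, len(tokens), stride)]
    # original -> [['A','B','C','D'], ['C','D','E','F'], ['E','F']]
    # The new loop breaks at start=2 because end == len(tokens) there,
    # so start=4 never runs and the trailing ['E', 'F'] chunk is dropped.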
Comment on lines +73 to +84

Putting all these together, you can rewrite the loop more concisely as

    for start in range(0, len(tokens), stride):
        chunks.append(tokens[start: start + self._chunk_token_count])

which is equivalent to the original list comprehension.
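A quick, illustrative check (assumed values) that the concise form matches the original comprehension:

    tokens = list("ABCDEFGHIJK")  # illustrative values
    chunk_token_count, overlap = 5, 2
    stride = chunk_token_count - overlap

    comprehension = [tokens[a:a + chunk_token_count] for a in range(0, len(tokens), stride)]

    chunks = []
    for start in range(0, len(tokens), stride):
        chunks.append(tokens[start:start + chunk_token_count])

    assert chunks == comprehension  # ABCDE, DEFGH, GHIJK, JK (as lists of characters)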
+
+        return chunks
Generics were added in Python 3.12. We want to be able to support >=3.9,<3.13. You will need to find another way of writing this.

sycamore/lib/sycamore/pyproject.toml (Line 16 in 96b61e0)