Updated chunker.py #986

Open. Wants to merge 1 commit into base: main.
95 changes: 70 additions & 25 deletions lib/sycamore/sycamore/functions/chunker.py
@@ -1,41 +1,86 @@
 from abc import abstractmethod
-from typing import Any
+from typing import Any, List, TypeVar, Generic
 
+T = TypeVar('T')
 
-class Chunker:
+class Chunker(Generic[T]):
     @abstractmethod
-    def chunk(self, tokens: list[Any]) -> list[Any]:
+    def chunk(self, tokens: List[T]) -> List[List[T]]:
         pass
Comment on lines +4 to 9
@MarkLindblad (Contributor) commented on Oct 28, 2024

Generics were added in python 3.12. We want to be able to support >=3.9,<3.13. You will need to find another way of writing this.

python = ">=3.9,<3.13"
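For context, a minimal sketch (not from the PR) contrasting the two ways of writing the generic base class: the typing-module form below runs on Python 3.9+, while the bracketed PEP 695 syntax is only accepted by Python 3.12+.

    from typing import Generic, List, TypeVar

    T = TypeVar("T")  # 3.9-compatible type variable

    class Chunker(Generic[T]):  # 3.9-compatible generic base class
        def chunk(self, tokens: List[T]) -> List[List[T]]:
            raise NotImplementedError

    # The PEP 695 form requires Python 3.12+ and will not parse on 3.9-3.11,
    # so it cannot be used while supporting ">=3.9,<3.13":
    #
    # class Chunker[T]:
    #     def chunk(self, tokens: list[T]) -> list[list[T]]: ...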



-class TextOverlapChunker(Chunker):
+class TextOverlapChunker(Chunker[T]):
     """
-    TextOverlapChunker is a class for chunking text into smaller segments while allowing for token overlap.
-
-    This class inherits from the Chunker class and is designed to divide long text tokens into chunks, each containing
-    a specified number of tokens. It allows for a controlled overlap of tokens between adjacent chunks.
+    TextOverlapChunker is a class for chunking sequences into smaller segments with controlled overlap.
+
+    This class inherits from the Chunker class and divides sequences into chunks of a specified size,
+    with configurable overlap between adjacent chunks. The implementation ensures that:
+    1. All tokens are included in at least one chunk
+    2. No chunk exceeds the specified maximum size
+    3. Overlap is consistently maintained between chunks
+    4. The last chunk is handled correctly even if smaller than chunk_token_count
 
     Args:
         chunk_token_count: The maximum number of tokens to include in each chunk.
-        chunk_overlap_token_count: The number of tokens that can overlap between adjacent chunks.
-            This value must be less than the `chunk_token_count` to ensure meaningful chunking.
+        chunk_overlap_token_count: The number of tokens that should overlap between adjacent chunks.
+            Must be less than chunk_token_count.
+
+    Raises:
+        ValueError: If chunk_overlap_token_count >= chunk_token_count or if either parameter is negative.
 
     Example:
         .. code-block:: python
 
-            chunker = TextOverlapChunker(chunk_token_count=1000, chunk_overlap_token_count=100)
-            chunks = chunker.chunk(data)
+            >>> chunker = TextOverlapChunker(chunk_token_count=5, chunk_overlap_token_count=2)
+            >>> tokens = list("ABCDEFGHIJK")
+            >>> chunks = chunker.chunk(tokens)
+            >>> for chunk in chunks: print(''.join(chunk))
+            ABCDE
+            DEFGH
+            GHIJK
     """

     def __init__(self, chunk_token_count: int = 1000, chunk_overlap_token_count: int = 100) -> None:
         super().__init__()
+        if chunk_token_count <= 0:
+            raise ValueError("Chunk token count must be positive")
+        if chunk_overlap_token_count < 0:
+            raise ValueError("Chunk overlap token count must be non-negative")
         if chunk_overlap_token_count >= chunk_token_count:
-            raise Exception("Token overlap count between chunks must be lesser than chunk token count")
+            raise ValueError("Token overlap count between chunks must be less than chunk token count")
 
         self._chunk_token_count = chunk_token_count
         self._chunk_overlap_token_count = chunk_overlap_token_count
 
-    def chunk(self, tokens: list[Any]) -> list[Any]:
-        return [
-            tokens[a : a + self._chunk_token_count]
-            for a in range(0, len(tokens), self._chunk_token_count - self._chunk_overlap_token_count)
-        ]

+    def chunk(self, tokens: List[T]) -> List[List[T]]:
+        """
+        Divide the input sequence into overlapping chunks.
+
+        Args:
+            tokens: The input sequence to be chunked.
+
+        Returns:
+            A list of chunks, where each chunk is a list of tokens.
+
+        Note:
+            The last chunk may be smaller than chunk_token_count but will maintain
+            the specified overlap with the previous chunk if possible.
+        """
+        if not tokens:
+            return []
+
+        chunks = []
+        stride = self._chunk_token_count - self._chunk_overlap_token_count
+
+        for start in range(0, len(tokens), stride):
+            # Calculate end index for current chunk
+            end = min(start + self._chunk_token_count, len(tokens))
+            chunk = tokens[start:end]

Comment on lines +73 to +75
Collaborator:

python can understand out-of-bounds slices and truncates appropriately. e.g.

>>> "01234"[2:10]
"234"

+            # Add chunk if it's the first chunk, maintains minimum size, or is the last piece
+            if (start == 0 or
+                len(chunk) >= self._chunk_overlap_token_count or
+                end == len(tokens)):
Comment on lines +77 to +79
Collaborator:

I'm struggling to understand why we need all these conditions. Is there ever a case, given the validations added in __init__, where this can evaluate to False?

If this is not the last chunk, then
len(chunk) = _chunk_token_count >= _chunk_overlap_token_count

If this is the last chunk, then we add it anyway.

I think you only need the start == 0 case for when there is only one rather small chunk, in which case it's also the last chunk.
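
To make that argument concrete, here is a small brute-force check (hypothetical, not part of the PR) that the guard never evaluates to False for any parameters the new __init__ validations allow:

    # Hypothetical standalone check: with 0 <= overlap < chunk_token_count,
    # the guard (start == 0 or len(chunk) >= overlap or end == len(tokens)) is always True.
    def guard_always_true(chunk_token_count: int, overlap: int, n_tokens: int) -> bool:
        stride = chunk_token_count - overlap
        for start in range(0, n_tokens, stride):
            end = min(start + chunk_token_count, n_tokens)
            chunk_len = end - start
            if not (start == 0 or chunk_len >= overlap or end == n_tokens):
                return False
        return True

    assert all(
        guard_always_true(size, overlap, n)
        for size in range(1, 12)
        for overlap in range(0, size)   # overlap < size, as the new validations enforce
        for n in range(1, 40)
    )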

+                chunks.append(chunk)
+
+            # If we've processed all tokens, break
+            if end == len(tokens):
+                break
Comment on lines +83 to +84
Collaborator:
we can just exit the for loop because we're at the end of the for loop in this case, right?

Collaborator:
edit: I think this is the cause of the test failures - if the natural chunk boundary lines up exactly with the end of the sequence then you'll break prematurely, right?
With chunk_token_count = 5 and chunk_overlap_token_count = 2:
"ABCDEFGH" -> "ABCDE", "DEFGH" instead of "ABCDE", "DEFGH", "GH"

Comment on lines +73 to +84
Collaborator:

putting all these together, you can rewrite the loop more concisely as

for start in range(0, len(tokens), stride):
    chunks.append(tokens[start: start + self._chunk_token_count])

which is equivalent to the original list comprehension


+        return chunks
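
For reference, a standalone sketch (hypothetical function name, not the merged code) of the concise form suggested above, run on the example from the earlier comment:

    from typing import Any, List

    def chunk_concise(tokens: List[Any], chunk_token_count: int, overlap: int) -> List[List[Any]]:
        # Python truncates out-of-bounds slices, so no min()/break bookkeeping is needed.
        stride = chunk_token_count - overlap
        return [tokens[start : start + chunk_token_count] for start in range(0, len(tokens), stride)]

    print(["".join(c) for c in chunk_concise(list("ABCDEFGH"), 5, 2)])
    # ['ABCDE', 'DEFGH', 'GH'] -- the trailing chunk is kept, unlike the break-based loop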