diff --git a/lib/sycamore/sycamore/data/table.py b/lib/sycamore/sycamore/data/table.py
index 2ec195dd9..1600947aa 100644
--- a/lib/sycamore/sycamore/data/table.py
+++ b/lib/sycamore/sycamore/data/table.py
@@ -184,6 +184,33 @@ def to_dict(self) -> dict[str, Any]:
d["num_cols"] = self.num_cols
return d
+    def _num_leading_header_rows(self) -> int:
+        """Return how many consecutive rows, starting at row 0, contain a header cell.
+
+        Only the contiguous run of header rows at the top of the table counts;
+        a header cell that first appears below a non-header row does not extend
+        the header. The result is therefore the index of the first data row.
+        """
+        header_rows = {row_num for cell in self.cells if cell.is_header for row_num in cell.rows}
+        n = 0
+        while n in header_rows:
+            n += 1
+        return n
+
+    def data_cells(self) -> "Table":
+        """Returns a table containing only the data cells in this table.
+
+        A cell is a data cell when its first row lies below the leading block
+        of header rows. Row indices are shifted so the first data row becomes
+        row 0; columns, content, ``is_header`` and bbox are carried over as-is.
+        The returned table keeps this table's caption and has no column headers.
+        """
+        n_header_rows = self._num_leading_header_rows()
+        shifted_cells = [
+            TableCell(
+                content=c.content,
+                rows=[r - n_header_rows for r in c.rows],
+                cols=c.cols,
+                is_header=c.is_header,
+                bbox=c.bbox,
+            )
+            for c in self.cells
+            if c.rows[0] >= n_header_rows
+        ]
+        return Table(shifted_cells, column_headers=[], caption=self.caption)
+
+    def header_cells(self) -> list[TableCell]:
+        """Returns the header cells as a list.
+
+        These are exactly the cells that ``data_cells`` leaves out: every cell
+        whose first row falls inside the leading block of header rows.
+        """
+        n_header_rows = self._num_leading_header_rows()
+        return [c for c in self.cells if c.rows[0] < n_header_rows]
+
# TODO: There are likely edge cases where this will break or lose information. Nested or non-contiguous
# headers are one likely source of issues. We also don't support missing closing tags (which are allowed in
# the spec) because html.parser doesn't handle them. If and when this becomes an issue, we can consider
diff --git a/lib/sycamore/sycamore/tests/unit/transforms/test_split_elements.py b/lib/sycamore/sycamore/tests/unit/transforms/test_split_elements.py
index 7fec46a72..a1b6e8693 100644
--- a/lib/sycamore/sycamore/tests/unit/transforms/test_split_elements.py
+++ b/lib/sycamore/sycamore/tests/unit/transforms/test_split_elements.py
@@ -1,8 +1,8 @@
import ray.data
-from sycamore.data import Document
+from sycamore.data import Document, TableElement, Table
from sycamore.transforms.split_elements import SplitElements
-from sycamore.functions.tokenizer import HuggingFaceTokenizer
+from sycamore.functions.tokenizer import HuggingFaceTokenizer, CharacterTokenizer
from sycamore.plan_nodes import Node
@@ -32,6 +32,29 @@ class TestSplitElements:
}
)
+ bigtable = """
+
+ headerA | headerB | headerC |
+ headerD | headerE | headerF |
+ data1a | data2a | data3a | data4a |
+ data1b | data2b | data3b | data4b |
+ data1c | data2c | data3c | data4c |
+ data1d | data2d | data3d | data4d |
+
+ """
+
+ tabledoc = Document(
+ {
+ "doc_id": "id",
+ "type": "pdf",
+ "text_representation": "lkqwrg",
+ "binary_representation": None,
+ "parent_id": None,
+ "properties": {"path": "/filename.yolo", "title": "lkqwrg"},
+ "elements": [TableElement(table=Table.from_html(bigtable))],
+ }
+ )
+
def test_split_elements(self):
tokenizer = HuggingFaceTokenizer("sentence-transformers/all-MiniLM-L6-v2")
doc = SplitElements(None, tokenizer, 15).run(self.doc)
@@ -62,3 +85,17 @@ def validateElems(self, elems):
assert elems[7].text_representation == "thirtyeight thirtynine forty fortyone fortytwo fortythree "
assert elems[8].text_representation == "fortyfour fortyfive fortysix "
assert elems[9].text_representation == "fortyseven fortyeight fortynine"
+
+ def test_split_table(self):
+ """Split the table fixture with a CharacterTokenizer and a limit of 35, then compare the HTML of every resulting element's table against the expected set (order-insensitive)."""
+ tk = CharacterTokenizer()
+ doc = SplitElements(None, tk, 35).run(self.tabledoc)
+ # NOTE(review): the expected strings below look like HTML whose tags were lost in this rendering of the patch - confirm against the original change.
+ answers = {
+ 'headerA |
---|
data1a |
data1b |
data1c |
data1d |
', # noqa: E501
+ "headerB |
---|
headerD |
---|
data2a |
data2b |
",
+ "headerB |
---|
headerD |
---|
data2c |
data2d |
",
+ "headerB |
---|
headerE |
---|
data3a |
data3b |
",
+ "headerB |
---|
headerE |
---|
data3c |
data3d |
",
+ "headerC |
---|
headerF |
---|
data4a |
data4b |
",
+ "headerC |
---|
headerF |
---|
data4c |
data4d |
",
+ }
+ assert {e.table.to_html() for e in doc.elements} == answers
diff --git a/lib/sycamore/sycamore/transforms/split_elements.py b/lib/sycamore/sycamore/transforms/split_elements.py
index 1966ce1ec..bf382bbf1 100644
--- a/lib/sycamore/sycamore/transforms/split_elements.py
+++ b/lib/sycamore/sycamore/transforms/split_elements.py
@@ -1,6 +1,7 @@
+import math
from typing import Optional
import logging
-from sycamore.data import Document, Element, TableElement
+from sycamore.data import Document, Element, TableElement, TableCell, Table, BoundingBox
from sycamore.functions.tokenizer import Tokenizer
from sycamore.plan_nodes import Node, SingleThreadUser, NonGPUUser
from sycamore.transforms.map import Map
@@ -8,6 +9,8 @@
logger = logging.getLogger(__name__)
+# Recursion depth limit shared by SplitElements.split_one and split_one_table;
+# once exceeded, a warning is logged and the element is returned unsplit.
+RECURSIVE_SPLIT_MAX_DEPTH = 20
+
class SplitElements(SingleThreadUser, NonGPUUser, Map):
"""
@@ -46,10 +49,13 @@ def split_doc(parent: Document, tokenizer: Tokenizer, max: int) -> Document:
@staticmethod
def split_one(elem: Element, tokenizer: Tokenizer, max: int, depth: int = 0) -> list[Element]:
- if depth > 20:
+ if depth > RECURSIVE_SPLIT_MAX_DEPTH:
logger.warning("Max split depth exceeded, truncating the splitting")
return [elem]
+ if elem.type == "table" and isinstance(elem, TableElement) and elem.table is not None:
+ return SplitElements.split_one_table(elem, tokenizer, max, depth)
+
txt = elem.text_representation
if not txt:
return [elem]
@@ -138,3 +144,160 @@ def split_one(elem: Element, tokenizer: Tokenizer, max: int, depth: int = 0) ->
bb = SplitElements.split_one(ment, tokenizer, max, depth + 1)
aa.extend(bb)
return aa
+
+    @staticmethod
+    def split_one_table(element: TableElement, tokenizer: Tokenizer, max_tokens: int, depth: int = 0) -> list[Element]:
+        """
+        Special handling for tables: If the column header is too big, no amount of splitting the
+        rows will save us, as we want to attach the col header to each subtable. In this case,
+        split the table horizontally. If the column header is small enough, we can guess the number
+        of rows per subtable and try to break the table into chunks of that size for evenness (still
+        breaking when we run out of tokens). Special care is taken to adjust the bounding box appropriately.
+
+        Args:
+            element: The table element to split; its ``table`` must not be None.
+            tokenizer: Tokenizer used to measure the header and each data row.
+            max_tokens: Token budget for each resulting piece.
+            depth: Current recursion depth; splitting stops once it exceeds
+                RECURSIVE_SPLIT_MAX_DEPTH.
+
+        Returns:
+            The list of elements the table was split into. This is ``[element]``
+            itself when the depth limit is hit or when the table has no data rows
+            and its header already fits.
+        """
+        if depth > RECURSIVE_SPLIT_MAX_DEPTH:
+            logger.warning("Max split depth exceeded, truncating the splitting")
+            return [element]
+
+        assert element.table is not None, "Cannot split a table without table structure"
+        table = element.table
+
+        col_header_len = len(tokenizer.tokenize(", ".join(table.column_headers)))
+        data_table = table.data_cells()
+        data_row_lens = [
+            len(tokenizer.tokenize(", ".join([c.content for c in data_table.cells if i in c.rows])))
+            for i in range(data_table.num_rows)
+        ]
+        # default=0: a header-only table has no data rows, and max() of an
+        # empty list would raise ValueError.
+        if col_header_len > max_tokens - max(data_row_lens, default=0):
+            # If there is a row and column header that cannot combine,
+            # split table horizontally in half and recurse. Splitting
+            # is done by column number rather than text.
+            ncols = table.num_cols
+            if ncols <= 1:
+                # One-column table that's too big - turn it into text and split in the traditional way.
+                new_elt = element.copy()
+                new_elt.data["text_representation"] = new_elt.text_representation
+                new_elt.type = "Text"
+                return SplitElements.split_one(new_elt, tokenizer, max_tokens, depth + 1)
+            # Split the table by splitting the cells into groups. A cell that
+            # spans the midpoint lands in both halves, trimmed to the columns
+            # that belong to each half.
+            mid = ncols // 2
+            elem_cells = [
+                TableCell(c.content, c.rows, [cl for cl in c.cols if cl < mid], c.is_header, c.bbox)
+                for c in table.cells
+                if min(c.cols) < mid
+            ]
+            ment_cells = [
+                TableCell(c.content, c.rows, [cl - mid for cl in c.cols if cl >= mid], c.is_header, c.bbox)
+                for c in table.cells
+                if max(c.cols) >= mid
+            ]
+            elem = element.copy()
+            ment = element.copy()
+            elem.table = Table(elem_cells, table.caption)
+            ment.table = Table(ment_cells, table.caption)
+            _reset_table_bbox(elem)
+            _reset_table_bbox(ment)
+            return SplitElements.split_one_table(
+                elem, tokenizer, max_tokens, depth + 1
+            ) + SplitElements.split_one_table(ment, tokenizer, max_tokens, depth + 1)
+
+        if not data_row_lens:
+            # Header-only table whose header already fits: nothing to split.
+            return [element]
+
+        # We can attach the column header to every row, so break the rows up
+        # evenly into groups and form new tables each containing a set of rows.
+        # Ensure that each resulting table is below the token limit too.
+        data_max_tokens = max_tokens - col_header_len
+        header_cells = table.header_cells()
+        if header_cells:
+            n_header_rows = max(r for c in header_cells for r in c.rows) + 1
+        else:
+            n_header_rows = 0
+
+        def make_subtable(rows: list[int]) -> Table:
+            """Build a table from the header cells plus the given consecutive data rows."""
+            begin, end = rows[0], rows[-1]
+            cells = header_cells + [
+                TableCell(
+                    c.content,
+                    # Re-base the kept rows so they start right below the header.
+                    [r - begin + n_header_rows for r in c.rows if begin <= r <= end],
+                    c.cols,
+                    c.is_header,
+                    c.bbox,
+                )
+                for c in data_table.cells
+                if any(begin <= r <= end for r in c.rows)
+            ]
+            return Table(cells, table.caption)
+
+        # Try to be slightly less greedy by giving each chunk a row limit.
+        # Guard both divisions: data_max_tokens is 0 when the header uses the
+        # whole budget, and the ceil() is 0 when every data row is empty.
+        if data_max_tokens > 0:
+            expected_chunks = max(1, math.ceil(sum(data_row_lens) / data_max_tokens))
+        else:
+            expected_chunks = 1
+        expected_rows_per_chunk = math.ceil(len(data_row_lens) / expected_chunks)
+
+        curr_len = 0
+        curr_rows: list[int] = []
+        subtables = []
+        for i, drl in enumerate(data_row_lens):
+            fits = curr_len + drl < data_max_tokens and len(curr_rows) < expected_rows_per_chunk
+            # Only flush a non-empty chunk: a first row that already reaches the
+            # budget must start a chunk rather than flush an empty one.
+            if curr_rows and not fits:
+                subtables.append(make_subtable(curr_rows))
+                curr_rows = []
+                curr_len = 0
+            curr_rows.append(i)
+            curr_len += drl
+        subtables.append(make_subtable(curr_rows))
+
+        elms = [element.copy() for _ in subtables]
+        first = True
+        for elm, sbt in zip(elms, subtables):
+            elm.table = sbt
+            # Only the first piece keeps the header inside its bounding box.
+            _reset_table_bbox(elm, ignore_header=not first)
+            first = False
+        return elms  # type: ignore
+
+
+def _reset_table_bbox(te: TableElement, ignore_header: bool = False):
+    """
+    Set a table element's overall bbox to something that aligns with
+    its cells. Specifically, take the median left, right, top, and bottom
+    edges of all cells on the corresponding edge of the table. We want
+    the median here rather than the extreme because in split tables with
+    spanning cells the extreme may be the edge of the spanning cell, which
+    may hang outside of the columns or rows of the table we're working with.
+
+    Args:
+        te: Table element whose ``bbox`` is updated in place. Left untouched
+            when it has no bbox or none of the considered cells has one.
+        ignore_header: When True, only the table's data cells are considered.
+    """
+    assert te.table is not None
+    dc = te.table.data_cells().cells if ignore_header else te.table.cells
+    boxes = [(c, c.bbox) for c in dc if c.bbox is not None]
+    if te.bbox is None or not boxes:
+        return
+    # Edge rows/columns are taken over all considered cells, with or without a bbox.
+    max_row = max(c.rows[-1] for c in dc)
+    max_col = max(c.cols[-1] for c in dc)
+    min_row = min(c.rows[0] for c in dc)
+    min_col = min(c.cols[0] for c in dc)
+    x1s = [bb.x1 for c, bb in boxes if c.cols[0] == min_col]
+    x2s = [bb.x2 for c, bb in boxes if c.cols[-1] == max_col]
+    y1s = [bb.y1 for c, bb in boxes if c.rows[0] == min_row]
+    y2s = [bb.y2 for c, bb in boxes if c.rows[-1] == max_row]
+    # An edge can have no cell with a bbox even though other cells have one;
+    # fall back to the extreme over all boxed cells instead of indexing an
+    # empty list.
+    te.bbox = BoundingBox(
+        _median(x1s) if x1s else min(bb.x1 for _, bb in boxes),
+        _median(y1s) if y1s else min(bb.y1 for _, bb in boxes),
+        _median(x2s) if x2s else max(bb.x2 for _, bb in boxes),
+        _median(y2s) if y2s else max(bb.y2 for _, bb in boxes),
+    )
+
+
+def _median(nums: list[float]) -> float:
+    """Return the middle element of ``nums`` in sorted order.
+
+    For an even number of values this is the upper of the two middle values.
+    The input list is left unmodified. Raises IndexError if ``nums`` is empty.
+    """
+    return sorted(nums)[len(nums) // 2]