Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split tables horizontally in split_elements if the column headers are big enough to make splitting hard #1104

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions lib/sycamore/sycamore/data/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,33 @@ def to_dict(self) -> dict[str, Any]:
d["num_cols"] = self.num_cols
return d

def data_cells(self) -> "Table":
"""Returns a table containing only the data cells in this table."""
header_rows = sorted(set((row_num for cell in self.cells for row_num in cell.rows if cell.is_header)))
i = -1
for r in header_rows:
i += 1
if r != i:
break
data_cells = [c for c in self.cells if c.rows[0] > i]
shifted_cells = [
TableCell(
content=c.content, rows=[r - i - 1 for r in c.rows], cols=c.cols, is_header=c.is_header, bbox=c.bbox
)
for c in data_cells
]
return Table(shifted_cells, column_headers=[], caption=self.caption)

def header_cells(self) -> list[TableCell]:
"""Returns the header cells as a list"""
header_rows = sorted(set((row_num for cell in self.cells for row_num in cell.rows if cell.is_header)))
i = -1
for r in header_rows:
i += 1
if r != i:
break
return [c for c in self.cells if c.rows[0] <= i]

# TODO: There are likely edge cases where this will break or lose information. Nested or non-contiguous
# headers are one likely source of issues. We also don't support missing closing tags (which are allowed in
# the spec) because html.parser doesn't handle them. If and when this becomes an issue, we can consider
Expand Down
41 changes: 39 additions & 2 deletions lib/sycamore/sycamore/tests/unit/transforms/test_split_elements.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import ray.data

from sycamore.data import Document
from sycamore.data import Document, TableElement, Table
from sycamore.transforms.split_elements import SplitElements
from sycamore.functions.tokenizer import HuggingFaceTokenizer
from sycamore.functions.tokenizer import HuggingFaceTokenizer, CharacterTokenizer
from sycamore.plan_nodes import Node


Expand Down Expand Up @@ -32,6 +32,29 @@ class TestSplitElements:
}
)

bigtable = """
<table>
<tr><th rowspan="2">headerA</th><th colspan="2">headerB</th><th>headerC</th></tr>
<tr><th>headerD</th><th>headerE</th><th>headerF</th></tr>
<tr><td>data1a</td><td>data2a</td><td>data3a</td><td>data4a</td></tr>
<tr><td>data1b</td><td>data2b</td><td>data3b</td><td>data4b</td></tr>
<tr><td>data1c</td><td>data2c</td><td>data3c</td><td>data4c</td></tr>
<tr><td>data1d</td><td>data2d</td><td>data3d</td><td>data4d</td></tr>
</table>
"""

tabledoc = Document(
{
"doc_id": "id",
"type": "pdf",
"text_representation": "lkqwrg",
"binary_representation": None,
"parent_id": None,
"properties": {"path": "/filename.yolo", "title": "lkqwrg"},
"elements": [TableElement(table=Table.from_html(bigtable))],
}
)

def test_split_elements(self):
tokenizer = HuggingFaceTokenizer("sentence-transformers/all-MiniLM-L6-v2")
doc = SplitElements(None, tokenizer, 15).run(self.doc)
Expand Down Expand Up @@ -62,3 +85,17 @@ def validateElems(self, elems):
assert elems[7].text_representation == "thirtyeight thirtynine forty fortyone fortytwo fortythree "
assert elems[8].text_representation == "fortyfour fortyfive fortysix "
assert elems[9].text_representation == "fortyseven fortyeight fortynine"

def test_split_table(self):
tk = CharacterTokenizer()
doc = SplitElements(None, tk, 35).run(self.tabledoc)
answers = {
'<table><tr><th rowspan="2">headerA</th></tr><tr><td>data1a</td></tr><tr><td>data1b</td></tr><tr><td>data1c</td></tr><tr><td>data1d</td></tr></table>', # noqa: E501
"<table><tr><th>headerB</th></tr><tr><th>headerD</th></tr><tr><td>data2a</td></tr><tr><td>data2b</td></tr></table>",
"<table><tr><th>headerB</th></tr><tr><th>headerD</th></tr><tr><td>data2c</td></tr><tr><td>data2d</td></tr></table>",
"<table><tr><th>headerB</th></tr><tr><th>headerE</th></tr><tr><td>data3a</td></tr><tr><td>data3b</td></tr></table>",
"<table><tr><th>headerB</th></tr><tr><th>headerE</th></tr><tr><td>data3c</td></tr><tr><td>data3d</td></tr></table>",
"<table><tr><th>headerC</th></tr><tr><th>headerF</th></tr><tr><td>data4a</td></tr><tr><td>data4b</td></tr></table>",
"<table><tr><th>headerC</th></tr><tr><th>headerF</th></tr><tr><td>data4c</td></tr><tr><td>data4d</td></tr></table>",
}
assert {e.table.to_html() for e in doc.elements} == answers
167 changes: 165 additions & 2 deletions lib/sycamore/sycamore/transforms/split_elements.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import math
from typing import Optional
import logging
from sycamore.data import Document, Element, TableElement
from sycamore.data import Document, Element, TableElement, TableCell, Table, BoundingBox
from sycamore.functions.tokenizer import Tokenizer
from sycamore.plan_nodes import Node, SingleThreadUser, NonGPUUser
from sycamore.transforms.map import Map
from sycamore.utils.time_trace import timetrace

logger = logging.getLogger(__name__)

RECURSIVE_SPLIT_MAX_DEPTH = 20


class SplitElements(SingleThreadUser, NonGPUUser, Map):
"""
Expand Down Expand Up @@ -46,10 +49,13 @@ def split_doc(parent: Document, tokenizer: Tokenizer, max: int) -> Document:

@staticmethod
def split_one(elem: Element, tokenizer: Tokenizer, max: int, depth: int = 0) -> list[Element]:
if depth > 20:
if depth > RECURSIVE_SPLIT_MAX_DEPTH:
logger.warning("Max split depth exceeded, truncating the splitting")
return [elem]

if elem.type == "table" and isinstance(elem, TableElement) and elem.table is not None:
return SplitElements.split_one_table(elem, tokenizer, max, depth)

txt = elem.text_representation
if not txt:
return [elem]
Expand Down Expand Up @@ -138,3 +144,160 @@ def split_one(elem: Element, tokenizer: Tokenizer, max: int, depth: int = 0) ->
bb = SplitElements.split_one(ment, tokenizer, max, depth + 1)
aa.extend(bb)
return aa

@staticmethod
def split_one_table(element: TableElement, tokenizer: Tokenizer, max_tokens: int, depth: int = 0) -> list[Element]:
"""
Special handling for tables: If the column header is too big, no amount of splitting the
rows will save us, as we want to attach the col header to each subtable. In this case,
split the table horizontally. If the column header is small enough, we can guess the number
of rows per subtable and try to break the table into chunks of that size for evenness (still
breaking when we run out of tokens). Special care is taken to adjust the bounding box appropriately.
"""
if depth > RECURSIVE_SPLIT_MAX_DEPTH:
logger.warning("Max split depth exceeded, truncating the splitting")
return [element]

assert element.table is not None, "Cannot split a table without table structure"

col_header_len = len(tokenizer.tokenize(", ".join(element.table.column_headers)))
data_table = element.table.data_cells()
data_row_lens = [
len(tokenizer.tokenize(", ".join([c.content for c in data_table.cells if i in c.rows])))
for i in range(data_table.num_rows)
]
if col_header_len > max_tokens - max(data_row_lens):
# If there is a row and column header that cannot combine,
# split table horizontally in half and recurse. Splitting
# is done by column number rather than text.
ncols = element.table.num_cols
if ncols <= 1:
# One-column table that's too big - turn it into text and split in the traditional way.
new_elt = element.copy()
new_elt.data["text_representation"] = new_elt.text_representation
new_elt.type = "Text"
return SplitElements.split_one(new_elt, tokenizer, max_tokens, depth + 1)
# Split the table by splitting the cells into groups.
elem_cells = [
TableCell(c.content, c.rows, [cl for cl in c.cols if cl < ncols // 2], c.is_header, c.bbox)
for c in element.table.cells
if min(c.cols) < ncols // 2
]
ment_cells = [
TableCell(
c.content, c.rows, [cl - ncols // 2 for cl in c.cols if cl >= ncols // 2], c.is_header, c.bbox
)
for c in element.table.cells
if max(c.cols) >= ncols // 2
]
elem = element.copy()
ment = element.copy()
elem.table = Table(elem_cells, element.table.caption)
ment.table = Table(ment_cells, element.table.caption)
_reset_table_bbox(elem)
_reset_table_bbox(ment)
return SplitElements.split_one_table(
elem, tokenizer, max_tokens, depth + 1
) + SplitElements.split_one_table(ment, tokenizer, max_tokens, depth + 1)
# We can attach the column header to every row, so break the rows up
# evenly into groups and form new tables each containing a set of rows.
# Ensure that each resulting table is below the token limit too.
data_max_tokens = max_tokens - col_header_len
header_cells = element.table.header_cells()
if len(header_cells) > 0:
n_header_rows = max(r for c in header_cells for r in c.rows) + 1
else:
n_header_rows = 0
# Try to be slightly less greedy by giving each chunk a row limit
expected_chunks = math.ceil(sum(data_row_lens) / data_max_tokens)
expected_rows_per_chunk = math.ceil(len(data_row_lens) / expected_chunks)
curr_len = 0
curr_rows: list[int] = []
subtables = []
for i, drl in enumerate(data_row_lens):
if curr_len + drl < data_max_tokens and len(curr_rows) < expected_rows_per_chunk:
curr_rows.append(i)
curr_len += drl
else:
begin, end = curr_rows[0], curr_rows[-1]
new_table_cells = header_cells + [
TableCell(
c.content,
[r - begin + n_header_rows for r in c.rows if begin <= r <= end],
c.cols,
c.is_header,
c.bbox,
)
for c in data_table.cells
if any(begin <= r <= end for r in c.rows)
]
subtables.append(Table(new_table_cells, element.table.caption))
curr_len = drl
curr_rows = [i]

begin, end = curr_rows[0], curr_rows[-1]
new_table_cells = header_cells + [
TableCell(
c.content, [r - begin + n_header_rows for r in c.rows if begin <= r <= end], c.cols, c.is_header, c.bbox
)
for c in data_table.cells
if any(begin <= r <= end for r in c.rows)
]
subtables.append(Table(new_table_cells, element.table.caption))
elms = [element.copy() for _ in subtables]
first = True
for elm, sbt in zip(elms, subtables):
elm.table = sbt
_reset_table_bbox(elm, ignore_header=not first)
first = False
return elms # type: ignore


def _reset_table_bbox(te: TableElement, ignore_header: bool = False):
"""
Set a table element's overall bbox to something that aligns with
its cells. Specifically, take the median left, right, top, and bottom
edges of all cells on the corresponding edge of the table. We want
the median here rather then the extreme because in split tables with
spanning cells the extreme may be the edge of the spanning cell, which
may hang outside of the columns or rows of the table we're working with.
"""
assert te.table is not None
if ignore_header:
dc = te.table.data_cells().cells
else:
dc = te.table.cells
if te.bbox is None or all(c.bbox is None for c in dc):
return
max_row = max(c.rows[-1] for c in dc)
max_col = max(c.cols[-1] for c in dc)
min_row = min(c.rows[0] for c in dc)
min_col = min(c.cols[0] for c in dc)
x1s = []
x2s = []
y1s = []
y2s = []
for c in dc:
if c.bbox is None:
continue
if c.cols[0] == min_col:
x1s.append(c.bbox.x1)
if c.cols[-1] == max_col:
x2s.append(c.bbox.x2)
if c.rows[0] == min_row:
y1s.append(c.bbox.y1)
if c.rows[-1] == max_row:
y2s.append(c.bbox.y2)
new_bb = BoundingBox(
_median(x1s),
_median(y1s),
_median(x2s),
_median(y2s),
)
if new_bb is not None:
te.bbox = new_bb


def _median(nums: list[float]) -> float:
nums.sort()
return nums[len(nums) // 2]
Loading