diff --git a/lib/sycamore/sycamore/data/table.py b/lib/sycamore/sycamore/data/table.py
index 2ec195dd9..1600947aa 100644
--- a/lib/sycamore/sycamore/data/table.py
+++ b/lib/sycamore/sycamore/data/table.py
@@ -184,6 +184,33 @@ def to_dict(self) -> dict[str, Any]:
         d["num_cols"] = self.num_cols
         return d
 
+    def data_cells(self) -> "Table":
+        """Returns a table containing only the data cells in this table."""
+        header_rows = sorted(set((row_num for cell in self.cells for row_num in cell.rows if cell.is_header)))
+        i = -1
+        for r in header_rows:
+            i += 1
+            if r != i:
+                break
+        data_cells = [c for c in self.cells if c.rows[0] > i]
+        shifted_cells = [
+            TableCell(
+                content=c.content, rows=[r - i - 1 for r in c.rows], cols=c.cols, is_header=c.is_header, bbox=c.bbox
+            )
+            for c in data_cells
+        ]
+        return Table(shifted_cells, column_headers=[], caption=self.caption)
+
+    def header_cells(self) -> list[TableCell]:
+        """Returns the header cells as a list."""
+        header_rows = sorted(set((row_num for cell in self.cells for row_num in cell.rows if cell.is_header)))
+        i = -1
+        for r in header_rows:
+            i += 1
+            if r != i:
+                break
+        return [c for c in self.cells if c.rows[0] <= i]
+
     # TODO: There are likely edge cases where this will break or lose information. Nested or non-contiguous
     # headers are one likely source of issues. We also don't support missing closing tags (which are allowed in
     # the spec) because html.parser doesn't handle them. If and when this becomes an issue, we can consider
diff --git a/lib/sycamore/sycamore/tests/unit/transforms/test_split_elements.py b/lib/sycamore/sycamore/tests/unit/transforms/test_split_elements.py
index 7fec46a72..a1b6e8693 100644
--- a/lib/sycamore/sycamore/tests/unit/transforms/test_split_elements.py
+++ b/lib/sycamore/sycamore/tests/unit/transforms/test_split_elements.py
@@ -1,8 +1,8 @@
 import ray.data
 
-from sycamore.data import Document
+from sycamore.data import Document, TableElement, Table
 from sycamore.transforms.split_elements import SplitElements
-from sycamore.functions.tokenizer import HuggingFaceTokenizer
+from sycamore.functions.tokenizer import HuggingFaceTokenizer, CharacterTokenizer
 from sycamore.plan_nodes import Node
 
 
@@ -32,6 +32,29 @@ class TestSplitElements:
         }
     )
 
+    bigtable = """
+    <table>
+        <tr><th rowspan="2">headerA</th><th colspan="2">headerB</th><th>headerC</th></tr>
+        <tr><th>headerD</th><th>headerE</th><th>headerF</th></tr>
+        <tr><td>data1a</td><td>data2a</td><td>data3a</td><td>data4a</td></tr>
+        <tr><td>data1b</td><td>data2b</td><td>data3b</td><td>data4b</td></tr>
+        <tr><td>data1c</td><td>data2c</td><td>data3c</td><td>data4c</td></tr>
+        <tr><td>data1d</td><td>data2d</td><td>data3d</td><td>data4d</td></tr>
+    </table>
+ """ + + tabledoc = Document( + { + "doc_id": "id", + "type": "pdf", + "text_representation": "lkqwrg", + "binary_representation": None, + "parent_id": None, + "properties": {"path": "/filename.yolo", "title": "lkqwrg"}, + "elements": [TableElement(table=Table.from_html(bigtable))], + } + ) + def test_split_elements(self): tokenizer = HuggingFaceTokenizer("sentence-transformers/all-MiniLM-L6-v2") doc = SplitElements(None, tokenizer, 15).run(self.doc) @@ -62,3 +85,17 @@ def validateElems(self, elems): assert elems[7].text_representation == "thirtyeight thirtynine forty fortyone fortytwo fortythree " assert elems[8].text_representation == "fortyfour fortyfive fortysix " assert elems[9].text_representation == "fortyseven fortyeight fortynine" + + def test_split_table(self): + tk = CharacterTokenizer() + doc = SplitElements(None, tk, 35).run(self.tabledoc) + answers = { + '
headerA
data1a
data1b
data1c
data1d
', # noqa: E501 + "
headerB
headerD
data2a
data2b
", + "
headerB
headerD
data2c
data2d
", + "
headerB
headerE
data3a
data3b
", + "
headerB
headerE
data3c
data3d
", + "
headerC
headerF
data4a
data4b
", + "
headerC
headerF
data4c
data4d
", + } + assert {e.table.to_html() for e in doc.elements} == answers diff --git a/lib/sycamore/sycamore/transforms/split_elements.py b/lib/sycamore/sycamore/transforms/split_elements.py index 1966ce1ec..bf382bbf1 100644 --- a/lib/sycamore/sycamore/transforms/split_elements.py +++ b/lib/sycamore/sycamore/transforms/split_elements.py @@ -1,6 +1,7 @@ +import math from typing import Optional import logging -from sycamore.data import Document, Element, TableElement +from sycamore.data import Document, Element, TableElement, TableCell, Table, BoundingBox from sycamore.functions.tokenizer import Tokenizer from sycamore.plan_nodes import Node, SingleThreadUser, NonGPUUser from sycamore.transforms.map import Map @@ -8,6 +9,8 @@ logger = logging.getLogger(__name__) +RECURSIVE_SPLIT_MAX_DEPTH = 20 + class SplitElements(SingleThreadUser, NonGPUUser, Map): """ @@ -46,10 +49,13 @@ def split_doc(parent: Document, tokenizer: Tokenizer, max: int) -> Document: @staticmethod def split_one(elem: Element, tokenizer: Tokenizer, max: int, depth: int = 0) -> list[Element]: - if depth > 20: + if depth > RECURSIVE_SPLIT_MAX_DEPTH: logger.warning("Max split depth exceeded, truncating the splitting") return [elem] + if elem.type == "table" and isinstance(elem, TableElement) and elem.table is not None: + return SplitElements.split_one_table(elem, tokenizer, max, depth) + txt = elem.text_representation if not txt: return [elem] @@ -138,3 +144,160 @@ def split_one(elem: Element, tokenizer: Tokenizer, max: int, depth: int = 0) -> bb = SplitElements.split_one(ment, tokenizer, max, depth + 1) aa.extend(bb) return aa + + @staticmethod + def split_one_table(element: TableElement, tokenizer: Tokenizer, max_tokens: int, depth: int = 0) -> list[Element]: + """ + Special handling for tables: If the column header is too big, no amount of splitting the + rows will save us, as we want to attach the col header to each subtable. In this case, + split the table horizontally. If the column header is small enough, we can guess the number + of rows per subtable and try to break the table into chunks of that size for evenness (still + breaking when we run out of tokens). Special care is taken to adjust the bounding box appropriately. + """ + if depth > RECURSIVE_SPLIT_MAX_DEPTH: + logger.warning("Max split depth exceeded, truncating the splitting") + return [element] + + assert element.table is not None, "Cannot split a table without table structure" + + col_header_len = len(tokenizer.tokenize(", ".join(element.table.column_headers))) + data_table = element.table.data_cells() + data_row_lens = [ + len(tokenizer.tokenize(", ".join([c.content for c in data_table.cells if i in c.rows]))) + for i in range(data_table.num_rows) + ] + if col_header_len > max_tokens - max(data_row_lens): + # If there is a row and column header that cannot combine, + # split table horizontally in half and recurse. Splitting + # is done by column number rather than text. + ncols = element.table.num_cols + if ncols <= 1: + # One-column table that's too big - turn it into text and split in the traditional way. + new_elt = element.copy() + new_elt.data["text_representation"] = new_elt.text_representation + new_elt.type = "Text" + return SplitElements.split_one(new_elt, tokenizer, max_tokens, depth + 1) + # Split the table by splitting the cells into groups. 
+            elem_cells = [
+                TableCell(c.content, c.rows, [cl for cl in c.cols if cl < ncols // 2], c.is_header, c.bbox)
+                for c in element.table.cells
+                if min(c.cols) < ncols // 2
+            ]
+            ment_cells = [
+                TableCell(
+                    c.content, c.rows, [cl - ncols // 2 for cl in c.cols if cl >= ncols // 2], c.is_header, c.bbox
+                )
+                for c in element.table.cells
+                if max(c.cols) >= ncols // 2
+            ]
+            elem = element.copy()
+            ment = element.copy()
+            elem.table = Table(elem_cells, element.table.caption)
+            ment.table = Table(ment_cells, element.table.caption)
+            _reset_table_bbox(elem)
+            _reset_table_bbox(ment)
+            return SplitElements.split_one_table(
+                elem, tokenizer, max_tokens, depth + 1
+            ) + SplitElements.split_one_table(ment, tokenizer, max_tokens, depth + 1)
+        # We can attach the column header to every row, so break the rows up
+        # evenly into groups and form new tables each containing a set of rows.
+        # Ensure that each resulting table is below the token limit too.
+        data_max_tokens = max_tokens - col_header_len
+        header_cells = element.table.header_cells()
+        if len(header_cells) > 0:
+            n_header_rows = max(r for c in header_cells for r in c.rows) + 1
+        else:
+            n_header_rows = 0
+        # Try to be slightly less greedy by giving each chunk a row limit.
+        expected_chunks = math.ceil(sum(data_row_lens) / data_max_tokens)
+        expected_rows_per_chunk = math.ceil(len(data_row_lens) / expected_chunks)
+        curr_len = 0
+        curr_rows: list[int] = []
+        subtables = []
+        for i, drl in enumerate(data_row_lens):
+            if curr_len + drl < data_max_tokens and len(curr_rows) < expected_rows_per_chunk:
+                curr_rows.append(i)
+                curr_len += drl
+            else:
+                begin, end = curr_rows[0], curr_rows[-1]
+                new_table_cells = header_cells + [
+                    TableCell(
+                        c.content,
+                        [r - begin + n_header_rows for r in c.rows if begin <= r <= end],
+                        c.cols,
+                        c.is_header,
+                        c.bbox,
+                    )
+                    for c in data_table.cells
+                    if any(begin <= r <= end for r in c.rows)
+                ]
+                subtables.append(Table(new_table_cells, element.table.caption))
+                curr_len = drl
+                curr_rows = [i]
+
+        begin, end = curr_rows[0], curr_rows[-1]
+        new_table_cells = header_cells + [
+            TableCell(
+                c.content, [r - begin + n_header_rows for r in c.rows if begin <= r <= end], c.cols, c.is_header, c.bbox
+            )
+            for c in data_table.cells
+            if any(begin <= r <= end for r in c.rows)
+        ]
+        subtables.append(Table(new_table_cells, element.table.caption))
+        elms = [element.copy() for _ in subtables]
+        first = True
+        for elm, sbt in zip(elms, subtables):
+            elm.table = sbt
+            _reset_table_bbox(elm, ignore_header=not first)
+            first = False
+        return elms  # type: ignore
+
+
+def _reset_table_bbox(te: TableElement, ignore_header: bool = False):
+    """
+    Set a table element's overall bbox to something that aligns with
+    its cells. Specifically, take the median left, right, top, and bottom
+    edges of all cells on the corresponding edge of the table. We want
+    the median here rather than the extreme because in split tables with
+    spanning cells the extreme may be the edge of a spanning cell, which
+    may hang outside of the columns or rows of the table we're working with.
+ """ + assert te.table is not None + if ignore_header: + dc = te.table.data_cells().cells + else: + dc = te.table.cells + if te.bbox is None or all(c.bbox is None for c in dc): + return + max_row = max(c.rows[-1] for c in dc) + max_col = max(c.cols[-1] for c in dc) + min_row = min(c.rows[0] for c in dc) + min_col = min(c.cols[0] for c in dc) + x1s = [] + x2s = [] + y1s = [] + y2s = [] + for c in dc: + if c.bbox is None: + continue + if c.cols[0] == min_col: + x1s.append(c.bbox.x1) + if c.cols[-1] == max_col: + x2s.append(c.bbox.x2) + if c.rows[0] == min_row: + y1s.append(c.bbox.y1) + if c.rows[-1] == max_row: + y2s.append(c.bbox.y2) + new_bb = BoundingBox( + _median(x1s), + _median(y1s), + _median(x2s), + _median(y2s), + ) + if new_bb is not None: + te.bbox = new_bb + + +def _median(nums: list[float]) -> float: + nums.sort() + return nums[len(nums) // 2]