Skip to content

Commit

Permalink
Add deformable table extractor (#1053)
Browse files Browse the repository at this point in the history
* add deformable table extractor

Signed-off-by: Henry Lindeman <[email protected]>

* add docstrings

Signed-off-by: Henry Lindeman <[email protected]>

---------

Signed-off-by: Henry Lindeman <[email protected]>
  • Loading branch information
HenryL27 authored Dec 4, 2024
1 parent 9220d7a commit 7e6b626
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 7 deletions.
56 changes: 51 additions & 5 deletions lib/sycamore/sycamore/transforms/table_structure/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,16 @@ def _prepare_tokens(self, tokens: list[dict[str, Any]], crop_box, width, height)
t["block_num"] = 0
return tokens

def _init_structure_model(self):
from transformers import TableTransformerForObjectDetection

self.structure_model = TableTransformerForObjectDetection.from_pretrained(self.model).to(self._get_device())

@timetrace("tblExtr")
@requires_modules(["torch", "torchvision"], extra="local-inference")
def extract(self, element: TableElement, doc_image: Image.Image, union_tokens=False) -> TableElement:
def extract(
self, element: TableElement, doc_image: Image.Image, union_tokens=False, apply_thresholds=False
) -> TableElement:
"""Extracts the table structure from the specified element using a TableTransformer model.
Takes a TableElement containing a bounding box, for example from the SycamorePartitioner,
Expand All @@ -112,6 +119,8 @@ def extract(self, element: TableElement, doc_image: Image.Image, union_tokens=Fa
element: A TableElement. The bounding box must be non-null.
doc_image: A PIL object containing an image of the Document page containing the element.
Used for bounding box calculations.
union_tokens: Make sure that ocr/pdfminer tokens are _all_ included in the table.
apply_thresholds: Apply class thresholds to the objects output by the model.
"""

# We need a bounding box to be able to do anything.
Expand All @@ -123,9 +132,7 @@ def extract(self, element: TableElement, doc_image: Image.Image, union_tokens=Fa
width, height = doc_image.size

if self.structure_model is None:
from transformers import TableTransformerForObjectDetection

self.structure_model = TableTransformerForObjectDetection.from_pretrained(self.model).to(self._get_device())
self._init_structure_model()
assert self.structure_model is not None # For typechecking

# Crop the image to encompass just the table + some padding.
Expand Down Expand Up @@ -161,7 +168,9 @@ def extract(self, element: TableElement, doc_image: Image.Image, union_tokens=Fa
structure_id2label = self.structure_model.config.id2label
structure_id2label[len(structure_id2label)] = "no object"

objects = table_transformers.outputs_to_objects(outputs, cropped_image.size, structure_id2label)
objects = table_transformers.outputs_to_objects(
outputs, cropped_image.size, structure_id2label, apply_thresholds=apply_thresholds
)

# Convert the raw objects to our internal table representation. This involves multiple
# phases of postprocessing.
Expand All @@ -182,6 +191,43 @@ def extract(self, element: TableElement, doc_image: Image.Image, union_tokens=Fa
return element


class DeformableTableStructureExtractor(TableTransformerStructureExtractor):
"""A TableStructureExtractor implementation that uses the Deformable DETR model."""

def __init__(self, model: str, device=None):
"""
Creates a TableTransformerStructureExtractor
Args:
model: The HuggingFace URL or local path for the DeformableDETR model to use.
"""

super().__init__(model, device)

def _init_structure_model(self):
from transformers import DeformableDetrForObjectDetection

self.structure_model = DeformableDetrForObjectDetection.from_pretrained(self.model).to(self._get_device())

def extract(
self, element: TableElement, doc_image: Image.Image, union_tokens=False, apply_thresholds=True
) -> TableElement:
"""Extracts the table structure from the specified element using a DeformableDETR model.
Takes a TableElement containing a bounding box, for example from the SycamorePartitioner,
and populates the table property with information about the cells.
Args:
element: A TableElement. The bounding box must be non-null.
doc_image: A PIL object containing an image of the Document page containing the element.
Used for bounding box calculations.
union_tokens: Make sure that ocr/pdfminer tokens are _all_ included in the table.
apply_thresholds: Apply class thresholds to the objects output by the model.
"""
# Literally just call the super but change the default for apply_thresholds
return super().extract(element, doc_image, union_tokens, apply_thresholds)


DEFAULT_TABLE_STRUCTURE_EXTRACTOR = TableTransformerStructureExtractor


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,18 @@ def rescale_bboxes(out_bbox, size):
return b


def outputs_to_objects(outputs, img_size, id2label):
def outputs_to_objects(outputs, img_size, id2label, apply_thresholds: bool = False):
m = outputs.logits.softmax(-1).max(-1)
pred_labels = list(m.indices.detach().cpu().numpy())[0]
pred_scores = list(m.values.detach().cpu().numpy())[0]
pred_bboxes = outputs["pred_boxes"].detach().cpu()[0]
pred_bboxes = [elem.tolist() for elem in rescale_bboxes(pred_bboxes, img_size)]

if apply_thresholds:
pred_bboxes, pred_scores, pred_labels = apply_class_thresholds(
pred_bboxes, pred_labels, pred_scores, id2label, DEFAULT_STRUCTURE_CLASS_THRESHOLDS
)

objects = []
for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes):
if float(bbox[0]) > float(bbox[2]) or float(bbox[1]) > float(bbox[3]):
Expand Down Expand Up @@ -906,9 +911,10 @@ def objects_to_structures(objects, tokens, class_thresholds):
if len(tables) == 0:
return {}
if len(tables) > 1:
tables.sort(key=lambda x: x["score"], reverse=True)
import logging

logging.warning("Got multiple tables in document. Using only the first one")
logging.warning("Got multiple tables in document. Using only the highest-scoring one")

table = tables[0]
structure = {}
Expand Down

0 comments on commit 7e6b626

Please sign in to comment.