From 7e6b62639ce9b8f63d56cb35a32837d1c97e711e Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Wed, 4 Dec 2024 14:15:38 -0800 Subject: [PATCH] Add deformable table extractor (#1053) * add deformable table extractor Signed-off-by: Henry Lindeman * add docstrings Signed-off-by: Henry Lindeman --------- Signed-off-by: Henry Lindeman --- .../transforms/table_structure/extract.py | 56 +++++++++++++++++-- .../table_structure/table_transformers.py | 10 +++- 2 files changed, 59 insertions(+), 7 deletions(-) diff --git a/lib/sycamore/sycamore/transforms/table_structure/extract.py b/lib/sycamore/sycamore/transforms/table_structure/extract.py index 92d569ade..bde614ce7 100644 --- a/lib/sycamore/sycamore/transforms/table_structure/extract.py +++ b/lib/sycamore/sycamore/transforms/table_structure/extract.py @@ -100,9 +100,16 @@ def _prepare_tokens(self, tokens: list[dict[str, Any]], crop_box, width, height) t["block_num"] = 0 return tokens + def _init_structure_model(self): + from transformers import TableTransformerForObjectDetection + + self.structure_model = TableTransformerForObjectDetection.from_pretrained(self.model).to(self._get_device()) + @timetrace("tblExtr") @requires_modules(["torch", "torchvision"], extra="local-inference") - def extract(self, element: TableElement, doc_image: Image.Image, union_tokens=False) -> TableElement: + def extract( + self, element: TableElement, doc_image: Image.Image, union_tokens=False, apply_thresholds=False + ) -> TableElement: """Extracts the table structure from the specified element using a TableTransformer model. Takes a TableElement containing a bounding box, for example from the SycamorePartitioner, @@ -112,6 +119,8 @@ def extract(self, element: TableElement, doc_image: Image.Image, union_tokens=Fa element: A TableElement. The bounding box must be non-null. doc_image: A PIL object containing an image of the Document page containing the element. Used for bounding box calculations. 
+ union_tokens: Make sure that ocr/pdfminer tokens are _all_ included in the table. + apply_thresholds: Apply class thresholds to the objects output by the model. """ # We need a bounding box to be able to do anything. @@ -123,9 +132,7 @@ def extract(self, element: TableElement, doc_image: Image.Image, union_tokens=Fa width, height = doc_image.size if self.structure_model is None: - from transformers import TableTransformerForObjectDetection - - self.structure_model = TableTransformerForObjectDetection.from_pretrained(self.model).to(self._get_device()) + self._init_structure_model() assert self.structure_model is not None # For typechecking # Crop the image to encompass just the table + some padding. @@ -161,7 +168,9 @@ def extract(self, element: TableElement, doc_image: Image.Image, union_tokens=Fa structure_id2label = self.structure_model.config.id2label structure_id2label[len(structure_id2label)] = "no object" - objects = table_transformers.outputs_to_objects(outputs, cropped_image.size, structure_id2label) + objects = table_transformers.outputs_to_objects( + outputs, cropped_image.size, structure_id2label, apply_thresholds=apply_thresholds + ) # Convert the raw objects to our internal table representation. This involves multiple # phases of postprocessing. @@ -182,6 +191,43 @@ def extract(self, element: TableElement, doc_image: Image.Image, union_tokens=Fa return element +class DeformableTableStructureExtractor(TableTransformerStructureExtractor): + """A TableStructureExtractor implementation that uses the Deformable DETR model.""" + + def __init__(self, model: str, device=None): + """ + Creates a DeformableTableStructureExtractor + + Args: + model: The HuggingFace URL or local path for the DeformableDETR model to use.
+ """ + + super().__init__(model, device) + + def _init_structure_model(self): + from transformers import DeformableDetrForObjectDetection + + self.structure_model = DeformableDetrForObjectDetection.from_pretrained(self.model).to(self._get_device()) + + def extract( + self, element: TableElement, doc_image: Image.Image, union_tokens=False, apply_thresholds=True + ) -> TableElement: + """Extracts the table structure from the specified element using a DeformableDETR model. + + Takes a TableElement containing a bounding box, for example from the SycamorePartitioner, + and populates the table property with information about the cells. + + Args: + element: A TableElement. The bounding box must be non-null. + doc_image: A PIL object containing an image of the Document page containing the element. + Used for bounding box calculations. + union_tokens: Make sure that ocr/pdfminer tokens are _all_ included in the table. + apply_thresholds: Apply class thresholds to the objects output by the model. 
+ """ + # Literally just call the super but change the default for apply_thresholds + return super().extract(element, doc_image, union_tokens, apply_thresholds) + + DEFAULT_TABLE_STRUCTURE_EXTRACTOR = TableTransformerStructureExtractor diff --git a/lib/sycamore/sycamore/transforms/table_structure/table_transformers.py b/lib/sycamore/sycamore/transforms/table_structure/table_transformers.py index 7b25a89a1..71cd8da75 100644 --- a/lib/sycamore/sycamore/transforms/table_structure/table_transformers.py +++ b/lib/sycamore/sycamore/transforms/table_structure/table_transformers.py @@ -75,13 +75,18 @@ def rescale_bboxes(out_bbox, size): return b -def outputs_to_objects(outputs, img_size, id2label): +def outputs_to_objects(outputs, img_size, id2label, apply_thresholds: bool = False): m = outputs.logits.softmax(-1).max(-1) pred_labels = list(m.indices.detach().cpu().numpy())[0] pred_scores = list(m.values.detach().cpu().numpy())[0] pred_bboxes = outputs["pred_boxes"].detach().cpu()[0] pred_bboxes = [elem.tolist() for elem in rescale_bboxes(pred_bboxes, img_size)] + if apply_thresholds: + pred_bboxes, pred_scores, pred_labels = apply_class_thresholds( + pred_bboxes, pred_labels, pred_scores, id2label, DEFAULT_STRUCTURE_CLASS_THRESHOLDS + ) + objects = [] for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes): if float(bbox[0]) > float(bbox[2]) or float(bbox[1]) > float(bbox[3]): @@ -906,9 +911,10 @@ def objects_to_structures(objects, tokens, class_thresholds): if len(tables) == 0: return {} if len(tables) > 1: + tables.sort(key=lambda x: x["score"], reverse=True) import logging - logging.warning("Got multiple tables in document. Using only the first one") + logging.warning("Got multiple tables in document. Using only the highest-scoring one") table = tables[0] structure = {}