Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TEDS evaluation code #929

Draft
wants to merge 24 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
9a31051
pluggable table evaluation script
HenryL27 Jul 31, 2024
407e317
fix token bboxes and clean out spurious html tags
HenryL27 Jul 31, 2024
450090b
merge
HenryL27 Jul 31, 2024
dfa4239
add local fintabnet s3 scan
HenryL27 Aug 1, 2024
1483167
fix image resizing issue
HenryL27 Aug 1, 2024
f61997b
commit changes
HenryL27 Oct 15, 2024
cfd23cc
Merge branch 'main' of github.com:aryn-ai/sycamore into hml-teds-eval
HenryL27 Oct 15, 2024
c238bb7
relock sycamore poetry.lock
HenryL27 Oct 15, 2024
622425e
relock monorepo poetry.lock
HenryL27 Oct 15, 2024
44ff27a
add distance
HenryL27 Oct 15, 2024
471a483
add paddlev2 extractor
HenryL27 Oct 16, 2024
5c596ff
Merge branch 'hml-teds-eval-2' of ssh://testing:/home/ubuntu/sycamore…
HenryL27 Oct 16, 2024
3f1ff50
Load model only once when the class object is created.
akarshgupta7 Oct 18, 2024
758891d
Added code to use debug and limit args properly.
akarshgupta7 Oct 18, 2024
d251c34
add homemade model to eval script
HenryL27 Nov 19, 2024
0700b11
factor around deformable detr loading/lockfile management for use wit…
HenryL27 Dec 4, 2024
9b5a9e9
remove unused global variable
HenryL27 Dec 4, 2024
ea27bcf
move .to(device) inside the lock
HenryL27 Dec 4, 2024
0a8a2b1
jitpick
HenryL27 Dec 4, 2024
4f49ede
set deformable table extractor choose_device detr=True
HenryL27 Dec 5, 2024
ee1f80c
misc postprocessing tweaks
HenryL27 Dec 16, 2024
bdfd21b
before merge
HenryL27 Dec 16, 2024
c5671d2
Merge branch 'hml-table-ppfix2' of github.com:aryn-ai/sycamore into h…
HenryL27 Dec 16, 2024
bfb7b3d
save and merge?
HenryL27 Dec 16, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
683 changes: 345 additions & 338 deletions lib/sycamore/poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions lib/sycamore/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ transformers = { version = "^4.43.1", optional = true }
# Legacy partitioner dependencies
unstructured = { version = "0.10.20", optional = true }
python-pptx = {version = "^0.6.22", optional = true }
distance = "^0.1.3"
nanoid = "^2.0.0"

[tool.poetry.group.test.dependencies]
Expand Down
91 changes: 91 additions & 0 deletions lib/sycamore/sycamore/evaluation/tables/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from argparse import ArgumentParser
from tqdm import tqdm
from ray.data import ActorPoolStrategy
import sycamore
from sycamore.context import ExecMode
from sycamore.evaluation.tables.extractors import ExtractTableFromImage, FlorenceTableStructureExtractor, HomemadeTableTransformerTableStructureExtractor, PaddleTableStructureExtractor, TextractTableStructureExtractor, PaddleV2TableStructureExtractor
from sycamore.evaluation.tables.table_metrics import TEDSMetric, apply_metric
from sycamore.transforms.table_structure.extract import TableTransformerStructureExtractor

from .benchmark_scans import CohereTabNetS3Scan, FinTabNetS3Scan, PubTabNetScan, TableEvalDoc
from time import time

# Benchmark dataset name (as given on the command line) -> scan class.
SCANS = {
    "pubtabnet": PubTabNetScan,
    "fintabnet": FinTabNetS3Scan,
    "coheretabnet": CohereTabNetS3Scan,
}

# Extractor name (as given on the command line) -> 3-tuple of
# (extractor class, actor-pool strategy or None, constructor kwargs).
EXTRACTORS = {
    "tabletransformer": (
        TableTransformerStructureExtractor,
        ActorPoolStrategy(size=1),
        {"device": "cuda:0"},
    ),
    "paddleocr": (PaddleTableStructureExtractor, None, {}),
    "paddlev2": (PaddleV2TableStructureExtractor, None, {}),
    "textract": (TextractTableStructureExtractor, None, {}),
    "florence": (FlorenceTableStructureExtractor, None, {}),
    "homemade": (
        HomemadeTableTransformerTableStructureExtractor,
        ActorPoolStrategy(size=1),
        {"device": "cuda:0"},
    ),
}


def local_aggregate(docs, *agg_fns):
    """Fold ``docs`` through each aggregate function in-process.

    Every aggregate function must expose ``name``, ``init(name)``,
    ``accumulate_row(state, doc, in_ray=False)`` and ``finalize(state)``.

    Args:
        docs: Iterable of documents; it is traversed exactly once.
        *agg_fns: Aggregate functions to run over the documents.

    Returns:
        A dict mapping each aggregate function's ``name`` to its finalized
        result.
    """
    # One running state per aggregator, keyed by the aggregator's name.
    running = {}
    for fn in agg_fns:
        running[fn.name] = fn.init(fn.name)

    for row in docs:
        for fn in agg_fns:
            key = fn.name
            running[key] = fn.accumulate_row(running[key], row, in_ray=False)

    results = {}
    for fn in agg_fns:
        results[fn.name] = fn.finalize(running[fn.name])
    return results


# Command-line interface: two positional selectors plus optional switches.
parser = ArgumentParser()
parser.add_argument("dataset", help="dataset to evaluate", choices=list(SCANS))
parser.add_argument("extractor", help="TableStructureExtractor to evaluate", choices=list(EXTRACTORS))
parser.add_argument("--debug", action="store_true")
# A limit of -1 is the "not specified" default; the loading code below checks for it.
parser.add_argument("-l", "--limit", type=int, default=-1, required=False)
parser.add_argument("--noreal", action="store_true")
args = parser.parse_args()
# Echo the parsed arguments so every run records its configuration.
print(args)

# Metrics applied to every extracted table. Only the structure-only TEDS
# variant is enabled; the other variant is kept here for reference:
# TEDSMetric(structure_only=False),
metrics = [TEDSMetric(structure_only=True)]

# The evaluation runs on a local-execution sycamore context.
local_ctx = sycamore.init(exec_mode=ExecMode.LOCAL)

# Disabled alternative that builds the docset from a default sycamore context:
# ray_ctx = sycamore.init()
# sc = SCANS[args.dataset]().to_docset(ray_ctx)
# if args.debug:
#     sc = sc.limit(10)

# Materialize the benchmark documents into a plain list.
docgenerator = iter(SCANS[args.dataset]().local_process(limit=args.limit))
if args.debug:
    # Debug runs without an explicit limit are capped at 10 documents.
    if args.limit == -1:
        args.limit = 10
    progress = tqdm(range(args.limit), desc="Loading documents")
    # zip advances the progress range first, so the generator is never asked
    # for more than args.limit documents, and stops early if it runs dry.
    docs = [loaded for _, loaded in zip(progress, docgenerator)]
else:
    docs = list(tqdm(docgenerator, desc="Loading documents"))
print(f"Loaded {len(docs)} documents")

# Start the clock before building the extraction + metrics pipeline.
start = time()

# ``actorpool`` is unpacked from the registry entry but is not used below.
extractor_cls, actorpool, extractor_kwargs = EXTRACTORS[args.extractor]
table_extractor = ExtractTableFromImage(extractor_cls(**extractor_kwargs))

# Run the extractor over the loaded documents, then chain every metric onto the result.
measured = local_ctx.read.document(docs).map_batch(table_extractor)
for metric in metrics:
    measured = measured.map(apply_metric(metric))

# In debug mode, print one measured document: ground-truth table HTML,
# predicted table HTML, then the remaining document data.
if args.debug:
    sample = measured.take(2)
    if not sample:
        print("No documents were produced; nothing to display for --debug.")
    else:
        # Prefer the second document (the original behaviour), but fall back to
        # the only one available instead of raising IndexError when fewer than
        # two documents were loaded (e.g. ``--debug -l 1``).
        doc = sample[-1]
        ed = TableEvalDoc(doc.data)
        # Remove the image and the tokens before printing ``ed.data``
        # (presumably to keep the printed output readable).
        del ed["image"]
        del ed.properties["tokens"]
        print(ed.gt_table.to_html())
        print(ed.pred_table.to_html())
        print(ed.data)

# --noreal skips the full-dataset aggregation below.
if args.noreal:
    # Raise SystemExit directly rather than calling the bare ``exit()`` helper:
    # that name is injected by the ``site`` module for interactive sessions and
    # is not defined when the interpreter runs with ``-S``. Exit status is 0.
    raise SystemExit(0)

# Disabled alternative that aggregates through the execution plan:
# aggs = measured.plan.execute().aggregate(*[m.to_aggregate_fn() for m in metrics])
# Aggregate every metric in-process over all measured documents.
aggs = local_aggregate(measured.take_all(), *[m.to_aggregate_fn(in_ray=False) for m in metrics])
print("=" * 80)
print(aggs)
print(f"Time spent for extraction and metrics calculation: {time() - start} seconds.")
Loading
Loading