From acbd837f3ce395653a490c1e7b14f4ce60de5045 Mon Sep 17 00:00:00 2001 From: Jim Bosch Date: Fri, 13 Dec 2024 12:40:47 -0500 Subject: [PATCH] Add memory-efficient stack for consolidate tasks. Using this for MakeCcdVisitTableTask is trickier because we can't ask an ExposureCatalog how many rows it has before loading it. If that's needed, we can extend this code to do it on another branch. --- python/lsst/pipe/tasks/postprocess.py | 98 ++++++++++++++++++++++++++- 1 file changed, 95 insertions(+), 3 deletions(-) diff --git a/python/lsst/pipe/tasks/postprocess.py b/python/lsst/pipe/tasks/postprocess.py index 684a1e745..0dd96fc03 100644 --- a/python/lsst/pipe/tasks/postprocess.py +++ b/python/lsst/pipe/tasks/postprocess.py @@ -44,6 +44,7 @@ import numpy as np import pandas as pd import astropy.table +import astropy.utils.metadata import lsst.geom import lsst.pex.config as pexConfig @@ -82,6 +83,95 @@ def flattenFilters(df, noDupCols=["coord_ra", "coord_dec"], camelCase=False, inp return newDf +class TableVStack: + """A helper class for stacking astropy tables without having them all in + memory at once. + + Parameters + ---------- + capacity : `int` + Full size of the final table. + + Notes + ----- + Unlike `astropy.table.vstack`, this class requires all tables to have the + exact same columns (it's slightly more strict than even the + ``join_type="exact"`` argument to `astropy.table.vstack`). + """ + + def __init__(self, capacity): + self.index = 0 + self.capacity = capacity + self.result = None + + @classmethod + def from_handles(cls, handles): + """Construct from an iterable of + `lsst.daf.butler.DeferredDatasetHandle`. + + Parameters + ---------- + handles : `~collections.abc.Iterable` [ \ + `lsst.daf.butler.DeferredDatasetHandle` ] + Iterable of handles. Must have a storage class that supports the + "rowcount" component, which is all that will be fetched. + + Returns + ------- + vstack : `TableVStack` + An instance of this class, initialized with capacity equal to the + sum of the rowcounts of all the given table handles. + """ + capacity = sum(handle.get(component="rowcount") for handle in handles) + return cls(capacity=capacity) + + def extend(self, table): + """Add a single table to the stack. + + Parameters + ---------- + table : `astropy.table.Table` + An astropy table instance. + """ + if self.result is None: + self.result = astropy.table.Table() + for name in table.colnames: + column = table[name] + column_cls = type(column) + self.result[name] = column_cls.info.new_like([column], self.capacity, name=name) + self.index = len(table) + self.result.meta = table.meta.copy() + else: + next_index = self.index + len(table) + for name in table.colnames: + self.result[name][self.index:next_index] = table[name] + self.index = next_index + self.result.meta = astropy.utils.metadata.merge(self.result.meta, table.meta) + + @classmethod + def vstack_handles(cls, handles): + """Vertically stack tables represented by deferred dataset handles. + + Parameters + ---------- + handles : `~collections.abc.Iterable` [ \ + `lsst.daf.butler.DeferredDatasetHandle` ] + Iterable of handles. Must have the "ArrowAstropy" storage class + and identical columns. + + Returns + ------- + table : `astropy.table.Table` + Concatenated table with the same columns as each input table and + the rows of all of them. + """ + handles = tuple(handles) # guard against single-pass iterators + vstack = cls.from_handles(handles) + for handle in handles: + vstack.extend(handle.get()) + return vstack.result + + class WriteObjectTableConnections(pipeBase.PipelineTaskConnections, defaultTemplates={"coaddName": "deep"}, dimensions=("tract", "patch", "skymap")): @@ -932,6 +1022,7 @@ class ConsolidateObjectTableConnections(pipeBase.PipelineTaskConnections, storageClass="ArrowAstropy", dimensions=("tract", "patch", "skymap"), multiple=True, + deferLoad=True, ) outputCatalog = connectionTypes.Output( doc="Pre-tract horizontal concatenation of the input objectTables", @@ -965,7 +1056,7 @@ def runQuantum(self, butlerQC, inputRefs, outputRefs): inputs = butlerQC.get(inputRefs) self.log.info("Concatenating %s per-patch Object Tables", len(inputs["inputCatalogs"])) - table = astropy.table.vstack(inputs["inputCatalogs"], join_type="exact") + table = TableVStack.vstack_handles(inputs["inputCatalogs"]) butlerQC.put(pipeBase.Struct(outputCatalog=table), outputRefs) @@ -1142,7 +1233,8 @@ class ConsolidateSourceTableConnections(pipeBase.PipelineTaskConnections, name="{catalogType}sourceTable", storageClass="ArrowAstropy", dimensions=("instrument", "visit", "detector"), - multiple=True + multiple=True, + deferLoad=True, ) outputCatalog = connectionTypes.Output( doc="Per-visit concatenation of Source Table", @@ -1175,7 +1267,7 @@ def runQuantum(self, butlerQC, inputRefs, outputRefs): inputs = butlerQC.get(inputRefs) self.log.info("Concatenating %s per-detector Source Tables", len(inputs["inputCatalogs"])) - table = astropy.table.vstack(inputs["inputCatalogs"], join_type="exact") + table = TableVStack.vstack_handles(inputs["inputCatalogs"]) butlerQC.put(pipeBase.Struct(outputCatalog=table), outputRefs)