From e38a7e628f15d9a7a40d499597bb9111190aa153 Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Thu, 26 Oct 2023 16:19:57 -0400 Subject: [PATCH] feat: Stubs and RowGroups now inherit from Sequence --- gt/_gt_data.py | 97 ++++++++++++++++++++++++++++++++++---------------- 1 file changed, 67 insertions(+), 30 deletions(-) diff --git a/gt/_gt_data.py b/gt/_gt_data.py index 1f0a47df7..29819df5f 100644 --- a/gt/_gt_data.py +++ b/gt/_gt_data.py @@ -1,5 +1,40 @@ from __future__ import annotations +from typing import overload, TypeVar +from collections import abc +from dataclasses import dataclass + + +T = TypeVar("T") + + +class _Sequence(abc.Sequence[T]): + _d: list[T] + + def __init__(self, data: Any): + self._d = data + + @overload + def __getitem__(self, ii: int) -> T: + ... + + @overload + def __getitem__(self, ii: slice) -> _Sequence[T]: + ... + + def __getitem__(self, ii: int | slice) -> T | _Sequence[T]: + if isinstance(ii, slice): + return self.__class__(self._d[ii]) + + return self._d[ii] + + def __len__(self): + return len(self._d) + + def __repr__(self): + return f"{type(self).__name__}({self._d.__repr__()})" + + # Body ---- __Body = None @@ -144,12 +179,13 @@ def _get_effective_number_of_columns(self) -> int: from ._tbl_data import TblData, n_rows +@dataclass class RowInfo: # TODO: Make `rownum_i` readonly rownum_i: int - group_id: Optional[str] - rowname: Optional[str] - group_label: Optional[str] + group_id: Optional[str] = None + rowname: Optional[str] = None + group_label: Optional[str] = None built: bool = False # The components of the stub are: @@ -159,30 +195,23 @@ class RowInfo: # `group_label` = None # `built` = False - def __init__( - self, - rownum_i: int, - group_id: Optional[str] = None, - rowname: Optional[str] = None, - group_label: Optional[str] = None, - built: bool = False, - ): - self.rownum_i = rownum_i - self.group_id = group_id - self.rowname = rowname - self.group_label = group_label - self.built = built +class Stub(_Sequence[RowInfo]): + _d: list[RowInfo] -class Stub: - def __init__(self, data: TblData): - # Obtain a list of row indices from the data and initialize - # the `_stub` from that - row_indices = list(range(n_rows(data))) + def __init__(self, data: TblData | list[RowInfo]): + if isinstance(data, list): + self._d = list(data) + + else: + # Obtain a list of row indices from the data and initialize + # the `_stub` from that + row_indices = list(range(n_rows(data))) + + # Obtain the column names from the data and initialize the + # `_boxhead` from that + self._d = [RowInfo(col) for col in row_indices] - # Obtain the column names from the data and initialize the - # `_boxhead` from that - self._stub: list[RowInfo] = [RowInfo(col) for col in row_indices] # Row groups ---- @@ -191,11 +220,14 @@ def __init__(self, data: TblData): from typing import Optional -class RowGroups: - row_groups: Optional[str] +class RowGroups(_Sequence[str]): + _d: list[str] - def __init__(self): - pass + def __init__(self, group_ids: Optional[list[str]] = None): + if group_ids is None: + self._d = [] + else: + self._d = group_ids # Spanners ---- @@ -723,12 +755,17 @@ class GTData: @classmethod def from_data(cls, data: TblData, locale: str | None = None): + stub = Stub(data) + + group_ids = set(row.group_id for row in stub if row.group_id is not None) + row_groups = list(group_ids) + return cls( _tbl_data=data, _body=Body(data, data), _boxhead=Boxhead(data), # uses get_tbl_data() - _stub=Stub(data), # uses get_tbl_data - _row_groups=RowGroups(), + _stub=stub, # uses get_tbl_data + _row_groups=RowGroups(row_groups), _spanners=Spanners(), _heading=Heading(), _stubhead=Stubhead(),