Skip to content

Commit

Permalink
feat: Stubs and RowGroups now inherit from Sequence
Browse files Browse the repository at this point in the history
  • Loading branch information
machow committed Oct 26, 2023
1 parent 6343803 commit e38a7e6
Showing 1 changed file with 67 additions and 30 deletions.
97 changes: 67 additions & 30 deletions gt/_gt_data.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,40 @@
from __future__ import annotations

from typing import overload, TypeVar
from collections import abc
from dataclasses import dataclass


T = TypeVar("T")


class _Sequence(abc.Sequence[T]):
_d: list[T]

def __init__(self, data: Any):
self._d = data

@overload
def __getitem__(self, ii: int) -> T:
...

@overload
def __getitem__(self, ii: slice) -> _Sequence[T]:
...

def __getitem__(self, ii: int | slice) -> T | _Sequence[T]:
if isinstance(ii, slice):
return self.__class__(self._d[ii])

return self._d[ii]

def __len__(self):
return len(self._d)

def __repr__(self):
return f"{type(self).__name__}({self._d.__repr__()})"


# Body ----
__Body = None

Expand Down Expand Up @@ -144,12 +179,13 @@ def _get_effective_number_of_columns(self) -> int:
from ._tbl_data import TblData, n_rows


@dataclass
class RowInfo:
# TODO: Make `rownum_i` readonly
rownum_i: int
group_id: Optional[str]
rowname: Optional[str]
group_label: Optional[str]
group_id: Optional[str] = None
rowname: Optional[str] = None
group_label: Optional[str] = None
built: bool = False

# The components of the stub are:
Expand All @@ -159,30 +195,23 @@ class RowInfo:
# `group_label` = None
# `built` = False

def __init__(
self,
rownum_i: int,
group_id: Optional[str] = None,
rowname: Optional[str] = None,
group_label: Optional[str] = None,
built: bool = False,
):
self.rownum_i = rownum_i
self.group_id = group_id
self.rowname = rowname
self.group_label = group_label
self.built = built

class Stub(_Sequence[RowInfo]):
_d: list[RowInfo]

class Stub:
def __init__(self, data: TblData):
# Obtain a list of row indices from the data and initialize
# the `_stub` from that
row_indices = list(range(n_rows(data)))
def __init__(self, data: TblData | list[RowInfo]):
if isinstance(data, list):
self._d = list(data)

else:
# Obtain a list of row indices from the data and initialize
# the `_stub` from that
row_indices = list(range(n_rows(data)))

# Obtain the column names from the data and initialize the
# `_boxhead` from that
self._d = [RowInfo(col) for col in row_indices]

# Obtain the column names from the data and initialize the
# `_boxhead` from that
self._stub: list[RowInfo] = [RowInfo(col) for col in row_indices]


# Row groups ----
Expand All @@ -191,11 +220,14 @@ def __init__(self, data: TblData):
from typing import Optional


class RowGroups:
row_groups: Optional[str]
class RowGroups(_Sequence[str]):
_d: list[str]

def __init__(self):
pass
def __init__(self, group_ids: Optional[list[str]] = None):
if group_ids is None:
self._d = []
else:
self._d = group_ids


# Spanners ----
Expand Down Expand Up @@ -723,12 +755,17 @@ class GTData:

@classmethod
def from_data(cls, data: TblData, locale: str | None = None):
stub = Stub(data)

group_ids = set(row.group_id for row in stub if row.group_id is not None)
row_groups = list(group_ids)

return cls(
_tbl_data=data,
_body=Body(data, data),
_boxhead=Boxhead(data), # uses get_tbl_data()
_stub=Stub(data), # uses get_tbl_data
_row_groups=RowGroups(),
_stub=stub, # uses get_tbl_data
_row_groups=RowGroups(row_groups),
_spanners=Spanners(),
_heading=Heading(),
_stubhead=Stubhead(),
Expand Down

0 comments on commit e38a7e6

Please sign in to comment.