Skip to content

Commit

Permalink
Allow passing a list of functions to GT.pipes()
Browse files Browse the repository at this point in the history
  • Loading branch information
jrycw committed May 25, 2024
1 parent 1623719 commit f428d25
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 4 deletions.
24 changes: 21 additions & 3 deletions great_tables/_pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def tbl_style(gtbl: GT, columns: list[str], colors: list[str]) -> GT:
return func(self, *args, **kwargs)


def pipes(self: "GT", *funcs: Callable["GT", "GT"]) -> "GT":
def pipes(self: "GT", *funcs: Callable["GT", "GT"] | list[Callable["GT", "GT"]]) -> "GT":
"""
Provide a structured way to chain functions for a GT object.
Expand All @@ -105,7 +105,8 @@ def pipes(self: "GT", *funcs: Callable["GT", "GT"]) -> "GT":
Parameters
----------
*funcs
Multiple functions, each receiving a GT object and returning a GT object.
Multiple functions or a list of functions, each receiving a GT object and returning a GT
object.
Returns
-------
Expand Down Expand Up @@ -173,11 +174,28 @@ def tbl_style(gtbl: GT, column: str, color: str) -> GT:
towny_mini[["name", "land_area_km2", "density_2021"]],
rowname_col="name",
).pipes(
*[partial(tbl_style, column=column, color=color) for column, color in zip(columns, colors)]
*[partial(tbl_style, column=column, color=color)
for column, color in zip(columns, colors)]
)
)
```
Alternatively, you can collect all the functions in a list like this:
```{python}
(
GT(
towny_mini[["name", "land_area_km2", "density_2021"]],
rowname_col="name",
).pipes(
[partial(tbl_style, column=column, color=color)
for column, color in zip(columns, colors)]
)
)
```
"""
if isinstance(funcs[0], list) and len(funcs) == 1:
funcs = funcs[0]
for func in funcs:
self = pipe(self, func)
return self
6 changes: 5 additions & 1 deletion tests/test_pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,8 @@ def tbl_style(gtbl: GT, column: str, color: str) -> GT:
*[partial(tbl_style, column=column, color=color) for column, color in zip(columns, colors)]
)

assert gt1._styles == gt2._styles
gt3 = GT(df).pipes(
[partial(tbl_style, column=column, color=color) for column, color in zip(columns, colors)]
)

assert gt1._styles == gt2._styles == gt3._styles

0 comments on commit f428d25

Please sign in to comment.