diff --git a/gbmi/training_tools/logging.py b/gbmi/training_tools/logging.py index 6c0d217..b0c29cf 100644 --- a/gbmi/training_tools/logging.py +++ b/gbmi/training_tools/logging.py @@ -110,6 +110,7 @@ def plot_tensors( plot_1D_kind: Literal["line", "scatter"] = "line", title="Subplots of Matrices", groups: Optional[Collection[Collection[str]]] = None, + default_heatmap_kwargs: dict[str, Any] = {}, **kwargs, ) -> go.Figure: # Calculate grid size based on the number of matrices @@ -149,7 +150,11 @@ def plot_tensors( elif len(matrix.shape) == 2: # 2D data - heatmap fig.add_trace( - go.Heatmap(z=matrix, name=name, **zmax_zmin_args.get(name, {})), + go.Heatmap( + z=matrix, + name=name, + **zmax_zmin_args.get(name, default_heatmap_kwargs), + ), row=row, col=col, ) @@ -992,6 +997,7 @@ def log_matrices( model: HookedTransformer, *, unsafe: bool = False, + default_heatmap_kwargs: dict[str, Any] = {}, **kwargs, ): matrices = dict(self.matrices_to_log(model, unsafe=unsafe)) @@ -1021,6 +1027,7 @@ def log_matrices( if self.group_colorbars else None ), + default_heatmap_kwargs=default_heatmap_kwargs, ) } if matrices