diff --git a/gbmi/training_tools/logging.py b/gbmi/training_tools/logging.py index b0c29cf..0b03470 100644 --- a/gbmi/training_tools/logging.py +++ b/gbmi/training_tools/logging.py @@ -52,11 +52,31 @@ def str_mean(s: str) -> str: return f"𝔼({s})" +def calculate_zmax_zmin_args_one(ms: Collection[Tensor], **kwargs) -> dict[str, Any]: + """Computes zmax and zmin by grouping matrices""" + result: dict[str, float] = {} + result = kwargs | { + "zmax": max(m[~m.isnan()].max().item() for m in ms), + "zmin": min(m[~m.isnan()].min().item() for m in ms), + } + if "zmid" in kwargs: + zhalfrange = np.max( + ( + np.abs(result["zmax"] - result["zmid"]), + np.abs(result["zmin"] - result["zmid"]), + ) + ) + result["zmax"] = result["zmid"] + zhalfrange + result["zmin"] = result["zmid"] - zhalfrange + return result + + def calculate_zmax_zmin_args( matrices: Iterable[Tuple[str, Tensor]], groups: Optional[ Union[Collection[Collection[str]], dict[Collection[str], dict[str, Any]]] ] = None, + **kwargs, ) -> dict[str, dict[str, Any]]: """Computes zmax and zmin by grouping matrices""" if groups is None: @@ -67,37 +87,16 @@ def calculate_zmax_zmin_args( for name in group: groups_map[name] = i if isinstance(groups, dict): - groups_extra_args[i] = groups[group] + groups_extra_args[i] = kwargs | groups[group] group_to_matrix_map: dict[Optional[int], list[Tensor]] = defaultdict(list) matrices = list(matrices) for name, matrix in matrices: group_to_matrix_map[groups_map.get(name)].append(matrix) zmax_zmin_args_by_group: dict[Optional[int], dict[str, float]] = {} for i, ms in group_to_matrix_map.items(): - zmax_zmin_args_by_group[i] = { - "zmax": max(m[~m.isnan()].max().item() for m in ms), - "zmin": min(m[~m.isnan()].min().item() for m in ms), - **groups_extra_args.get(i, {}), - } - if "zmid" in zmax_zmin_args_by_group[i]: - zhalfrange = np.max( - ( - np.abs( - zmax_zmin_args_by_group[i]["zmax"] - - zmax_zmin_args_by_group[i]["zmid"] - ), - np.abs( - zmax_zmin_args_by_group[i]["zmin"] - - zmax_zmin_args_by_group[i]["zmid"] - ), - ) - ) - zmax_zmin_args_by_group[i]["zmax"] = ( - zmax_zmin_args_by_group[i]["zmid"] + zhalfrange - ) - zmax_zmin_args_by_group[i]["zmin"] = ( - zmax_zmin_args_by_group[i]["zmid"] - zhalfrange - ) + zmax_zmin_args_by_group[i] = calculate_zmax_zmin_args_one( + ms, **groups_extra_args.get(i, kwargs) + ) zmax_zmin_args = {} for name, _ in matrices: zmax_zmin_args[name] = zmax_zmin_args_by_group[groups_map.get(name)] @@ -115,7 +114,9 @@ def plot_tensors( ) -> go.Figure: # Calculate grid size based on the number of matrices matrices = list(matrices) - zmax_zmin_args = calculate_zmax_zmin_args(matrices, groups=groups) + zmax_zmin_args = calculate_zmax_zmin_args( + matrices, groups=groups, **default_heatmap_kwargs + ) num_matrices = len(matrices) grid_size = int(np.ceil(np.sqrt(num_matrices))) subplot_titles = [name for name, _ in matrices] if len(matrices) else None @@ -153,7 +154,12 @@ def plot_tensors( go.Heatmap( z=matrix, name=name, - **zmax_zmin_args.get(name, default_heatmap_kwargs), + **zmax_zmin_args.get( + name, + calculate_zmax_zmin_args_one( + [matrix], **default_heatmap_kwargs + ), + ), ), row=row, col=col,