Skip to content

Commit

Permalink
Factor min max args
Browse files Browse the repository at this point in the history
  • Loading branch information
JasonGross committed Dec 20, 2024
1 parent 8d96f31 commit 59af0c6
Showing 1 changed file with 33 additions and 27 deletions.
60 changes: 33 additions & 27 deletions gbmi/training_tools/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,31 @@ def str_mean(s: str) -> str:
return f"𝔼({s})"


def calculate_zmax_zmin_args_one(ms: Collection[Tensor], **kwargs) -> dict[str, Any]:
"""Computes zmax and zmin by grouping matrices"""
result: dict[str, float] = {}
result = kwargs | {
"zmax": max(m[~m.isnan()].max().item() for m in ms),
"zmin": min(m[~m.isnan()].min().item() for m in ms),
}
if "zmid" in kwargs:
zhalfrange = np.max(
(
np.abs(result["zmax"] - result["zmid"]),
np.abs(result["zmin"] - result["zmid"]),
)
)
result["zmax"] = result["zmid"] + zhalfrange
result["zmin"] = result["zmid"] - zhalfrange
return result


def calculate_zmax_zmin_args(
matrices: Iterable[Tuple[str, Tensor]],
groups: Optional[
Union[Collection[Collection[str]], dict[Collection[str], dict[str, Any]]]
] = None,
**kwargs,
) -> dict[str, dict[str, Any]]:
"""Computes zmax and zmin by grouping matrices"""
if groups is None:
Expand All @@ -67,37 +87,16 @@ def calculate_zmax_zmin_args(
for name in group:
groups_map[name] = i
if isinstance(groups, dict):
groups_extra_args[i] = groups[group]
groups_extra_args[i] = kwargs | groups[group]
group_to_matrix_map: dict[Optional[int], list[Tensor]] = defaultdict(list)
matrices = list(matrices)
for name, matrix in matrices:
group_to_matrix_map[groups_map.get(name)].append(matrix)
zmax_zmin_args_by_group: dict[Optional[int], dict[str, float]] = {}
for i, ms in group_to_matrix_map.items():
zmax_zmin_args_by_group[i] = {
"zmax": max(m[~m.isnan()].max().item() for m in ms),
"zmin": min(m[~m.isnan()].min().item() for m in ms),
**groups_extra_args.get(i, {}),
}
if "zmid" in zmax_zmin_args_by_group[i]:
zhalfrange = np.max(
(
np.abs(
zmax_zmin_args_by_group[i]["zmax"]
- zmax_zmin_args_by_group[i]["zmid"]
),
np.abs(
zmax_zmin_args_by_group[i]["zmin"]
- zmax_zmin_args_by_group[i]["zmid"]
),
)
)
zmax_zmin_args_by_group[i]["zmax"] = (
zmax_zmin_args_by_group[i]["zmid"] + zhalfrange
)
zmax_zmin_args_by_group[i]["zmin"] = (
zmax_zmin_args_by_group[i]["zmid"] - zhalfrange
)
zmax_zmin_args_by_group[i] = calculate_zmax_zmin_args_one(
ms, **groups_extra_args.get(i, kwargs)
)
zmax_zmin_args = {}
for name, _ in matrices:
zmax_zmin_args[name] = zmax_zmin_args_by_group[groups_map.get(name)]
Expand All @@ -115,7 +114,9 @@ def plot_tensors(
) -> go.Figure:
# Calculate grid size based on the number of matrices
matrices = list(matrices)
zmax_zmin_args = calculate_zmax_zmin_args(matrices, groups=groups)
zmax_zmin_args = calculate_zmax_zmin_args(
matrices, groups=groups, **default_heatmap_kwargs
)
num_matrices = len(matrices)
grid_size = int(np.ceil(np.sqrt(num_matrices)))
subplot_titles = [name for name, _ in matrices] if len(matrices) else None
Expand Down Expand Up @@ -153,7 +154,12 @@ def plot_tensors(
go.Heatmap(
z=matrix,
name=name,
**zmax_zmin_args.get(name, default_heatmap_kwargs),
**zmax_zmin_args.get(
name,
calculate_zmax_zmin_args_one(
[matrix], **default_heatmap_kwargs
),
),
),
row=row,
col=col,
Expand Down

0 comments on commit 59af0c6

Please sign in to comment.