Skip to content

Commit

Permalink
Add the "rms-sort" diagnostics (#1851)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhu-han authored Dec 30, 2024
1 parent ad966fb commit 57e9f2a
Showing 1 changed file with 21 additions and 1 deletion.
22 changes: 21 additions & 1 deletion icefall/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,22 @@ def get_tensor_stats(
"rms" -> square before summing, we'll take sqrt later
"value" -> just sum x itself
"max", "min" -> take the maximum or minimum [over all other dims but dim] instead of summing
"rms-sort" -> this is a bit different than the others, it's based on computing the
rms over the specified dim and returning percentiles of the result (11 of them).
Returns:
stats: a Tensor of shape (x.shape[dim],).
count: an integer saying how many items were counted in each element
of stats.
"""

if stats_type == "rms-sort":
rms = (x**2).mean(dim=dim).sqrt()
rms = rms.flatten()
rms = rms.sort()[0]
rms = rms[(torch.arange(11) * rms.numel() // 10).clamp(max=rms.numel() - 1)]
count = 1.0
return rms, count

count = x.numel() // x.shape[dim]

if stats_type == "eigs":
Expand Down Expand Up @@ -164,7 +174,17 @@ def accumulate(self, x, class_name: Optional[str] = None):
for dim in range(ndim):
this_dim_stats = self.stats[dim]
if ndim > 1:
stats_types = ["abs", "max", "min", "positive", "value", "rms"]
# rms-sort is different from the others, it's based on summing over just this
# dim, then sorting and returning the percentiles.
stats_types = [
"abs",
"max",
"min",
"positive",
"value",
"rms",
"rms-sort",
]
if x.shape[dim] <= self.opts.max_eig_dim:
stats_types.append("eigs")
else:
Expand Down

0 comments on commit 57e9f2a

Please sign in to comment.