From 57e9f2a8db43eaa62d8701ef456f8323e9bcb8ff Mon Sep 17 00:00:00 2001 From: Han Zhu <1106766460@qq.com> Date: Mon, 30 Dec 2024 15:27:05 +0800 Subject: [PATCH] Add the "rms-sort" diagnostics (#1851) --- icefall/diagnostics.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py index 37872f2331..e5eaba619e 100644 --- a/icefall/diagnostics.py +++ b/icefall/diagnostics.py @@ -63,12 +63,22 @@ def get_tensor_stats( "rms" -> square before summing, we'll take sqrt later "value" -> just sum x itself "max", "min" -> take the maximum or minimum [over all other dims but dim] instead of summing + "rms-sort" -> this is a bit different than the others, it's based on computing the + rms over the specified dim and returning percentiles of the result (11 of them). Returns: stats: a Tensor of shape (x.shape[dim],). count: an integer saying how many items were counted in each element of stats. """ + if stats_type == "rms-sort": + rms = (x**2).mean(dim=dim).sqrt() + rms = rms.flatten() + rms = rms.sort()[0] + rms = rms[(torch.arange(11) * rms.numel() // 10).clamp(max=rms.numel() - 1)] + count = 1.0 + return rms, count + count = x.numel() // x.shape[dim] if stats_type == "eigs": @@ -164,7 +174,17 @@ def accumulate(self, x, class_name: Optional[str] = None): for dim in range(ndim): this_dim_stats = self.stats[dim] if ndim > 1: - stats_types = ["abs", "max", "min", "positive", "value", "rms"] + # rms-sort is different from the others, it's based on summing over just this + # dim, then sorting and returning the percentiles. + stats_types = [ + "abs", + "max", + "min", + "positive", + "value", + "rms", + "rms-sort", + ] if x.shape[dim] <= self.opts.max_eig_dim: stats_types.append("eigs") else: