Improve infinity-check (#1862)
1. Attach the inf-check hooks if the grad scale is getting too small.
2. Add a try-except so an OOM inside the inf-check hooks does not crash training.
3. Set warmup_start=0.1 to reduce the chance of divergence.
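
Regarding point 1, register_inf_check_hooks comes from icefall/hooks.py (modified below). A minimal usage sketch, assuming icefall is importable; the toy Exploder module and the expected log line are illustrative, not part of this commit:

import logging

import torch
import torch.nn as nn

from icefall.hooks import register_inf_check_hooks

logging.basicConfig(level=logging.WARNING)


class Exploder(nn.Module):
    # Toy module whose output overflows to inf/nan so the hooks have
    # something to report.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 1.0e38 * 1.0e38  # overflows float32 to +/-inf


model = nn.Sequential(nn.Linear(4, 4), Exploder())
register_inf_check_hooks(model)

y = model(torch.randn(2, 4))
# Expect a warning roughly like "The sum of 1.output is not finite" from the
# forward hook on the Exploder submodule; the backward hooks behave analogously
# for non-finite gradients.
y.sum().backward()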
zhu-han authored Jan 9, 2025
1 parent 8d60280 commit ab91112
Showing 2 changed files with 36 additions and 15 deletions.
25 changes: 18 additions & 7 deletions egs/librispeech/ASR/zipformer/train.py
@@ -1165,23 +1165,34 @@ def save_bad_model(suffix: str = ""):
                 rank=rank,
             )
 
-        if batch_idx % 100 == 0 and params.use_autocast:
-            # If the grad scale was less than 1, try increasing it. The _growth_interval
-            # of the grad scaler is configurable, but we can't configure it to have different
-            # behavior depending on the current grad scale.
+        if params.use_autocast:
             cur_grad_scale = scaler._scale.item()
 
-            if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
-                scaler.update(cur_grad_scale * 2.0)
             if cur_grad_scale < 0.01:
                 if not saved_bad_model:
                     save_bad_model(suffix="-first-warning")
                     saved_bad_model = True
+                if not params.inf_check:
+                    register_inf_check_hooks(model)
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
+
                 if cur_grad_scale < 1.0e-05:
                     save_bad_model()
                     raise_grad_scale_is_too_small_error(cur_grad_scale)
 
+            # If the grad scale was less than 1, try increasing it. The _growth_interval
+            # of the grad scaler is configurable, but we can't configure it to have different
+            # behavior depending on the current grad scale.
+            if (
+                batch_idx % 25 == 0
+                and cur_grad_scale < 2.0
+                or batch_idx % 100 == 0
+                and cur_grad_scale < 8.0
+                or batch_idx % 400 == 0
+                and cur_grad_scale < 32.0
+            ):
+                scaler.update(cur_grad_scale * 2.0)
 
         if batch_idx % params.log_interval == 0:
             cur_lr = max(scheduler.get_last_lr())
             cur_grad_scale = scaler._scale.item() if params.use_autocast else 1.0
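
The rewritten growth condition above doubles the grad scale on a tiered schedule: the smaller the scale, the more often it is grown. A standalone sketch of just that condition (plain Python, detached from the GradScaler plumbing in train.py):

# Standalone sketch of the grad-scale growth schedule introduced above
# (illustration only; train.py applies the same condition to
# scaler.update(cur_grad_scale * 2.0)).


def should_double_grad_scale(batch_idx: int, cur_grad_scale: float) -> bool:
    # Smaller scales are grown more aggressively: every 25 batches below 2.0,
    # every 100 batches below 8.0, every 400 batches below 32.0.
    return (
        batch_idx % 25 == 0
        and cur_grad_scale < 2.0
        or batch_idx % 100 == 0
        and cur_grad_scale < 8.0
        or batch_idx % 400 == 0
        and cur_grad_scale < 32.0
    )


assert should_double_grad_scale(batch_idx=25, cur_grad_scale=1.0)
assert not should_double_grad_scale(batch_idx=25, cur_grad_scale=4.0)
assert should_double_grad_scale(batch_idx=100, cur_grad_scale=4.0)
assert not should_double_grad_scale(batch_idx=100, cur_grad_scale=16.0)
assert should_double_grad_scale(batch_idx=400, cur_grad_scale=16.0)
assert not should_double_grad_scale(batch_idx=400, cur_grad_scale=64.0)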
@@ -1335,7 +1346,7 @@ def run(rank, world_size, args):
         clipping_scale=2.0,
     )
 
-    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs, warmup_start=0.1)
 
     if checkpoints and "optimizer" in checkpoints:
         logging.info("Loading optimizer state dict")
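
On warmup_start=0.1: Eden (icefall/optim.py) multiplies the learning rate by a warm-up factor during the first warm-up batches, and warmup_start sets where that ramp begins, so 0.1 keeps the effective LR low early in training. The sketch below is a hedged reconstruction of a linear ramp of that shape, with an assumed warmup_batches value; Eden's actual formula may differ in detail.

# Hedged reconstruction of the warm-up ramp controlled by warmup_start
# (assumption: Eden uses a linear ramp of roughly this shape, and 500 warm-up
# batches is only an example value; see icefall/optim.py for the real code).


def warmup_factor(
    step: int, warmup_batches: float = 500.0, warmup_start: float = 0.1
) -> float:
    if step >= warmup_batches:
        return 1.0
    return warmup_start + (1.0 - warmup_start) * (step / warmup_batches)


print(warmup_factor(0))    # 0.1  -> LR starts at 10% of its nominal value
print(warmup_factor(250))  # 0.55 -> halfway through warm-up
print(warmup_factor(500))  # 1.0  -> warm-up finished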
26 changes: 18 additions & 8 deletions icefall/hooks.py
@@ -39,24 +39,34 @@ def register_inf_check_hooks(model: nn.Module) -> None:
         # default param _name is a way to capture the current value of the variable "name".
         def forward_hook(_module, _input, _output, _name=name):
             if isinstance(_output, Tensor):
-                if not torch.isfinite(_output.to(torch.float32).sum()):
-                    logging.warning(f"The sum of {_name}.output is not finite")
+                try:
+                    if not torch.isfinite(_output.to(torch.float32).sum()):
+                        logging.warning(f"The sum of {_name}.output is not finite")
+                except RuntimeError:  # e.g. CUDA out of memory
+                    pass
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
                     if isinstance(o, tuple):
                         o = o[0]
                     if not isinstance(o, Tensor):
                         continue
-                    if not torch.isfinite(o.to(torch.float32).sum()):
-                        logging.warning(f"The sum of {_name}.output[{i}] is not finite")
+                    try:
+                        if not torch.isfinite(o.to(torch.float32).sum()):
+                            logging.warning(
+                                f"The sum of {_name}.output[{i}] is not finite"
+                            )
+                    except RuntimeError:  # e.g. CUDA out of memory
+                        pass
 
         # default param _name is a way to capture the current value of the variable "name".
         def backward_hook(_module, _input, _output, _name=name):
             if isinstance(_output, Tensor):
-                if not torch.isfinite(_output.to(torch.float32).sum()):
-                    logging.warning(
-                        f"The sum of {_name}.grad is not finite"  # ": {_output}"
-                    )
+                try:
+                    if not torch.isfinite(_output.to(torch.float32).sum()):
+                        logging.warning(f"The sum of {_name}.grad is not finite")
+                except RuntimeError:  # e.g. CUDA out of memory
+                    pass
+
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
                     if isinstance(o, tuple):
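
Every try/except added above implements the same guard: the finiteness check itself performs a .to(torch.float32).sum() reduction that can run out of GPU memory on large activations or grads, and in that case the hook should skip the check rather than kill training. Restated as a hypothetical helper (icefall keeps this logic inline in the hooks, so the function below is illustrative only):

import logging

import torch
from torch import Tensor


def warn_if_not_finite(t: Tensor, label: str) -> None:
    # Log a warning if the tensor contains inf/nan; if the reduction itself
    # fails (e.g. CUDA out of memory), silently skip the check.
    try:
        if not torch.isfinite(t.to(torch.float32).sum()):
            logging.warning(f"The sum of {label} is not finite")
    except RuntimeError:  # e.g. CUDA out of memory
        pass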
