Commit

disable grad scaler when dtype is bf16
HIT-cwh committed Jun 7, 2024
1 parent 85c83ba commit 6d4b725
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions mmengine/optim/optimizer/amp_optimizer_wrapper.py
@@ -91,16 +91,21 @@ def __init__(self,
         else:
             scaler_type = GradScaler
 
+        enable_loss_scaler = dtype != torch.bfloat16
+
         if loss_scale == 'dynamic':
             # If loss_scale is a string, it must be 'dynamic', then dynamic
             # loss scaling will be used.
-            self.loss_scaler = scaler_type()
+            self.loss_scaler = scaler_type(enabled=enable_loss_scaler)
         elif isinstance(loss_scale, float):
             # Static loss scaling
             self._scale_update_param = loss_scale
-            self.loss_scaler = scaler_type(init_scale=loss_scale)
+            self.loss_scaler = scaler_type(
+                init_scale=loss_scale, enabled=enable_loss_scaler)
         elif isinstance(loss_scale, dict):
             # More specific configuration.
+            loss_scale[
+                'enabled'] = loss_scale['enabled'] and enable_loss_scaler
             self.loss_scaler = scaler_type(**loss_scale)
         else:
             raise TypeError('loss_scale must be of type float, dict, or '
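For context, a minimal standalone sketch of the behaviour this commit relies on (plain PyTorch, not MMEngine's AmpOptimWrapper; the model, optimizer, and tensor names are made up, and a CUDA device is assumed): constructing GradScaler with enabled=False turns scale() into a pass-through, step() into a plain optimizer.step(), and update() into a no-op. That is what bf16 training wants, since bfloat16 shares fp32's 8-bit exponent range and does not need loss scaling.

    # Sketch only: shows the effect of enabled=False on GradScaler,
    # mirroring the enable_loss_scaler condition added in this commit.
    import torch

    dtype = torch.bfloat16                        # assumed training dtype
    enable_loss_scaler = dtype != torch.bfloat16  # same condition as the diff

    model = torch.nn.Linear(4, 2).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = torch.cuda.amp.GradScaler(enabled=enable_loss_scaler)

    x = torch.randn(8, 4, device='cuda')
    with torch.autocast(device_type='cuda', dtype=dtype):
        loss = model(x).float().mean()

    scaler.scale(loss).backward()  # disabled scaler returns the loss unscaled
    scaler.step(optimizer)         # reduces to a plain optimizer.step()
    scaler.update()                # no-op when the scaler is disabled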
