diff --git a/mmengine/optim/optimizer/amp_optimizer_wrapper.py b/mmengine/optim/optimizer/amp_optimizer_wrapper.py index 4f3323f2cc..074cd1cdf7 100644 --- a/mmengine/optim/optimizer/amp_optimizer_wrapper.py +++ b/mmengine/optim/optimizer/amp_optimizer_wrapper.py @@ -91,16 +91,21 @@ def __init__(self, else: scaler_type = GradScaler + enable_loss_scaler = dtype != torch.bfloat16 + if loss_scale == 'dynamic': # If loss_scale is a string, it must be 'dynamic', then dynamic # loss scaling will be used. - self.loss_scaler = scaler_type() + self.loss_scaler = scaler_type(enabled=enable_loss_scaler) elif isinstance(loss_scale, float): # Static loss scaling self._scale_update_param = loss_scale - self.loss_scaler = scaler_type(init_scale=loss_scale) + self.loss_scaler = scaler_type( + init_scale=loss_scale, enabled=enable_loss_scaler) elif isinstance(loss_scale, dict): # More specific configuration. + loss_scale[ + 'enabled'] = loss_scale['enabled'] and enable_loss_scaler self.loss_scaler = scaler_type(**loss_scale) else: raise TypeError('loss_scale must be of type float, dict, or '