diff --git a/mmengine/hooks/sync_buffer_hook.py b/mmengine/hooks/sync_buffer_hook.py
index 1a88d96096..7cc75757fe 100644
--- a/mmengine/hooks/sync_buffer_hook.py
+++ b/mmengine/hooks/sync_buffer_hook.py
@@ -13,6 +13,24 @@ class SyncBuffersHook(Hook):
 
     def __init__(self) -> None:
         self.distributed = is_distributed()
+        # Set to True by after_train_epoch once the buffers are synced, so
+        # before_val_epoch can skip a second all-reduce; reset to False there.
+        self.called_in_train = False
+
+    def before_val_epoch(self, runner) -> None:
+        """All-reduce model buffers before each validation epoch.
+
+        Synchronize the buffers before each validation if they have not been
+        synchronized at the end of the previous training epoch. This method
+        will be called when using IterBasedTrainLoop.
+
+        Args:
+            runner (Runner): The runner of the training process.
+        """
+        if self.distributed:
+            if not self.called_in_train:
+                all_reduce_params(runner.model.buffers(), op='mean')
+            self.called_in_train = False
 
     def after_train_epoch(self, runner) -> None:
         """All-reduce model buffers at the end of each epoch.
@@ -22,3 +40,4 @@ def after_train_epoch(self, runner) -> None:
         """
         if self.distributed:
             all_reduce_params(runner.model.buffers(), op='mean')
+            self.called_in_train = True