Skip to content

Commit

Permalink
[Fix] Call SyncBufferHook before validation in IterBasedTrainLoop (#982)
Browse files Browse the repository at this point in the history
* [Fix] Call SyncBufferHook before validation in IterBasedTrainLoop

* Add before_val_epoch in SyncBuffersHook

* Fix white space format

* Add comments for SyncBuffersHook

* Add comments for SyncBuffersHook

Co-authored-by: Zaida Zhou <[email protected]>

* Add comments for SyncBuffersHook

Co-authored-by: Zaida Zhou <[email protected]>

* Fix white space format

* Add before_test_epoch

* Remove before_test_epoch

---------

Co-authored-by: Zaida Zhou <[email protected]>
  • Loading branch information
Luo-Yihang and zhouzaida authored Apr 20, 2023
1 parent 0e5f9da commit 6ebb6f8
Showing 1 changed file with 19 additions and 0 deletions.
19 changes: 19 additions & 0 deletions mmengine/hooks/sync_buffer_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,24 @@ class SyncBuffersHook(Hook):

def __init__(self) -> None:
    """Set up the synchronization bookkeeping for this hook."""
    # Becomes True once after_train_epoch has all-reduced the buffers, so
    # that before_val_epoch can tell whether a sync is still pending.
    self.called_in_train = False
    # Cache the result of is_distributed() once; the epoch callbacks only
    # all-reduce buffers when this is truthy.
    self.distributed = is_distributed()

def before_val_epoch(self, runner) -> None:
    """All-reduce model buffers before each validation epoch.

    The buffers are synchronized here only when they were not already
    synchronized at the end of the preceding training epoch. This method
    will be called when using IterBasedTrainLoop.

    Args:
        runner (Runner): The runner of the training process.
    """
    # Nothing to synchronize outside of distributed runs; the flag is
    # left untouched in that case.
    if not self.distributed:
        return

    # Skip the all-reduce when after_train_epoch has just performed it.
    if not self.called_in_train:
        all_reduce_params(runner.model.buffers(), op='mean')

    # Clear the flag so the next validation syncs again unless another
    # training epoch ends (and syncs) in between.
    self.called_in_train = False

def after_train_epoch(self, runner) -> None:
"""All-reduce model buffers at the end of each epoch.
Expand All @@ -22,3 +40,4 @@ def after_train_epoch(self, runner) -> None:
"""
if self.distributed:
all_reduce_params(runner.model.buffers(), op='mean')
self.called_in_train = True

0 comments on commit 6ebb6f8

Please sign in to comment.