Skip to content

Commit

Permalink
[Fix] Failed to remove the previous best checkpoints (#1086)
Browse files Browse the repository at this point in the history
* [Fix] Only reserve one best checkpoint

* [Fix] Only reserve one best checkpoint

* Fix unit test

* shutdown logging

* clean the save_checkpoint logic
  • Loading branch information
HAOCHENYE authored Apr 20, 2023
1 parent 6ebb6f8 commit f1aca8e
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 10 deletions.
8 changes: 4 additions & 4 deletions mmengine/hooks/checkpoint_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,9 +479,9 @@ def _save_best_checkpoint(self, runner, metrics) -> None:
runner.message_hub.update_info(best_score_key, best_score)

if best_ckpt_path and \
self.file_client.isfile(best_ckpt_path) and \
self.file_backend.isfile(best_ckpt_path) and \
is_main_process():
self.file_client.remove(best_ckpt_path)
self.file_backend.remove(best_ckpt_path)
runner.logger.info(
f'The previous best checkpoint {best_ckpt_path} '
'is removed')
Expand All @@ -490,13 +490,13 @@ def _save_best_checkpoint(self, runner, metrics) -> None:
# Replace illegal characters for filename with `_`
best_ckpt_name = best_ckpt_name.replace('/', '_')
if len(self.key_indicators) == 1:
self.best_ckpt_path = self.file_client.join_path( # type: ignore # noqa: E501
self.best_ckpt_path = self.file_backend.join_path( # type: ignore # noqa: E501
self.out_dir, best_ckpt_name)
runner.message_hub.update_info(runtime_best_ckpt_key,
self.best_ckpt_path)
else:
self.best_ckpt_path_dict[
key_indicator] = self.file_client.join_path( # type: ignore # noqa: E501
key_indicator] = self.file_backend.join_path( # type: ignore # noqa: E501
self.out_dir, best_ckpt_name)
runner.message_hub.update_info(
runtime_best_ckpt_key,
Expand Down
6 changes: 5 additions & 1 deletion mmengine/runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2191,7 +2191,11 @@ def save_checkpoint(
checkpoint['param_schedulers'].append(state_dict)

self.call_hook('before_save_checkpoint', checkpoint=checkpoint)
save_checkpoint(checkpoint, filepath)
save_checkpoint(
checkpoint,
filepath,
file_client_args=file_client_args,
backend_args=backend_args)

@master_only
def dump_config(self) -> None:
Expand Down
10 changes: 10 additions & 0 deletions mmengine/testing/runner_test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import copy
import logging
import os
import shutil
import tempfile
import time
from unittest import TestCase
Expand Down Expand Up @@ -184,3 +185,12 @@ def setup_dist_env(self):
os.environ['RANK'] = self.dist_cfg['RANK']
os.environ['WORLD_SIZE'] = self.dist_cfg['WORLD_SIZE']
os.environ['LOCAL_RANK'] = self.dist_cfg['LOCAL_RANK']

def clear_work_dir(self):
    """Delete every entry inside ``self.temp_dir`` but keep the directory.

    Real sub-directories are removed recursively; regular files and
    symbolic links (including links to directories and dangling links)
    are unlinked. Only the link itself is removed, never its target.
    """
    # Shut logging down before deleting anything -- presumably so that log
    # files still held open by handlers in the work dir can be removed
    # (NOTE(review): confirm this is the reason, e.g. on Windows).
    logging.shutdown()
    work_dir = self.temp_dir.name
    for entry in os.listdir(work_dir):
        entry_path = os.path.join(work_dir, entry)
        # ``shutil.rmtree`` refuses symbolic links, so only hand it real
        # directories; everything else is removed with ``os.remove``.
        if os.path.isdir(entry_path) and not os.path.islink(entry_path):
            shutil.rmtree(entry_path)
        else:
            os.remove(entry_path)
50 changes: 50 additions & 0 deletions tests/test_hooks/test_checkpoint_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import os
import os.path as osp
import re
import sys
from unittest.mock import MagicMock, patch

import torch
from parameterized import parameterized
Expand Down Expand Up @@ -312,6 +314,54 @@ def test_after_val_epoch(self):
self.assertFalse(
osp.isfile(osp.join(runner.work_dir, 'last_checkpoint')))

# Only one best checkpoint should be kept
# disk backend
for by_epoch, cfg in [(True, self.epoch_based_cfg),
(False, self.iter_based_cfg)]:
self.clear_work_dir()
cfg = copy.deepcopy(cfg)
runner = self.build_runner(cfg)
checkpoint_hook = CheckpointHook(
interval=2, by_epoch=by_epoch, save_best='acc')
checkpoint_hook.before_train(runner)
checkpoint_hook.after_val_epoch(runner, metrics)
all_files = os.listdir(runner.work_dir)
best_ckpts = [
file for file in all_files if file.startswith('best')
]
self.assertTrue(len(best_ckpts) == 1)

# petrel backend
# TODO use real petrel oss bucket to test
petrel_client = MagicMock()
for by_epoch, cfg in [(True, self.epoch_based_cfg),
(False, self.iter_based_cfg)]:
isfile = MagicMock(return_value=True)
self.clear_work_dir()
with patch.dict(sys.modules, {'petrel_client': petrel_client}), \
patch('mmengine.fileio.backends.PetrelBackend.put') as put_mock, \
patch('mmengine.fileio.backends.PetrelBackend.remove') as remove_mock, \
patch('mmengine.fileio.backends.PetrelBackend.isfile') as isfile: # noqa: E501
cfg = copy.deepcopy(cfg)
runner = self.build_runner(cfg)
metrics = dict(acc=0.5)
petrel_client.client.Client = MagicMock(
return_value=petrel_client)
checkpoint_hook = CheckpointHook(
interval=2,
by_epoch=by_epoch,
save_best='acc',
backend_args=dict(backend='petrel'))
checkpoint_hook.before_train(runner)
checkpoint_hook.after_val_epoch(runner, metrics)
put_mock.assert_called_once()
metrics['acc'] += 0.1
runner.train_loop._epoch += 1
runner.train_loop._iter += 1
checkpoint_hook.after_val_epoch(runner, metrics)
isfile.assert_called_once()
remove_mock.assert_called_once()

def test_after_train_epoch(self):
cfg = copy.deepcopy(self.epoch_based_cfg)
runner = self.build_runner(cfg)
Expand Down
7 changes: 2 additions & 5 deletions tests/test_runner/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@
load_from_local, load_from_pavi,
save_checkpoint)

sys.modules['petrel_client'] = MagicMock()
sys.modules['petrel_client.client'] = MagicMock()


@MODEL_WRAPPERS.register_module()
class DDPWrapper:
Expand Down Expand Up @@ -150,9 +147,8 @@ def test_get_state_dict():
wrapped_model.module.conv.module.bias)


@patch.dict(sys.modules, {'pavi': MagicMock()})
def test_load_pavimodel_dist():
sys.modules['pavi'] = MagicMock()
sys.modules['pavi.modelcloud'] = MagicMock()
pavimodel = Mockpavimodel()
import pavi
pavi.modelcloud.get = MagicMock(return_value=pavimodel)
Expand Down Expand Up @@ -296,6 +292,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
assert torch.allclose(model_v2.conv1.weight, model_v2_conv1_weight)


@patch.dict(sys.modules, {'petrel_client': MagicMock()})
def test_checkpoint_loader():
filenames = [
'http://xx.xx/xx.pth', 'https://xx.xx/xx.pth',
Expand Down

0 comments on commit f1aca8e

Please sign in to comment.