diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py
index 15fd8ac66a..bae8bd65bc 100644
--- a/mmengine/hooks/checkpoint_hook.py
+++ b/mmengine/hooks/checkpoint_hook.py
@@ -479,9 +479,9 @@ def _save_best_checkpoint(self, runner, metrics) -> None:
             runner.message_hub.update_info(best_score_key, best_score)
 
             if best_ckpt_path and \
-               self.file_client.isfile(best_ckpt_path) and \
+               self.file_backend.isfile(best_ckpt_path) and \
                is_main_process():
-                self.file_client.remove(best_ckpt_path)
+                self.file_backend.remove(best_ckpt_path)
                 runner.logger.info(
                     f'The previous best checkpoint {best_ckpt_path} '
                     'is removed')
@@ -490,13 +490,13 @@ def _save_best_checkpoint(self, runner, metrics) -> None:
             # Replace illegal characters for filename with `_`
             best_ckpt_name = best_ckpt_name.replace('/', '_')
             if len(self.key_indicators) == 1:
-                self.best_ckpt_path = self.file_client.join_path(  # type: ignore # noqa: E501
+                self.best_ckpt_path = self.file_backend.join_path(  # type: ignore # noqa: E501
                     self.out_dir, best_ckpt_name)
                 runner.message_hub.update_info(runtime_best_ckpt_key,
                                                self.best_ckpt_path)
             else:
                 self.best_ckpt_path_dict[
-                    key_indicator] = self.file_client.join_path(  # type: ignore # noqa: E501
+                    key_indicator] = self.file_backend.join_path(  # type: ignore # noqa: E501
                     self.out_dir, best_ckpt_name)
                 runner.message_hub.update_info(
                     runtime_best_ckpt_key,
diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py
index e5b2f755d0..60b40a7e07 100644
--- a/mmengine/runner/runner.py
+++ b/mmengine/runner/runner.py
@@ -2191,7 +2191,11 @@ def save_checkpoint(
                     checkpoint['param_schedulers'].append(state_dict)
 
         self.call_hook('before_save_checkpoint', checkpoint=checkpoint)
-        save_checkpoint(checkpoint, filepath)
+        save_checkpoint(
+            checkpoint,
+            filepath,
+            file_client_args=file_client_args,
+            backend_args=backend_args)
 
     @master_only
     def dump_config(self) -> None:
diff --git a/mmengine/testing/runner_test_case.py b/mmengine/testing/runner_test_case.py
index e9dc5acbc6..a05c41d3e8 100644
--- a/mmengine/testing/runner_test_case.py
+++ b/mmengine/testing/runner_test_case.py
@@ -2,6 +2,7 @@
 import copy
 import logging
 import os
+import shutil
 import tempfile
 import time
 from unittest import TestCase
@@ -184,3 +185,18 @@ def setup_dist_env(self):
         os.environ['RANK'] = self.dist_cfg['RANK']
         os.environ['WORLD_SIZE'] = self.dist_cfg['WORLD_SIZE']
         os.environ['LOCAL_RANK'] = self.dist_cfg['LOCAL_RANK']
+
+    def clear_work_dir(self):
+        """Delete every entry inside ``self.temp_dir.name``.
+
+        ``logging.shutdown()`` is called first, which flushes and closes
+        all logging handlers before the files are removed.
+        """
+        logging.shutdown()
+        for filename in os.listdir(self.temp_dir.name):
+            filepath = os.path.join(self.temp_dir.name, filename)
+            # Links are unlinked, never followed: rmtree() rejects symlinks.
+            if os.path.isfile(filepath) or os.path.islink(filepath):
+                os.remove(filepath)
+            else:
+                shutil.rmtree(filepath)
diff --git a/tests/test_hooks/test_checkpoint_hook.py b/tests/test_hooks/test_checkpoint_hook.py
index d51d13d172..6a80dcdd12 100644
--- a/tests/test_hooks/test_checkpoint_hook.py
+++ b/tests/test_hooks/test_checkpoint_hook.py
@@ -3,6 +3,8 @@
 import os
 import os.path as osp
 import re
+import sys
+from unittest.mock import MagicMock, patch
 
 import torch
 from parameterized import parameterized
@@ -312,6 +314,53 @@ def test_after_val_epoch(self):
         self.assertFalse(
             osp.isfile(osp.join(runner.work_dir, 'last_checkpoint')))
 
+        # There should be only one best checkpoint reserved
+        # dist backend
+        for by_epoch, cfg in [(True, self.epoch_based_cfg),
+                              (False, self.iter_based_cfg)]:
+            self.clear_work_dir()
+            cfg = copy.deepcopy(cfg)
+            runner = self.build_runner(cfg)
+            checkpoint_hook = CheckpointHook(
+                interval=2, by_epoch=by_epoch, save_best='acc')
+            checkpoint_hook.before_train(runner)
+            checkpoint_hook.after_val_epoch(runner, metrics)
+            all_files = os.listdir(runner.work_dir)
+            best_ckpts = [
+                file for file in all_files if file.startswith('best')
+            ]
+            self.assertEqual(len(best_ckpts), 1)
+
+        # petrel backend
+        # TODO use real petrel oss bucket to test
+        petrel_client = MagicMock()
+        for by_epoch, cfg in [(True, self.epoch_based_cfg),
+                              (False, self.iter_based_cfg)]:
+            self.clear_work_dir()
+            with patch.dict(sys.modules, {'petrel_client': petrel_client}), \
+                 patch('mmengine.fileio.backends.PetrelBackend.put') as put_mock, \
+                 patch('mmengine.fileio.backends.PetrelBackend.remove') as remove_mock, \
+                 patch('mmengine.fileio.backends.PetrelBackend.isfile', return_value=True) as isfile:  # noqa: E501
+                cfg = copy.deepcopy(cfg)
+                runner = self.build_runner(cfg)
+                metrics = dict(acc=0.5)
+                petrel_client.client.Client = MagicMock(
+                    return_value=petrel_client)
+                checkpoint_hook = CheckpointHook(
+                    interval=2,
+                    by_epoch=by_epoch,
+                    save_best='acc',
+                    backend_args=dict(backend='petrel'))
+                checkpoint_hook.before_train(runner)
+                checkpoint_hook.after_val_epoch(runner, metrics)
+                put_mock.assert_called_once()
+                metrics['acc'] += 0.1
+                runner.train_loop._epoch += 1
+                runner.train_loop._iter += 1
+                checkpoint_hook.after_val_epoch(runner, metrics)
+                isfile.assert_called_once()
+                remove_mock.assert_called_once()
+
     def test_after_train_epoch(self):
         cfg = copy.deepcopy(self.epoch_based_cfg)
         runner = self.build_runner(cfg)
diff --git a/tests/test_runner/test_checkpoint.py b/tests/test_runner/test_checkpoint.py
index fd7d6d286d..65ebb17b48 100644
--- a/tests/test_runner/test_checkpoint.py
+++ b/tests/test_runner/test_checkpoint.py
@@ -20,9 +20,6 @@
                                         load_from_local, load_from_pavi,
                                         save_checkpoint)
 
-sys.modules['petrel_client'] = MagicMock()
-sys.modules['petrel_client.client'] = MagicMock()
-
 
 @MODEL_WRAPPERS.register_module()
 class DDPWrapper:
@@ -150,9 +147,8 @@ def test_get_state_dict():
         wrapped_model.module.conv.module.bias)
 
 
+@patch.dict(sys.modules, {'pavi': MagicMock()})
 def test_load_pavimodel_dist():
-    sys.modules['pavi'] = MagicMock()
-    sys.modules['pavi.modelcloud'] = MagicMock()
     pavimodel = Mockpavimodel()
     import pavi
     pavi.modelcloud.get = MagicMock(return_value=pavimodel)
@@ -296,6 +292,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
     assert torch.allclose(model_v2.conv1.weight, model_v2_conv1_weight)
 
 
+@patch.dict(sys.modules, {'petrel_client': MagicMock()})
 def test_checkpoint_loader():
     filenames = [
         'http://xx.xx/xx.pth', 'https://xx.xx/xx.pth',