Skip to content

Commit

Permalink
[BugFix & Feature]
Browse files Browse the repository at this point in the history
1.Fix set_max_subnet or set_min_subnet not found exception
2 Add the fleibility to define custom subnet kinds
3.Support to specifiy None with ResorceEstimator

Signed-off-by: Ming-Hsuan-Tu <[email protected]>
  • Loading branch information
twmht committed Mar 11, 2023
1 parent 9446b30 commit eb856cb
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 24 deletions.
53 changes: 29 additions & 24 deletions mmrazor/engine/runner/subnet_val_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ class SubnetValLoop(ValLoop, CalibrateBNMixin):
evaluator (Evaluator or dict or list): Used for computing metrics.
fp16 (bool): Whether to enable fp16 validation. Defaults to
False.
evaluate_fixed_subnet (bool): Whether to evaluate a fixed subnet only
or not. Defaults to False.
fix_subnet_kind (str): fix subnet kinds when evaluate, this would be
`sample_kinds` if not specified
calibrate_sample_num (int): The number of images to compute the true
average of per-batch mean/variance instead of the running average.
Defaults to 4096.
Expand All @@ -36,7 +36,7 @@ def __init__(
dataloader: Union[DataLoader, Dict],
evaluator: Union[Evaluator, Dict, List],
fp16: bool = False,
evaluate_fixed_subnet: bool = False,
fix_subnet_kinds: List[str] = [],
calibrate_sample_num: int = 4096,
estimator_cfg: Optional[Dict] = dict(type='mmrazor.ResourceEstimator')
) -> None:
Expand All @@ -48,9 +48,18 @@ def __init__(
model = self.runner.model

self.model = model
self.evaluate_fixed_subnet = evaluate_fixed_subnet
if fix_subnet_kinds is None and not hasattr(self.model,
'sample_kinds'):
raise ValueError(
'neither fix_subnet_kinds nor self.model.sample_kinds exists')

self.evaluate_kinds = fix_subnet_kinds if len(
fix_subnet_kinds) > 0 else getattr(self.model, 'sample_kinds')

self.calibrate_sample_num = calibrate_sample_num
self.estimator = TASK_UTILS.build(estimator_cfg)
self.estimator = None
if estimator_cfg:
self.estimator = TASK_UTILS.build(estimator_cfg)

def run(self):
"""Launch validation."""
Expand All @@ -59,24 +68,19 @@ def run(self):

all_metrics = dict()

if self.evaluate_fixed_subnet:
for kind in self.evaluate_kinds:
if kind == 'max':
self.model.mutator.set_max_choices()
elif kind == 'min':
self.model.mutator.set_min_choices()
elif 'random' in kind:
self.model.mutator.set_choices(
self.model.mutator.sample_choices())
else:
raise NotImplementedError(f'Unsupported Subnet {kind}')

metrics = self._evaluate_once()
all_metrics.update(add_prefix(metrics, 'fix_subnet'))
elif hasattr(self.model, 'sample_kinds'):
for kind in self.model.sample_kinds:
if kind == 'max':
self.model.mutator.set_max_choices()
metrics = self._evaluate_once()
all_metrics.update(add_prefix(metrics, 'max_subnet'))
elif kind == 'min':
self.model.mutator.set_min_choices()
metrics = self._evaluate_once()
all_metrics.update(add_prefix(metrics, 'min_subnet'))
elif 'random' in kind:
self.model.mutator.set_choices(
self.model.mutator.sample_choices())
metrics = self._evaluate_once()
all_metrics.update(add_prefix(metrics, f'{kind}_subnet'))
all_metrics.update(add_prefix(metrics, f'{kind}_subnet'))

self.runner.call_hook('after_val_epoch', metrics=all_metrics)
self.runner.call_hook('after_val')
Expand All @@ -90,7 +94,8 @@ def _evaluate_once(self) -> Dict:
self.run_iter(idx, data_batch)

metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
resource_metrics = self.estimator.estimate(self.model)
metrics.update(resource_metrics)
if self.estimator:
resource_metrics = self.estimator.estimate(self.model)
metrics.update(resource_metrics)

return metrics
50 changes: 50 additions & 0 deletions tests/test_engine/test_runner/test_subnet_val_loop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from unittest.mock import MagicMock, call, patch

from mmrazor.engine.runner import SubnetValLoop

class TestSubnetValLoop(TestCase):
def test_subnet_val_loop():
runner = MagicMock()
runner.distributed = False
runner.model = MagicMock()
dataloader = MagicMock()
evaluator = [MagicMock()]
fix_subnet_kinds = ['max', 'min']
loop = SubnetValLoop(
runner,
dataloader,
evaluator,
fix_subnet_kinds=fix_subnet_kinds,
estimator_cfg=None)
runner.train_dataloader = MagicMock()
with patch.object(loop, '_evaluate_once') as evaluate_mock:
evaluate_mock.return_value = dict(acc=10)
all_metrics = dict()
all_metrics['max_subnet.acc'] = 10
all_metrics['min_subnet.acc'] = 10
loop.run()
loop.runner.call_hook.assert_has_calls([
call('before_val'),
call('before_val_epoch'),
call('after_val_epoch', metrics=all_metrics),
call('after_val')
])
evaluate_mock.assert_has_calls([call(), call()])

runner.dataloader = MagicMock()
runner.dataloader.dataset = MagicMock()
loop.dataloader.__iter__.return_value = ['data_batch1']
with patch.object(loop, 'calibrate_bn_statistics') as calibration_bn_mock:
with patch.object(loop, 'run_iter') as run_iter_mock:
eval_result = dict(acc=10)
loop.evaluator.evaluate.return_value = eval_result
result = loop._evaluate_once()
calibration_bn_mock.assert_called_with(runner.train_dataloader,
loop.calibrate_sample_num)
runner.model.eval.assert_called()
run_iter_mock.assert_called_with(0, 'data_batch1')
loop.evaluator.evaluate.assert_called_with(
len(runner.dataloader.dataset))
assert result == eval_result

0 comments on commit eb856cb

Please sign in to comment.