Skip to content

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
Your Name committed Jan 30, 2023
1 parent 39e6295 commit 062ffe7
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from mmrazor.registry import TASK_UTILS
from mmrazor.utils import get_placeholder
from ...algorithms.base import BaseAlgorithm
from .demo_inputs import (BaseDemoInput, DefaultMMClsDemoInput,
DefaultMMDemoInput, DefaultMMDetDemoInput,
DefaultMMPoseDemoInput, DefaultMMRotateDemoInput,
Expand Down Expand Up @@ -70,8 +71,12 @@ def get_default_demo_input_class(model, scope):

def defaul_demo_inputs(model, input_shape, training=False, scope=None):
"""Get demo input according to a model and scope."""
demo_input = get_default_demo_input_class(model, scope)
return demo_input().get_data(model, input_shape, training)
if isinstance(model, BaseAlgorithm):
return defaul_demo_inputs(model.architecture, input_shape, training,
scope)
else:
demo_input = get_default_demo_input_class(model, scope)
return demo_input().get_data(model, input_shape, training)


@TASK_UTILS.register_module()
Expand Down
6 changes: 4 additions & 2 deletions mmrazor/models/task_modules/demo_inputs/demo_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ def _get_data(self, model, input_shape=None, training=None):
return data

def _get_mm_data(self, model, input_shape, training=False):
return {'inputs': torch.rand(input_shape), 'data_samples': None}
data = {'inputs': torch.rand(input_shape), 'data_samples': None}
data = model.data_preprocessor(data, training)
return data


@TASK_UTILS.register_module()
Expand Down Expand Up @@ -132,7 +134,7 @@ def _get_mm_data(self, model, input_shape, training=False):
from mmpose.models import TopdownPoseEstimator

from .mmpose_demo_input import demo_mmpose_inputs
assert isinstance(model, TopdownPoseEstimator)
assert isinstance(model, TopdownPoseEstimator), f'{type(model)}'

data = demo_mmpose_inputs(model, input_shape)
return data
6 changes: 4 additions & 2 deletions projects/cores/hooks/prune_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,10 @@ def __init__(self,

def before_run(self, runner) -> None:
model = get_model_from_runner(runner)
self.origin_delta = self._evaluate(model)[self.delta_type]
print_log(f'get original {self.delta_type}: {self.origin_delta}')
original_resource = self._evaluate(model)
print_log(f'get original resource: {original_resource}')

self.origin_delta = original_resource[self.delta_type]

# save checkpoint

Expand Down

0 comments on commit 062ffe7

Please sign in to comment.