From 823d238f01637d7cc9c8dc67ac3bf87011500fa1 Mon Sep 17 00:00:00 2001 From: yxy Date: Wed, 26 Jul 2023 19:45:51 +0800 Subject: [PATCH 01/39] Add recorder_hook and use ast to print assign node --- mmengine/hooks/__init__.py | 3 +- mmengine/hooks/recorder_hook.py | 160 ++++++++++++++++++++++++++++++++ 2 files changed, 162 insertions(+), 1 deletion(-) create mode 100644 mmengine/hooks/recorder_hook.py diff --git a/mmengine/hooks/__init__.py b/mmengine/hooks/__init__.py index 746be6b02a..9ff1b21324 100644 --- a/mmengine/hooks/__init__.py +++ b/mmengine/hooks/__init__.py @@ -9,6 +9,7 @@ from .naive_visualization_hook import NaiveVisualizationHook from .param_scheduler_hook import ParamSchedulerHook from .profiler_hook import NPUProfilerHook, ProfilerHook +from .recorder_hook import RecorderHook from .runtime_info_hook import RuntimeInfoHook from .sampler_seed_hook import DistSamplerSeedHook from .sync_buffer_hook import SyncBuffersHook @@ -18,5 +19,5 @@ 'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook', 'SyncBuffersHook', 'EmptyCacheHook', 'CheckpointHook', 'LoggerHook', 'NaiveVisualizationHook', 'EMAHook', 'RuntimeInfoHook', 'ProfilerHook', - 'PrepareTTAHook', 'NPUProfilerHook', 'EarlyStoppingHook' + 'PrepareTTAHook', 'NPUProfilerHook', 'EarlyStoppingHook', 'RecorderHook' ] diff --git a/mmengine/hooks/recorder_hook.py b/mmengine/hooks/recorder_hook.py new file mode 100644 index 0000000000..8af49f22c0 --- /dev/null +++ b/mmengine/hooks/recorder_hook.py @@ -0,0 +1,160 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import ast +import dis +import inspect +import textwrap +from typing import Any, List, Optional, Tuple, Union + +from mmengine.registry import HOOKS +from . import Hook + + +class FunctionRecorder(): + + def __init__(self): + self._data_buffer: List = list() + self.now_epoch = 0 + + @property + def data_buffer(self) -> List: + """list: data buffer.""" + return self._data_buffer + + def func_after_assign(self): + pass + + def next_epoch(self): + self.now_epoch += 1 + + def get_record_data(self, + record_idx: int = 0, + data_idx: Optional[int] = None) -> Any: + """Get data from ``data_buffer``. + + Args: + record_idx (int): The index of the record saved in + ``data_buffer``. If a source is executed N times during + forward, there will be N records in ``data_buffer``. + data_index (int, optional): The index of target data in + a record. A record may be a tuple or a list, if data_idx is + None, the whole list or tuple is returned. Defaults to None. + + Returns: + Any: The type of the return value is undefined, and different + source data may have different types. + """ + assert record_idx < len(self._data_buffer), \ + 'record_idx is illegal. The length of data_buffer is ' \ + f'{len(self._data_buffer)}, but record_idx is ' \ + f'{record_idx}.' + + record = self._data_buffer[record_idx] + + if data_idx is None: + target_data = record + else: + if isinstance(record, (list, tuple)): + assert data_idx < len(record), \ + 'data_idx is illegal. The length of record is ' \ + f'{len(record)}, but data_idx is {data_idx}.' + target_data = record[data_idx] + else: + raise TypeError('When data_idx is not None, record should be ' + 'a list or tuple instance, but got ' + f'{type(record)}.') + + return target_data + + def reset_data_buffer(self) -> None: + """Clear data in data_buffer.""" + + self._data_buffer = list() + + +# model的 存到 runner的 message_hub +class RecorderAdder(ast.NodeTransformer): + + def visit_Assign(self, node): + # 这将创建一个新的print调用节点 + print_call = ast.Expr( + value=ast.Call( + func=ast.Name(id='print', ctx=ast.Load()), + args=[ + ast.Str(s='Assigning to variable '), + ast.Name(id=node.targets[0].id, ctx=ast.Load()) + ], + keywords=[])) + + # 插入print语句 + return [node, print_call] + + +# class RecorderAdder(ast.NodeTransformer): +# def visit_Assign(self, node): +# # 这将创建一个新的print调用节点 +# print_call = ast.Expr( +# value=ast.Call( +# func=ast.Name(id='RecorderHook', ctx=ast.Load()), +# attr='add2buffer', +# args=[ +# ast.Name(id=node.targets[0].id, ctx=ast.Load()) +# ], +# keywords=[] +# ) +# ) +# +# # 插入print语句 +# return [node, print_call] + + +@HOOKS.register_module() +class RecorderHook(Hook): + priority = 'LOWEST' + + recorder = FunctionRecorder() + + def __init__(self, ): + pass + + def _modify_func(self, func): + # 获取函数的源代码 + source = inspect.getsource(func) + source = textwrap.dedent(source) + + # 解析源代码为AST + tree = ast.parse(source) + + import_from_statement = ast.ImportFrom( + module='mmengine.hooks', + names=[ast.alias(name='RecorderHook', asname=None)], + level=0) + + tree.body[0].body.insert(0, import_from_statement) + + # 修改AST + tree = RecorderAdder().visit(tree) + tree = ast.fix_missing_locations(tree) + + # print(ast.dump(tree, indent=4)) + + # 编译修改后的AST为一个新的函数 + namespace = {} + exec( + compile(tree, filename='', mode='exec'), func.__globals__, + namespace) + return namespace[func.__name__] + + def before_run(self, runner) -> None: + """Check `stop_training` variable in `runner.train_loop`. + + Args: + runner (Runner): The runner of the training process. + """ + import dis + model = runner.model + print('---------------------------') + # breakpoint() + import types + + model.forward = types.MethodType( + self._modify_func(model.forward), model) \ No newline at end of file From 85cde7150599e94271cc28186e07c2ce55384a98 Mon Sep 17 00:00:00 2001 From: yxy Date: Wed, 26 Jul 2023 20:50:22 +0800 Subject: [PATCH 02/39] use messagehub to store information --- mmengine/hooks/recorder_hook.py | 106 ++++++++------------------------ 1 file changed, 26 insertions(+), 80 deletions(-) diff --git a/mmengine/hooks/recorder_hook.py b/mmengine/hooks/recorder_hook.py index 8af49f22c0..941f7534bf 100644 --- a/mmengine/hooks/recorder_hook.py +++ b/mmengine/hooks/recorder_hook.py @@ -1,92 +1,24 @@ # Copyright (c) OpenMMLab. All rights reserved. import ast import dis +import types import inspect import textwrap from typing import Any, List, Optional, Tuple, Union from mmengine.registry import HOOKS +from mmengine.logging import MessageHub, HistoryBuffer from . import Hook - -class FunctionRecorder(): - - def __init__(self): - self._data_buffer: List = list() - self.now_epoch = 0 - - @property - def data_buffer(self) -> List: - """list: data buffer.""" - return self._data_buffer - - def func_after_assign(self): - pass - - def next_epoch(self): - self.now_epoch += 1 - - def get_record_data(self, - record_idx: int = 0, - data_idx: Optional[int] = None) -> Any: - """Get data from ``data_buffer``. - - Args: - record_idx (int): The index of the record saved in - ``data_buffer``. If a source is executed N times during - forward, there will be N records in ``data_buffer``. - data_index (int, optional): The index of target data in - a record. A record may be a tuple or a list, if data_idx is - None, the whole list or tuple is returned. Defaults to None. - - Returns: - Any: The type of the return value is undefined, and different - source data may have different types. - """ - assert record_idx < len(self._data_buffer), \ - 'record_idx is illegal. The length of data_buffer is ' \ - f'{len(self._data_buffer)}, but record_idx is ' \ - f'{record_idx}.' - - record = self._data_buffer[record_idx] - - if data_idx is None: - target_data = record - else: - if isinstance(record, (list, tuple)): - assert data_idx < len(record), \ - 'data_idx is illegal. The length of record is ' \ - f'{len(record)}, but data_idx is {data_idx}.' - target_data = record[data_idx] - else: - raise TypeError('When data_idx is not None, record should be ' - 'a list or tuple instance, but got ' - f'{type(record)}.') - - return target_data - - def reset_data_buffer(self) -> None: - """Clear data in data_buffer.""" - - self._data_buffer = list() - - # model的 存到 runner的 message_hub class RecorderAdder(ast.NodeTransformer): def visit_Assign(self, node): - # 这将创建一个新的print调用节点 - print_call = ast.Expr( - value=ast.Call( - func=ast.Name(id='print', ctx=ast.Load()), - args=[ - ast.Str(s='Assigning to variable '), - ast.Name(id=node.targets[0].id, ctx=ast.Load()) - ], - keywords=[])) + add2messagehub = ast.Expr(value=ast.Call(func=ast.Attribute(value=ast.Name(id='message_hub', ctx=ast.Load()), attr='update_info', ctx=ast.Load()), + args=[ast.Constant(value='task'), ast.Name(id=node.targets[0].id, ctx=ast.Load())], keywords=[])) # 插入print语句 - return [node, print_call] + return [node, add2messagehub] # class RecorderAdder(ast.NodeTransformer): @@ -111,11 +43,14 @@ def visit_Assign(self, node): class RecorderHook(Hook): priority = 'LOWEST' - recorder = FunctionRecorder() + # recorder = FunctionRecorder() def __init__(self, ): pass + def _get_ast(source_code): + return ast.parse(source_code) + def _modify_func(self, func): # 获取函数的源代码 source = inspect.getsource(func) @@ -125,17 +60,21 @@ def _modify_func(self, func): tree = ast.parse(source) import_from_statement = ast.ImportFrom( - module='mmengine.hooks', + module='mmengine.logging.MessageHub', names=[ast.alias(name='RecorderHook', asname=None)], level=0) - tree.body[0].body.insert(0, import_from_statement) + func_body = tree.body[0].body + import_statement = ast.ImportFrom(module='mmengine.logging', names=[ast.alias(name='MessageHub')], level=0) + add_message_hub = ast.Assign(targets=[ast.Name(id='message_hub', ctx=ast.Store())], value=ast.Call(func=ast.Attribute(value=ast.Name(id='MessageHub', ctx=ast.Load()), attr='get_instance', ctx=ast.Load()), args=[ast.Constant(value='mmengine')], keywords=[])) + tree.body[0].body = [import_statement, add_message_hub] + func_body + # tree.body[0].body.insert(0, import_statement) # 修改AST tree = RecorderAdder().visit(tree) tree = ast.fix_missing_locations(tree) - # print(ast.dump(tree, indent=4)) + print(ast.dump(tree, indent=4)) # 编译修改后的AST为一个新的函数 namespace = {} @@ -150,11 +89,18 @@ def before_run(self, runner) -> None: Args: runner (Runner): The runner of the training process. """ - import dis + log_scalars = dict(loss=HistoryBuffer()) + runtime_info = dict(task='task') + resumed_keys = dict(loss=True) + # create `MessageHub` from data. + message_hub2 = MessageHub( + name = 'name', + log_scalars = log_scalars, + runtime_info = runtime_info, + resumed_keys = resumed_keys) model = runner.model print('---------------------------') # breakpoint() - import types model.forward = types.MethodType( - self._modify_func(model.forward), model) \ No newline at end of file + self._modify_func(model.forward), model) From 074ee1ef14a1c3fd5f51e1d572aa55131c7e242e Mon Sep 17 00:00:00 2001 From: yxy Date: Fri, 28 Jul 2023 22:27:08 +0800 Subject: [PATCH 03/39] use message_hub.update_scalar --- mmengine/hooks/recorder_hook.py | 54 +++++++++++++++++++++++++-------- 1 file changed, 41 insertions(+), 13 deletions(-) diff --git a/mmengine/hooks/recorder_hook.py b/mmengine/hooks/recorder_hook.py index 941f7534bf..3724a7844f 100644 --- a/mmengine/hooks/recorder_hook.py +++ b/mmengine/hooks/recorder_hook.py @@ -1,21 +1,31 @@ # Copyright (c) OpenMMLab. All rights reserved. import ast import dis -import types import inspect import textwrap +import types from typing import Any, List, Optional, Tuple, Union +from mmengine.logging import HistoryBuffer, MessageHub from mmengine.registry import HOOKS -from mmengine.logging import MessageHub, HistoryBuffer from . import Hook + # model的 存到 runner的 message_hub class RecorderAdder(ast.NodeTransformer): def visit_Assign(self, node): - add2messagehub = ast.Expr(value=ast.Call(func=ast.Attribute(value=ast.Name(id='message_hub', ctx=ast.Load()), attr='update_info', ctx=ast.Load()), - args=[ast.Constant(value='task'), ast.Name(id=node.targets[0].id, ctx=ast.Load())], keywords=[])) + add2messagehub = ast.Expr( + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id='message_hub', ctx=ast.Load()), + attr='update_scalar', + ctx=ast.Load()), + args=[ + ast.Constant(value='task'), + ast.Name(id=node.targets[0].id, ctx=ast.Load()) + ], + keywords=[])) # 插入print语句 return [node, add2messagehub] @@ -65,8 +75,19 @@ def _modify_func(self, func): level=0) func_body = tree.body[0].body - import_statement = ast.ImportFrom(module='mmengine.logging', names=[ast.alias(name='MessageHub')], level=0) - add_message_hub = ast.Assign(targets=[ast.Name(id='message_hub', ctx=ast.Store())], value=ast.Call(func=ast.Attribute(value=ast.Name(id='MessageHub', ctx=ast.Load()), attr='get_instance', ctx=ast.Load()), args=[ast.Constant(value='mmengine')], keywords=[])) + import_statement = ast.ImportFrom( + module='mmengine.logging', + names=[ast.alias(name='MessageHub')], + level=0) + add_message_hub = ast.Assign( + targets=[ast.Name(id='message_hub', ctx=ast.Store())], + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id='MessageHub', ctx=ast.Load()), + attr='get_instance', + ctx=ast.Load()), + args=[ast.Constant(value='mmengine')], + keywords=[])) tree.body[0].body = [import_statement, add_message_hub] + func_body # tree.body[0].body.insert(0, import_statement) @@ -90,17 +111,24 @@ def before_run(self, runner) -> None: runner (Runner): The runner of the training process. """ log_scalars = dict(loss=HistoryBuffer()) - runtime_info = dict(task='task') + runtime_info = dict() resumed_keys = dict(loss=True) - # create `MessageHub` from data. - message_hub2 = MessageHub( - name = 'name', - log_scalars = log_scalars, - runtime_info = runtime_info, - resumed_keys = resumed_keys) + # create `MessageHub` from data. + self.message_hub2 = MessageHub( + name='name', + log_scalars=log_scalars, + runtime_info=runtime_info, + resumed_keys=resumed_keys) model = runner.model print('---------------------------') # breakpoint() model.forward = types.MethodType( self._modify_func(model.forward), model) + + def after_train_iter(self, + runner, + batch_idx: int, + data_batch = None, + outputs = None) -> None: + print(self.message_hub2.__dict__) From f4dcc13b19fb4c4f52a5d76a6504e8b9b5f7ee77 Mon Sep 17 00:00:00 2001 From: yxy Date: Fri, 28 Jul 2023 22:27:50 +0800 Subject: [PATCH 04/39] use message_hub.update_scalar --- mmengine/hooks/recorder_hook.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mmengine/hooks/recorder_hook.py b/mmengine/hooks/recorder_hook.py index 3724a7844f..ddbc634337 100644 --- a/mmengine/hooks/recorder_hook.py +++ b/mmengine/hooks/recorder_hook.py @@ -127,8 +127,8 @@ def before_run(self, runner) -> None: self._modify_func(model.forward), model) def after_train_iter(self, - runner, - batch_idx: int, - data_batch = None, - outputs = None) -> None: + runner, + batch_idx: int, + data_batch=None, + outputs=None) -> None: print(self.message_hub2.__dict__) From e8144c99852d6de3c38fc45ca6d7e384135f414b Mon Sep 17 00:00:00 2001 From: yxy Date: Tue, 1 Aug 2023 00:18:35 +0800 Subject: [PATCH 05/39] design class Recorder --- mmengine/hooks/recorder_hook.py | 99 ++++++++++++++++++++++----------- 1 file changed, 68 insertions(+), 31 deletions(-) diff --git a/mmengine/hooks/recorder_hook.py b/mmengine/hooks/recorder_hook.py index ddbc634337..4f65935d62 100644 --- a/mmengine/hooks/recorder_hook.py +++ b/mmengine/hooks/recorder_hook.py @@ -5,10 +5,53 @@ import textwrap import types from typing import Any, List, Optional, Tuple, Union +from collections import defaultdict from mmengine.logging import HistoryBuffer, MessageHub from mmengine.registry import HOOKS from . import Hook +from abc import ABCMeta, abstractmethod + + +class Recorder(metaclass=ABCMeta): + def __init__(self, target: str): + self._target = target + + @abstractmethod + def rewrite(self, ast_tree): + pass + +# FunctionRecorder +class FunctionRecorder(Recorder): + def __init__(self, target: str): + super().__init__(target) + self.visit_assign = self._get_adder_class() + # super.__init__() + + def _get_adder_class(self): + outer_class = self + class FunctionRecorderAdder(ast.NodeTransformer): + def visit_Assign(self, node): + if node.targets[0].id != outer_class._target: + return node + add2messagehub = ast.Expr( + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id='message_hub', ctx=ast.Load()), + attr='update_info', + ctx=ast.Load()), + args=[ + ast.Constant(value = node.targets[0].id), + ast.Name(id=node.targets[0].id, ctx=ast.Load()) + ], + keywords=[])) + + # 插入print语句 + return [node, add2messagehub] + return FunctionRecorderAdder() + + def rewrite(self, ast_tree): + return self.visit_assign.visit(ast_tree) # model的 存到 runner的 message_hub @@ -19,7 +62,7 @@ def visit_Assign(self, node): value=ast.Call( func=ast.Attribute( value=ast.Name(id='message_hub', ctx=ast.Load()), - attr='update_scalar', + attr='update_info', ctx=ast.Load()), args=[ ast.Constant(value='task'), @@ -31,31 +74,16 @@ def visit_Assign(self, node): return [node, add2messagehub] -# class RecorderAdder(ast.NodeTransformer): -# def visit_Assign(self, node): -# # 这将创建一个新的print调用节点 -# print_call = ast.Expr( -# value=ast.Call( -# func=ast.Name(id='RecorderHook', ctx=ast.Load()), -# attr='add2buffer', -# args=[ -# ast.Name(id=node.targets[0].id, ctx=ast.Load()) -# ], -# keywords=[] -# ) -# ) -# -# # 插入print语句 -# return [node, print_call] - - @HOOKS.register_module() class RecorderHook(Hook): priority = 'LOWEST' + # RECORDER_MESSAGEHUB_NAME = "_recorder" + # recorder = FunctionRecorder() def __init__(self, ): + self.tensor_dict = defaultdict(list) pass def _get_ast(source_code): @@ -75,7 +103,7 @@ def _modify_func(self, func): level=0) func_body = tree.body[0].body - import_statement = ast.ImportFrom( + import_messagehub_statement = ast.ImportFrom( module='mmengine.logging', names=[ast.alias(name='MessageHub')], level=0) @@ -84,15 +112,17 @@ def _modify_func(self, func): value=ast.Call( func=ast.Attribute( value=ast.Name(id='MessageHub', ctx=ast.Load()), - attr='get_instance', + attr='get_current_instance', ctx=ast.Load()), - args=[ast.Constant(value='mmengine')], + args=[], keywords=[])) - tree.body[0].body = [import_statement, add_message_hub] + func_body + + tree.body[0].body = [import_messagehub_statement, add_message_hub] + func_body # tree.body[0].body.insert(0, import_statement) # 修改AST - tree = RecorderAdder().visit(tree) + # breakpoint() + tree = FunctionRecorder("x").rewrite(tree) tree = ast.fix_missing_locations(tree) print(ast.dump(tree, indent=4)) @@ -113,12 +143,16 @@ def before_run(self, runner) -> None: log_scalars = dict(loss=HistoryBuffer()) runtime_info = dict() resumed_keys = dict(loss=True) - # create `MessageHub` from data. - self.message_hub2 = MessageHub( - name='name', - log_scalars=log_scalars, - runtime_info=runtime_info, - resumed_keys=resumed_keys) + # # create `MessageHub` from data. + # self.message_hub2 = MessageHub( + # name=RecorderHook.RECORDER_MESSAGEHUB_NAME, + # log_scalars=log_scalars, + # runtime_info=runtime_info, + # resumed_keys=resumed_keys) + # self.message_hub2.update_info("task", "1111") + # self.message_hub2.update_info("task", {1231312: "dfasfd"}) + self.message_hub2 = MessageHub.get_current_instance() + model = runner.model print('---------------------------') # breakpoint() @@ -131,4 +165,7 @@ def after_train_iter(self, batch_idx: int, data_batch=None, outputs=None) -> None: - print(self.message_hub2.__dict__) + + # print(self.message_hub2.__dict__) + # print(self.message_hub2.get_info("task")) + self.tensor_dict["task"].append(self.message_hub2.get_info("task")) \ No newline at end of file From c60b934a962aff361daa3137ceadba1b0d8c4627 Mon Sep 17 00:00:00 2001 From: yxy Date: Tue, 1 Aug 2023 00:24:10 +0800 Subject: [PATCH 06/39] add recover forward logic --- mmengine/hooks/recorder_hook.py | 31 +++++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/mmengine/hooks/recorder_hook.py b/mmengine/hooks/recorder_hook.py index 4f65935d62..bdd2a47d1c 100644 --- a/mmengine/hooks/recorder_hook.py +++ b/mmengine/hooks/recorder_hook.py @@ -4,16 +4,17 @@ import inspect import textwrap import types -from typing import Any, List, Optional, Tuple, Union +from abc import ABCMeta, abstractmethod from collections import defaultdict +from typing import Any, List, Optional, Tuple, Union from mmengine.logging import HistoryBuffer, MessageHub from mmengine.registry import HOOKS from . import Hook -from abc import ABCMeta, abstractmethod class Recorder(metaclass=ABCMeta): + def __init__(self, target: str): self._target = target @@ -21,16 +22,21 @@ def __init__(self, target: str): def rewrite(self, ast_tree): pass + # FunctionRecorder class FunctionRecorder(Recorder): + def __init__(self, target: str): super().__init__(target) self.visit_assign = self._get_adder_class() + # super.__init__() def _get_adder_class(self): outer_class = self + class FunctionRecorderAdder(ast.NodeTransformer): + def visit_Assign(self, node): if node.targets[0].id != outer_class._target: return node @@ -41,13 +47,14 @@ def visit_Assign(self, node): attr='update_info', ctx=ast.Load()), args=[ - ast.Constant(value = node.targets[0].id), + ast.Constant(value=node.targets[0].id), ast.Name(id=node.targets[0].id, ctx=ast.Load()) ], keywords=[])) # 插入print语句 return [node, add2messagehub] + return FunctionRecorderAdder() def rewrite(self, ast_tree): @@ -84,6 +91,7 @@ class RecorderHook(Hook): def __init__(self, ): self.tensor_dict = defaultdict(list) + self.origin_forward = None pass def _get_ast(source_code): @@ -117,12 +125,13 @@ def _modify_func(self, func): args=[], keywords=[])) - tree.body[0].body = [import_messagehub_statement, add_message_hub] + func_body + tree.body[0].body = [import_messagehub_statement, add_message_hub + ] + func_body # tree.body[0].body.insert(0, import_statement) # 修改AST # breakpoint() - tree = FunctionRecorder("x").rewrite(tree) + tree = FunctionRecorder('x').rewrite(tree) tree = ast.fix_missing_locations(tree) print(ast.dump(tree, indent=4)) @@ -156,6 +165,7 @@ def before_run(self, runner) -> None: model = runner.model print('---------------------------') # breakpoint() + self.origin_forward = model.forward model.forward = types.MethodType( self._modify_func(model.forward), model) @@ -168,4 +178,13 @@ def after_train_iter(self, # print(self.message_hub2.__dict__) # print(self.message_hub2.get_info("task")) - self.tensor_dict["task"].append(self.message_hub2.get_info("task")) \ No newline at end of file + self.tensor_dict['task'].append(self.message_hub2.get_info('task')) + + def before_train(self, runner) -> None: + model = runner.model + + model.forward = types.MethodType( + self._modify_func(model.forward), model) + + def after_train(self, runner) -> None: + runner.model.forward = self.origin_forward \ No newline at end of file From 54412dc3f250137df8855fa88d5a1b2401623748 Mon Sep 17 00:00:00 2001 From: yxy Date: Tue, 1 Aug 2023 00:28:35 +0800 Subject: [PATCH 07/39] FunctionRecord actually should be AttributeRecorder because we find attribute by var name --- mmengine/hooks/recorder_hook.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/mmengine/hooks/recorder_hook.py b/mmengine/hooks/recorder_hook.py index bdd2a47d1c..8ae830273d 100644 --- a/mmengine/hooks/recorder_hook.py +++ b/mmengine/hooks/recorder_hook.py @@ -23,8 +23,8 @@ def rewrite(self, ast_tree): pass -# FunctionRecorder -class FunctionRecorder(Recorder): +# AttributeRecorder +class AttributeRecorder(Recorder): def __init__(self, target: str): super().__init__(target) @@ -35,7 +35,7 @@ def __init__(self, target: str): def _get_adder_class(self): outer_class = self - class FunctionRecorderAdder(ast.NodeTransformer): + class AttributeRecorderAdder(ast.NodeTransformer): def visit_Assign(self, node): if node.targets[0].id != outer_class._target: @@ -55,7 +55,7 @@ def visit_Assign(self, node): # 插入print语句 return [node, add2messagehub] - return FunctionRecorderAdder() + return AttributeRecorderAdder() def rewrite(self, ast_tree): return self.visit_assign.visit(ast_tree) @@ -87,7 +87,7 @@ class RecorderHook(Hook): # RECORDER_MESSAGEHUB_NAME = "_recorder" - # recorder = FunctionRecorder() + # recorder = AttributeRecorder() def __init__(self, ): self.tensor_dict = defaultdict(list) @@ -131,7 +131,7 @@ def _modify_func(self, func): # 修改AST # breakpoint() - tree = FunctionRecorder('x').rewrite(tree) + tree = AttributeRecorder('x').rewrite(tree) tree = ast.fix_missing_locations(tree) print(ast.dump(tree, indent=4)) From c7df8bb339416ec11c6fa0438dbdffaf97415f2d Mon Sep 17 00:00:00 2001 From: yxy Date: Sat, 5 Aug 2023 11:25:27 +0800 Subject: [PATCH 08/39] add FunctionRecorder --- mmengine/hooks/recorder_hook.py | 151 ++++++++++++++++++++------------ 1 file changed, 95 insertions(+), 56 deletions(-) diff --git a/mmengine/hooks/recorder_hook.py b/mmengine/hooks/recorder_hook.py index 8ae830273d..60a5c1b670 100644 --- a/mmengine/hooks/recorder_hook.py +++ b/mmengine/hooks/recorder_hook.py @@ -12,6 +12,75 @@ from mmengine.registry import HOOKS from . import Hook +class AttributeRecorderAdder(ast.NodeTransformer): + def __init__(self, target): + super().__init__() + self._target = target + + def visit_Assign(self, node): + if node.targets[0].id != self._target: + return node + add2messagehub = ast.Expr( + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id='message_hub', ctx=ast.Load()), + attr='update_info', + ctx=ast.Load()), + args=[ + ast.Constant(value=node.targets[0].id), + ast.Name(id=node.targets[0].id, ctx=ast.Load()) + ], + keywords=[])) + + # 插入print语句 + return [node, add2messagehub] + +def get_node_name(func_name): + return "tmp_func_" + func_name + +class FuncCallVisitor(ast.NodeTransformer): + def __init__(self, func_name): + self.func_name = func_name + self.call_nodes = [] + + def is_target_call(self, call_node): + assert isinstance(call_node, ast.Call) + call_node = call_node.func + call_chain_list = self.func_name.split(".") + if len(call_chain_list) == 1: + return isinstance(call_node.func, ast.Name) and call_node.func.id == call_chain_list[0] + else: + # 倒序遍历call_chain_list + for i in range(len(call_chain_list) - 1, 0, -1): + print(ast.dump(call_node)) + if isinstance(call_node, ast.Attribute) and call_node.attr == call_chain_list[i]: + call_node = call_node.value + else: + return False + return isinstance(call_node, ast.Name) and call_node.id == call_chain_list[0] + + def visit_Call(self, node): + if not self.is_target_call(node): + return node + new_node = ast.Name(id=get_node_name(self.func_name.replace(".", "_")), ctx=ast.Load()) + self.call_nodes.append(node) + return new_node + +class FunctionRecorderAdder(ast.NodeTransformer): + def __init__(self, target): + super().__init__() + self._target = target + self.function_visitor = FuncCallVisitor(target) + + def visit_Assign(self, node): + self.function_visitor.visit(node) + if self.function_visitor.call_nodes: + assign_node = self.function_visitor.call_nodes[0] + # test = assign node + assign = ast.Assign(targets=[ast.Name(id=get_node_name(self._target.replace(".", "_")), ctx=ast.Store())], value=assign_node) + self.function_visitor.call_nodes.clear() + return [assign, node] + return node class Recorder(metaclass=ABCMeta): @@ -33,53 +102,33 @@ def __init__(self, target: str): # super.__init__() def _get_adder_class(self): - outer_class = self - - class AttributeRecorderAdder(ast.NodeTransformer): - - def visit_Assign(self, node): - if node.targets[0].id != outer_class._target: - return node - add2messagehub = ast.Expr( - value=ast.Call( - func=ast.Attribute( - value=ast.Name(id='message_hub', ctx=ast.Load()), - attr='update_info', - ctx=ast.Load()), - args=[ - ast.Constant(value=node.targets[0].id), - ast.Name(id=node.targets[0].id, ctx=ast.Load()) - ], - keywords=[])) - - # 插入print语句 - return [node, add2messagehub] - - return AttributeRecorderAdder() + return AttributeRecorderAdder(self._target) def rewrite(self, ast_tree): - return self.visit_assign.visit(ast_tree) + new_ast_tree = self.visit_assign.visit(ast_tree) + new_ast_tree = ast.fix_missing_locations(new_ast_tree) + + modified_source_code = ast.unparse(new_ast_tree) + print(modified_source_code) + return new_ast_tree -# model的 存到 runner的 message_hub -class RecorderAdder(ast.NodeTransformer): +class FunctionRecorder(Recorder): + def __init__(self, target: str): + super().__init__(target) + self.visit_assign = self._get_adder_class() - def visit_Assign(self, node): - add2messagehub = ast.Expr( - value=ast.Call( - func=ast.Attribute( - value=ast.Name(id='message_hub', ctx=ast.Load()), - attr='update_info', - ctx=ast.Load()), - args=[ - ast.Constant(value='task'), - ast.Name(id=node.targets[0].id, ctx=ast.Load()) - ], - keywords=[])) + def _get_adder_class(self): + return FunctionRecorderAdder(self._target) - # 插入print语句 - return [node, add2messagehub] + def rewrite(self, ast_tree): + new_ast_tree = self.visit_assign.visit(ast_tree) + new_ast_tree = ast.fix_missing_locations(new_ast_tree) + + modified_source_code = ast.unparse(new_ast_tree) + print(modified_source_code) + return new_ast_tree @HOOKS.register_module() class RecorderHook(Hook): @@ -130,12 +179,10 @@ def _modify_func(self, func): # tree.body[0].body.insert(0, import_statement) # 修改AST - # breakpoint() - tree = AttributeRecorder('x').rewrite(tree) + # tree = AttributeRecorder('x').rewrite(tree) + tree = FunctionRecorder('self.resnet').rewrite(tree) tree = ast.fix_missing_locations(tree) - print(ast.dump(tree, indent=4)) - # 编译修改后的AST为一个新的函数 namespace = {} exec( @@ -152,14 +199,6 @@ def before_run(self, runner) -> None: log_scalars = dict(loss=HistoryBuffer()) runtime_info = dict() resumed_keys = dict(loss=True) - # # create `MessageHub` from data. - # self.message_hub2 = MessageHub( - # name=RecorderHook.RECORDER_MESSAGEHUB_NAME, - # log_scalars=log_scalars, - # runtime_info=runtime_info, - # resumed_keys=resumed_keys) - # self.message_hub2.update_info("task", "1111") - # self.message_hub2.update_info("task", {1231312: "dfasfd"}) self.message_hub2 = MessageHub.get_current_instance() model = runner.model @@ -180,11 +219,11 @@ def after_train_iter(self, # print(self.message_hub2.get_info("task")) self.tensor_dict['task'].append(self.message_hub2.get_info('task')) - def before_train(self, runner) -> None: - model = runner.model + # def before_train(self, runner) -> None: + # model = runner.model - model.forward = types.MethodType( - self._modify_func(model.forward), model) + # model.forward = types.MethodType( + # self._modify_func(model.forward), model) def after_train(self, runner) -> None: runner.model.forward = self.origin_forward \ No newline at end of file From e4351ba043373f613b1deae3b36678933f0fa9f8 Mon Sep 17 00:00:00 2001 From: yxy Date: Sat, 5 Aug 2023 11:37:20 +0800 Subject: [PATCH 09/39] add update2 messagehub logic --- mmengine/hooks/recorder_hook.py | 60 ++++++++++++++++++++++++++------- 1 file changed, 47 insertions(+), 13 deletions(-) diff --git a/mmengine/hooks/recorder_hook.py b/mmengine/hooks/recorder_hook.py index 60a5c1b670..460eca0116 100644 --- a/mmengine/hooks/recorder_hook.py +++ b/mmengine/hooks/recorder_hook.py @@ -12,7 +12,9 @@ from mmengine.registry import HOOKS from . import Hook + class AttributeRecorderAdder(ast.NodeTransformer): + def __init__(self, target): super().__init__() self._target = target @@ -35,38 +37,48 @@ def visit_Assign(self, node): # 插入print语句 return [node, add2messagehub] + def get_node_name(func_name): - return "tmp_func_" + func_name + return 'tmp_func_' + func_name + class FuncCallVisitor(ast.NodeTransformer): + def __init__(self, func_name): self.func_name = func_name self.call_nodes = [] - + def is_target_call(self, call_node): assert isinstance(call_node, ast.Call) call_node = call_node.func - call_chain_list = self.func_name.split(".") + call_chain_list = self.func_name.split('.') if len(call_chain_list) == 1: - return isinstance(call_node.func, ast.Name) and call_node.func.id == call_chain_list[0] + return isinstance( + call_node.func, + ast.Name) and call_node.func.id == call_chain_list[0] else: # 倒序遍历call_chain_list for i in range(len(call_chain_list) - 1, 0, -1): print(ast.dump(call_node)) - if isinstance(call_node, ast.Attribute) and call_node.attr == call_chain_list[i]: + if isinstance(call_node, ast.Attribute + ) and call_node.attr == call_chain_list[i]: call_node = call_node.value else: return False - return isinstance(call_node, ast.Name) and call_node.id == call_chain_list[0] + return isinstance(call_node, + ast.Name) and call_node.id == call_chain_list[0] def visit_Call(self, node): if not self.is_target_call(node): return node - new_node = ast.Name(id=get_node_name(self.func_name.replace(".", "_")), ctx=ast.Load()) - self.call_nodes.append(node) + new_node = ast.Name( + id=get_node_name(self.func_name.replace('.', '_')), ctx=ast.Load()) + self.call_nodes.append(node) return new_node + class FunctionRecorderAdder(ast.NodeTransformer): + def __init__(self, target): super().__init__() self._target = target @@ -76,12 +88,31 @@ def visit_Assign(self, node): self.function_visitor.visit(node) if self.function_visitor.call_nodes: assign_node = self.function_visitor.call_nodes[0] + assign_node_name = get_node_name(self._target.replace('.', '_')) # test = assign node - assign = ast.Assign(targets=[ast.Name(id=get_node_name(self._target.replace(".", "_")), ctx=ast.Store())], value=assign_node) + assign = ast.Assign( + targets=[ + ast.Name( + id=assign_node_name, + ctx=ast.Store()) + ], + value=assign_node) + add2messagehub = ast.Expr( + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id='message_hub', ctx=ast.Load()), + attr='update_info', + ctx=ast.Load()), + args=[ + ast.Constant(value=self._target), + ast.Name(id=assign_node_name, ctx=ast.Load()) + ], + keywords=[])) self.function_visitor.call_nodes.clear() - return [assign, node] + return [assign, add2messagehub, node] return node + class Recorder(metaclass=ABCMeta): def __init__(self, target: str): @@ -107,13 +138,15 @@ def _get_adder_class(self): def rewrite(self, ast_tree): new_ast_tree = self.visit_assign.visit(ast_tree) new_ast_tree = ast.fix_missing_locations(new_ast_tree) - + modified_source_code = ast.unparse(new_ast_tree) print(modified_source_code) return new_ast_tree + class FunctionRecorder(Recorder): + def __init__(self, target: str): super().__init__(target) self.visit_assign = self._get_adder_class() @@ -124,12 +157,13 @@ def _get_adder_class(self): def rewrite(self, ast_tree): new_ast_tree = self.visit_assign.visit(ast_tree) new_ast_tree = ast.fix_missing_locations(new_ast_tree) - + modified_source_code = ast.unparse(new_ast_tree) print(modified_source_code) return new_ast_tree + @HOOKS.register_module() class RecorderHook(Hook): priority = 'LOWEST' @@ -226,4 +260,4 @@ def after_train_iter(self, # self._modify_func(model.forward), model) def after_train(self, runner) -> None: - runner.model.forward = self.origin_forward \ No newline at end of file + runner.model.forward = self.origin_forward From 9fe0e7d5c7ac7ac562c8196d27caec689d40b9b7 Mon Sep 17 00:00:00 2001 From: yxy Date: Sat, 5 Aug 2023 15:00:54 +0800 Subject: [PATCH 10/39] clean up code --- mmengine/hooks/recorder_hook.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/mmengine/hooks/recorder_hook.py b/mmengine/hooks/recorder_hook.py index 460eca0116..10e3bdf238 100644 --- a/mmengine/hooks/recorder_hook.py +++ b/mmengine/hooks/recorder_hook.py @@ -59,7 +59,6 @@ def is_target_call(self, call_node): else: # 倒序遍历call_chain_list for i in range(len(call_chain_list) - 1, 0, -1): - print(ast.dump(call_node)) if isinstance(call_node, ast.Attribute ) and call_node.attr == call_chain_list[i]: call_node = call_node.value @@ -168,14 +167,11 @@ def rewrite(self, ast_tree): class RecorderHook(Hook): priority = 'LOWEST' - # RECORDER_MESSAGEHUB_NAME = "_recorder" - - # recorder = AttributeRecorder() - - def __init__(self, ): + def __init__(self, + recorders: Optional[List[Dict]] = None): self.tensor_dict = defaultdict(list) self.origin_forward = None - pass + self._recorders = [] def _get_ast(source_code): return ast.parse(source_code) From fd6b8e48892b7f7e666006a9e025ed344b4cc64e Mon Sep 17 00:00:00 2001 From: yxy Date: Sun, 6 Aug 2023 01:06:15 +0800 Subject: [PATCH 11/39] add comment and registry for AttributeRecorder and FunctionRecorder --- mmengine/hooks/recorder_hook.py | 89 ++++++++++++--------------------- mmengine/registry/__init__.py | 9 ++-- mmengine/registry/root.py | 3 ++ 3 files changed, 41 insertions(+), 60 deletions(-) diff --git a/mmengine/hooks/recorder_hook.py b/mmengine/hooks/recorder_hook.py index 10e3bdf238..cea364a78c 100644 --- a/mmengine/hooks/recorder_hook.py +++ b/mmengine/hooks/recorder_hook.py @@ -1,20 +1,17 @@ # Copyright (c) OpenMMLab. All rights reserved. import ast -import dis import inspect import textwrap import types from abc import ABCMeta, abstractmethod -from collections import defaultdict -from typing import Any, List, Optional, Tuple, Union +from typing import Dict, List, Optional -from mmengine.logging import HistoryBuffer, MessageHub -from mmengine.registry import HOOKS +from mmengine.logging import MessageHub +from mmengine.registry import HOOKS, RECORDERS from . import Hook class AttributeRecorderAdder(ast.NodeTransformer): - def __init__(self, target): super().__init__() self._target = target @@ -22,7 +19,7 @@ def __init__(self, target): def visit_Assign(self, node): if node.targets[0].id != self._target: return node - add2messagehub = ast.Expr( + update_messagehub_node = ast.Expr( value=ast.Call( func=ast.Attribute( value=ast.Name(id='message_hub', ctx=ast.Load()), @@ -34,8 +31,7 @@ def visit_Assign(self, node): ], keywords=[])) - # 插入print语句 - return [node, add2messagehub] + return [node, update_messagehub_node] def get_node_name(func_name): @@ -57,7 +53,7 @@ def is_target_call(self, call_node): call_node.func, ast.Name) and call_node.func.id == call_chain_list[0] else: - # 倒序遍历call_chain_list + # Traversal call_chain_list in reverse order for i in range(len(call_chain_list) - 1, 0, -1): if isinstance(call_node, ast.Attribute ) and call_node.attr == call_chain_list[i]: @@ -90,13 +86,9 @@ def visit_Assign(self, node): assign_node_name = get_node_name(self._target.replace('.', '_')) # test = assign node assign = ast.Assign( - targets=[ - ast.Name( - id=assign_node_name, - ctx=ast.Store()) - ], + targets=[ast.Name(id=assign_node_name, ctx=ast.Store())], value=assign_node) - add2messagehub = ast.Expr( + update_messagehub_node = ast.Expr( value=ast.Call( func=ast.Attribute( value=ast.Name(id='message_hub', ctx=ast.Load()), @@ -108,7 +100,7 @@ def visit_Assign(self, node): ], keywords=[])) self.function_visitor.call_nodes.clear() - return [assign, add2messagehub, node] + return [assign, update_messagehub_node, node] return node @@ -122,15 +114,13 @@ def rewrite(self, ast_tree): pass -# AttributeRecorder +@RECORDERS.register_module() class AttributeRecorder(Recorder): def __init__(self, target: str): super().__init__(target) self.visit_assign = self._get_adder_class() - # super.__init__() - def _get_adder_class(self): return AttributeRecorderAdder(self._target) @@ -144,6 +134,7 @@ def rewrite(self, ast_tree): return new_ast_tree +@RECORDERS.register_module() class FunctionRecorder(Recorder): def __init__(self, target: str): @@ -167,34 +158,34 @@ def rewrite(self, ast_tree): class RecorderHook(Hook): priority = 'LOWEST' - def __init__(self, - recorders: Optional[List[Dict]] = None): - self.tensor_dict = defaultdict(list) + def __init__(self, recorders: Optional[List[Dict]] = None): + self.tensor_dict = {} self.origin_forward = None - self._recorders = [] + self._recorders = {} + for recorder in recorders: + assert recorder.get('target') is not None + self.tensor_dict[recorder['target']] = list() + self._recorders[recorder['target']] = RECORDERS.build(recorder) def _get_ast(source_code): return ast.parse(source_code) def _modify_func(self, func): - # 获取函数的源代码 + # Gets the source code for the function source = inspect.getsource(func) source = textwrap.dedent(source) - # 解析源代码为AST + # Parse source code as ast tree = ast.parse(source) - import_from_statement = ast.ImportFrom( - module='mmengine.logging.MessageHub', - names=[ast.alias(name='RecorderHook', asname=None)], - level=0) - func_body = tree.body[0].body - import_messagehub_statement = ast.ImportFrom( + # import mmengine.logging.MessageHub + import_messagehub_node = ast.ImportFrom( module='mmengine.logging', names=[ast.alias(name='MessageHub')], level=0) - add_message_hub = ast.Assign( + # get messagehub instance + get_messagehub_node = ast.Assign( targets=[ast.Name(id='message_hub', ctx=ast.Store())], value=ast.Call( func=ast.Attribute( @@ -204,16 +195,15 @@ def _modify_func(self, func): args=[], keywords=[])) - tree.body[0].body = [import_messagehub_statement, add_message_hub + tree.body[0].body = [import_messagehub_node, get_messagehub_node ] + func_body # tree.body[0].body.insert(0, import_statement) - # 修改AST - # tree = AttributeRecorder('x').rewrite(tree) - tree = FunctionRecorder('self.resnet').rewrite(tree) + for recorder in self._recorders.values(): + tree = recorder.rewrite(tree) tree = ast.fix_missing_locations(tree) - # 编译修改后的AST为一个新的函数 + # Compile the modified ast as a new function namespace = {} exec( compile(tree, filename='', mode='exec'), func.__globals__, @@ -226,16 +216,11 @@ def before_run(self, runner) -> None: Args: runner (Runner): The runner of the training process. """ - log_scalars = dict(loss=HistoryBuffer()) - runtime_info = dict() - resumed_keys = dict(loss=True) - self.message_hub2 = MessageHub.get_current_instance() - + # get messagehub instance and store it. + self.message_hub = MessageHub.get_current_instance() + # get model and modify its forward function model = runner.model - print('---------------------------') - # breakpoint() self.origin_forward = model.forward - model.forward = types.MethodType( self._modify_func(model.forward), model) @@ -244,16 +229,8 @@ def after_train_iter(self, batch_idx: int, data_batch=None, outputs=None) -> None: - - # print(self.message_hub2.__dict__) - # print(self.message_hub2.get_info("task")) - self.tensor_dict['task'].append(self.message_hub2.get_info('task')) - - # def before_train(self, runner) -> None: - # model = runner.model - - # model.forward = types.MethodType( - # self._modify_func(model.forward), model) + for key in self.tensor_dict.keys(): + self.tensor_dict[key].append(self.message_hub.get_info(key)) def after_train(self, runner) -> None: runner.model.forward = self.origin_forward diff --git a/mmengine/registry/__init__.py b/mmengine/registry/__init__.py index cce2737043..7bbe5a316d 100644 --- a/mmengine/registry/__init__.py +++ b/mmengine/registry/__init__.py @@ -6,9 +6,9 @@ from .root import (DATA_SAMPLERS, DATASETS, EVALUATOR, FUNCTIONS, HOOKS, INFERENCERS, LOG_PROCESSORS, LOOPS, METRICS, MODEL_WRAPPERS, MODELS, OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS, - OPTIMIZERS, PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS, RUNNERS, - STRATEGIES, TASK_UTILS, TRANSFORMS, VISBACKENDS, - VISUALIZERS, WEIGHT_INITIALIZERS) + OPTIMIZERS, PARAM_SCHEDULERS, RECORDERS, + RUNNER_CONSTRUCTORS, RUNNERS, STRATEGIES, TASK_UTILS, + TRANSFORMS, VISBACKENDS, VISUALIZERS, WEIGHT_INITIALIZERS) from .utils import (count_registered_modules, init_default_scope, traverse_registry_tree) @@ -20,5 +20,6 @@ 'VISBACKENDS', 'VISUALIZERS', 'LOG_PROCESSORS', 'EVALUATOR', 'INFERENCERS', 'DefaultScope', 'traverse_registry_tree', 'count_registered_modules', 'build_model_from_cfg', 'build_runner_from_cfg', 'build_from_cfg', - 'build_scheduler_from_cfg', 'init_default_scope', 'FUNCTIONS', 'STRATEGIES' + 'build_scheduler_from_cfg', 'init_default_scope', 'FUNCTIONS', + 'STRATEGIES', 'RECORDERS' ] diff --git a/mmengine/registry/root.py b/mmengine/registry/root.py index 2663dffcd9..d75f38af26 100644 --- a/mmengine/registry/root.py +++ b/mmengine/registry/root.py @@ -57,6 +57,9 @@ # manage visualizer backend VISBACKENDS = Registry('vis_backend') +# manage recorders +RECORDERS = Registry('recorder') + # manage logprocessor LOG_PROCESSORS = Registry('log_processor') From 25b2415c5557a27f67d2ae1ef995fe5e170de3b9 Mon Sep 17 00:00:00 2001 From: yxy Date: Sun, 6 Aug 2023 01:11:54 +0800 Subject: [PATCH 12/39] fix commit verify --- mmengine/hooks/recorder_hook.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/mmengine/hooks/recorder_hook.py b/mmengine/hooks/recorder_hook.py index cea364a78c..47d82da66a 100644 --- a/mmengine/hooks/recorder_hook.py +++ b/mmengine/hooks/recorder_hook.py @@ -4,7 +4,7 @@ import textwrap import types from abc import ABCMeta, abstractmethod -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional from mmengine.logging import MessageHub from mmengine.registry import HOOKS, RECORDERS @@ -12,6 +12,7 @@ class AttributeRecorderAdder(ast.NodeTransformer): + def __init__(self, target): super().__init__() self._target = target @@ -159,9 +160,11 @@ class RecorderHook(Hook): priority = 'LOWEST' def __init__(self, recorders: Optional[List[Dict]] = None): - self.tensor_dict = {} + self.tensor_dict: Dict[str, Any] = {} self.origin_forward = None - self._recorders = {} + self._recorders: Dict[str, Recorder] = {} + if recorders is None: + raise ValueError('recorders not initialized') for recorder in recorders: assert recorder.get('target') is not None self.tensor_dict[recorder['target']] = list() From ce0bfbea1e6e08337d0673fb9850463a9e63d0c7 Mon Sep 17 00:00:00 2001 From: yxy Date: Sun, 6 Aug 2023 01:16:19 +0800 Subject: [PATCH 13/39] do some clean up --- mmengine/hooks/recorder_hook.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/mmengine/hooks/recorder_hook.py b/mmengine/hooks/recorder_hook.py index 47d82da66a..cf611d183c 100644 --- a/mmengine/hooks/recorder_hook.py +++ b/mmengine/hooks/recorder_hook.py @@ -35,6 +35,10 @@ def visit_Assign(self, node): return [node, update_messagehub_node] +# Take "x = self.conv1(x)" as an example +# genertate "tmp_func_self_conv1 = self.conv1(x)" +# and "x = tmp_func_self_conv1" +# and "message_hub.update_info('conv1', tmp_func_conv1)" def get_node_name(func_name): return 'tmp_func_' + func_name @@ -83,12 +87,11 @@ def __init__(self, target): def visit_Assign(self, node): self.function_visitor.visit(node) if self.function_visitor.call_nodes: - assign_node = self.function_visitor.call_nodes[0] + assign_right_node = self.function_visitor.call_nodes[0] assign_node_name = get_node_name(self._target.replace('.', '_')) - # test = assign node - assign = ast.Assign( + assign_left_node = ast.Assign( targets=[ast.Name(id=assign_node_name, ctx=ast.Store())], - value=assign_node) + value=assign_right_node) update_messagehub_node = ast.Expr( value=ast.Call( func=ast.Attribute( @@ -101,7 +104,7 @@ def visit_Assign(self, node): ], keywords=[])) self.function_visitor.call_nodes.clear() - return [assign, update_messagehub_node, node] + return [assign_left_node, update_messagehub_node, node] return node @@ -200,7 +203,6 @@ def _modify_func(self, func): tree.body[0].body = [import_messagehub_node, get_messagehub_node ] + func_body - # tree.body[0].body.insert(0, import_statement) for recorder in self._recorders.values(): tree = recorder.rewrite(tree) From 9ef4e4430f1bf291670731dec4a758e0e9e2da25 Mon Sep 17 00:00:00 2001 From: yxy Date: Sun, 6 Aug 2023 01:59:31 +0800 Subject: [PATCH 14/39] add recorder_hook_test.py --- examples/recorder_hook_test.py | 87 +++++++++++++++++++++++++++++++++ mmengine/hooks/recorder_hook.py | 43 ++++++++++------ 2 files changed, 114 insertions(+), 16 deletions(-) create mode 100644 examples/recorder_hook_test.py diff --git a/examples/recorder_hook_test.py b/examples/recorder_hook_test.py new file mode 100644 index 0000000000..512d463a4d --- /dev/null +++ b/examples/recorder_hook_test.py @@ -0,0 +1,87 @@ +import torch.nn.functional as F +import torchvision +import torchvision.transforms as transforms +from torch.optim import SGD +from torch.utils.data import DataLoader + +from mmengine.evaluator import BaseMetric +from mmengine.model import BaseModel +from mmengine.runner import Runner + + +class MMResNet50(BaseModel): + + def __init__(self): + super().__init__() + self.resnet = torchvision.models.resnet50() + + def forward(self, imgs, labels, mode): + x = self.resnet(imgs) + if mode == 'loss': + return {'loss': F.cross_entropy(x, labels)} + elif mode == 'predict': + return x, labels + + +class Accuracy(BaseMetric): + + def process(self, data_batch, data_samples): + score, gt = data_samples + self.results.append({ + 'batch_size': len(gt), + 'correct': (score.argmax(dim=1) == gt).sum().cpu(), + }) + + def compute_metrics(self, results): + total_correct = sum(item['correct'] for item in results) + total_size = sum(item['batch_size'] for item in results) + return dict(accuracy=100 * total_correct / total_size) + + +norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201]) +train_dataloader = DataLoader( + batch_size=32, + shuffle=True, + dataset=torchvision.datasets.CIFAR10( + 'data/cifar10', + train=True, + download=True, + transform=transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(**norm_cfg) + ]))) + +val_dataloader = DataLoader( + batch_size=32, + shuffle=False, + dataset=torchvision.datasets.CIFAR10( + 'data/cifar10', + train=False, + download=True, + transform=transforms.Compose( + [transforms.ToTensor(), + transforms.Normalize(**norm_cfg)]))) + +# default_hooks = dict(logger=dict(type='LoggerHook', interval=20)) + +runner = Runner( + # default_hooks=default_hooks, + custom_hooks=[ + dict( + type='RecorderHook', + recorders=[dict(type='FunctionRecorder', target='self.resnet')], + save_dir='./work_dir', + print_modification=True) + ], + model=MMResNet50(), + work_dir='./work_dir', + train_dataloader=train_dataloader, + optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)), + train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1), + val_dataloader=val_dataloader, + val_cfg=dict(), + val_evaluator=dict(type=Accuracy), +) +runner.train() diff --git a/mmengine/hooks/recorder_hook.py b/mmengine/hooks/recorder_hook.py index cf611d183c..4e565b5f77 100644 --- a/mmengine/hooks/recorder_hook.py +++ b/mmengine/hooks/recorder_hook.py @@ -1,12 +1,13 @@ # Copyright (c) OpenMMLab. All rights reserved. import ast import inspect +import logging import textwrap import types from abc import ABCMeta, abstractmethod from typing import Any, Dict, List, Optional -from mmengine.logging import MessageHub +from mmengine.logging import MessageHub, print_log from mmengine.registry import HOOKS, RECORDERS from . import Hook @@ -129,13 +130,7 @@ def _get_adder_class(self): return AttributeRecorderAdder(self._target) def rewrite(self, ast_tree): - new_ast_tree = self.visit_assign.visit(ast_tree) - new_ast_tree = ast.fix_missing_locations(new_ast_tree) - - modified_source_code = ast.unparse(new_ast_tree) - print(modified_source_code) - - return new_ast_tree + return self.visit_assign.visit(ast_tree) @RECORDERS.register_module() @@ -149,23 +144,29 @@ def _get_adder_class(self): return FunctionRecorderAdder(self._target) def rewrite(self, ast_tree): - new_ast_tree = self.visit_assign.visit(ast_tree) - new_ast_tree = ast.fix_missing_locations(new_ast_tree) - - modified_source_code = ast.unparse(new_ast_tree) - print(modified_source_code) - - return new_ast_tree + return self.visit_assign.visit(ast_tree) @HOOKS.register_module() class RecorderHook(Hook): priority = 'LOWEST' - def __init__(self, recorders: Optional[List[Dict]] = None): + def __init__(self, + recorders: Optional[List[Dict]] = None, + print_modification: bool = True, + save_dir: str = ''): self.tensor_dict: Dict[str, Any] = {} self.origin_forward = None self._recorders: Dict[str, Recorder] = {} + self._print_modification = print_modification + if not save_dir: + print_log( + '`RecorderHook` cannot save the tensor values ' + 'because save_dir is None.', + logger='current', + level=logging.WARNING) + self._save_dir = save_dir + if recorders is None: raise ValueError('recorders not initialized') for recorder in recorders: @@ -206,6 +207,14 @@ def _modify_func(self, func): for recorder in self._recorders.values(): tree = recorder.rewrite(tree) + if self._print_modification: + new_tree = ast.fix_missing_locations(tree) + modified_source_code = ast.unparse(new_tree) + print_log( + f'After modification, the source code is:\n' + f'{modified_source_code}', + logger='current', + level=logging.INFO) tree = ast.fix_missing_locations(tree) # Compile the modified ast as a new function @@ -238,4 +247,6 @@ def after_train_iter(self, self.tensor_dict[key].append(self.message_hub.get_info(key)) def after_train(self, runner) -> None: + import pickle runner.model.forward = self.origin_forward + pickle.dump(self.tensor_dict, open(self._save_dir + 'tensor', 'wb')) From 4b396aadd443f610f597fab48d0e0d8c1fed7f71 Mon Sep 17 00:00:00 2001 From: yxy Date: Wed, 9 Aug 2023 00:04:36 +0800 Subject: [PATCH 15/39] redesign FunctionRecorder and AttributeRecorder --- mmengine/hooks/recorder_hook.py | 128 ++++++++++++++++++-------------- 1 file changed, 74 insertions(+), 54 deletions(-) diff --git a/mmengine/hooks/recorder_hook.py b/mmengine/hooks/recorder_hook.py index 4e565b5f77..e607d94eea 100644 --- a/mmengine/hooks/recorder_hook.py +++ b/mmengine/hooks/recorder_hook.py @@ -2,6 +2,7 @@ import ast import inspect import logging +import pickle import textwrap import types from abc import ABCMeta, abstractmethod @@ -12,7 +13,7 @@ from . import Hook -class AttributeRecorderAdder(ast.NodeTransformer): +class FunctionRecorderTransformer(ast.NodeTransformer): def __init__(self, target): super().__init__() @@ -40,8 +41,15 @@ def visit_Assign(self, node): # genertate "tmp_func_self_conv1 = self.conv1(x)" # and "x = tmp_func_self_conv1" # and "message_hub.update_info('conv1', tmp_func_conv1)" -def get_node_name(func_name): - return 'tmp_func_' + func_name +def _get_tensor_key(target, attribute=None): + target = target.replace('.', '_') + if attribute: + target = target + '_' + attribute + return target + + +# def get_node_name(func_name): +# return 'tmp_func_' + func_name class FuncCallVisitor(ast.NodeTransformer): @@ -50,10 +58,11 @@ def __init__(self, func_name): self.func_name = func_name self.call_nodes = [] - def is_target_call(self, call_node): + # judge if the ast.Call node is user wanted + def _is_target_call(self, call_node): assert isinstance(call_node, ast.Call) - call_node = call_node.func call_chain_list = self.func_name.split('.') + call_node = call_node.func if len(call_chain_list) == 1: return isinstance( call_node.func, @@ -70,39 +79,43 @@ def is_target_call(self, call_node): ast.Name) and call_node.id == call_chain_list[0] def visit_Call(self, node): - if not self.is_target_call(node): + if not self._is_target_call(node): return node - new_node = ast.Name( - id=get_node_name(self.func_name.replace('.', '_')), ctx=ast.Load()) + new_node = ast.Name(id=_get_tensor_key(self.func_name), ctx=ast.Load()) self.call_nodes.append(node) return new_node -class FunctionRecorderAdder(ast.NodeTransformer): +class AttributeRecorderTransformer(ast.NodeTransformer): - def __init__(self, target): + def __init__(self, target, attribute): super().__init__() self._target = target + self._attribute = attribute self.function_visitor = FuncCallVisitor(target) def visit_Assign(self, node): self.function_visitor.visit(node) if self.function_visitor.call_nodes: assign_right_node = self.function_visitor.call_nodes[0] - assign_node_name = get_node_name(self._target.replace('.', '_')) + assign_node_name = _get_tensor_key(self._target, self._attribute) assign_left_node = ast.Assign( targets=[ast.Name(id=assign_node_name, ctx=ast.Store())], value=assign_right_node) + if self._attribute: + ast_arg2 = ast.Attribute( + value=ast.Name(assign_node_name, ctx=ast.Load()), + attr=self._attribute, + ctx=ast.Load()) + else: + ast_arg2 = ast.Name(id=assign_node_name, ctx=ast.Load()) update_messagehub_node = ast.Expr( value=ast.Call( func=ast.Attribute( value=ast.Name(id='message_hub', ctx=ast.Load()), attr='update_info', ctx=ast.Load()), - args=[ - ast.Constant(value=self._target), - ast.Name(id=assign_node_name, ctx=ast.Load()) - ], + args=[ast.Constant(value=assign_node_name), ast_arg2], keywords=[])) self.function_visitor.call_nodes.clear() return [assign_left_node, update_messagehub_node, node] @@ -111,8 +124,9 @@ def visit_Assign(self, node): class Recorder(metaclass=ABCMeta): - def __init__(self, target: str): + def __init__(self, target: str, saved_tensor_key: str): self._target = target + self._saved_tensor_key = saved_tensor_key @abstractmethod def rewrite(self, ast_tree): @@ -120,28 +134,32 @@ def rewrite(self, ast_tree): @RECORDERS.register_module() -class AttributeRecorder(Recorder): +class FunctionRecorder(Recorder): - def __init__(self, target: str): - super().__init__(target) - self.visit_assign = self._get_adder_class() + def __init__(self, target: str, saved_tensor_key: str): + super().__init__(target, saved_tensor_key) + self.visit_assign = self._get_transformer_class() - def _get_adder_class(self): - return AttributeRecorderAdder(self._target) + def _get_transformer_class(self): + return FunctionRecorderTransformer(self._target) def rewrite(self, ast_tree): return self.visit_assign.visit(ast_tree) @RECORDERS.register_module() -class FunctionRecorder(Recorder): +class AttributeRecorder(Recorder): - def __init__(self, target: str): - super().__init__(target) - self.visit_assign = self._get_adder_class() + def __init__(self, + target: str, + saved_tensor_key: str, + attribute: str = None): + super().__init__(target, saved_tensor_key) + self.attribute = attribute + self.visit_assign = self._get_transformer_class() - def _get_adder_class(self): - return FunctionRecorderAdder(self._target) + def _get_transformer_class(self): + return AttributeRecorderTransformer(self._target, self.attribute) def rewrite(self, ast_tree): return self.visit_assign.visit(ast_tree) @@ -151,31 +169,34 @@ def rewrite(self, ast_tree): class RecorderHook(Hook): priority = 'LOWEST' - def __init__(self, - recorders: Optional[List[Dict]] = None, - print_modification: bool = True, - save_dir: str = ''): + def __init__( + self, + recorders: Optional[List[Dict]] = None, + print_modification: bool = True, + save_dir: str = None, + filename_tmpl: Optional[str] = None, + ): self.tensor_dict: Dict[str, Any] = {} self.origin_forward = None self._recorders: Dict[str, Recorder] = {} - self._print_modification = print_modification - if not save_dir: - print_log( - '`RecorderHook` cannot save the tensor values ' - 'because save_dir is None.', - logger='current', - level=logging.WARNING) - self._save_dir = save_dir - - if recorders is None: + self.print_modification = print_modification + self.save_dir = save_dir # type: ignore + + if recorders is None or len(recorders) == 0: raise ValueError('recorders not initialized') for recorder in recorders: - assert recorder.get('target') is not None - self.tensor_dict[recorder['target']] = list() - self._recorders[recorder['target']] = RECORDERS.build(recorder) + target = recorder.get('target') + attribute = recorder.get('attribute') - def _get_ast(source_code): - return ast.parse(source_code) + if target is None: + print_log( + '`RecorderHook` cannot be initialized ' + 'because recorder has no target', + logger='current', + level=logging.WARNING) + tensor_key = _get_tensor_key(target, attribute) + self.tensor_dict[tensor_key] = list() + self._recorders[tensor_key] = RECORDERS.build(recorder) def _modify_func(self, func): # Gets the source code for the function @@ -207,7 +228,7 @@ def _modify_func(self, func): for recorder in self._recorders.values(): tree = recorder.rewrite(tree) - if self._print_modification: + if self.print_modification: new_tree = ast.fix_missing_locations(tree) modified_source_code = ast.unparse(new_tree) print_log( @@ -225,11 +246,9 @@ def _modify_func(self, func): return namespace[func.__name__] def before_run(self, runner) -> None: - """Check `stop_training` variable in `runner.train_loop`. + if not self.save_dir: + self.save_dir = runner.work_dir - Args: - runner (Runner): The runner of the training process. - """ # get messagehub instance and store it. self.message_hub = MessageHub.get_current_instance() # get model and modify its forward function @@ -247,6 +266,7 @@ def after_train_iter(self, self.tensor_dict[key].append(self.message_hub.get_info(key)) def after_train(self, runner) -> None: - import pickle + data = pickle.dumps(self.tensor_dict) + print(data) + # use self.save_dir to save data runner.model.forward = self.origin_forward - pickle.dump(self.tensor_dict, open(self._save_dir + 'tensor', 'wb')) From 2d48fae97da22d06f03ac01493d5e61e0ea4e462 Mon Sep 17 00:00:00 2001 From: yxy Date: Thu, 10 Aug 2023 23:17:17 +0800 Subject: [PATCH 16/39] modify recorder_hook_test.py --- examples/recorder_hook_test.py | 15 +++++++++++++-- mmengine/hooks/recorder_hook.py | 14 +++++--------- 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/examples/recorder_hook_test.py b/examples/recorder_hook_test.py index 512d463a4d..fdf65ce402 100644 --- a/examples/recorder_hook_test.py +++ b/examples/recorder_hook_test.py @@ -67,11 +67,22 @@ def compute_metrics(self, results): # default_hooks = dict(logger=dict(type='LoggerHook', interval=20)) runner = Runner( - # default_hooks=default_hooks, + # custom_hooks=[ + # dict( + # type='RecorderHook', + # recorders=[dict(type='FunctionRecorder', target='x')], + # save_dir='./work_dir', + # print_modification=True) + # ], custom_hooks=[ dict( type='RecorderHook', - recorders=[dict(type='FunctionRecorder', target='self.resnet')], + recorders=[ + dict( + type='AttributeRecorder', + target='self.resnet', + attribute='weight') + ], save_dir='./work_dir', print_modification=True) ], diff --git a/mmengine/hooks/recorder_hook.py b/mmengine/hooks/recorder_hook.py index e607d94eea..b09a70a341 100644 --- a/mmengine/hooks/recorder_hook.py +++ b/mmengine/hooks/recorder_hook.py @@ -124,9 +124,8 @@ def visit_Assign(self, node): class Recorder(metaclass=ABCMeta): - def __init__(self, target: str, saved_tensor_key: str): + def __init__(self, target: str): self._target = target - self._saved_tensor_key = saved_tensor_key @abstractmethod def rewrite(self, ast_tree): @@ -136,8 +135,8 @@ def rewrite(self, ast_tree): @RECORDERS.register_module() class FunctionRecorder(Recorder): - def __init__(self, target: str, saved_tensor_key: str): - super().__init__(target, saved_tensor_key) + def __init__(self, target: str): + super().__init__(target) self.visit_assign = self._get_transformer_class() def _get_transformer_class(self): @@ -150,11 +149,8 @@ def rewrite(self, ast_tree): @RECORDERS.register_module() class AttributeRecorder(Recorder): - def __init__(self, - target: str, - saved_tensor_key: str, - attribute: str = None): - super().__init__(target, saved_tensor_key) + def __init__(self, target: str, attribute: str = None): + super().__init__(target) self.attribute = attribute self.visit_assign = self._get_transformer_class() From 9a6ff6ff4930fd44504493f0db1ebdccd550c105 Mon Sep 17 00:00:00 2001 From: yxy Date: Mon, 4 Sep 2023 13:44:29 +0800 Subject: [PATCH 17/39] modify attribute recorder --- mmengine/hooks/recorder_hook.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/mmengine/hooks/recorder_hook.py b/mmengine/hooks/recorder_hook.py index b09a70a341..2799b4448a 100644 --- a/mmengine/hooks/recorder_hook.py +++ b/mmengine/hooks/recorder_hook.py @@ -94,19 +94,26 @@ def __init__(self, target, attribute): self._attribute = attribute self.function_visitor = FuncCallVisitor(target) + def _get_target_attribute(self): + func_chain = self._target.split('.') + func_chain.append(self._attribute) + assert len(func_chain) >= 2 + attr = ast.Attribute(value=ast.Name(id=func_chain[0], ctx=ast.Load()), attr=func_chain[1], ctx=ast.Load()) + for ele in func_chain[2:]: + attr = ast.Attribute(value=attr, attr=ele, ctx=ast.Load()) + return attr + def visit_Assign(self, node): self.function_visitor.visit(node) if self.function_visitor.call_nodes: assign_right_node = self.function_visitor.call_nodes[0] - assign_node_name = _get_tensor_key(self._target, self._attribute) + assign_node_name = _get_tensor_key(self._target, None) assign_left_node = ast.Assign( targets=[ast.Name(id=assign_node_name, ctx=ast.Store())], value=assign_right_node) if self._attribute: - ast_arg2 = ast.Attribute( - value=ast.Name(assign_node_name, ctx=ast.Load()), - attr=self._attribute, - ctx=ast.Load()) + assign_node_name = _get_tensor_key(self._target, self._attribute) + ast_arg2 = self._get_target_attribute() else: ast_arg2 = ast.Name(id=assign_node_name, ctx=ast.Load()) update_messagehub_node = ast.Expr( From 4c5d27b832ff12c9f82de1464fe1b460ce4c701c Mon Sep 17 00:00:00 2001 From: yxy Date: Mon, 4 Sep 2023 22:34:25 +0800 Subject: [PATCH 18/39] store function recorder in a format of assign_name@index --- mmengine/hooks/recorder_hook.py | 70 +++++++++++++++++++++++++-------- 1 file changed, 54 insertions(+), 16 deletions(-) diff --git a/mmengine/hooks/recorder_hook.py b/mmengine/hooks/recorder_hook.py index 2799b4448a..585bc07be3 100644 --- a/mmengine/hooks/recorder_hook.py +++ b/mmengine/hooks/recorder_hook.py @@ -2,6 +2,7 @@ import ast import inspect import logging +import os.path as osp import pickle import textwrap import types @@ -13,15 +14,24 @@ from . import Hook +def function_with_index(function, index): + return function + '@' + str(index) + + class FunctionRecorderTransformer(ast.NodeTransformer): - def __init__(self, target): + def __init__(self, target, target_index): super().__init__() self._target = target + self._target_index = set(target_index) + self.count = 0 def visit_Assign(self, node): if node.targets[0].id != self._target: return node + self.count += 1 + if self.count not in self._target_index: + return node update_messagehub_node = ast.Expr( value=ast.Call( func=ast.Attribute( @@ -29,7 +39,9 @@ def visit_Assign(self, node): attr='update_info', ctx=ast.Load()), args=[ - ast.Constant(value=node.targets[0].id), + ast.Constant( + value=function_with_index(node.targets[0].id, + self.count)), ast.Name(id=node.targets[0].id, ctx=ast.Load()) ], keywords=[])) @@ -48,10 +60,6 @@ def _get_tensor_key(target, attribute=None): return target -# def get_node_name(func_name): -# return 'tmp_func_' + func_name - - class FuncCallVisitor(ast.NodeTransformer): def __init__(self, func_name): @@ -98,10 +106,13 @@ def _get_target_attribute(self): func_chain = self._target.split('.') func_chain.append(self._attribute) assert len(func_chain) >= 2 - attr = ast.Attribute(value=ast.Name(id=func_chain[0], ctx=ast.Load()), attr=func_chain[1], ctx=ast.Load()) + attr = ast.Attribute( + value=ast.Name(id=func_chain[0], ctx=ast.Load()), + attr=func_chain[1], + ctx=ast.Load()) for ele in func_chain[2:]: attr = ast.Attribute(value=attr, attr=ele, ctx=ast.Load()) - return attr + return attr def visit_Assign(self, node): self.function_visitor.visit(node) @@ -112,7 +123,8 @@ def visit_Assign(self, node): targets=[ast.Name(id=assign_node_name, ctx=ast.Store())], value=assign_right_node) if self._attribute: - assign_node_name = _get_tensor_key(self._target, self._attribute) + assign_node_name = _get_tensor_key(self._target, + self._attribute) ast_arg2 = self._get_target_attribute() else: ast_arg2 = ast.Name(id=assign_node_name, ctx=ast.Load()) @@ -142,12 +154,13 @@ def rewrite(self, ast_tree): @RECORDERS.register_module() class FunctionRecorder(Recorder): - def __init__(self, target: str): + def __init__(self, target: str, index: list): super().__init__(target) + self.index = index self.visit_assign = self._get_transformer_class() def _get_transformer_class(self): - return FunctionRecorderTransformer(self._target) + return FunctionRecorderTransformer(self._target, self.index) def rewrite(self, ast_tree): return self.visit_assign.visit(ast_tree) @@ -184,12 +197,15 @@ def __init__( self._recorders: Dict[str, Recorder] = {} self.print_modification = print_modification self.save_dir = save_dir # type: ignore + if filename_tmpl is None: + self.filename_tmpl = 'record_epoch_{}.pth' if recorders is None or len(recorders) == 0: raise ValueError('recorders not initialized') for recorder in recorders: target = recorder.get('target') attribute = recorder.get('attribute') + tensor_key = _get_tensor_key(target, attribute) if target is None: print_log( @@ -197,8 +213,17 @@ def __init__( 'because recorder has no target', logger='current', level=logging.WARNING) - tensor_key = _get_tensor_key(target, attribute) - self.tensor_dict[tensor_key] = list() + if recorder.get('type') == 'FunctionRecorder': + index = recorder.get('index') + if isinstance(index, list): + for i in index: + self.tensor_dict[function_with_index(target, + i)] = list() + elif isinstance(index, int): + self.tensor_dict[function_with_index(target, + index)] = list() + elif recorder.get('type') == 'AttributeRecorder': + self.tensor_dict[tensor_key] = list() self._recorders[tensor_key] = RECORDERS.build(recorder) def _modify_func(self, func): @@ -268,8 +293,21 @@ def after_train_iter(self, for key in self.tensor_dict.keys(): self.tensor_dict[key].append(self.message_hub.get_info(key)) + def _save_record(self, step): + recorder_file_name = self.filename_tmpl.format(step) + path = osp.join(self.save_dir, recorder_file_name) + with open(path, 'wb') as f: + pickle.dump(self.tensor_dict, f) + + def _init_tensor_dict(self): + for k in self.tensor_dict.keys(): + self.tensor_dict[k] = list() + + def after_train_epoch(self, runner) -> None: + step = runner.epoch + 1 + runner.logger.info(f'Saving record at {runner.epoch + 1} epochs') + self._save_record(step) + self._init_tensor_dict() + def after_train(self, runner) -> None: - data = pickle.dumps(self.tensor_dict) - print(data) - # use self.save_dir to save data runner.model.forward = self.origin_forward From 2d8b64b0844555eff9eba9df8052f6b1aff873b3 Mon Sep 17 00:00:00 2001 From: yxy Date: Mon, 4 Sep 2023 23:14:41 +0800 Subject: [PATCH 19/39] modify function recorder index: start from 0 --- mmengine/hooks/recorder_hook.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmengine/hooks/recorder_hook.py b/mmengine/hooks/recorder_hook.py index 585bc07be3..7cef0ab40b 100644 --- a/mmengine/hooks/recorder_hook.py +++ b/mmengine/hooks/recorder_hook.py @@ -24,7 +24,7 @@ def __init__(self, target, target_index): super().__init__() self._target = target self._target_index = set(target_index) - self.count = 0 + self.count = -1 def visit_Assign(self, node): if node.targets[0].id != self._target: From 7bdf2c097b8f5ff3d2d4c7f37567b54fd6eba6ca Mon Sep 17 00:00:00 2001 From: yxy Date: Tue, 12 Sep 2023 15:23:46 +0800 Subject: [PATCH 20/39] use torch.save to dump data; handle when index is int --- mmengine/hooks/recorder_hook.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/mmengine/hooks/recorder_hook.py b/mmengine/hooks/recorder_hook.py index 7cef0ab40b..49b451b4c4 100644 --- a/mmengine/hooks/recorder_hook.py +++ b/mmengine/hooks/recorder_hook.py @@ -3,12 +3,13 @@ import inspect import logging import os.path as osp -import pickle import textwrap import types from abc import ABCMeta, abstractmethod from typing import Any, Dict, List, Optional +import torch + from mmengine.logging import MessageHub, print_log from mmengine.registry import HOOKS, RECORDERS from . import Hook @@ -23,7 +24,10 @@ class FunctionRecorderTransformer(ast.NodeTransformer): def __init__(self, target, target_index): super().__init__() self._target = target - self._target_index = set(target_index) + if isinstance(target_index, list): + self._target_index = set(target_index) + else: + self._target_index = {target_index} self.count = -1 def visit_Assign(self, node): @@ -296,8 +300,7 @@ def after_train_iter(self, def _save_record(self, step): recorder_file_name = self.filename_tmpl.format(step) path = osp.join(self.save_dir, recorder_file_name) - with open(path, 'wb') as f: - pickle.dump(self.tensor_dict, f) + torch.save(self.tensor_dict, path) def _init_tensor_dict(self): for k in self.tensor_dict.keys(): From 9fa6c945b0f6ef71b199583608b473b2f1689021 Mon Sep 17 00:00:00 2001 From: yxy Date: Tue, 12 Sep 2023 15:29:51 +0800 Subject: [PATCH 21/39] add default value for FunctionRecorder's index --- mmengine/hooks/recorder_hook.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmengine/hooks/recorder_hook.py b/mmengine/hooks/recorder_hook.py index 49b451b4c4..630546da02 100644 --- a/mmengine/hooks/recorder_hook.py +++ b/mmengine/hooks/recorder_hook.py @@ -218,7 +218,7 @@ def __init__( logger='current', level=logging.WARNING) if recorder.get('type') == 'FunctionRecorder': - index = recorder.get('index') + index = recorder.get('index', 0) if isinstance(index, list): for i in index: self.tensor_dict[function_with_index(target, From 4102fa22f98bee557ad07f3d0a05cd949e9de1d5 Mon Sep 17 00:00:00 2001 From: yxy Date: Tue, 12 Sep 2023 17:47:05 +0800 Subject: [PATCH 22/39] add copy.deepcopy to collect weight in layer --- mmengine/hooks/recorder_hook.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/mmengine/hooks/recorder_hook.py b/mmengine/hooks/recorder_hook.py index 630546da02..e13ab2ce46 100644 --- a/mmengine/hooks/recorder_hook.py +++ b/mmengine/hooks/recorder_hook.py @@ -132,13 +132,24 @@ def visit_Assign(self, node): ast_arg2 = self._get_target_attribute() else: ast_arg2 = ast.Name(id=assign_node_name, ctx=ast.Load()) + + deep_copy_ast_arg2 = ast.Call( + func=ast.Attribute( + value=ast.Name(id='copy', ctx=ast.Load()), + attr='deepcopy', + ctx=ast.Load()), + args=[ast_arg2], + keywords=[]) update_messagehub_node = ast.Expr( value=ast.Call( func=ast.Attribute( value=ast.Name(id='message_hub', ctx=ast.Load()), attr='update_info', ctx=ast.Load()), - args=[ast.Constant(value=assign_node_name), ast_arg2], + args=[ + ast.Constant( + value=assign_node_name), deep_copy_ast_arg2 + ], keywords=[])) self.function_visitor.call_nodes.clear() return [assign_left_node, update_messagehub_node, node] @@ -244,6 +255,7 @@ def _modify_func(self, func): module='mmengine.logging', names=[ast.alias(name='MessageHub')], level=0) + import_copy_node = ast.Import(names=[ast.alias(name='copy')]) # get messagehub instance get_messagehub_node = ast.Assign( targets=[ast.Name(id='message_hub', ctx=ast.Store())], @@ -255,8 +267,9 @@ def _modify_func(self, func): args=[], keywords=[])) - tree.body[0].body = [import_messagehub_node, get_messagehub_node - ] + func_body + tree.body[0].body = [ + import_messagehub_node, import_copy_node, get_messagehub_node + ] + func_body for recorder in self._recorders.values(): tree = recorder.rewrite(tree) From f72c7b1025647bee3b578772fdb7fb21e12c4f28 Mon Sep 17 00:00:00 2001 From: yxy Date: Tue, 12 Sep 2023 23:51:29 +0800 Subject: [PATCH 23/39] rename var name --- mmengine/hooks/recorder_hook.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/mmengine/hooks/recorder_hook.py b/mmengine/hooks/recorder_hook.py index e13ab2ce46..e77a6886bd 100644 --- a/mmengine/hooks/recorder_hook.py +++ b/mmengine/hooks/recorder_hook.py @@ -54,9 +54,9 @@ def visit_Assign(self, node): # Take "x = self.conv1(x)" as an example -# genertate "tmp_func_self_conv1 = self.conv1(x)" -# and "x = tmp_func_self_conv1" -# and "message_hub.update_info('conv1', tmp_func_conv1)" +# genertate "self_conv1 = self.conv1(x)" +# and "x = self_conv1" +# and "message_hub.update_info('self_conv1', self_conv1)" def _get_tensor_key(target, attribute=None): target = target.replace('.', '_') if attribute: @@ -129,16 +129,16 @@ def visit_Assign(self, node): if self._attribute: assign_node_name = _get_tensor_key(self._target, self._attribute) - ast_arg2 = self._get_target_attribute() + attribute_node = self._get_target_attribute() else: - ast_arg2 = ast.Name(id=assign_node_name, ctx=ast.Load()) + attribute_node = ast.Name(id=assign_node_name, ctx=ast.Load()) - deep_copy_ast_arg2 = ast.Call( + deep_copy_attribute_node = ast.Call( func=ast.Attribute( value=ast.Name(id='copy', ctx=ast.Load()), attr='deepcopy', ctx=ast.Load()), - args=[ast_arg2], + args=[attribute_node], keywords=[]) update_messagehub_node = ast.Expr( value=ast.Call( @@ -148,7 +148,7 @@ def visit_Assign(self, node): ctx=ast.Load()), args=[ ast.Constant( - value=assign_node_name), deep_copy_ast_arg2 + value=assign_node_name), deep_copy_attribute_node ], keywords=[])) self.function_visitor.call_nodes.clear() From 33dd3864399aaf4260f6694ca1beed80bce811b5 Mon Sep 17 00:00:00 2001 From: yxy Date: Wed, 13 Sep 2023 01:48:29 +0800 Subject: [PATCH 24/39] add model select in recorder --- mmengine/hooks/recorder_hook.py | 45 ++++++++++++++++++++++++++------- 1 file changed, 36 insertions(+), 9 deletions(-) diff --git a/mmengine/hooks/recorder_hook.py b/mmengine/hooks/recorder_hook.py index e77a6886bd..5129a522a5 100644 --- a/mmengine/hooks/recorder_hook.py +++ b/mmengine/hooks/recorder_hook.py @@ -6,6 +6,7 @@ import textwrap import types from abc import ABCMeta, abstractmethod +from operator import attrgetter from typing import Any, Dict, List, Optional import torch @@ -158,7 +159,8 @@ def visit_Assign(self, node): class Recorder(metaclass=ABCMeta): - def __init__(self, target: str): + def __init__(self, model, target: str): + self._model = model self._target = target @abstractmethod @@ -169,8 +171,8 @@ def rewrite(self, ast_tree): @RECORDERS.register_module() class FunctionRecorder(Recorder): - def __init__(self, target: str, index: list): - super().__init__(target) + def __init__(self, model: str, target: str, index: list): + super().__init__(model, target) self.index = index self.visit_assign = self._get_transformer_class() @@ -184,8 +186,8 @@ def rewrite(self, ast_tree): @RECORDERS.register_module() class AttributeRecorder(Recorder): - def __init__(self, target: str, attribute: str = None): - super().__init__(target) + def __init__(self, model: str, target: str, attribute: str = None): + super().__init__(model, target) self.attribute = attribute self.visit_assign = self._get_transformer_class() @@ -218,6 +220,9 @@ def __init__( if recorders is None or len(recorders) == 0: raise ValueError('recorders not initialized') for recorder in recorders: + model = recorder.get('model') + if model is None: + recorder['model'] = '' target = recorder.get('target') attribute = recorder.get('attribute') tensor_key = _get_tensor_key(target, attribute) @@ -241,7 +246,7 @@ def __init__( self.tensor_dict[tensor_key] = list() self._recorders[tensor_key] = RECORDERS.build(recorder) - def _modify_func(self, func): + def _modify_forward_func(self, func, recorders): # Gets the source code for the function source = inspect.getsource(func) source = textwrap.dedent(source) @@ -271,7 +276,7 @@ def _modify_func(self, func): import_messagehub_node, import_copy_node, get_messagehub_node ] + func_body - for recorder in self._recorders.values(): + for recorder in recorders: tree = recorder.rewrite(tree) if self.print_modification: new_tree = ast.fix_missing_locations(tree) @@ -290,6 +295,24 @@ def _modify_func(self, func): namespace) return namespace[func.__name__] + def _get_model(self, model_name): + if not model_name: + return self.base_model + module_list = model_name.split('.') + model = self.base_model + for model_name in module_list: + model = attrgetter(model_name)(model) + return model + + def _group_recorder_by_model(self): + group_dict = {} + for recorder in self._recorders.values(): + key = recorder._model + if key not in group_dict: + group_dict[key] = [] + group_dict[key].append(recorder) + return group_dict + def before_run(self, runner) -> None: if not self.save_dir: self.save_dir = runner.work_dir @@ -298,9 +321,13 @@ def before_run(self, runner) -> None: self.message_hub = MessageHub.get_current_instance() # get model and modify its forward function model = runner.model + self.base_model = model self.origin_forward = model.forward - model.forward = types.MethodType( - self._modify_func(model.forward), model) + self.grouped_recorders = self._group_recorder_by_model() + for model_name, recorders in self.grouped_recorders.items(): + model = self._get_model(model_name) + model.forward = types.MethodType( + self._modify_forward_func(model.forward, recorders), model) def after_train_iter(self, runner, From a995399345b982f997af8d0fdd1ecb5389e7bb9a Mon Sep 17 00:00:00 2001 From: yxy Date: Sat, 16 Sep 2023 01:53:20 +0800 Subject: [PATCH 25/39] refactor: modify AttributeRecorderTransformer; modify _get_model; modify restore forward logic --- mmengine/hooks/recorder_hook.py | 156 +++++++++++++------------------- 1 file changed, 62 insertions(+), 94 deletions(-) diff --git a/mmengine/hooks/recorder_hook.py b/mmengine/hooks/recorder_hook.py index 5129a522a5..7b5715e1d9 100644 --- a/mmengine/hooks/recorder_hook.py +++ b/mmengine/hooks/recorder_hook.py @@ -58,58 +58,20 @@ def visit_Assign(self, node): # genertate "self_conv1 = self.conv1(x)" # and "x = self_conv1" # and "message_hub.update_info('self_conv1', self_conv1)" -def _get_tensor_key(target, attribute=None): +def _get_tensor_key(target): target = target.replace('.', '_') - if attribute: - target = target + '_' + attribute return target -class FuncCallVisitor(ast.NodeTransformer): - - def __init__(self, func_name): - self.func_name = func_name - self.call_nodes = [] - - # judge if the ast.Call node is user wanted - def _is_target_call(self, call_node): - assert isinstance(call_node, ast.Call) - call_chain_list = self.func_name.split('.') - call_node = call_node.func - if len(call_chain_list) == 1: - return isinstance( - call_node.func, - ast.Name) and call_node.func.id == call_chain_list[0] - else: - # Traversal call_chain_list in reverse order - for i in range(len(call_chain_list) - 1, 0, -1): - if isinstance(call_node, ast.Attribute - ) and call_node.attr == call_chain_list[i]: - call_node = call_node.value - else: - return False - return isinstance(call_node, - ast.Name) and call_node.id == call_chain_list[0] - - def visit_Call(self, node): - if not self._is_target_call(node): - return node - new_node = ast.Name(id=_get_tensor_key(self.func_name), ctx=ast.Load()) - self.call_nodes.append(node) - return new_node - - class AttributeRecorderTransformer(ast.NodeTransformer): - def __init__(self, target, attribute): + def __init__(self, target): super().__init__() self._target = target - self._attribute = attribute - self.function_visitor = FuncCallVisitor(target) + self._visited = False def _get_target_attribute(self): func_chain = self._target.split('.') - func_chain.append(self._attribute) assert len(func_chain) >= 2 attr = ast.Attribute( value=ast.Name(id=func_chain[0], ctx=ast.Load()), @@ -120,47 +82,38 @@ def _get_target_attribute(self): return attr def visit_Assign(self, node): - self.function_visitor.visit(node) - if self.function_visitor.call_nodes: - assign_right_node = self.function_visitor.call_nodes[0] - assign_node_name = _get_tensor_key(self._target, None) - assign_left_node = ast.Assign( - targets=[ast.Name(id=assign_node_name, ctx=ast.Store())], - value=assign_right_node) - if self._attribute: - assign_node_name = _get_tensor_key(self._target, - self._attribute) - attribute_node = self._get_target_attribute() - else: - attribute_node = ast.Name(id=assign_node_name, ctx=ast.Load()) - - deep_copy_attribute_node = ast.Call( + if self._visited: + return node + if node.targets[0].id == 'message_hub': + self._visited = True + + attribute_node = self._get_target_attribute() + + deep_copy_attribute_node = ast.Call( + func=ast.Attribute( + value=ast.Name(id='copy', ctx=ast.Load()), + attr='deepcopy', + ctx=ast.Load()), + args=[attribute_node], + keywords=[]) + update_messagehub_node = ast.Expr( + value=ast.Call( func=ast.Attribute( - value=ast.Name(id='copy', ctx=ast.Load()), - attr='deepcopy', + value=ast.Name(id='message_hub', ctx=ast.Load()), + attr='update_info', ctx=ast.Load()), - args=[attribute_node], - keywords=[]) - update_messagehub_node = ast.Expr( - value=ast.Call( - func=ast.Attribute( - value=ast.Name(id='message_hub', ctx=ast.Load()), - attr='update_info', - ctx=ast.Load()), - args=[ - ast.Constant( - value=assign_node_name), deep_copy_attribute_node - ], - keywords=[])) - self.function_visitor.call_nodes.clear() - return [assign_left_node, update_messagehub_node, node] - return node + args=[ + ast.Constant(value=_get_tensor_key(self._target)), + deep_copy_attribute_node + ], + keywords=[])) + return [node, update_messagehub_node] class Recorder(metaclass=ABCMeta): def __init__(self, model, target: str): - self._model = model + self.model = model self._target = target @abstractmethod @@ -174,10 +127,8 @@ class FunctionRecorder(Recorder): def __init__(self, model: str, target: str, index: list): super().__init__(model, target) self.index = index - self.visit_assign = self._get_transformer_class() - - def _get_transformer_class(self): - return FunctionRecorderTransformer(self._target, self.index) + self.visit_assign = FunctionRecorderTransformer( + self._target, self.index) def rewrite(self, ast_tree): return self.visit_assign.visit(ast_tree) @@ -186,17 +137,25 @@ def rewrite(self, ast_tree): @RECORDERS.register_module() class AttributeRecorder(Recorder): - def __init__(self, model: str, target: str, attribute: str = None): + def __init__(self, model: str, target: str): super().__init__(model, target) - self.attribute = attribute - self.visit_assign = self._get_transformer_class() - - def _get_transformer_class(self): - return AttributeRecorderTransformer(self._target, self.attribute) + self.visit_assign = AttributeRecorderTransformer(self._target) def rewrite(self, ast_tree): return self.visit_assign.visit(ast_tree) + def _get_target_attribute(self): + func_chain = self._target.split('.') + func_chain.append(self._attribute) + assert len(func_chain) >= 2 + attr = ast.Attribute( + value=ast.Name(id=func_chain[0], ctx=ast.Load()), + attr=func_chain[1], + ctx=ast.Load()) + for ele in func_chain[2:]: + attr = ast.Attribute(value=attr, attr=ele, ctx=ast.Load()) + return attr + @HOOKS.register_module() class RecorderHook(Hook): @@ -211,6 +170,7 @@ def __init__( ): self.tensor_dict: Dict[str, Any] = {} self.origin_forward = None + self.origin_func: Dict[Any, Any] = {} self._recorders: Dict[str, Recorder] = {} self.print_modification = print_modification self.save_dir = save_dir # type: ignore @@ -224,8 +184,7 @@ def __init__( if model is None: recorder['model'] = '' target = recorder.get('target') - attribute = recorder.get('attribute') - tensor_key = _get_tensor_key(target, attribute) + tensor_key = _get_tensor_key(target) if target is None: print_log( @@ -298,16 +257,14 @@ def _modify_forward_func(self, func, recorders): def _get_model(self, model_name): if not model_name: return self.base_model - module_list = model_name.split('.') model = self.base_model - for model_name in module_list: - model = attrgetter(model_name)(model) + model = attrgetter(model_name)(model) return model def _group_recorder_by_model(self): group_dict = {} for recorder in self._recorders.values(): - key = recorder._model + key = recorder.model if key not in group_dict: group_dict[key] = [] group_dict[key].append(recorder) @@ -322,10 +279,19 @@ def before_run(self, runner) -> None: # get model and modify its forward function model = runner.model self.base_model = model - self.origin_forward = model.forward self.grouped_recorders = self._group_recorder_by_model() for model_name, recorders in self.grouped_recorders.items(): - model = self._get_model(model_name) + try: + model = self._get_model(model_name) + except AttributeError: + print_log( + f'Can not record {model_name} in runner.model' + 'because it doesn\'t exist', + logger='current', + level=logging.WARNING) + continue + self.origin_func[model] = model.forward + print('here') model.forward = types.MethodType( self._modify_forward_func(model.forward, recorders), model) @@ -353,4 +319,6 @@ def after_train_epoch(self, runner) -> None: self._init_tensor_dict() def after_train(self, runner) -> None: - runner.model.forward = self.origin_forward + # restore forward function after train + for m, f in self.origin_func.items(): + m.forward = f From 963b54e1969fe771995fa3bc01560a5bc6b3ce7b Mon Sep 17 00:00:00 2001 From: yxy Date: Sat, 16 Sep 2023 17:51:54 +0800 Subject: [PATCH 26/39] add deepcopy, if var is Tensor, use Tensor.detach().clone() --- mmengine/hooks/recorder_hook.py | 74 +++++++++++++++++++++++---------- 1 file changed, 52 insertions(+), 22 deletions(-) diff --git a/mmengine/hooks/recorder_hook.py b/mmengine/hooks/recorder_hook.py index 7b5715e1d9..3b0eac6cfa 100644 --- a/mmengine/hooks/recorder_hook.py +++ b/mmengine/hooks/recorder_hook.py @@ -81,21 +81,64 @@ def _get_target_attribute(self): attr = ast.Attribute(value=attr, attr=ele, ctx=ast.Load()) return attr + def _deepcopy_varname(self): + return f'_deep_copy_{self._target.replace(".", "_")}' + + def _get_deep_copy_node(self, var_node): + if_node = ast.If( + test=ast.Call( + func=ast.Name(id='isinstance', ctx=ast.Load()), + args=[ + var_node, + ast.Attribute( + value=ast.Name(id='torch', ctx=ast.Load()), + attr='Tensor', + ctx=ast.Load()) + ], + keywords=[]), + body=[ + ast.Assign( + targets=[ + ast.Name(id=self._deepcopy_varname(), ctx=ast.Store()) + ], + value=ast.Call( + func=ast.Attribute( + value=ast.Call( + func=ast.Attribute( + var_node, attr='detach', ctx=ast.Load()), + args=[], + keywords=[]), + attr='clone', + ctx=ast.Load()), + args=[], + keywords=[])) + ], + orelse=[ + ast.Assign( + targets=[ + ast.Name(id=self._deepcopy_varname(), ctx=ast.Store()) + ], + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id='copy', ctx=ast.Load()), + attr='deepcopy', + ctx=ast.Load()), + args=[var_node], + keywords=[])) + ]) + return if_node + def visit_Assign(self, node): if self._visited: return node + # insert update attribute node after message_hub assign node if node.targets[0].id == 'message_hub': self._visited = True attribute_node = self._get_target_attribute() - - deep_copy_attribute_node = ast.Call( - func=ast.Attribute( - value=ast.Name(id='copy', ctx=ast.Load()), - attr='deepcopy', - ctx=ast.Load()), - args=[attribute_node], - keywords=[]) + if_node = self._get_deep_copy_node(attribute_node) + deep_copy_attribute_node = ast.Name( + id=self._deepcopy_varname(), ctx=ast.Load()) update_messagehub_node = ast.Expr( value=ast.Call( func=ast.Attribute( @@ -107,7 +150,7 @@ def visit_Assign(self, node): deep_copy_attribute_node ], keywords=[])) - return [node, update_messagehub_node] + return [node, if_node, update_messagehub_node] class Recorder(metaclass=ABCMeta): @@ -144,18 +187,6 @@ def __init__(self, model: str, target: str): def rewrite(self, ast_tree): return self.visit_assign.visit(ast_tree) - def _get_target_attribute(self): - func_chain = self._target.split('.') - func_chain.append(self._attribute) - assert len(func_chain) >= 2 - attr = ast.Attribute( - value=ast.Name(id=func_chain[0], ctx=ast.Load()), - attr=func_chain[1], - ctx=ast.Load()) - for ele in func_chain[2:]: - attr = ast.Attribute(value=attr, attr=ele, ctx=ast.Load()) - return attr - @HOOKS.register_module() class RecorderHook(Hook): @@ -291,7 +322,6 @@ def before_run(self, runner) -> None: level=logging.WARNING) continue self.origin_func[model] = model.forward - print('here') model.forward = types.MethodType( self._modify_forward_func(model.forward, recorders), model) From 4f434e455fa416be79eb823a929dc604e0e60fab Mon Sep 17 00:00:00 2001 From: yxy Date: Sun, 17 Sep 2023 10:11:02 +0800 Subject: [PATCH 27/39] refactor about store var name --- mmengine/hooks/recorder_hook.py | 94 +++++++++++++++++++-------------- 1 file changed, 55 insertions(+), 39 deletions(-) diff --git a/mmengine/hooks/recorder_hook.py b/mmengine/hooks/recorder_hook.py index 3b0eac6cfa..003f9f8678 100644 --- a/mmengine/hooks/recorder_hook.py +++ b/mmengine/hooks/recorder_hook.py @@ -22,8 +22,9 @@ def function_with_index(function, index): class FunctionRecorderTransformer(ast.NodeTransformer): - def __init__(self, target, target_index): + def __init__(self, model, target, target_index): super().__init__() + self._model = model self._target = target if isinstance(target_index, list): self._target_index = set(target_index) @@ -31,6 +32,9 @@ def __init__(self, target, target_index): self._target_index = {target_index} self.count = -1 + def get_store_varname_with_index(self, index): + return f'{self._model}:{self._target}@{index}' + def visit_Assign(self, node): if node.targets[0].id != self._target: return node @@ -45,8 +49,7 @@ def visit_Assign(self, node): ctx=ast.Load()), args=[ ast.Constant( - value=function_with_index(node.targets[0].id, - self.count)), + value=self.get_store_varname_with_index(self.count)), ast.Name(id=node.targets[0].id, ctx=ast.Load()) ], keywords=[])) @@ -54,19 +57,11 @@ def visit_Assign(self, node): return [node, update_messagehub_node] -# Take "x = self.conv1(x)" as an example -# genertate "self_conv1 = self.conv1(x)" -# and "x = self_conv1" -# and "message_hub.update_info('self_conv1', self_conv1)" -def _get_tensor_key(target): - target = target.replace('.', '_') - return target - - class AttributeRecorderTransformer(ast.NodeTransformer): - def __init__(self, target): + def __init__(self, model, target): super().__init__() + self._model = model self._target = target self._visited = False @@ -84,6 +79,9 @@ def _get_target_attribute(self): def _deepcopy_varname(self): return f'_deep_copy_{self._target.replace(".", "_")}' + def _get_tensor_name(self): + return f'{self._model}:{self._target}' + def _get_deep_copy_node(self, var_node): if_node = ast.If( test=ast.Call( @@ -146,7 +144,7 @@ def visit_Assign(self, node): attr='update_info', ctx=ast.Load()), args=[ - ast.Constant(value=_get_tensor_key(self._target)), + ast.Constant(value=self._get_tensor_name()), deep_copy_attribute_node ], keywords=[])) @@ -156,13 +154,17 @@ def visit_Assign(self, node): class Recorder(metaclass=ABCMeta): def __init__(self, model, target: str): - self.model = model + self._model = model self._target = target @abstractmethod def rewrite(self, ast_tree): pass + @abstractmethod + def get_store_varname(self): + pass + @RECORDERS.register_module() class FunctionRecorder(Recorder): @@ -171,22 +173,29 @@ def __init__(self, model: str, target: str, index: list): super().__init__(model, target) self.index = index self.visit_assign = FunctionRecorderTransformer( - self._target, self.index) + self._model, self._target, self.index) def rewrite(self, ast_tree): return self.visit_assign.visit(ast_tree) + def get_store_varname(self): + return [f'{self._model}:{self._target}@{i}' for i in self.index] + @RECORDERS.register_module() class AttributeRecorder(Recorder): def __init__(self, model: str, target: str): super().__init__(model, target) - self.visit_assign = AttributeRecorderTransformer(self._target) + self.visit_assign = AttributeRecorderTransformer( + self._model, self._target) def rewrite(self, ast_tree): return self.visit_assign.visit(ast_tree) + def get_store_varname(self): + return f'{self._model}:{self._target}' + @HOOKS.register_module() class RecorderHook(Hook): @@ -202,7 +211,7 @@ def __init__( self.tensor_dict: Dict[str, Any] = {} self.origin_forward = None self.origin_func: Dict[Any, Any] = {} - self._recorders: Dict[str, Recorder] = {} + self._recorders: List[Recorder] = [] self.print_modification = print_modification self.save_dir = save_dir # type: ignore if filename_tmpl is None: @@ -213,9 +222,10 @@ def __init__( for recorder in recorders: model = recorder.get('model') if model is None: - recorder['model'] = '' + recorder['model'] = 'runner_model' target = recorder.get('target') - tensor_key = _get_tensor_key(target) + print(recorder) + # tensor_key = _get_tensor_key(target) if target is None: print_log( @@ -223,18 +233,18 @@ def __init__( 'because recorder has no target', logger='current', level=logging.WARNING) - if recorder.get('type') == 'FunctionRecorder': - index = recorder.get('index', 0) - if isinstance(index, list): - for i in index: - self.tensor_dict[function_with_index(target, - i)] = list() - elif isinstance(index, int): - self.tensor_dict[function_with_index(target, - index)] = list() - elif recorder.get('type') == 'AttributeRecorder': - self.tensor_dict[tensor_key] = list() - self._recorders[tensor_key] = RECORDERS.build(recorder) + # if recorder.get('type') == 'FunctionRecorder': + # index = recorder.get('index', 0) + # if isinstance(index, list): + # for i in index: + # self.tensor_dict[function_with_index(target, + # i)] = list() + # elif isinstance(index, int): + # self.tensor_dict[function_with_index(target, + # index)] = list() + # elif recorder.get('type') == 'AttributeRecorder': + # self.tensor_dict[tensor_key] = list() + self._recorders.append(RECORDERS.build(recorder)) def _modify_forward_func(self, func, recorders): # Gets the source code for the function @@ -286,7 +296,7 @@ def _modify_forward_func(self, func, recorders): return namespace[func.__name__] def _get_model(self, model_name): - if not model_name: + if not model_name or model_name == 'runner_model': return self.base_model model = self.base_model model = attrgetter(model_name)(model) @@ -294,8 +304,8 @@ def _get_model(self, model_name): def _group_recorder_by_model(self): group_dict = {} - for recorder in self._recorders.values(): - key = recorder.model + for recorder in self._recorders: + key = recorder._model if key not in group_dict: group_dict[key] = [] group_dict[key].append(recorder) @@ -307,9 +317,10 @@ def before_run(self, runner) -> None: # get messagehub instance and store it. self.message_hub = MessageHub.get_current_instance() + # init_save_var_dict + self._init_tensor_dict() # get model and modify its forward function - model = runner.model - self.base_model = model + self.base_model = runner.model self.grouped_recorders = self._group_recorder_by_model() for model_name, recorders in self.grouped_recorders.items(): try: @@ -339,8 +350,13 @@ def _save_record(self, step): torch.save(self.tensor_dict, path) def _init_tensor_dict(self): - for k in self.tensor_dict.keys(): - self.tensor_dict[k] = list() + for recorder in self._recorders: + varname = recorder.get_store_varname() + if isinstance(varname, list): + for name in varname: + self.tensor_dict[name] = list() + else: + self.tensor_dict[varname] = list() def after_train_epoch(self, runner) -> None: step = runner.epoch + 1 From 1f54cfce7d8bab74c071738d554daa9aa92e85de Mon Sep 17 00:00:00 2001 From: yxy Date: Sun, 17 Sep 2023 10:11:59 +0800 Subject: [PATCH 28/39] delete useless lines --- mmengine/hooks/recorder_hook.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/mmengine/hooks/recorder_hook.py b/mmengine/hooks/recorder_hook.py index 003f9f8678..82a6713cae 100644 --- a/mmengine/hooks/recorder_hook.py +++ b/mmengine/hooks/recorder_hook.py @@ -16,10 +16,6 @@ from . import Hook -def function_with_index(function, index): - return function + '@' + str(index) - - class FunctionRecorderTransformer(ast.NodeTransformer): def __init__(self, model, target, target_index): @@ -233,17 +229,6 @@ def __init__( 'because recorder has no target', logger='current', level=logging.WARNING) - # if recorder.get('type') == 'FunctionRecorder': - # index = recorder.get('index', 0) - # if isinstance(index, list): - # for i in index: - # self.tensor_dict[function_with_index(target, - # i)] = list() - # elif isinstance(index, int): - # self.tensor_dict[function_with_index(target, - # index)] = list() - # elif recorder.get('type') == 'AttributeRecorder': - # self.tensor_dict[tensor_key] = list() self._recorders.append(RECORDERS.build(recorder)) def _modify_forward_func(self, func, recorders): From 581d668d2c709fe68208dc93f5f5266aa2f763c5 Mon Sep 17 00:00:00 2001 From: yxy Date: Sun, 17 Sep 2023 14:52:47 +0800 Subject: [PATCH 29/39] add appoint specify method --- mmengine/hooks/recorder_hook.py | 91 ++++++++++++++++++++++++--------- 1 file changed, 66 insertions(+), 25 deletions(-) diff --git a/mmengine/hooks/recorder_hook.py b/mmengine/hooks/recorder_hook.py index 82a6713cae..3fac855280 100644 --- a/mmengine/hooks/recorder_hook.py +++ b/mmengine/hooks/recorder_hook.py @@ -18,9 +18,10 @@ class FunctionRecorderTransformer(ast.NodeTransformer): - def __init__(self, model, target, target_index): + def __init__(self, model, method, target, target_index): super().__init__() self._model = model + self._method = method self._target = target if isinstance(target_index, list): self._target_index = set(target_index) @@ -29,7 +30,7 @@ def __init__(self, model, target, target_index): self.count = -1 def get_store_varname_with_index(self, index): - return f'{self._model}:{self._target}@{index}' + return f'{self._model}:{self._method}:{self._target}@{index}' def visit_Assign(self, node): if node.targets[0].id != self._target: @@ -55,9 +56,10 @@ def visit_Assign(self, node): class AttributeRecorderTransformer(ast.NodeTransformer): - def __init__(self, model, target): + def __init__(self, model, method, target): super().__init__() self._model = model + self._method = method self._target = target self._visited = False @@ -76,7 +78,7 @@ def _deepcopy_varname(self): return f'_deep_copy_{self._target.replace(".", "_")}' def _get_tensor_name(self): - return f'{self._model}:{self._target}' + return f'{self._model}:{self._method}:{self._target}' def _get_deep_copy_node(self, var_node): if_node = ast.If( @@ -149,8 +151,9 @@ def visit_Assign(self, node): class Recorder(metaclass=ABCMeta): - def __init__(self, model, target: str): + def __init__(self, model, method, target: str): self._model = model + self._method = method self._target = target @abstractmethod @@ -165,32 +168,35 @@ def get_store_varname(self): @RECORDERS.register_module() class FunctionRecorder(Recorder): - def __init__(self, model: str, target: str, index: list): - super().__init__(model, target) + def __init__(self, model: str, method: str, target: str, index: list): + super().__init__(model, method, target) self.index = index self.visit_assign = FunctionRecorderTransformer( - self._model, self._target, self.index) + self._model, self._method, self._target, self.index) def rewrite(self, ast_tree): return self.visit_assign.visit(ast_tree) def get_store_varname(self): - return [f'{self._model}:{self._target}@{i}' for i in self.index] + return [ + f'{self._model}:{self._method}:{self._target}@{i}' + for i in self.index + ] @RECORDERS.register_module() class AttributeRecorder(Recorder): - def __init__(self, model: str, target: str): - super().__init__(model, target) + def __init__(self, model: str, method: str, target: str): + super().__init__(model, method, target) self.visit_assign = AttributeRecorderTransformer( - self._model, self._target) + self._model, self._method, self._target) def rewrite(self, ast_tree): return self.visit_assign.visit(ast_tree) def get_store_varname(self): - return f'{self._model}:{self._target}' + return f'{self._model}:{self._method}:{self._target}' @HOOKS.register_module() @@ -206,7 +212,7 @@ def __init__( ): self.tensor_dict: Dict[str, Any] = {} self.origin_forward = None - self.origin_func: Dict[Any, Any] = {} + self.origin_methods: Dict[Any, Any] = {} self._recorders: List[Recorder] = [] self.print_modification = print_modification self.save_dir = save_dir # type: ignore @@ -220,8 +226,9 @@ def __init__( if model is None: recorder['model'] = 'runner_model' target = recorder.get('target') - print(recorder) - # tensor_key = _get_tensor_key(target) + method = recorder.get('method') + if method is None: + recorder['method'] = 'forward' if target is None: print_log( @@ -287,15 +294,31 @@ def _get_model(self, model_name): model = attrgetter(model_name)(model) return model - def _group_recorder_by_model(self): + def _group_recorder_by_model_method(self): group_dict = {} for recorder in self._recorders: key = recorder._model if key not in group_dict: group_dict[key] = [] group_dict[key].append(recorder) + for model_name, recorders in group_dict.items(): + group_dict[model_name] = self._group_recorder_by_method(recorders) return group_dict + def _group_recorder_by_method(self, recorders): + group_dict = {} + for recorder in recorders: + key = recorder._method + if key not in group_dict: + group_dict[key] = [] + group_dict[key].append(recorder) + return group_dict + + def _save_origin_method(self, model, method_name, origin_method): + if model not in self.origin_methods: + self.origin_methods[model] = {} + self.origin_methods[model][method_name] = origin_method + def before_run(self, runner) -> None: if not self.save_dir: self.save_dir = runner.work_dir @@ -306,20 +329,37 @@ def before_run(self, runner) -> None: self._init_tensor_dict() # get model and modify its forward function self.base_model = runner.model - self.grouped_recorders = self._group_recorder_by_model() - for model_name, recorders in self.grouped_recorders.items(): + self.grouped_recorders = self._group_recorder_by_model_method() + for model_name, group_method_recorders in self.grouped_recorders.items( + ): try: model = self._get_model(model_name) except AttributeError: print_log( - f'Can not record {model_name} in runner.model' + f'Can not record {model_name} in runner.model ' 'because it doesn\'t exist', logger='current', level=logging.WARNING) continue - self.origin_func[model] = model.forward - model.forward = types.MethodType( - self._modify_forward_func(model.forward, recorders), model) + for method_name, recorders in group_method_recorders.items(): + try: + method = getattr(model, method_name) + except AttributeError: + print_log( + f'Can not record {method_name} in {model_name}' + 'because it doesn\'t exist', + logger='current', + level=logging.WARNING) + continue + # self.origin_methods[model][method_name] = method + print_log( + f'Modify {method_name} in {model_name}', + logger='current', + level=logging.INFO) + self._save_origin_method(model, method_name, method) + new_method = types.MethodType( + self._modify_forward_func(method, recorders), model) + setattr(model, method_name, new_method) def after_train_iter(self, runner, @@ -351,5 +391,6 @@ def after_train_epoch(self, runner) -> None: def after_train(self, runner) -> None: # restore forward function after train - for m, f in self.origin_func.items(): - m.forward = f + for model, v in self.origin_methods.items(): + for method_name, origin_method in v.items(): + setattr(model, method_name, origin_method) From 10e447f77d8a9bd78e23ea3ab1deb73b53dcf1ca Mon Sep 17 00:00:00 2001 From: yxy Date: Sun, 17 Sep 2023 15:27:43 +0800 Subject: [PATCH 30/39] update test script --- examples/attribute_toy_test.py | 54 ++++++++++++++++++++++++++++++++++ examples/function_toy_test.py | 54 ++++++++++++++++++++++++++++++++++ examples/recorder_hook_test.py | 10 ++++--- 3 files changed, 114 insertions(+), 4 deletions(-) create mode 100644 examples/attribute_toy_test.py create mode 100644 examples/function_toy_test.py diff --git a/examples/attribute_toy_test.py b/examples/attribute_toy_test.py new file mode 100644 index 0000000000..a151957c5a --- /dev/null +++ b/examples/attribute_toy_test.py @@ -0,0 +1,54 @@ +import torch +from torch import nn +from torch.utils.data import DataLoader + +from mmengine.model import BaseModel +from mmengine.runner import Runner + + +class ToyModel(BaseModel): + + def __init__(self, data_preprocessor=None): + super().__init__(data_preprocessor=data_preprocessor) + self.linear1 = nn.Linear(2, 2) + self.linear2 = nn.Linear(2, 1) + + def forward(self, inputs, data_samples, mode='tensor'): + if isinstance(inputs, list): + inputs = torch.stack(inputs) + if isinstance(data_samples, list): + data_sample = torch.stack(data_samples) + outputs = self.linear1(inputs) + outputs = self.linear2(outputs) + + if mode == 'tensor': + return outputs + elif mode == 'loss': + loss = (data_sample - outputs).sum() + outputs = dict(loss=loss) + return outputs + elif mode == 'predict': + return outputs + + +x = [(torch.ones(2, 2), [torch.ones(2, 1)])] +# train_dataset = [x, x, x] +train_dataset = x * 50 +train_dataloader = DataLoader(train_dataset, batch_size=1) + +runner = Runner( + model=ToyModel(), + custom_hooks=[ + dict( + type='RecorderHook', + recorders=[ + dict(type='AttributeRecorder', target='self.linear1.weight') + ], + save_dir='./work_dir', + print_modification=True) + ], + work_dir='tmp_dir', + train_dataloader=train_dataloader, + train_cfg=dict(by_epoch=True, max_epochs=10), + optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.01))) +runner.train() diff --git a/examples/function_toy_test.py b/examples/function_toy_test.py new file mode 100644 index 0000000000..7f151ab5f1 --- /dev/null +++ b/examples/function_toy_test.py @@ -0,0 +1,54 @@ +import torch +from torch import nn +from torch.utils.data import DataLoader + +from mmengine.model import BaseModel +from mmengine.runner import Runner + + +class ToyModel(BaseModel): + + def __init__(self, data_preprocessor=None): + super().__init__(data_preprocessor=data_preprocessor) + self.linear1 = nn.Linear(2, 2) + self.linear2 = nn.Linear(2, 1) + + def forward(self, inputs, data_samples, mode='tensor'): + if isinstance(inputs, list): + inputs = torch.stack(inputs) + if isinstance(data_samples, list): + data_sample = torch.stack(data_samples) + outputs = self.linear1(inputs) + outputs = self.linear2(outputs) + + if mode == 'tensor': + return outputs + elif mode == 'loss': + loss = (data_sample - outputs).sum() + outputs = dict(loss=loss) + return outputs + elif mode == 'predict': + return outputs + + +x = [(torch.ones(2, 2), [torch.ones(2, 1)])] +# train_dataset = [x, x, x] +train_dataset = x * 50 +train_dataloader = DataLoader(train_dataset, batch_size=2) + +runner = Runner( + model=ToyModel(), + custom_hooks=[ + dict( + type='RecorderHook', + recorders=[ + dict(type='FunctionRecorder', target='outputs', index=[1]) + ], + save_dir='./work_dir', + print_modification=True) + ], + work_dir='tmp_dir', + train_dataloader=train_dataloader, + train_cfg=dict(by_epoch=True, max_epochs=10), + optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.01))) +runner.train() diff --git a/examples/recorder_hook_test.py b/examples/recorder_hook_test.py index fdf65ce402..fe975e80c9 100644 --- a/examples/recorder_hook_test.py +++ b/examples/recorder_hook_test.py @@ -79,9 +79,11 @@ def compute_metrics(self, results): type='RecorderHook', recorders=[ dict( - type='AttributeRecorder', - target='self.resnet', - attribute='weight') + model='resnet', + method='_forward_impl', + type='FunctionRecorder', + target='x', + index=[0, 1, 2]) ], save_dir='./work_dir', print_modification=True) @@ -90,7 +92,7 @@ def compute_metrics(self, results): work_dir='./work_dir', train_dataloader=train_dataloader, optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)), - train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1), + train_cfg=dict(by_epoch=True, max_epochs=1, val_interval=1), val_dataloader=val_dataloader, val_cfg=dict(), val_evaluator=dict(type=Accuracy), From 2d5447b88e5e613fe489e7c57b3f0f0dd0d7dfca Mon Sep 17 00:00:00 2001 From: yxy Date: Tue, 19 Sep 2023 21:41:12 +0800 Subject: [PATCH 31/39] use MessageHub.get_instance --- mmengine/hooks/recorder_hook.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/mmengine/hooks/recorder_hook.py b/mmengine/hooks/recorder_hook.py index 3fac855280..f2dd9ffff2 100644 --- a/mmengine/hooks/recorder_hook.py +++ b/mmengine/hooks/recorder_hook.py @@ -218,6 +218,8 @@ def __init__( self.save_dir = save_dir # type: ignore if filename_tmpl is None: self.filename_tmpl = 'record_epoch_{}.pth' + else: + self.filename_tmpl = filename_tmpl if recorders is None or len(recorders) == 0: raise ValueError('recorders not initialized') @@ -259,9 +261,9 @@ def _modify_forward_func(self, func, recorders): value=ast.Call( func=ast.Attribute( value=ast.Name(id='MessageHub', ctx=ast.Load()), - attr='get_current_instance', + attr='get_instance', ctx=ast.Load()), - args=[], + args=[ast.Constant(value='recorder_hook')], keywords=[])) tree.body[0].body = [ @@ -324,7 +326,7 @@ def before_run(self, runner) -> None: self.save_dir = runner.work_dir # get messagehub instance and store it. - self.message_hub = MessageHub.get_current_instance() + self.message_hub = MessageHub.get_instance('recorder_hook') # init_save_var_dict self._init_tensor_dict() # get model and modify its forward function From e7e439dce3751bf6b49dbddd714c7e18b087f37e Mon Sep 17 00:00:00 2001 From: yxy Date: Wed, 20 Sep 2023 13:26:24 +0800 Subject: [PATCH 32/39] add docs --- mmengine/hooks/recorder_hook.py | 225 +++++++++++++++++++++++++++++++- 1 file changed, 224 insertions(+), 1 deletion(-) diff --git a/mmengine/hooks/recorder_hook.py b/mmengine/hooks/recorder_hook.py index f2dd9ffff2..397c760267 100644 --- a/mmengine/hooks/recorder_hook.py +++ b/mmengine/hooks/recorder_hook.py @@ -17,6 +17,19 @@ class FunctionRecorderTransformer(ast.NodeTransformer): + """Transformer that modifies the Abstract Syntax Tree (AST) for function- + related record updates. + + The transformer is responsible for updating the AST to add the logic needed + to record tensor data at specific indices when a function is called within + the model's forward pass. + + Args: + model (str): The name or identifier of the model. + method (str): The method in which the transformer operates. + target (str): The target function to be recorded. + target_index (int or list): Index of var to record. + """ def __init__(self, model, method, target, target_index): super().__init__() @@ -30,9 +43,25 @@ def __init__(self, model, method, target, target_index): self.count = -1 def get_store_varname_with_index(self, index): + """Generate and return the variable name with the specified index. + + Args: + index (int): The index for which to generate the variable name. + + Returns: + str: The variable name for the given index. + """ return f'{self._model}:{self._method}:{self._target}@{index}' def visit_Assign(self, node): + """Visit and possibly transform an assignment node in the AST. + + Args: + node: The AST node being visited. + + Returns: + Modified AST node or a list of AST nodes. + """ if node.targets[0].id != self._target: return node self.count += 1 @@ -55,6 +84,17 @@ def visit_Assign(self, node): class AttributeRecorderTransformer(ast.NodeTransformer): + """Transformer that modifies the Abstract Syntax Tree (AST) for attribute- + related record updates. + + The transformer is responsible for updating the AST to add the logic needed + to record tensor data from model attributes during the forward pass. + + Args: + model (str): The name or identifier of the model. + method (str): The method in which the transformer operates. + target (str): The target attribute to be recorded. + """ def __init__(self, model, method, target): super().__init__() @@ -64,6 +104,11 @@ def __init__(self, model, method, target): self._visited = False def _get_target_attribute(self): + """Extract and return the target attribute from the AST as a node. + + Returns: + ast.Attribute: The node representing the target attribute. + """ func_chain = self._target.split('.') assert len(func_chain) >= 2 attr = ast.Attribute( @@ -75,12 +120,33 @@ def _get_target_attribute(self): return attr def _deepcopy_varname(self): + """Generate and return a variable name for the deep copy of the target + attribute. + + Returns: + str: The variable name for the deep copy of the target attribute. + """ return f'_deep_copy_{self._target.replace(".", "_")}' def _get_tensor_name(self): + """Generate and return the tensor name for the target attribute. + + Returns: + str: The tensor name for the target attribute. + """ return f'{self._model}:{self._method}:{self._target}' def _get_deep_copy_node(self, var_node): + """Generate and return the AST node for deep copying the target + attribute. + + Args: + var_node (ast.Name): + The AST node representing the variable to be deep copied. + + Returns: + ast.If: The `if` node for deep copying the target attribute. + """ if_node = ast.If( test=ast.Call( func=ast.Name(id='isinstance', ctx=ast.Load()), @@ -125,6 +191,14 @@ def _get_deep_copy_node(self, var_node): return if_node def visit_Assign(self, node): + """Visit and possibly transform an assignment node in the AST. + + Args: + node: The AST node being visited. + + Returns: + Modified AST node or a list of AST nodes. + """ if self._visited: return node # insert update attribute node after message_hub assign node @@ -150,6 +224,16 @@ def visit_Assign(self, node): class Recorder(metaclass=ABCMeta): + """Abstract base class for implementing tensor data recorders. + + The Recorder is intended to be a blueprint for creating specific recorder + types to capture tensor data during model forward passes. + + Args: + model: The name or identifier of the model. + method: The method on which the Recorder is attached. + target (str): The target layer or tensor to be recorded. + """ def __init__(self, model, method, target: str): self._model = model @@ -158,15 +242,40 @@ def __init__(self, model, method, target: str): @abstractmethod def rewrite(self, ast_tree): + """Rewrite the AST tree to include recording logic. + + Args: + ast_tree: The Abstract Syntax Tree to be rewritten. + + Returns: + Modified AST tree. + """ pass @abstractmethod def get_store_varname(self): + """Get the variable name used for storing recorded data. + + Returns: + Variable name or a list of variable names. + """ pass @RECORDERS.register_module() class FunctionRecorder(Recorder): + """A Recorder implementation to capture output tensor data from function + calls. + + This Recorder hooks into specific function calls within the model's forward + pass and records tensor data at specified indices. + + Args: + model (str): The name or identifier of the model. + method (str): The method on which the Recorder is attached. + target (str): The target function to be recorded. + index (list): List of indices within the function call to record. + """ def __init__(self, model: str, method: str, target: str, index: list): super().__init__(model, method, target) @@ -175,9 +284,16 @@ def __init__(self, model: str, method: str, target: str, index: list): self._model, self._method, self._target, self.index) def rewrite(self, ast_tree): + """Rewrite the AST tree to include recording logic for output of + function calls.""" return self.visit_assign.visit(ast_tree) def get_store_varname(self): + """Generate and return variable names based on output name. + + Outputs with the same name will be distinguished based on the sequence + number. + """ return [ f'{self._model}:{self._method}:{self._target}@{i}' for i in self.index @@ -186,6 +302,16 @@ def get_store_varname(self): @RECORDERS.register_module() class AttributeRecorder(Recorder): + """A Recorder implementation to capture tensor data from model attributes. + + This Recorder hooks into model attributes and records their tensor data + during the forward pass. + + Args: + model (str): The name or identifier of the model. + method (str): The method on which the Recorder is attached. + target (str): The target attribute to be recorded. + """ def __init__(self, model: str, method: str, target: str): super().__init__(model, method, target) @@ -193,14 +319,45 @@ def __init__(self, model: str, method: str, target: str): self._model, self._method, self._target) def rewrite(self, ast_tree): + """Rewrite the AST tree to include recording logic for attributes.""" return self.visit_assign.visit(ast_tree) def get_store_varname(self): + """Generate and return variable name based on model attributes.""" return f'{self._model}:{self._method}:{self._target}' @HOOKS.register_module() class RecorderHook(Hook): + """A hook to record information during model training. + + This hook allows users to modify and record certain model variables + during training iterations and save them for analysis purposes. + It provides the ability to modify any function of a model + using ast module in python. + + Args: + recorders (Optional[List[Dict]]): + Configurations for individual recorders. + Each recorder dict should contain the target model and method. + print_modification (bool): Whether to print the modified source code + after it's been altered by a recorder. Defaults to True. + save_dir (str): The directory where recorded data will be saved. + If not specified, it will use the runner's work directory. + Defaults to None. + filename_tmpl (Optional[str]): The filename template used when saving + recorded data. If not provided, a default template will be used. + Defaults to None. + + Examples: + >>> recorder_hook_cfg = dict( + ... recorders=[{'model': 'runner_model', + ... 'target': 'layer1', 'method': 'forward'}], + ... print_modification=True, + ... save_dir='./records', + ... filename_tmpl='record_epoch_{}.pth' + ... ) + """ priority = 'LOWEST' def __init__( @@ -241,6 +398,15 @@ def __init__( self._recorders.append(RECORDERS.build(recorder)) def _modify_forward_func(self, func, recorders): + """Modify the forward function to incorporate recording behaviors. + + Args: + func (callable): Original forward function to modify. + recorders (List[Recorder]): List of recorder instances. + + Returns: + callable: Modified forward function. + """ # Gets the source code for the function source = inspect.getsource(func) source = textwrap.dedent(source) @@ -290,6 +456,15 @@ def _modify_forward_func(self, func, recorders): return namespace[func.__name__] def _get_model(self, model_name): + """Retrieve a specific model from runner. + If model_name == 'runner_model', return runner.model. + Else, return runner.model.model_name + Args: + model_name (str): Name of the model to retrieve. + + Returns: + Model: Requested model instance. + """ if not model_name or model_name == 'runner_model': return self.base_model model = self.base_model @@ -297,6 +472,11 @@ def _get_model(self, model_name): return model def _group_recorder_by_model_method(self): + """Group recorders by model and method. + + Returns: + dict: Grouped recorders. + """ group_dict = {} for recorder in self._recorders: key = recorder._model @@ -308,6 +488,14 @@ def _group_recorder_by_model_method(self): return group_dict def _group_recorder_by_method(self, recorders): + """Group recorders by method. + + Args: + recorders (List[Recorder]): List of recorder instances. + + Returns: + dict: Grouped recorders. + """ group_dict = {} for recorder in recorders: key = recorder._method @@ -317,11 +505,23 @@ def _group_recorder_by_method(self, recorders): return group_dict def _save_origin_method(self, model, method_name, origin_method): + """Save reference to the original method of a model. + + Args: + model (Model): Model instance. + method_name (str): Name of the method to save. + origin_method (callable): Original method to save. + """ if model not in self.origin_methods: self.origin_methods[model] = {} self.origin_methods[model][method_name] = origin_method def before_run(self, runner) -> None: + """Prepare for training by modifying methods for recording. + + Args: + runner (Runner): The runner of the training process. + """ if not self.save_dir: self.save_dir = runner.work_dir @@ -368,15 +568,29 @@ def after_train_iter(self, batch_idx: int, data_batch=None, outputs=None) -> None: + """Record specific tensors after each training iteration. + + Args: + runner (Runner): The runner of the training process. + batch_idx (int): Index of the current batch. + data_batch (Optional): Current data batch. Default is None. + outputs (Optional): Outputs from the current iteration. + """ for key in self.tensor_dict.keys(): self.tensor_dict[key].append(self.message_hub.get_info(key)) def _save_record(self, step): + """Save recorded tensors to disk. + + Args: + step (int): Current training step or epoch. + """ recorder_file_name = self.filename_tmpl.format(step) path = osp.join(self.save_dir, recorder_file_name) torch.save(self.tensor_dict, path) def _init_tensor_dict(self): + """Initialize the tensor dictionary for recording.""" for recorder in self._recorders: varname = recorder.get_store_varname() if isinstance(varname, list): @@ -386,13 +600,22 @@ def _init_tensor_dict(self): self.tensor_dict[varname] = list() def after_train_epoch(self, runner) -> None: + """Save recorded tensors after each training epoch. + + Args: + runner (Runner): The runner of the training process. + """ step = runner.epoch + 1 runner.logger.info(f'Saving record at {runner.epoch + 1} epochs') self._save_record(step) self._init_tensor_dict() def after_train(self, runner) -> None: - # restore forward function after train + """Restore original methods after training. + + Args: + runner (Runner): The runner of the training process. + """ for model, v in self.origin_methods.items(): for method_name, origin_method in v.items(): setattr(model, method_name, origin_method) From b58540b0507dcdd404661dfcf7b3c3ab70678906 Mon Sep 17 00:00:00 2001 From: yxy Date: Wed, 20 Sep 2023 13:30:54 +0800 Subject: [PATCH 33/39] try to add type hint --- mmengine/hooks/recorder_hook.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/mmengine/hooks/recorder_hook.py b/mmengine/hooks/recorder_hook.py index 397c760267..1129cc19cd 100644 --- a/mmengine/hooks/recorder_hook.py +++ b/mmengine/hooks/recorder_hook.py @@ -291,7 +291,7 @@ def rewrite(self, ast_tree): def get_store_varname(self): """Generate and return variable names based on output name. - Outputs with the same name will be distinguished based on the sequence + Outputs with the same name will be distinguished based on the index number. """ return [ @@ -368,11 +368,10 @@ def __init__( filename_tmpl: Optional[str] = None, ): self.tensor_dict: Dict[str, Any] = {} - self.origin_forward = None self.origin_methods: Dict[Any, Any] = {} self._recorders: List[Recorder] = [] - self.print_modification = print_modification - self.save_dir = save_dir # type: ignore + self.print_modification: bool = print_modification + self.save_dir: Optional[str] = save_dir # type: ignore if filename_tmpl is None: self.filename_tmpl = 'record_epoch_{}.pth' else: From ea29bfaeb1a65d2942912039002e6db1493e8558 Mon Sep 17 00:00:00 2001 From: yxy Date: Wed, 20 Sep 2023 16:23:51 +0800 Subject: [PATCH 34/39] add type hint --- mmengine/hooks/recorder_hook.py | 95 ++++++++++++++++++++------------- 1 file changed, 57 insertions(+), 38 deletions(-) diff --git a/mmengine/hooks/recorder_hook.py b/mmengine/hooks/recorder_hook.py index 1129cc19cd..47ff0acb87 100644 --- a/mmengine/hooks/recorder_hook.py +++ b/mmengine/hooks/recorder_hook.py @@ -7,7 +7,7 @@ import types from abc import ABCMeta, abstractmethod from operator import attrgetter -from typing import Any, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Union import torch @@ -31,7 +31,8 @@ class FunctionRecorderTransformer(ast.NodeTransformer): target_index (int or list): Index of var to record. """ - def __init__(self, model, method, target, target_index): + def __init__(self, model: str, method: str, target: str, + target_index: Union[int, List[int]]): super().__init__() self._model = model self._method = method @@ -42,7 +43,7 @@ def __init__(self, model, method, target, target_index): self._target_index = {target_index} self.count = -1 - def get_store_varname_with_index(self, index): + def get_store_varname_with_index(self, index: int) -> str: """Generate and return the variable name with the specified index. Args: @@ -53,7 +54,7 @@ def get_store_varname_with_index(self, index): """ return f'{self._model}:{self._method}:{self._target}@{index}' - def visit_Assign(self, node): + def visit_Assign(self, node: ast.Assign) -> Union[Any, List[Any]]: """Visit and possibly transform an assignment node in the AST. Args: @@ -62,6 +63,7 @@ def visit_Assign(self, node): Returns: Modified AST node or a list of AST nodes. """ + assert isinstance(node.targets[0], ast.Name) if node.targets[0].id != self._target: return node self.count += 1 @@ -103,7 +105,7 @@ def __init__(self, model, method, target): self._target = target self._visited = False - def _get_target_attribute(self): + def _get_target_attribute(self) -> ast.Attribute: """Extract and return the target attribute from the AST as a node. Returns: @@ -119,7 +121,7 @@ def _get_target_attribute(self): attr = ast.Attribute(value=attr, attr=ele, ctx=ast.Load()) return attr - def _deepcopy_varname(self): + def _deepcopy_varname(self) -> str: """Generate and return a variable name for the deep copy of the target attribute. @@ -128,7 +130,7 @@ def _deepcopy_varname(self): """ return f'_deep_copy_{self._target.replace(".", "_")}' - def _get_tensor_name(self): + def _get_tensor_name(self) -> str: """Generate and return the tensor name for the target attribute. Returns: @@ -136,7 +138,7 @@ def _get_tensor_name(self): """ return f'{self._model}:{self._method}:{self._target}' - def _get_deep_copy_node(self, var_node): + def _get_deep_copy_node(self, var_node) -> ast.If: """Generate and return the AST node for deep copying the target attribute. @@ -190,7 +192,7 @@ def _get_deep_copy_node(self, var_node): ]) return if_node - def visit_Assign(self, node): + def visit_Assign(self, node: ast.Assign) -> Union[Any, List[Any]]: """Visit and possibly transform an assignment node in the AST. Args: @@ -202,7 +204,8 @@ def visit_Assign(self, node): if self._visited: return node # insert update attribute node after message_hub assign node - if node.targets[0].id == 'message_hub': + if isinstance(node.targets[0], + ast.Name) and node.targets[0].id == 'message_hub': self._visited = True attribute_node = self._get_target_attribute() @@ -235,13 +238,13 @@ class Recorder(metaclass=ABCMeta): target (str): The target layer or tensor to be recorded. """ - def __init__(self, model, method, target: str): + def __init__(self, model: str, method: str, target: str): self._model = model self._method = method self._target = target @abstractmethod - def rewrite(self, ast_tree): + def rewrite(self, ast_tree) -> Any: """Rewrite the AST tree to include recording logic. Args: @@ -253,7 +256,7 @@ def rewrite(self, ast_tree): pass @abstractmethod - def get_store_varname(self): + def get_store_varname(self) -> Union[str, List[str]]: """Get the variable name used for storing recorded data. Returns: @@ -283,12 +286,12 @@ def __init__(self, model: str, method: str, target: str, index: list): self.visit_assign = FunctionRecorderTransformer( self._model, self._method, self._target, self.index) - def rewrite(self, ast_tree): + def rewrite(self, ast_tree) -> Any: """Rewrite the AST tree to include recording logic for output of function calls.""" return self.visit_assign.visit(ast_tree) - def get_store_varname(self): + def get_store_varname(self) -> List[str]: """Generate and return variable names based on output name. Outputs with the same name will be distinguished based on the index @@ -318,11 +321,11 @@ def __init__(self, model: str, method: str, target: str): self.visit_assign = AttributeRecorderTransformer( self._model, self._method, self._target) - def rewrite(self, ast_tree): + def rewrite(self, ast_tree) -> Any: """Rewrite the AST tree to include recording logic for attributes.""" return self.visit_assign.visit(ast_tree) - def get_store_varname(self): + def get_store_varname(self) -> str: """Generate and return variable name based on model attributes.""" return f'{self._model}:{self._method}:{self._target}' @@ -396,7 +399,8 @@ def __init__( level=logging.WARNING) self._recorders.append(RECORDERS.build(recorder)) - def _modify_forward_func(self, func, recorders): + def _modify_forward_func(self, func: Callable, + recorders: List[Recorder]) -> Callable: """Modify the forward function to incorporate recording behaviors. Args: @@ -412,8 +416,12 @@ def _modify_forward_func(self, func, recorders): # Parse source code as ast tree = ast.parse(source) - - func_body = tree.body[0].body + breakpoint() + if isinstance(tree.body[0], ast.FunctionDef): + func_body = tree.body[0].body + else: + raise ValueError( + "Unexpected node type that doesn't have a body attribute.") # import mmengine.logging.MessageHub import_messagehub_node = ast.ImportFrom( module='mmengine.logging', @@ -448,13 +456,18 @@ def _modify_forward_func(self, func, recorders): tree = ast.fix_missing_locations(tree) # Compile the modified ast as a new function - namespace = {} + namespace: Dict[str, Any] = {} + if isinstance(func, types.FunctionType): + globals_dict = func.__globals__ + func_name = func.__name__ + else: + raise TypeError('It is not a function') exec( - compile(tree, filename='', mode='exec'), func.__globals__, + compile(tree, filename='', mode='exec'), globals_dict, namespace) - return namespace[func.__name__] + return namespace[func_name] - def _get_model(self, model_name): + def _get_model(self, model_name: str) -> Any: """Retrieve a specific model from runner. If model_name == 'runner_model', return runner.model. Else, return runner.model.model_name @@ -470,23 +483,28 @@ def _get_model(self, model_name): model = attrgetter(model_name)(model) return model - def _group_recorder_by_model_method(self): + def _group_recorder_by_model_method( + self) -> Dict[str, Dict[str, List[Recorder]]]: """Group recorders by model and method. Returns: dict: Grouped recorders. """ - group_dict = {} + group_model_dist = {} + group_model_method_dict: Dict[str, Dict[str, List[Recorder]]] = {} for recorder in self._recorders: key = recorder._model - if key not in group_dict: - group_dict[key] = [] - group_dict[key].append(recorder) - for model_name, recorders in group_dict.items(): - group_dict[model_name] = self._group_recorder_by_method(recorders) - return group_dict - - def _group_recorder_by_method(self, recorders): + if key not in group_model_dist: + group_model_dist[key] = [recorder] + else: + group_model_dist[key].append(recorder) + for model_name, recorders in group_model_dist.items(): + group_model_method_dict[ + model_name] = self._group_recorder_by_method(recorders) + return group_model_method_dict + + def _group_recorder_by_method( + self, recorders: List[Recorder]) -> Dict[str, List[Recorder]]: """Group recorders by method. Args: @@ -495,15 +513,16 @@ def _group_recorder_by_method(self, recorders): Returns: dict: Grouped recorders. """ - group_dict = {} + group_dict: Dict[str, List[Recorder]] = {} for recorder in recorders: key = recorder._method if key not in group_dict: - group_dict[key] = [] + group_dict[key] = [recorder] group_dict[key].append(recorder) return group_dict - def _save_origin_method(self, model, method_name, origin_method): + def _save_origin_method(self, model: Any, method_name: str, + origin_method: Callable) -> None: """Save reference to the original method of a model. Args: @@ -582,7 +601,7 @@ def _save_record(self, step): """Save recorded tensors to disk. Args: - step (int): Current training step or epoch. + step (int): Current training epoch. """ recorder_file_name = self.filename_tmpl.format(step) path = osp.join(self.save_dir, recorder_file_name) From 06fabbe322feaa3585caa2ba98519f1863d3c101 Mon Sep 17 00:00:00 2001 From: yxy Date: Wed, 27 Sep 2023 17:19:26 +0800 Subject: [PATCH 35/39] add type ignore --- mmengine/hooks/recorder_hook.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/mmengine/hooks/recorder_hook.py b/mmengine/hooks/recorder_hook.py index 47ff0acb87..f391fd55a9 100644 --- a/mmengine/hooks/recorder_hook.py +++ b/mmengine/hooks/recorder_hook.py @@ -416,7 +416,6 @@ def _modify_forward_func(self, func: Callable, # Parse source code as ast tree = ast.parse(source) - breakpoint() if isinstance(tree.body[0], ast.FunctionDef): func_body = tree.body[0].body else: @@ -457,11 +456,8 @@ def _modify_forward_func(self, func: Callable, # Compile the modified ast as a new function namespace: Dict[str, Any] = {} - if isinstance(func, types.FunctionType): - globals_dict = func.__globals__ - func_name = func.__name__ - else: - raise TypeError('It is not a function') + globals_dict = func.__globals__ # type: ignore + func_name = func.__name__ # type: ignore exec( compile(tree, filename='', mode='exec'), globals_dict, namespace) @@ -518,7 +514,8 @@ def _group_recorder_by_method( key = recorder._method if key not in group_dict: group_dict[key] = [recorder] - group_dict[key].append(recorder) + else: + group_dict[key].append(recorder) return group_dict def _save_origin_method(self, model: Any, method_name: str, From d4406d6d8352d092936263e4d8cddf6a9e61c744 Mon Sep 17 00:00:00 2001 From: yxy Date: Wed, 4 Oct 2023 17:43:20 +0800 Subject: [PATCH 36/39] add recorder_hook test --- mmengine/hooks/recorder_hook.py | 68 ++++++----- tests/test_hooks/test_recorder_hook.py | 157 +++++++++++++++++++++++++ 2 files changed, 195 insertions(+), 30 deletions(-) create mode 100644 tests/test_hooks/test_recorder_hook.py diff --git a/mmengine/hooks/recorder_hook.py b/mmengine/hooks/recorder_hook.py index f391fd55a9..284c4cf8c8 100644 --- a/mmengine/hooks/recorder_hook.py +++ b/mmengine/hooks/recorder_hook.py @@ -34,9 +34,9 @@ class FunctionRecorderTransformer(ast.NodeTransformer): def __init__(self, model: str, method: str, target: str, target_index: Union[int, List[int]]): super().__init__() - self._model = model - self._method = method - self._target = target + self.model = model + self.method = method + self.target = target if isinstance(target_index, list): self._target_index = set(target_index) else: @@ -52,7 +52,7 @@ def get_store_varname_with_index(self, index: int) -> str: Returns: str: The variable name for the given index. """ - return f'{self._model}:{self._method}:{self._target}@{index}' + return f'{self.model}:{self.method}:{self.target}@{index}' def visit_Assign(self, node: ast.Assign) -> Union[Any, List[Any]]: """Visit and possibly transform an assignment node in the AST. @@ -64,7 +64,7 @@ def visit_Assign(self, node: ast.Assign) -> Union[Any, List[Any]]: Modified AST node or a list of AST nodes. """ assert isinstance(node.targets[0], ast.Name) - if node.targets[0].id != self._target: + if node.targets[0].id != self.target: return node self.count += 1 if self.count not in self._target_index: @@ -100,9 +100,9 @@ class AttributeRecorderTransformer(ast.NodeTransformer): def __init__(self, model, method, target): super().__init__() - self._model = model - self._method = method - self._target = target + self.model = model + self.method = method + self.target = target self._visited = False def _get_target_attribute(self) -> ast.Attribute: @@ -111,7 +111,8 @@ def _get_target_attribute(self) -> ast.Attribute: Returns: ast.Attribute: The node representing the target attribute. """ - func_chain = self._target.split('.') + target = 'self.' + self.target + func_chain = target.split('.') assert len(func_chain) >= 2 attr = ast.Attribute( value=ast.Name(id=func_chain[0], ctx=ast.Load()), @@ -128,7 +129,7 @@ def _deepcopy_varname(self) -> str: Returns: str: The variable name for the deep copy of the target attribute. """ - return f'_deep_copy_{self._target.replace(".", "_")}' + return f'_deep_copy_{self.target.replace(".", "_")}' def _get_tensor_name(self) -> str: """Generate and return the tensor name for the target attribute. @@ -136,7 +137,7 @@ def _get_tensor_name(self) -> str: Returns: str: The tensor name for the target attribute. """ - return f'{self._model}:{self._method}:{self._target}' + return f'{self.model}:{self.method}:{self.target}' def _get_deep_copy_node(self, var_node) -> ast.If: """Generate and return the AST node for deep copying the target @@ -239,9 +240,9 @@ class Recorder(metaclass=ABCMeta): """ def __init__(self, model: str, method: str, target: str): - self._model = model - self._method = method - self._target = target + self.model = model + self.method = method + self.target = target @abstractmethod def rewrite(self, ast_tree) -> Any: @@ -284,7 +285,7 @@ def __init__(self, model: str, method: str, target: str, index: list): super().__init__(model, method, target) self.index = index self.visit_assign = FunctionRecorderTransformer( - self._model, self._method, self._target, self.index) + self.model, self.method, self.target, self.index) def rewrite(self, ast_tree) -> Any: """Rewrite the AST tree to include recording logic for output of @@ -298,8 +299,7 @@ def get_store_varname(self) -> List[str]: number. """ return [ - f'{self._model}:{self._method}:{self._target}@{i}' - for i in self.index + f'{self.model}:{self.method}:{self.target}@{i}' for i in self.index ] @@ -317,9 +317,11 @@ class AttributeRecorder(Recorder): """ def __init__(self, model: str, method: str, target: str): + if target.startswith('self.'): + target = target[5:] super().__init__(model, method, target) self.visit_assign = AttributeRecorderTransformer( - self._model, self._method, self._target) + self.model, self.method, self.target) def rewrite(self, ast_tree) -> Any: """Rewrite the AST tree to include recording logic for attributes.""" @@ -327,7 +329,7 @@ def rewrite(self, ast_tree) -> Any: def get_store_varname(self) -> str: """Generate and return variable name based on model attributes.""" - return f'{self._model}:{self._method}:{self._target}' + return f'{self.model}:{self.method}:{self.target}' @HOOKS.register_module() @@ -348,8 +350,9 @@ class RecorderHook(Hook): save_dir (str): The directory where recorded data will be saved. If not specified, it will use the runner's work directory. Defaults to None. - filename_tmpl (Optional[str]): The filename template used when saving - recorded data. If not provided, a default template will be used. + filename_tmpl (str, optional): The filename template used when saving + recorded data. If specified, must contain one and only one "{}", + which will be replaced with ``epoch + 1``. Defaults to None. Examples: @@ -372,7 +375,7 @@ def __init__( ): self.tensor_dict: Dict[str, Any] = {} self.origin_methods: Dict[Any, Any] = {} - self._recorders: List[Recorder] = [] + self.recorders: List[Recorder] = [] self.print_modification: bool = print_modification self.save_dir: Optional[str] = save_dir # type: ignore if filename_tmpl is None: @@ -383,6 +386,12 @@ def __init__( if recorders is None or len(recorders) == 0: raise ValueError('recorders not initialized') for recorder in recorders: + if recorder.get('type') == 'FunctionRecorder': + if recorder.get('index') is None: + recorder['index'] = 0 + if not isinstance(recorder['index'], list): + recorder['index'] = [recorder['index']] + model = recorder.get('model') if model is None: recorder['model'] = 'runner_model' @@ -397,7 +406,7 @@ def __init__( 'because recorder has no target', logger='current', level=logging.WARNING) - self._recorders.append(RECORDERS.build(recorder)) + self.recorders.append(RECORDERS.build(recorder)) def _modify_forward_func(self, func: Callable, recorders: List[Recorder]) -> Callable: @@ -488,8 +497,8 @@ def _group_recorder_by_model_method( """ group_model_dist = {} group_model_method_dict: Dict[str, Dict[str, List[Recorder]]] = {} - for recorder in self._recorders: - key = recorder._model + for recorder in self.recorders: + key = recorder.model if key not in group_model_dist: group_model_dist[key] = [recorder] else: @@ -511,7 +520,7 @@ def _group_recorder_by_method( """ group_dict: Dict[str, List[Recorder]] = {} for recorder in recorders: - key = recorder._method + key = recorder.method if key not in group_dict: group_dict[key] = [recorder] else: @@ -568,7 +577,6 @@ def before_run(self, runner) -> None: logger='current', level=logging.WARNING) continue - # self.origin_methods[model][method_name] = method print_log( f'Modify {method_name} in {model_name}', logger='current', @@ -594,7 +602,7 @@ def after_train_iter(self, for key in self.tensor_dict.keys(): self.tensor_dict[key].append(self.message_hub.get_info(key)) - def _save_record(self, step): + def _save_record_to_file(self, step): """Save recorded tensors to disk. Args: @@ -606,7 +614,7 @@ def _save_record(self, step): def _init_tensor_dict(self): """Initialize the tensor dictionary for recording.""" - for recorder in self._recorders: + for recorder in self.recorders: varname = recorder.get_store_varname() if isinstance(varname, list): for name in varname: @@ -622,7 +630,7 @@ def after_train_epoch(self, runner) -> None: """ step = runner.epoch + 1 runner.logger.info(f'Saving record at {runner.epoch + 1} epochs') - self._save_record(step) + self._save_record_to_file(step) self._init_tensor_dict() def after_train(self, runner) -> None: diff --git a/tests/test_hooks/test_recorder_hook.py b/tests/test_hooks/test_recorder_hook.py new file mode 100644 index 0000000000..1729957cc9 --- /dev/null +++ b/tests/test_hooks/test_recorder_hook.py @@ -0,0 +1,157 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import logging +import tempfile +from unittest.mock import Mock + +import torch +import torch.nn as nn +from torch.utils.data import Dataset + +from mmengine.evaluator import BaseMetric +from mmengine.hooks import RecorderHook +from mmengine.logging import MMLogger +from mmengine.model import BaseModel +from mmengine.testing import RunnerTestCase + + +class ToyModel(BaseModel): + + def __init__(self): + super().__init__() + self.linear = nn.Linear(2, 1) + + def forward(self, inputs, data_sample, mode='tensor'): + labels = torch.stack(data_sample) + inputs = torch.stack(inputs) + outputs = self.linear(inputs) + if mode == 'tensor': + return outputs + elif mode == 'loss': + loss = (labels - outputs).sum() + outputs = dict(loss=loss) + return outputs + else: + return outputs + + +class DummyDataset(Dataset): + METAINFO = dict() # type: ignore + data = torch.randn(12, 2) + label = torch.ones(12) + + @property + def metainfo(self): + return self.METAINFO + + def __len__(self): + return self.data.size(0) + + def __getitem__(self, index): + return dict(inputs=self.data[index], data_sample=self.label[index]) + + +class DummyMetric(BaseMetric): + + default_prefix: str = 'test' + + def __init__(self, length): + super().__init__() + self.length = length + self.best_idx = length // 2 + self.cur_idx = 0 + self.vals = [90, 91, 92, 88, 89, 90] * 2 + + def process(self, *args, **kwargs): + self.results.append(0) + + def compute_metrics(self, *args, **kwargs): + acc = self.vals[self.cur_idx] + self.cur_idx += 1 + return dict(acc=acc) + + +def get_mock_runner(): + runner = Mock() + runner.train_loop = Mock() + runner.train_loop.stop_training = False + return runner + + +class TestRecorderHook(RunnerTestCase): + + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory() + + def tearDown(self): + # `FileHandler` should be closed in Windows, otherwise we cannot + # delete the temporary directory + logging.shutdown() + MMLogger._instance_dict.clear() + self.temp_dir.cleanup() + + def test_init(self): + # Test recorders + with self.assertRaisesRegex(ValueError, 'recorders not initialized'): + RecorderHook(None, self.temp_dir) + with self.assertRaisesRegex(ValueError, 'recorders not initialized'): + RecorderHook([], self.temp_dir) + + hook = RecorderHook([dict(type='FunctionRecorder', target='x')]) + self.assertEqual(len(hook.recorders), 1) + self.assertEqual(hook.recorders[0].target, 'x') + + self.assertEqual(hook.recorders[0].model, 'runner_model') + self.assertEqual(hook.recorders[0].method, 'forward') + + hook = RecorderHook( + [dict(type='AttributeRecorder', target='self.linear1.weight')]) + self.assertEqual(len(hook.recorders), 1) + self.assertEqual(hook.recorders[0].model, 'runner_model') + self.assertEqual(hook.recorders[0].method, 'forward') + self.assertEqual(hook.recorders[0].target, 'linear1.weight') + + hook = RecorderHook([ + dict(type='FunctionRecorder', target='x'), + dict(type='AttributeRecorder', target='self.linear1.weight') + ]) + self.assertEqual(len(hook.recorders), 2) + + hook = RecorderHook([ + dict( + type='AttributeRecorder', + model='resnet', + method='_forward_impl', + target='x') + ]) + self.assertEqual(len(hook.recorders), 1) + self.assertEqual(hook.recorders[0].model, 'resnet') + self.assertEqual(hook.recorders[0].method, '_forward_impl') + self.assertEqual(hook.recorders[0].target, 'x') + + def test_before_run(self): + # test method modification + runner = Mock() + base_model = ToyModel() + origin_forward = base_model.forward + runner.model = base_model + runner.work_dir = self.temp_dir.name + + hook = RecorderHook([dict(type='FunctionRecorder', target='x')]) + hook.before_run(runner) + self.assertEqual(hook.save_dir, self.temp_dir.name) + self.assertEqual(hook.base_model, base_model) + self.assertNotEqual(origin_forward, hook.base_model.forward) + + def test_after_train(self): + runner = Mock() + base_model = ToyModel() + origin_forward = base_model.forward + runner.model = base_model + + hook = RecorderHook([dict(type='FunctionRecorder', target='x')]) + hook.before_run(runner) + self.assertEqual(hook.base_model, base_model) + self.assertNotEqual(origin_forward, hook.base_model.forward) + + hook.after_train(runner) + self.assertEqual(origin_forward, hook.base_model.forward) From 9f5f35a97635e09d4cb05e05f72ac98f1f12ec57 Mon Sep 17 00:00:00 2001 From: yxy Date: Thu, 5 Oct 2023 10:55:47 +0800 Subject: [PATCH 37/39] modify test_recorder_hook --- mmengine/hooks/recorder_hook.py | 69 +++++++++++-- tests/test_hooks/test_recorder_hook.py | 129 ++++++++++++++++--------- 2 files changed, 143 insertions(+), 55 deletions(-) diff --git a/mmengine/hooks/recorder_hook.py b/mmengine/hooks/recorder_hook.py index 284c4cf8c8..c89da9c9a0 100644 --- a/mmengine/hooks/recorder_hook.py +++ b/mmengine/hooks/recorder_hook.py @@ -342,6 +342,13 @@ class RecorderHook(Hook): using ast module in python. Args: + interval (int): The saving period. If ``by_epoch=True``, interval + indicates epochs, otherwise it indicates iterations. + Defaults to -1, which means "never". + by_epoch (bool): Saving checkpoints by epoch or by iteration. + Defaults to True. + save_last (bool): Whether to force the last record to be + saved regardless of interval. Defaults to True. recorders (Optional[List[Dict]]): Configurations for individual recorders. Each recorder dict should contain the target model and method. @@ -368,18 +375,27 @@ class RecorderHook(Hook): def __init__( self, + interval: int = -1, + by_epoch: bool = True, + save_last: bool = True, recorders: Optional[List[Dict]] = None, print_modification: bool = True, save_dir: str = None, filename_tmpl: Optional[str] = None, ): + self.interval = interval + self.by_epoch = by_epoch + self.save_last = save_last self.tensor_dict: Dict[str, Any] = {} self.origin_methods: Dict[Any, Any] = {} self.recorders: List[Recorder] = [] self.print_modification: bool = print_modification self.save_dir: Optional[str] = save_dir # type: ignore if filename_tmpl is None: - self.filename_tmpl = 'record_epoch_{}.pth' + if self.by_epoch: + self.filename_tmpl = 'record_epoch_{}.pth' + else: + self.filename_tmpl = 'record_iter_{}.pth' else: self.filename_tmpl = filename_tmpl @@ -599,8 +615,23 @@ def after_train_iter(self, data_batch (Optional): Current data batch. Default is None. outputs (Optional): Outputs from the current iteration. """ - for key in self.tensor_dict.keys(): - self.tensor_dict[key].append(self.message_hub.get_info(key)) + if self.by_epoch: + for key in self.tensor_dict.keys(): + self.tensor_dict[key].append(self.message_hub.get_info(key)) + else: + for key in self.tensor_dict.keys(): + self.tensor_dict[key] = self.message_hub.get_info(key) + # save record for following cases: + # 1. every ``self.interval`` iterations + # 2. reach the last iteration of training + if (self.every_n_train_iters(runner, self.interval) + or self.save_last and self.is_last_train_iter(runner)): + step = runner.iter + 1 + runner.logger.info( + f'Saving record at {runner.iter + 1} iterations') + self._save_record_to_file(step) + # every iteration will clear the tensor_dict + self._init_tensor_dict() def _save_record_to_file(self, step): """Save recorded tensors to disk. @@ -618,9 +649,23 @@ def _init_tensor_dict(self): varname = recorder.get_store_varname() if isinstance(varname, list): for name in varname: - self.tensor_dict[name] = list() + if self.by_epoch: + self.tensor_dict[name] = list() + else: + self.tensor_dict[name] = None else: + if self.by_epoch: + self.tensor_dict[varname] = list() + else: + self.tensor_dict[varname] = None + + def _clear_tensor_dict(self): + """Clear the tensor dictionary.""" + for varname, record in self.tensor_dict.items(): + if isinstance(record, list): self.tensor_dict[varname] = list() + else: + self.tensor_dict[varname] = None def after_train_epoch(self, runner) -> None: """Save recorded tensors after each training epoch. @@ -628,10 +673,18 @@ def after_train_epoch(self, runner) -> None: Args: runner (Runner): The runner of the training process. """ - step = runner.epoch + 1 - runner.logger.info(f'Saving record at {runner.epoch + 1} epochs') - self._save_record_to_file(step) - self._init_tensor_dict() + if not self.by_epoch: + return + # save record for following cases: + # 1. every ``self.interval`` epochs + # 2. reach the last epoch of training + if self.every_n_epochs(runner, self.interval) or ( + self.save_last and self.is_last_train_epoch(runner)): + step = runner.epoch + 1 + runner.logger.info(f'Saving record at {runner.epoch + 1} epochs') + self._save_record_to_file(step) + # every epoch will clear the tensor_dict + self._clear_tensor_dict() def after_train(self, runner) -> None: """Restore original methods after training. diff --git a/tests/test_hooks/test_recorder_hook.py b/tests/test_hooks/test_recorder_hook.py index 1729957cc9..b80f52c65d 100644 --- a/tests/test_hooks/test_recorder_hook.py +++ b/tests/test_hooks/test_recorder_hook.py @@ -1,13 +1,14 @@ # Copyright (c) OpenMMLab. All rights reserved. +import copy import logging +import os.path as osp import tempfile from unittest.mock import Mock import torch import torch.nn as nn -from torch.utils.data import Dataset +from parameterized import parameterized -from mmengine.evaluator import BaseMetric from mmengine.hooks import RecorderHook from mmengine.logging import MMLogger from mmengine.model import BaseModel @@ -34,42 +35,6 @@ def forward(self, inputs, data_sample, mode='tensor'): return outputs -class DummyDataset(Dataset): - METAINFO = dict() # type: ignore - data = torch.randn(12, 2) - label = torch.ones(12) - - @property - def metainfo(self): - return self.METAINFO - - def __len__(self): - return self.data.size(0) - - def __getitem__(self, index): - return dict(inputs=self.data[index], data_sample=self.label[index]) - - -class DummyMetric(BaseMetric): - - default_prefix: str = 'test' - - def __init__(self, length): - super().__init__() - self.length = length - self.best_idx = length // 2 - self.cur_idx = 0 - self.vals = [90, 91, 92, 88, 89, 90] * 2 - - def process(self, *args, **kwargs): - self.results.append(0) - - def compute_metrics(self, *args, **kwargs): - acc = self.vals[self.cur_idx] - self.cur_idx += 1 - return dict(acc=acc) - - def get_mock_runner(): runner = Mock() runner.train_loop = Mock() @@ -80,6 +45,7 @@ def get_mock_runner(): class TestRecorderHook(RunnerTestCase): def setUp(self): + super().setUp() self.temp_dir = tempfile.TemporaryDirectory() def tearDown(self): @@ -92,31 +58,33 @@ def tearDown(self): def test_init(self): # Test recorders with self.assertRaisesRegex(ValueError, 'recorders not initialized'): - RecorderHook(None, self.temp_dir) + RecorderHook(recorders=None, save_dir=self.temp_dir) with self.assertRaisesRegex(ValueError, 'recorders not initialized'): - RecorderHook([], self.temp_dir) + RecorderHook(recorders=[], save_dir=self.temp_dir) - hook = RecorderHook([dict(type='FunctionRecorder', target='x')]) + hook = RecorderHook( + recorders=[dict(type='FunctionRecorder', target='x')]) self.assertEqual(len(hook.recorders), 1) self.assertEqual(hook.recorders[0].target, 'x') self.assertEqual(hook.recorders[0].model, 'runner_model') self.assertEqual(hook.recorders[0].method, 'forward') - hook = RecorderHook( - [dict(type='AttributeRecorder', target='self.linear1.weight')]) + hook = RecorderHook(recorders=[ + dict(type='AttributeRecorder', target='self.linear1.weight') + ]) self.assertEqual(len(hook.recorders), 1) self.assertEqual(hook.recorders[0].model, 'runner_model') self.assertEqual(hook.recorders[0].method, 'forward') self.assertEqual(hook.recorders[0].target, 'linear1.weight') - hook = RecorderHook([ + hook = RecorderHook(recorders=[ dict(type='FunctionRecorder', target='x'), dict(type='AttributeRecorder', target='self.linear1.weight') ]) self.assertEqual(len(hook.recorders), 2) - hook = RecorderHook([ + hook = RecorderHook(recorders=[ dict( type='AttributeRecorder', model='resnet', @@ -136,7 +104,8 @@ def test_before_run(self): runner.model = base_model runner.work_dir = self.temp_dir.name - hook = RecorderHook([dict(type='FunctionRecorder', target='x')]) + hook = RecorderHook( + recorders=[dict(type='FunctionRecorder', target='x')]) hook.before_run(runner) self.assertEqual(hook.save_dir, self.temp_dir.name) self.assertEqual(hook.base_model, base_model) @@ -148,10 +117,76 @@ def test_after_train(self): origin_forward = base_model.forward runner.model = base_model - hook = RecorderHook([dict(type='FunctionRecorder', target='x')]) + hook = RecorderHook( + recorders=[dict(type='FunctionRecorder', target='x')]) hook.before_run(runner) self.assertEqual(hook.base_model, base_model) self.assertNotEqual(origin_forward, hook.base_model.forward) hook.after_train(runner) self.assertEqual(origin_forward, hook.base_model.forward) + + @parameterized.expand([['iter'], ['epoch']]) + def test_with_runner(self, training_type): + common_cfg = getattr(self, f'{training_type}_based_cfg') + setattr(common_cfg.train_cfg, f'max_{training_type}s', 11) + recorder_cfg = dict( + type='RecorderHook', by_epoch=training_type == 'epoch', interval=1) + common_cfg.default_hooks = dict(recorder=recorder_cfg) + + # Test interval in epoch based training + cfg = copy.deepcopy(common_cfg) + cfg.default_hooks.recorder.recorders = [ + dict(type='FunctionRecorder', target='outputs', index=[0, 1]) + ] + cfg.default_hooks.recorder.interval = 2 + runner = self.build_runner(cfg) + runner.train() + + for i in range(1, 11): + self.assertEqual( + osp.isfile( + osp.join(cfg.work_dir, f'record_{training_type}_{i}.pth')), + i % 2 == 0) + + record = torch.load( + osp.join(cfg.work_dir, f'record_{training_type}_10.pth')) + self.assertEqual(len(record), 2) + for varname, var in record.items(): + self.assertTrue(varname.startswith('runner_model:forward:outputs')) + # tensor_list should be a list of tensor + if training_type == 'epoch': + self.assertTrue( + all(isinstance(item, torch.Tensor) for item in var)) + else: + self.assertTrue(isinstance(var, torch.Tensor)) + + self.clear_work_dir() + + cfg = copy.deepcopy(common_cfg) + cfg.default_hooks.recorder.recorders = [ + dict(type='AttributeRecorder', target='linear1.weight'), + dict(type='AttributeRecorder', target='linear2.bias') + ] + + runner = self.build_runner(cfg) + runner.train() + + for i in range(1, 11): + self.assertEqual( + osp.isfile( + osp.join(cfg.work_dir, f'record_{training_type}_{i}.pth')), + True) + + record = torch.load( + osp.join(cfg.work_dir, f'record_{training_type}_10.pth')) + self.assertEqual(len(record), 2) + for varname, var in record.items(): + self.assertTrue( + varname.startswith('runner_model:forward:linear1.weight') + or varname.startswith('runner_model:forward:linear2.bias')) + if training_type == 'epoch': + self.assertTrue( + all(isinstance(item, torch.Tensor) for item in var)) + else: + self.assertTrue(isinstance(var, torch.Tensor)) From 4e81004723adebc564e9449dd9899da6ab71354c Mon Sep 17 00:00:00 2001 From: yxy Date: Thu, 5 Oct 2023 11:12:24 +0800 Subject: [PATCH 38/39] delete modification option --- mmengine/hooks/recorder_hook.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/mmengine/hooks/recorder_hook.py b/mmengine/hooks/recorder_hook.py index c89da9c9a0..a647e3d1c0 100644 --- a/mmengine/hooks/recorder_hook.py +++ b/mmengine/hooks/recorder_hook.py @@ -352,8 +352,6 @@ class RecorderHook(Hook): recorders (Optional[List[Dict]]): Configurations for individual recorders. Each recorder dict should contain the target model and method. - print_modification (bool): Whether to print the modified source code - after it's been altered by a recorder. Defaults to True. save_dir (str): The directory where recorded data will be saved. If not specified, it will use the runner's work directory. Defaults to None. @@ -366,7 +364,6 @@ class RecorderHook(Hook): >>> recorder_hook_cfg = dict( ... recorders=[{'model': 'runner_model', ... 'target': 'layer1', 'method': 'forward'}], - ... print_modification=True, ... save_dir='./records', ... filename_tmpl='record_epoch_{}.pth' ... ) @@ -379,7 +376,6 @@ def __init__( by_epoch: bool = True, save_last: bool = True, recorders: Optional[List[Dict]] = None, - print_modification: bool = True, save_dir: str = None, filename_tmpl: Optional[str] = None, ): @@ -389,7 +385,6 @@ def __init__( self.tensor_dict: Dict[str, Any] = {} self.origin_methods: Dict[Any, Any] = {} self.recorders: List[Recorder] = [] - self.print_modification: bool = print_modification self.save_dir: Optional[str] = save_dir # type: ignore if filename_tmpl is None: if self.by_epoch: @@ -469,14 +464,6 @@ def _modify_forward_func(self, func: Callable, for recorder in recorders: tree = recorder.rewrite(tree) - if self.print_modification: - new_tree = ast.fix_missing_locations(tree) - modified_source_code = ast.unparse(new_tree) - print_log( - f'After modification, the source code is:\n' - f'{modified_source_code}', - logger='current', - level=logging.INFO) tree = ast.fix_missing_locations(tree) # Compile the modified ast as a new function From ec757ba80feaf2cd38d0cfd61e87b335d1f3d0bc Mon Sep 17 00:00:00 2001 From: yxy Date: Fri, 6 Oct 2023 22:23:39 +0800 Subject: [PATCH 39/39] add save to messagehub --- mmengine/hooks/recorder_hook.py | 92 ++++++++++++++++++++------ tests/test_hooks/test_recorder_hook.py | 45 ++++++++++++- 2 files changed, 114 insertions(+), 23 deletions(-) diff --git a/mmengine/hooks/recorder_hook.py b/mmengine/hooks/recorder_hook.py index a647e3d1c0..54b1423c4a 100644 --- a/mmengine/hooks/recorder_hook.py +++ b/mmengine/hooks/recorder_hook.py @@ -347,6 +347,8 @@ class RecorderHook(Hook): Defaults to -1, which means "never". by_epoch (bool): Saving checkpoints by epoch or by iteration. Defaults to True. + save_to (str): The place to save the recorded data. It can be either + 'file' or 'messagehub'. Defaults to 'file'. save_last (bool): Whether to force the last record to be saved regardless of interval. Defaults to True. recorders (Optional[List[Dict]]): @@ -356,9 +358,17 @@ class RecorderHook(Hook): If not specified, it will use the runner's work directory. Defaults to None. filename_tmpl (str, optional): The filename template used when saving - recorded data. If specified, must contain one and only one "{}", - which will be replaced with ``epoch + 1``. + recorded data to file. + If specified, must contain one and only one "{}", + which will be replaced with ``epoch + 1`` or ``iter + 1``. Defaults to None. + messagehub_key_tmpl (str, optional): The messagehub_key + template used when saving recorded data to messagehub. + If specified, must contain one and only one "{}", + which will be replaced with ``epoch + 1`` or ``iter + 1``. + Defaults to None. + messagehub_name (str, optional): The name of messagehub instance, + only useful when ``save_to`` is equal to "messagehub". Examples: >>> recorder_hook_cfg = dict( @@ -370,29 +380,46 @@ class RecorderHook(Hook): """ priority = 'LOWEST' - def __init__( - self, - interval: int = -1, - by_epoch: bool = True, - save_last: bool = True, - recorders: Optional[List[Dict]] = None, - save_dir: str = None, - filename_tmpl: Optional[str] = None, - ): + def __init__(self, + interval: int = -1, + by_epoch: bool = True, + save_to: str = 'disk', + save_last: bool = True, + recorders: Optional[List[Dict]] = None, + save_dir: str = None, + filename_tmpl: Optional[str] = None, + messagehub_key_tmpl: Optional[str] = None, + messagehub_name: Optional[str] = None): self.interval = interval self.by_epoch = by_epoch self.save_last = save_last self.tensor_dict: Dict[str, Any] = {} self.origin_methods: Dict[Any, Any] = {} self.recorders: List[Recorder] = [] + self.save_to: Optional[str] = save_to self.save_dir: Optional[str] = save_dir # type: ignore - if filename_tmpl is None: - if self.by_epoch: - self.filename_tmpl = 'record_epoch_{}.pth' + if save_to == 'disk': + if filename_tmpl is None: + if self.by_epoch: + self.save_tmpl = 'record_epoch_{}.pth' + else: + self.save_tmpl = 'record_iter_{}.pth' else: - self.filename_tmpl = 'record_iter_{}.pth' + self.save_tmpl = filename_tmpl + elif save_to == 'messagehub': + if messagehub_name is None: + messagehub_name = 'recorder_hook' + self.save_messagehub = MessageHub.get_instance(messagehub_name) + if messagehub_key_tmpl is None: + if self.by_epoch: + self.save_tmpl = 'record_epoch_{}' + else: + self.save_tmpl = 'record_iter_{}' + else: + self.save_tmpl = messagehub_key_tmpl else: - self.filename_tmpl = filename_tmpl + raise ValueError(f"save_to should be 'file' or 'messagehub', " + f'but got {save_to}') if recorders is None or len(recorders) == 0: raise ValueError('recorders not initialized') @@ -553,7 +580,7 @@ def before_run(self, runner) -> None: self.save_dir = runner.work_dir # get messagehub instance and store it. - self.message_hub = MessageHub.get_instance('recorder_hook') + self.buffer_messagehub = MessageHub.get_instance('recorder_hook') # init_save_var_dict self._init_tensor_dict() # get model and modify its forward function @@ -604,10 +631,11 @@ def after_train_iter(self, """ if self.by_epoch: for key in self.tensor_dict.keys(): - self.tensor_dict[key].append(self.message_hub.get_info(key)) + self.tensor_dict[key].append( + self.buffer_messagehub.get_info(key)) else: for key in self.tensor_dict.keys(): - self.tensor_dict[key] = self.message_hub.get_info(key) + self.tensor_dict[key] = self.buffer_messagehub.get_info(key) # save record for following cases: # 1. every ``self.interval`` iterations # 2. reach the last iteration of training @@ -616,20 +644,40 @@ def after_train_iter(self, step = runner.iter + 1 runner.logger.info( f'Saving record at {runner.iter + 1} iterations') - self._save_record_to_file(step) + self._save_record(step) # every iteration will clear the tensor_dict self._init_tensor_dict() + def _save_record(self, step): + """Save recorded tensors to disk or messagehub. + + Args: + step (int): Current training epoch. + """ + if self.save_to == 'disk': + self._save_record_to_file(step) + elif self.save_to == 'messagehub': + self._save_record_to_messagehub(step) + def _save_record_to_file(self, step): """Save recorded tensors to disk. Args: step (int): Current training epoch. """ - recorder_file_name = self.filename_tmpl.format(step) + recorder_file_name = self.save_tmpl.format(step) path = osp.join(self.save_dir, recorder_file_name) torch.save(self.tensor_dict, path) + def _save_record_to_messagehub(self, step): + """Save recorded tensors to messagehub. + + Args: + step (int): Current training epoch. + """ + self.save_messagehub.update_info( + self.save_tmpl.format(step), self.tensor_dict.copy()) + def _init_tensor_dict(self): """Initialize the tensor dictionary for recording.""" for recorder in self.recorders: @@ -669,7 +717,7 @@ def after_train_epoch(self, runner) -> None: self.save_last and self.is_last_train_epoch(runner)): step = runner.epoch + 1 runner.logger.info(f'Saving record at {runner.epoch + 1} epochs') - self._save_record_to_file(step) + self._save_record(step) # every epoch will clear the tensor_dict self._clear_tensor_dict() diff --git a/tests/test_hooks/test_recorder_hook.py b/tests/test_hooks/test_recorder_hook.py index b80f52c65d..77208081cd 100644 --- a/tests/test_hooks/test_recorder_hook.py +++ b/tests/test_hooks/test_recorder_hook.py @@ -10,7 +10,7 @@ from parameterized import parameterized from mmengine.hooks import RecorderHook -from mmengine.logging import MMLogger +from mmengine.logging import MessageHub, MMLogger from mmengine.model import BaseModel from mmengine.testing import RunnerTestCase @@ -96,6 +96,18 @@ def test_init(self): self.assertEqual(hook.recorders[0].method, '_forward_impl') self.assertEqual(hook.recorders[0].target, 'x') + # Test interval, by_epoch, save_dir, save_to + hook = RecorderHook( + recorders=[dict(type='FunctionRecorder', target='x')], + interval=1, + by_epoch=True, + save_dir=self.temp_dir.name, + save_to='messagehub') + self.assertEqual(hook.interval, 1) + self.assertEqual(hook.by_epoch, True) + self.assertEqual(hook.save_dir, self.temp_dir.name) + self.assertEqual(hook.save_to, 'messagehub') + def test_before_run(self): # test method modification runner = Mock() @@ -151,6 +163,7 @@ def test_with_runner(self, training_type): record = torch.load( osp.join(cfg.work_dir, f'record_{training_type}_10.pth')) + self.assertTrue(isinstance(record, dict)) self.assertEqual(len(record), 2) for varname, var in record.items(): self.assertTrue(varname.startswith('runner_model:forward:outputs')) @@ -180,6 +193,7 @@ def test_with_runner(self, training_type): record = torch.load( osp.join(cfg.work_dir, f'record_{training_type}_10.pth')) + self.assertTrue(isinstance(record, dict)) self.assertEqual(len(record), 2) for varname, var in record.items(): self.assertTrue( @@ -190,3 +204,32 @@ def test_with_runner(self, training_type): all(isinstance(item, torch.Tensor) for item in var)) else: self.assertTrue(isinstance(var, torch.Tensor)) + + self.clear_work_dir() + + # Test store to messagehub + cfg = copy.deepcopy(common_cfg) + cfg.default_hooks.recorder.recorders = [ + dict(type='FunctionRecorder', target='outputs', index=[0, 1]) + ] + cfg.default_hooks.recorder.save_to = 'messagehub' + cfg.default_hooks.recorder.messagehub_name = 'test_messagehub' + runner = self.build_runner(cfg) + runner.train() + + test_messagehub = MessageHub.get_instance('test_messagehub') + for i in range(1, 11): + key = f'record_{training_type}_{i}' + record = test_messagehub.get_info(key) + self.assertIsNotNone(record) + + record = test_messagehub.get_info(f'record_{training_type}_10') + self.assertTrue(isinstance(record, dict)) + self.assertEqual(len(record), 2) + for varname, var in record.items(): + self.assertTrue(varname.startswith('runner_model:forward:outputs')) + if training_type == 'epoch': + self.assertTrue( + all(isinstance(item, torch.Tensor) for item in var)) + else: + self.assertTrue(isinstance(var, torch.Tensor))