Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add RecorderHook #1300

Open
wants to merge 39 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 30 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
823d238
Add recorder_hook and use ast to print assign node
Xinyu302 Jul 26, 2023
85cde71
use messagehub to store information
Xinyu302 Jul 26, 2023
074ee1e
use message_hub.update_scalar
Xinyu302 Jul 28, 2023
f4dcc13
use message_hub.update_scalar
Xinyu302 Jul 28, 2023
e8144c9
design class Recorder
Xinyu302 Jul 31, 2023
c60b934
add recover forward logic
Xinyu302 Jul 31, 2023
54412dc
FunctionRecord actually should be AttributeRecorder because we find a…
Xinyu302 Jul 31, 2023
c7df8bb
add FunctionRecorder
Xinyu302 Aug 5, 2023
e4351ba
add update2 messagehub logic
Xinyu302 Aug 5, 2023
9fe0e7d
clean up code
Xinyu302 Aug 5, 2023
fd6b8e4
add comment and registry for AttributeRecorder and FunctionRecorder
Xinyu302 Aug 5, 2023
25b2415
fix commit verify
Xinyu302 Aug 5, 2023
ce0bfbe
do some clean up
Xinyu302 Aug 5, 2023
9ef4e44
add recorder_hook_test.py
Xinyu302 Aug 5, 2023
4b396aa
redesign FunctionRecorder and AttributeRecorder
Xinyu302 Aug 8, 2023
2d48fae
modify recorder_hook_test.py
Xinyu302 Aug 10, 2023
9a6ff6f
modify attribute recorder
Xinyu302 Sep 4, 2023
4c5d27b
store function recorder in a format of assign_name@index
Xinyu302 Sep 4, 2023
2d8b64b
modify function recorder index: start from 0
Xinyu302 Sep 4, 2023
7bdf2c0
use torch.save to dump data; handle when index is int
Xinyu302 Sep 12, 2023
9fa6c94
add default value for FunctionRecorder's index
Xinyu302 Sep 12, 2023
4102fa2
add copy.deepcopy to collect weight in layer
Xinyu302 Sep 12, 2023
f72c7b1
rename var name
Xinyu302 Sep 12, 2023
33dd386
add model select in recorder
Xinyu302 Sep 12, 2023
a995399
refactor: modify AttributeRecorderTransformer; modify _get_model; mod…
Xinyu302 Sep 15, 2023
963b54e
add deepcopy, if var is Tensor, use Tensor.detach().clone()
Xinyu302 Sep 16, 2023
4f434e4
refactor about store var name
Xinyu302 Sep 17, 2023
1f54cfc
delete useless lines
Xinyu302 Sep 17, 2023
581d668
add appoint specify method
Xinyu302 Sep 17, 2023
10e447f
update test script
Xinyu302 Sep 17, 2023
2d5447b
use MessageHub.get_instance
Xinyu302 Sep 19, 2023
e7e439d
add docs
Xinyu302 Sep 20, 2023
b58540b
try to add type hint
Xinyu302 Sep 20, 2023
ea29bfa
add type hint
Xinyu302 Sep 20, 2023
06fabbe
add type ignore
Xinyu302 Sep 27, 2023
d4406d6
add recorder_hook test
Xinyu302 Oct 4, 2023
9f5f35a
modify test_recorder_hook
Xinyu302 Oct 5, 2023
4e81004
delete modification option
Xinyu302 Oct 5, 2023
ec757ba
add save to messagehub
Xinyu302 Oct 6, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions examples/attribute_toy_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import torch
from torch import nn
from torch.utils.data import DataLoader

from mmengine.model import BaseModel
from mmengine.runner import Runner


class ToyModel(BaseModel):

def __init__(self, data_preprocessor=None):
super().__init__(data_preprocessor=data_preprocessor)
self.linear1 = nn.Linear(2, 2)
self.linear2 = nn.Linear(2, 1)

def forward(self, inputs, data_samples, mode='tensor'):
if isinstance(inputs, list):
inputs = torch.stack(inputs)
if isinstance(data_samples, list):
data_sample = torch.stack(data_samples)
outputs = self.linear1(inputs)
outputs = self.linear2(outputs)

if mode == 'tensor':
return outputs
elif mode == 'loss':
loss = (data_sample - outputs).sum()
outputs = dict(loss=loss)
return outputs
elif mode == 'predict':
return outputs


x = [(torch.ones(2, 2), [torch.ones(2, 1)])]
# train_dataset = [x, x, x]
train_dataset = x * 50
train_dataloader = DataLoader(train_dataset, batch_size=1)

runner = Runner(
model=ToyModel(),
custom_hooks=[
dict(
type='RecorderHook',
recorders=[
dict(type='AttributeRecorder', target='self.linear1.weight')
],
save_dir='./work_dir',
print_modification=True)
],
work_dir='tmp_dir',
train_dataloader=train_dataloader,
train_cfg=dict(by_epoch=True, max_epochs=10),
optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.01)))
runner.train()
54 changes: 54 additions & 0 deletions examples/function_toy_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import torch
from torch import nn
from torch.utils.data import DataLoader

from mmengine.model import BaseModel
from mmengine.runner import Runner


class ToyModel(BaseModel):

def __init__(self, data_preprocessor=None):
super().__init__(data_preprocessor=data_preprocessor)
self.linear1 = nn.Linear(2, 2)
self.linear2 = nn.Linear(2, 1)

def forward(self, inputs, data_samples, mode='tensor'):
if isinstance(inputs, list):
inputs = torch.stack(inputs)
if isinstance(data_samples, list):
data_sample = torch.stack(data_samples)
outputs = self.linear1(inputs)
outputs = self.linear2(outputs)

if mode == 'tensor':
return outputs
elif mode == 'loss':
loss = (data_sample - outputs).sum()
outputs = dict(loss=loss)
return outputs
elif mode == 'predict':
return outputs


x = [(torch.ones(2, 2), [torch.ones(2, 1)])]
# train_dataset = [x, x, x]
train_dataset = x * 50
train_dataloader = DataLoader(train_dataset, batch_size=2)

runner = Runner(
model=ToyModel(),
custom_hooks=[
dict(
type='RecorderHook',
recorders=[
dict(type='FunctionRecorder', target='outputs', index=[1])
],
save_dir='./work_dir',
print_modification=True)
],
work_dir='tmp_dir',
train_dataloader=train_dataloader,
train_cfg=dict(by_epoch=True, max_epochs=10),
optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.01)))
runner.train()
100 changes: 100 additions & 0 deletions examples/recorder_hook_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.optim import SGD
from torch.utils.data import DataLoader

from mmengine.evaluator import BaseMetric
from mmengine.model import BaseModel
from mmengine.runner import Runner


class MMResNet50(BaseModel):

def __init__(self):
super().__init__()
self.resnet = torchvision.models.resnet50()

def forward(self, imgs, labels, mode):
x = self.resnet(imgs)
if mode == 'loss':
return {'loss': F.cross_entropy(x, labels)}
elif mode == 'predict':
return x, labels


class Accuracy(BaseMetric):

def process(self, data_batch, data_samples):
score, gt = data_samples
self.results.append({
'batch_size': len(gt),
'correct': (score.argmax(dim=1) == gt).sum().cpu(),
})

def compute_metrics(self, results):
total_correct = sum(item['correct'] for item in results)
total_size = sum(item['batch_size'] for item in results)
return dict(accuracy=100 * total_correct / total_size)


norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])
train_dataloader = DataLoader(
batch_size=32,
shuffle=True,
dataset=torchvision.datasets.CIFAR10(
'data/cifar10',
train=True,
download=True,
transform=transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(**norm_cfg)
])))

val_dataloader = DataLoader(
batch_size=32,
shuffle=False,
dataset=torchvision.datasets.CIFAR10(
'data/cifar10',
train=False,
download=True,
transform=transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize(**norm_cfg)])))

# default_hooks = dict(logger=dict(type='LoggerHook', interval=20))

runner = Runner(
# custom_hooks=[
# dict(
# type='RecorderHook',
# recorders=[dict(type='FunctionRecorder', target='x')],
# save_dir='./work_dir',
# print_modification=True)
# ],
custom_hooks=[
dict(
type='RecorderHook',
recorders=[
dict(
model='resnet',
method='_forward_impl',
type='FunctionRecorder',
target='x',
index=[0, 1, 2])
],
save_dir='./work_dir',
print_modification=True)
],
model=MMResNet50(),
work_dir='./work_dir',
train_dataloader=train_dataloader,
optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
train_cfg=dict(by_epoch=True, max_epochs=1, val_interval=1),
val_dataloader=val_dataloader,
val_cfg=dict(),
val_evaluator=dict(type=Accuracy),
)
runner.train()
3 changes: 2 additions & 1 deletion mmengine/hooks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .naive_visualization_hook import NaiveVisualizationHook
from .param_scheduler_hook import ParamSchedulerHook
from .profiler_hook import NPUProfilerHook, ProfilerHook
from .recorder_hook import RecorderHook
from .runtime_info_hook import RuntimeInfoHook
from .sampler_seed_hook import DistSamplerSeedHook
from .sync_buffer_hook import SyncBuffersHook
Expand All @@ -18,5 +19,5 @@
'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook',
'SyncBuffersHook', 'EmptyCacheHook', 'CheckpointHook', 'LoggerHook',
'NaiveVisualizationHook', 'EMAHook', 'RuntimeInfoHook', 'ProfilerHook',
'PrepareTTAHook', 'NPUProfilerHook', 'EarlyStoppingHook'
'PrepareTTAHook', 'NPUProfilerHook', 'EarlyStoppingHook', 'RecorderHook'
]
Loading