-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrainer.py
executable file
·84 lines (77 loc) · 4.12 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import torch
from YCBDataModule import YCBDataModule
from LinemodDataModule import LinemodDataModule
from CustomDataModule import CustomDataModule
from DenseFusionModule import DenseFusionModule
import pytorch_lightning as pl
import argparse
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
# Command-line options for the training script. Every numeric option carries
# an explicit ``type`` so a value given on the command line has the same
# Python type as its default (argparse leaves untyped values as ``str``).
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='custom', help='ycb, linemod or custom')
parser.add_argument('--dataset_root', type=str, default='datasets/ycb/YCB_Video_Dataset', help='dataset root dir (''YCB_Video_Dataset'' or ''Linemod_preprocessed'')')
parser.add_argument('--batch_size', type=int, default=8, help='batch size')
parser.add_argument('--workers', type=int, default=10, help='number of data loading workers')
parser.add_argument('--lr', type=float, default=0.0001, help='learning rate')
parser.add_argument('--lr_rate', type=float, default=0.3, help='learning rate decay rate')
# NOTE(review): the help texts of --w / --w_rate duplicate those of
# --lr / --lr_rate; they look copy-pasted — confirm the intended wording.
parser.add_argument('--w', type=float, default=0.015, help='learning rate')
parser.add_argument('--w_rate', type=float, default=0.3, help='learning rate decay rate')
parser.add_argument('--decay_margin', type=float, default=0.016, help='margin to decay lr & w')
parser.add_argument('--refine_margin', type=float, default=0.013, help='margin to start the training of iterative refinement')
parser.add_argument('--noise_trans', type=float, default=0.03, help='range of the random noise of translation added to the training data')
parser.add_argument('--iteration', type=int, default=2, help='number of refinement iterations')
parser.add_argument('--nepoch', type=int, default=500, help='max number of epochs to train')
parser.add_argument('--resume_posenet', type=str, default='', help='resume PoseNet model')
parser.add_argument('--resume_refinenet', type=str, default='', help='resume PoseRefineNet model')
parser.add_argument('--start_epoch', type=int, default=1, help='which epoch to start')
# Parsed at import time: ``opt`` is the shared options namespace that the
# entry point below extends and hands to the data module and the model.
opt = parser.parse_args()
# Extra, non-CLI attribute on the namespace; it always starts out False.
# NOTE(review): presumably flipped once refinement training begins — confirm
# against the code that consumes ``opt``.
opt.refine_start = False
if __name__ == '__main__':
    # Script entry point: fill in the dataset-specific settings on ``opt``,
    # build the matching data module and the model, then run Lightning's fit.
    torch.multiprocessing.freeze_support()

    if opt.dataset == 'ycb':
        opt.num_objects = 21  # number of object classes in the dataset
        opt.num_points = 1000  # number of points on the input pointcloud
        opt.outf = 'trained_models/ycb'  # folder to save trained models
        opt.log_dir = 'experiments/logs/ycb'  # folder to save logs
        opt.repeat_epoch = 1  # number of repeat times for one epoch training
        # init DataModule
        dataModule = YCBDataModule(opt)
    elif opt.dataset == 'linemod':
        opt.num_objects = 13
        opt.num_points = 500
        opt.outf = 'trained_models/linemod'
        opt.log_dir = 'experiments/logs/linemod'
        opt.repeat_epoch = 20
        # The epoch budget is scaled by the repeat factor; validation below
        # then runs once every ``repeat_epoch`` epochs.
        opt.nepoch = opt.nepoch * opt.repeat_epoch
        # init DataModule
        dataModule = LinemodDataModule(opt)
    elif opt.dataset == 'custom':
        # NOTE: this unconditionally replaces any --dataset_root given on the
        # command line.
        opt.dataset_root = 'datasets/custom/custom_preprocessed'
        opt.num_objects = 1
        opt.num_points = 500
        opt.outf = 'trained_models/custom'
        opt.log_dir = 'experiments/logs/custom'
        opt.repeat_epoch = 1
        # init DataModule
        dataModule = CustomDataModule(opt)
    else:
        # Stop here: continuing without a data module would only fail later
        # with a NameError on ``dataModule``.
        raise SystemExit(
            f'Unknown dataset: {opt.dataset!r} (expected ycb, linemod or custom)'
        )

    # init model
    densefusion = DenseFusionModule(opt)

    # NOTE(review): the checkpoint is ranked by the metric "loss" while the
    # filename template interpolates "val_loss" — confirm both are logged by
    # DenseFusionModule.
    checkpoint_callback = ModelCheckpoint(
        dirpath='ckpt/',
        filename='dense-fusion-{epoch:02d}-{val_loss:.2f}',
        monitor="loss",
        save_last=True,
        every_n_train_steps=1000,
    )
    logger = TensorBoardLogger("tb_logs", name="dense_fusion")

    # Gradients are accumulated over ``opt.batch_size`` optimisation steps.
    # NOTE(review): presumably the data loaders yield one sample per step, so
    # this emulates the batch size — confirm in the data modules.
    trainer = pl.Trainer(
        logger=logger,
        accumulate_grad_batches=opt.batch_size,
        callbacks=[checkpoint_callback],
        max_epochs=opt.nepoch,
        check_val_every_n_epoch=opt.repeat_epoch,
        gpus=1,
        # The CLI default is '' (nothing to resume); hand Lightning None in
        # that case instead of an empty path.
        resume_from_checkpoint=opt.resume_posenet or None,
    )
    trainer.fit(densefusion, datamodule=dataModule)