forked from paulxiong/SimCLR-4
model.py
import os
import torch

from modules import SimCLR, LARS


def load_model(args, loader, reload_model=False):
    model = SimCLR(args)

    if reload_model:
        model_fp = os.path.join(
            args.model_path, "checkpoint_{}.tar".format(args.epoch_num)
        )
        model.load_state_dict(torch.load(model_fp, map_location=args.device.type))

    model = model.to(args.device)

    scheduler = None
    if args.optimizer == "Adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)  # TODO: LARS
    elif args.optimizer == "LARS":
        # optimized using LARS with linear learning rate scaling
        # (i.e. LearningRate = 0.3 × BatchSize / 256) and weight decay of 10^-6
        learning_rate = 0.3 * args.batch_size / 256
        optimizer = LARS(
            model.parameters(),
            lr=learning_rate,
            weight_decay=args.weight_decay,
            exclude_from_weight_decay=["batch_normalization", "bias"],
        )

        # "decay the learning rate with the cosine decay schedule without restarts"
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, len(loader), eta_min=0, last_epoch=-1
        )
    else:
        raise NotImplementedError

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Install the apex package from https://www.github.com/nvidia/apex to use fp16 for training"
            )
        print("### USING FP16 ###")
        model, optimizer = amp.initialize(
            model, optimizer, opt_level=args.fp16_opt_level
        )

    return model, optimizer, scheduler
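
# Worked examples of the linear scaling rule in load_model (illustrative,
# not part of the training logic): lr = 0.3 * batch_size / 256, so
#   batch_size = 256  ->  lr = 0.3 * 256 / 256  = 0.3
#   batch_size = 512  ->  lr = 0.3 * 512 / 256  = 0.6
#   batch_size = 1024 ->  lr = 0.3 * 1024 / 256 = 1.2
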
def save_model(args, model, optimizer):
    out = os.path.join(args.out_dir, "checkpoint_{}.tar".format(args.current_epoch))

    # To save a DataParallel model generically, save model.module.state_dict();
    # that way the checkpoint can be loaded onto any device, wrapped or not.
    if isinstance(model, torch.nn.DataParallel):
        torch.save(model.module.state_dict(), out)
    else:
        torch.save(model.state_dict(), out)
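

# Illustrative usage sketch (an assumption, not part of the original file):
# load_model only reads the attributes shown below, plus whatever SimCLR(args)
# consumes inside the modules package, so this namespace is a stand-in and may
# need extra fields (e.g. the backbone choice) depending on that implementation.
if __name__ == "__main__":
    from types import SimpleNamespace

    # Stand-in loader: load_model only calls len(loader) (for the cosine
    # schedule), so any sized iterable will do for this sketch.
    loader = torch.utils.data.DataLoader(list(range(10)), batch_size=2)

    args = SimpleNamespace(
        optimizer="LARS",
        batch_size=256,  # linear scaling: lr = 0.3 * 256 / 256 = 0.3
        weight_decay=1e-6,
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        fp16=False,
        out_dir=".",
        current_epoch=0,
    )

    model, optimizer, scheduler = load_model(args, loader)
    save_model(args, model, optimizer)  # writes ./checkpoint_0.tar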