-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtrain.py
64 lines (47 loc) · 1.87 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import numpy as np
import torch
import argparse
import random
import os
from config import load_config_file, get_data, get_model_and_optim, get_loss_fn, get_renderer, get_seed
from trainer import Trainer
from mesh import load_mesh
from utils import model_summary
def parse_args():
    """Parse the training script's command line.

    Returns:
        An ``argparse.Namespace`` with:
        - ``config_path``: the positional path to the config file.
        - ``allow_checkpoint_loading`` and ``data_parallel``: boolean
          ``store_true`` flags, both ``False`` unless passed.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("config_path", type=str)
    # Both switches are opt-in and behave identically, so register them together.
    for flag in ("--allow_checkpoint_loading", "--data_parallel"):
        cli.add_argument(flag, default=False, action="store_true")
    return cli.parse_args()
def main():
args = parse_args()
config = load_config_file(args.config_path, args.allow_checkpoint_loading)
seed = get_seed(config)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
device = "cuda"
torch.backends.cudnn.benchmark = True
else:
device = "cpu"
mesh = load_mesh(config["data"]["mesh_path"])
data = get_data(config, device)
model, optim = get_model_and_optim(config, mesh, device)
# Print model summary
model_summary(model, data)
if args.data_parallel:
device_ids = [int(x) for x in os.environ.get("CUDA_VISIBLE_DEVICES", "").split(",")]
model = torch.nn.DataParallel(model, device_ids=device_ids)
loss_fn = get_loss_fn(config)
renderer = get_renderer(config, model, mesh, device)
# Seed again because different model architectures change seed. Make train samples consistent.
# https://discuss.pytorch.org/t/shuffle-issue-in-dataloader-how-to-get-the-same-data-shuffle-results-with-fixed-seed-but-different-network/45357/9
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
trainer = Trainer(model, optim, loss_fn, renderer, data, mesh, config, device)
trainer.train()
if __name__ == "__main__":
main()