-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtrain_market.py
65 lines (44 loc) · 1.75 KB
/
train_market.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from __future__ import print_function

import argparse
import os
import sys
import time

import torch
import torch.optim as optim

from config import Config_market
from utils import *
from datasets import *
from net.model import *
from engine.market_tainer import *
# ---------------------------------------------------------------------------
# Init setting: load the project config (`Config_market`) and make sure the
# directory named by `args.model_path` exists. The same path is stored as
# `checkpoint_path` and is also where the pretrained backbone weights are
# loaded from when the model is built.
# ---------------------------------------------------------------------------
args = Config_market()
checkpoint_path = args.model_path
# exist_ok=True makes this one race-free call instead of an isdir()/makedirs()
# pair; it still raises if the path exists but is not a directory.
os.makedirs(checkpoint_path, exist_ok=True)
# ---------------------------------------------------------------------------
# Load data and build the networks.
# ---------------------------------------------------------------------------
# `n_market_clc` comes from the loader factory and sizes the classifier head
# (C_net); it is also handed to the trainer. Presumably this is the number of
# identity classes in the training split -- confirm in make_market_data_loader.
market_loader, n_market_clc = make_market_data_loader(args)

print('==> Building model..')
# args.model_path is treated as a directory (it is created with os.makedirs),
# so join the weight filename onto it. Plain string concatenation produced a
# wrong path whenever the configured directory had no trailing separator.
Embed_net = Baseline(model_path=os.path.join(args.model_path, 'se_resnext50.pth'))
Classify_net = C_net(args.low_dim, n_market_clc)
A_net = Attribute_net(dim=args.low_dim, n_att=args.num_att)
trainer = create_trainer(args, Embed_net, Classify_net, A_net, n_market_clc)
# ---------------------------------------------------------------------------
# Training schedule.
#   * epochs [start_epoch, switch_point): trainer.do_train(..., 'softmax')
#   * epochs [switch_point, num_epochs): the loader is rebuilt once with
#     args.num_instance = triplet_num_instance, then
#     trainer.do_train(..., 'triplet_softmax')
# The learning rate is adjusted after every epoch, and a checkpoint is saved
# every args.save_epoch epochs (epoch 0 excluded).
# ---------------------------------------------------------------------------
start_epoch = 0
num_epochs = 121            # trains epochs 0..120 inclusive when start_epoch == 0
switch_point = 10           # first epoch that uses the 'triplet_softmax' mode
# Written to args.num_instance before the loader is rebuilt; presumably the
# number of images sampled per identity for triplet batches -- confirm in
# make_market_data_loader.
triplet_num_instance = 4
triplet_loader_ready = False

print('==> Start Training...')
# A fixed upper bound keeps the final epoch the same when resuming; the old
# `121 - start_epoch` bound ended the run start_epoch epochs early.
for epoch in range(start_epoch, num_epochs):
    if epoch >= switch_point and not triplet_loader_ready:
        # Rebuild exactly once, at the first triplet epoch. Testing
        # `epoch >= switch_point` (rather than `==`) also covers resuming with
        # start_epoch past switch_point, which would otherwise keep the
        # softmax-phase loader for the whole triplet phase.
        print('==> Preparing Data Loader...')
        args.num_instance = triplet_num_instance
        market_loader, n_triplet_clc = make_market_data_loader(args)
        triplet_loader_ready = True
        if n_triplet_clc != n_market_clc:
            # Classify_net and the trainer were sized with n_market_clc; a
            # different count here means the head no longer matches the data.
            print('WARNING: rebuilt loader reports %d classes but the '
                  'classifier was built for %d' % (n_triplet_clc, n_market_clc),
                  file=sys.stderr)

    loss_type = 'softmax' if epoch < switch_point else 'triplet_softmax'
    trainer.do_train(epoch, market_loader, loss_type)
    trainer.adjust_learning_rate(epoch)

    # Save a checkpoint every args.save_epoch epochs.
    if epoch > 0 and epoch % args.save_epoch == 0:
        trainer.save_model(epoch)