-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy path: bayes_cnn.py
executable file
·155 lines (137 loc) · 7.35 KB
/
bayes_cnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
"""
Code for Replica Exchange Stochastic Gradient MCMC on supervised learning
(c) Wei Deng, Liyao Gao
July 1, 2020
You can cite this paper 'Non-convex Learning via Replica Exchange Stochastic Gradient MCMC (ICML 2020)' if you find it useful.
Note that in Bayesian settings, the lr 2e-6 and weight decay 25 are equivalent to lr 0.1 and weight decay 5e-4 in standard setups.
"""
#!/usr/bin/python
import math
import copy
import sys
import os
import timeit
import csv
import argparse
from tqdm import tqdm ## better progressbar
from math import exp
from sys import getsizeof
import numpy as np
import random
import pickle
## import pytorch modules
import torch
from torch.autograd import Variable
import numpy as np
import torch.nn as nn
from torchvision import datasets, transforms
from torchvision.datasets import ImageFolder
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
import torch.utils.data as data
import torchvision.datasets as datasets
from tools import model_eval, save_or_pretrain, bayes_mv
from tools import loader, process_d0
from models.model_zoo import CNN, CNN1, BayesCNN, BayesCNN1
from trainer import trainer
import models.fashion as fmnist_models
import models.cifar as cifar_models
'''
model_names = sorted(name for name in models.__dict__
if name.islower() and not name.startswith("__")
and callable(models.__dict__[name]))
'''
def main():
    """Entry point for replica-exchange SGMCMC supervised training.

    Pipeline:
      0. Parse CLI hyperparameters, configure numpy printing, GPU and RNG seeds.
      1. Build the network (ResNet / WRN-16-8 / WRN-28-10) for the chosen dataset.
      2. Load train/test data loaders via ``loader``.
      3. Run Bayesian sampling with ``trainer`` over ``pars.chains`` replicas.

    Exits the process (``sys.exit`` / ``exit``) on overwrite refusal, missing
    CUDA, or an unknown model/dataset combination.
    """
    parser = argparse.ArgumentParser(description='Grid search')
    # dataset
    parser.add_argument('-data', default='cifar100', dest='data', help='MNIST/ Fashion MNIST/ CIFAR10/ CIFAR100')
    parser.add_argument('-aug', default=1, type=float, help='Data augmentation or not')
    # ResNet models or WRN models
    parser.add_argument('-model', default='resnet', type=str, help='resnet / preact / WRN')
    parser.add_argument('-depth', type=int, default=20, help='Model depth.')
    # number of training epochs
    parser.add_argument('-sn', default=500, type=int, help='Sampling Epochs')
    # learning rate, momentum, weight decay (L2), temperature and annealing factors
    parser.add_argument('-lr', default=2e-6, type=float, help='Sampling learning rate')
    parser.add_argument('-momentum', default=0.9, type=float, help='Sampling momentum learning rate')
    parser.add_argument('-wdecay', default=25, type=float, help='Samling weight decay')
    parser.add_argument('-T', default=0.01, type=float, help='Inverse temperature for high temperature chain')
    parser.add_argument('-anneal', default=1.002, type=float, help='temperature annealing factor (default for 500 epochs)')
    parser.add_argument('-lr_anneal', default=0.992, type=float, help='lr annealing factor (default for 500 epochs)')
    # high-temperature hyperparameters
    parser.add_argument('-chains', default=2, type=int, help='Total number of chains')
    parser.add_argument('-types', default='swap', type=str, help='swap type: greedy (low T copy high T), swap (low high T swap)')
    parser.add_argument('-Tgap', default=0.2, type=float, help='Tgap=low-temperature /high-temperature')
    parser.add_argument('-LRgap', default=0.66, type=float, help='LRgap=low-temperature lr / high-temperature lr')
    # default settings for the correction factor F
    parser.add_argument('-bias_F', default=3e5, type=float, help='correction factor F')
    parser.add_argument('-F_jump', default=1, type=float, help='F jump factor')
    parser.add_argument('-cool', default=0, type=int, help='No swaps happen during the cooling time after a swap')
    # other settings
    # BUG FIX: argparse `type=bool` converts ANY non-empty string (including
    # "False") to True; parse recognized truthy strings explicitly instead.
    parser.add_argument('-ck', default=False,
                        type=lambda s: str(s).lower() in ('1', 'true', 'y', 'yes'),
                        help='Check if we need overwriting check')
    parser.add_argument('-total', default=50000, type=int, help='Total data points')
    parser.add_argument('-train', default=1000, type=int, help='Training batch size')
    parser.add_argument('-test', default=500, type=int, help='Testing batch size')
    # BUG FIX: randint requires int bounds; 1e6 is a float and is rejected by
    # modern Python 3 (removed float support in randrange as of 3.12).
    parser.add_argument('-seed', default=random.randint(1, 10**6), type=int, help='Random Seed')
    parser.add_argument('-gpu', default=0, type=int, help='Default GPU')
    parser.add_argument('-multi', default=0, type=int, help='Multiple GPUs')
    parser.add_argument('-var', default=5, type=int, help='estimate variance piecewise, positive value means estimate every [var] epochs')
    parser.add_argument('-windows', default=20, type=int, help='Moving average of corrections')
    parser.add_argument('-alpha', default=0.3, type=float, help='forgetting rate')
    parser.add_argument('-repeats', default=50, type=int, help='number of samples to estimate sample std')
    parser.add_argument('-burn', default=200, type=float, help='burn in iterations for sampling when sn = 1000')
    parser.add_argument('-ifstop', default=1, type=int, help='stop iteration if acc is too low')
    parser.add_argument('-split', default=2, type=int, help='Bayes avg every split epochs')
    pars = parser.parse_args()

    # Step 0: numpy printing setup, GPU selection and seeding for determinism.
    print(pars)
    np.set_printoptions(precision=3)
    np.set_printoptions(suppress=True)
    try:
        torch.cuda.set_device(pars.gpu)
    except Exception:  # in case the device has only one GPU
        torch.cuda.set_device(0)
    torch.manual_seed(pars.seed)
    torch.cuda.manual_seed(pars.seed)
    np.random.seed(pars.seed)
    random.seed(pars.seed)
    torch.backends.cudnn.deterministic = True

    # Step 1: preprocessing — confirm overwrite, require CUDA, build the model.
    if pars.ck:
        # BUG FIX: `raw_input` is Python 2 only; Python 3 renamed it `input`.
        if input('Are you sure to overwrite the pretrained model? [y/n]') not in ['y', 'Y']:
            sys.exit('Fail in overwriting')
    if not torch.cuda.is_available():
        exit("CUDA does not exist!!!")

    net = None
    if pars.model == 'resnet':
        if pars.data == 'fmnist':
            net = fmnist_models.__dict__['resnet'](num_classes=10, depth=pars.depth).cuda()
        elif pars.data == 'cifar10':
            net = cifar_models.__dict__['resnet'](num_classes=10, depth=pars.depth).cuda()
        elif pars.data == 'cifar100':
            net = cifar_models.__dict__['resnet'](num_classes=100, depth=pars.depth).cuda()
    elif pars.model == 'wrn':
        if pars.data == 'fmnist':
            net = fmnist_models.__dict__['wrn'](num_classes=10, depth=16, widen_factor=8, dropRate=0).cuda()
        elif pars.data == 'cifar10':
            net = cifar_models.__dict__['wrn'](num_classes=10, depth=16, widen_factor=8, dropRate=0).cuda()
        elif pars.data == 'cifar100':
            net = cifar_models.__dict__['wrn'](num_classes=100, depth=16, widen_factor=8, dropRate=0).cuda()
    elif pars.model == 'wrn28':
        if pars.data == 'fmnist':
            net = fmnist_models.__dict__['wrn'](num_classes=10, depth=28, widen_factor=10, dropRate=0).cuda()
        elif pars.data == 'cifar10':
            net = cifar_models.__dict__['wrn'](num_classes=10, depth=28, widen_factor=10, dropRate=0).cuda()
        elif pars.data == 'cifar100':
            net = cifar_models.__dict__['wrn'](num_classes=100, depth=28, widen_factor=10, dropRate=0).cuda()
    # Previously an unsupported combination left `net` unbound and crashed
    # later with a confusing NameError; fail fast with a clear message.
    if net is None:
        sys.exit('Unsupported model/data combination: %s / %s' % (pars.model, pars.data))

    # parallelized over multiple GPUs in the batch dimension
    if pars.multi:
        net = torch.nn.DataParallel(net).cuda()
    # One independent replica (chain) per temperature: deep-copy so each chain
    # owns its parameters (deepcopy replaces the pickle round-trip idiom).
    nets = [net]
    for _ in range(1, pars.chains):
        nets.append(copy.deepcopy(net))

    # Step 2: load data.
    train_loader, test_loader, targetloader = loader(pars.train, pars.test, pars)

    # Step 3: Bayesian sampling via replica-exchange SGMCMC.
    trainer(nets, train_loader, test_loader, pars)
if __name__ == "__main__":
main()