-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
143 lines (108 loc) · 4.08 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import torch
import torch.nn.functional as F
from transformer_efficient import miniGPT, block_size, device
## Train Params ##
current_file = "texts.txt"                    # plain-text training corpus
load_path = "./saved_models/multi_batch.pt"   # checkpoint to resume from / evaluate
save_path = "./saved_models/multi_batch.pt"   # where train() writes the state_dict
num_heads = 4       # attention heads per block (must divide n_embd — see train())
num_blocks = 2      # number of transformer blocks
batch_size = 64     # sequences per training batch
num_itrs = 10000    # optimizer steps
learning_rate = 4e-3  # AdamW learning rate
##################
def get_data(path=None):
    """Load the corpus, encode it to token ids, and split 90/10.

    Args:
        path: Text file to read; defaults to the module-level ``current_file``
            (backward compatible: existing ``get_data()`` calls are unchanged).

    Returns:
        Tuple of two 1-D ``torch.long`` tensors: the first 90% of the
        encoded corpus (train) and the remaining 10% (val).
    """
    if path is None:
        path = current_file
    # 'with' guarantees the file is closed even if read() raises
    with open(path, 'r', encoding='utf-8') as f:
        text = f.read()
    chars = sorted(set(text))
    print(chars)
    # encoder: map each character to its index in the sorted vocabulary
    stoi = {ch: i for i, ch in enumerate(chars)}
    encode = lambda s: [stoi[c] for c in s]
    # (removed unused itos/decode locals — decoding lives in infer())
    data = torch.tensor(encode(text), dtype=torch.long)
    n = int(0.9 * len(data))  # 90/10 train/val split
    return data[:n], data[n:]
def get_vocab(path=None):
    """Return the sorted list of unique characters in the corpus.

    Args:
        path: Text file to read; defaults to the module-level ``current_file``
            (backward compatible: existing ``get_vocab()`` calls are unchanged).

    Returns:
        List of unique characters, sorted ascending.
    """
    if path is None:
        path = current_file
    # 'with' closes the file even on error (original leaked on exception);
    # the original f.seek(0) immediately after open() was a no-op.
    with open(path, 'r', encoding='utf-8') as f:
        text = f.read()
    return sorted(set(text))
def get_batch(data, batch_size):
    """Sample a random batch of (input, target) windows from ``data``.

    Targets are the inputs shifted one position to the right, the standard
    setup for next-token prediction.
    """
    # Pick batch_size random window start positions that leave room for
    # a full block plus the one-step-ahead target.
    starts = torch.randint(len(data) - block_size, (batch_size,))
    inputs = torch.stack([data[s:s + block_size] for s in starts])
    targets = torch.stack([data[s + 1:s + 1 + block_size] for s in starts])
    return inputs.to(device), targets.to(device)
def forward(model, x, vocab_size=None, targets=None):
    """Run ``model`` on ``x`` and optionally compute the cross-entropy loss.

    Args:
        model: Callable mapping an input batch to logits.
        x: Input batch, passed straight through to ``model``.
        vocab_size: Number of classes; required when ``targets`` is given.
        targets: Optional target ids; when provided, logits are flattened
            to (N, vocab_size) and compared against ``targets.view(-1)``.

    Returns:
        ``(logits, loss)`` where ``loss`` is ``None`` when no targets
        were supplied.
    """
    logits = model(x)
    loss = None
    # 'is not None' instead of '!= None': comparing a Tensor with !=
    # relies on fallback semantics; identity check is the correct test.
    if targets is not None:
        # F.cross_entropy == nn.CrossEntropyLoss with default (mean) reduction,
        # and avoids constructing a loss module when there are no targets.
        loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))
    return logits, loss
def train(load_path=None, save_path=None):
    """Train miniGPT on the corpus, optionally resuming from a checkpoint.

    Args:
        load_path: Optional checkpoint to initialize model weights from.
        save_path: Optional path to write the trained state_dict to.
    """
    t, v = get_data()  # training and val data
    vocab = get_vocab()
    # want num_heads to be a divisor of n_embd
    model = miniGPT(num_heads, num_blocks, len(vocab)).to(device=device)
    # 'is not None' instead of '!= None' — correct identity test
    if load_path is not None:
        model.load_state_dict(torch.load(load_path))
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    for i in range(num_itrs):
        optimizer.zero_grad()
        xb, yb = get_batch(t, batch_size)
        logits, loss = forward(model, xb, len(vocab), yb)
        if i % 1000 == 0:  # periodic progress print
            print(loss.item())
        loss.backward()
        optimizer.step()
    if save_path is not None:
        torch.save(model.state_dict(), save_path)
@torch.no_grad()
def eval(model_path, set="val"):
    """Load a checkpoint and print the loss on one random full-split batch.

    NOTE(review): this function shadows the builtin ``eval`` and its
    parameter shadows the builtin ``set``; renaming either would change
    the public interface, so only flagging it here.

    Args:
        model_path: Checkpoint file to load model weights from.
        set: "val" evaluates on the validation split; any other value
            evaluates on the training split.
    """
    vocab = get_vocab()
    model = miniGPT(num_heads, num_blocks, len(vocab))
    model.load_state_dict(torch.load(model_path))
    model.eval()
    t, v = get_data()
    # Move model and data to device
    model = model.to(device)
    if set == "val":
        # NOTE(review): batch size equals the whole split length, so this
        # draws len(v) random windows at once — may be memory-heavy.
        xval, yval = get_batch(v, v.shape[0])
        val_logits, val_loss = forward(model, xval, len(vocab), yval)
        print("The validation loss is", val_loss.item())
    else:
        xtr, ytr = get_batch(t, t.shape[0])
        tr_logits, tr_loss = forward(model, xtr, len(vocab), ytr)
        print("The train loss is", tr_loss.item())
@torch.no_grad()
def infer(model_path, prior=""):
    """Generate 500 characters from the model, optionally seeded with ``prior``.

    Args:
        model_path: Checkpoint file to load model weights from.
        prior: Optional seed string written into the start of the context
            window. Assumes len(prior) <= block_size and that every
            character of ``prior`` appears in the vocab — TODO confirm
            (out-of-vocab characters raise KeyError from ``stoi``).
    """
    vocab = get_vocab()
    model = miniGPT(num_heads, num_blocks, len(vocab)).to(device)
    model.load_state_dict(torch.load(model_path))
    model.eval()
    stoi = { ch:i for i,ch in enumerate(vocab) }
    # Context starts as all zeros, i.e. the id of the first vocab character.
    context = torch.zeros((1, block_size), dtype=torch.long).to(device) # 1, 8
    if prior != "":
        # Write the seed characters into the leading positions; positions
        # after the seed remain zero-padded.
        for i, c in enumerate(prior):
            context[0, i] = stoi[c]
    itos = { i:ch for i,ch in enumerate(vocab) }
    output = prior
    print("Model Loaded.")
    for i in range(500):
        logits, loss = forward(model, context)
        # Convert logits to per-position probability distributions.
        logits = logits.softmax(dim=-1).squeeze() # 1, block_size, vocab_size
        # sample from distribution using torch.multinomial
        samples = torch.multinomial(logits, 1) # 1, block_size, 1
        # Only the last position's sample is the next-token prediction.
        pred = samples[-1]
        output += itos[pred.item()]
        # Slide the window: drop the oldest token, append the new one.
        context = torch.cat((context[:, 1:], pred.view(1, -1)), dim=1)
    print(output)
# train(load_path, save_path)
# NOTE(review): these calls run at import time; consider wrapping them in
# an `if __name__ == "__main__":` guard so importing this module is side-
# effect free.
eval(load_path, "val")
infer(load_path)