diff --git a/makemore.py b/makemore.py
index 75efc08..f28bfc7 100644
--- a/makemore.py
+++ b/makemore.py
@@ -28,16 +28,19 @@ from torch.utils.tensorboard import SummaryWriter
 from torch.utils.tensorboard import SummaryWriter
 
 # -----------------------------------------------------------------------------
-# GPT model definition
 
 @dataclass
-class GPTConfig:
-    # size of the model
+class ModelConfig:
+    block_size: int = None # length of the input sequences of integers
+    vocab_size: int = None # the input integers are in range [0 .. vocab_size -1]
+    # parameters below control the sizes of each model slightly differently
     n_layer: int = 4
-    n_head: int = 4
     n_embd: int = 64
-    vocab_size: int = None
-    block_size: int = None
+    n_embd2: int = 64
+    n_head: int = 4
+
+# -----------------------------------------------------------------------------
+# Transformer Language Model (*exactly* as used in GPT-2)
 
 class NewGELU(nn.Module):
     """
@@ -127,6 +130,9 @@ def __init__(self, config):
         n_params = sum(p.numel() for p in self.transformer.parameters())
         print("number of parameters: %.2fM" % (n_params/1e6,))
 
+    def get_block_size(self):
+        return self.block_size
+
     def forward(self, idx, targets=None):
         device = idx.device
         b, t = idx.size()
@@ -149,45 +155,124 @@ def forward(self, idx, targets=None):
 
         return logits, loss
 
-    @torch.no_grad()
-    def generate(self, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None):
-        """
-        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
-        the sequence max_new_tokens times, feeding the predictions back into the model each time.
-        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
-        """
-        for _ in range(max_new_tokens):
-            # if the sequence context is growing too long we must crop it at block_size
-            idx_cond = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size:]
-            # forward the model to get the logits for the index in the sequence
-            logits, _ = self(idx_cond)
-            # pluck the logits at the final step and scale by desired temperature
-            logits = logits[:, -1, :] / temperature
-            # optionally crop the logits to only the top k options
-            if top_k is not None:
-                v, _ = torch.topk(logits, top_k)
-                logits[logits < v[:, [-1]]] = -float('Inf')
-            # apply softmax to convert logits to (normalized) probabilities
-            probs = F.softmax(logits, dim=-1)
-            # either sample from the distribution or take the most likely element
-            if do_sample:
-                idx_next = torch.multinomial(probs, num_samples=1)
-            else:
-                _, idx_next = torch.topk(probs, k=1, dim=-1)
-            # append sampled index to the running sequence and continue
-            idx = torch.cat((idx, idx_next), dim=1)
-
-        return idx
+# -----------------------------------------------------------------------------
+# MLP language model
+
+class MLP(nn.Module):
+    """
+    takes the previous block_size tokens, encodes them with a lookup table,
+    concatenates the vectors and predicts the next token with an MLP.
+
+    Reference:
+    Bengio et al. 2003 https://www.jmlr.org/papers/volume3/bengio03a/bengio03a.pdf
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        self.block_size = config.block_size
+        self.vocab_size = config.vocab_size
+        self.wte = nn.Embedding(config.vocab_size + 1, config.n_embd) # token embeddings table
+        # +1 in the line above for a special <BLANK> token that gets inserted if encoding a token
+        # before the beginning of the input sequence
+        self.mlp = nn.Sequential(
+            nn.Linear(self.block_size * config.n_embd, config.n_embd2), # TODO: option to vary this
+            nn.Tanh(),
+            nn.Linear(config.n_embd2, self.vocab_size)
+        )
+
+    def get_block_size(self):
+        return self.block_size
+
+    def forward(self, idx, targets=None):
+
+        # gather the word embeddings of the previous 3 words
+        embs = []
+        for k in range(self.block_size):
+            tok_emb = self.wte(idx) # token embeddings of shape (b, t, n_embd)
+            idx = torch.roll(idx, 1, 1)
+            idx[:, 0] = self.vocab_size # special <BLANK> token
+            embs.append(tok_emb)
+
+        # concat all of the embeddings together and pass through an MLP
+        x = torch.cat(embs, -1) # (b, t, n_embd * block_size)
+        logits = self.mlp(x)
+
+        # if we are given some desired targets also calculate the loss
+        loss = None
+        if targets is not None:
+            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
+
+        return logits, loss
+
+# -----------------------------------------------------------------------------
+# Bigram language model
+
+class Bigram(nn.Module):
+    """
+    Bigram Language Model 'neural net', simply a lookup table of logits for the
+    next character given a previous character.
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        n = config.vocab_size
+        self.logits = nn.Parameter(torch.zeros((n, n)))
+
+    def get_block_size(self):
+        return 1 # this model only needs one previous character to predict the next
+
+    def forward(self, idx, targets=None):
+
+        # 'forward pass', lol
+        logits = self.logits[idx]
+
+        # if we are given some desired targets also calculate the loss
+        loss = None
+        if targets is not None:
+            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
+
+        return logits, loss
 
 # -----------------------------------------------------------------------------
 # helper functions for evaluating and sampling from the model
 
+@torch.no_grad()
+def generate(model, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None):
+    """
+    Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
+    the sequence max_new_tokens times, feeding the predictions back into the model each time.
+    Most likely you'll want to make sure to be in model.eval() mode of operation for this.
+    """
+    block_size = model.get_block_size()
+    for _ in range(max_new_tokens):
+        # if the sequence context is growing too long we must crop it at block_size
+        idx_cond = idx if idx.size(1) <= block_size else idx[:, -block_size:]
+        # forward the model to get the logits for the index in the sequence
+        logits, _ = model(idx_cond)
+        # pluck the logits at the final step and scale by desired temperature
+        logits = logits[:, -1, :] / temperature
+        # optionally crop the logits to only the top k options
+        if top_k is not None:
+            v, _ = torch.topk(logits, top_k)
+            logits[logits < v[:, [-1]]] = -float('Inf')
+        # apply softmax to convert logits to (normalized) probabilities
+        probs = F.softmax(logits, dim=-1)
+        # either sample from the distribution or take the most likely element
+        if do_sample:
+            idx_next = torch.multinomial(probs, num_samples=1)
+        else:
+            _, idx_next = torch.topk(probs, k=1, dim=-1)
+        # append sampled index to the running sequence and continue
+        idx = torch.cat((idx, idx_next), dim=1)
+
+    return idx
+
 def print_samples(num=10):
     """ samples from the model and pretty prints the decoded samples """
     X_init = torch.zeros(num, 1, dtype=torch.long).to(args.device)
     top_k = args.top_k if args.top_k != -1 else None
     steps = train_dataset.get_output_length() - 1 # -1 because we already start with <START> token (index 0)
-    X_samp = model.generate(X_init, steps, top_k=top_k, do_sample=True).to('cpu')
+    X_samp = generate(model, X_init, steps, top_k=top_k, do_sample=True).to('cpu')
     train_samples, test_samples, new_samples = [], [], []
     for i in range(X_samp.size(0)):
         # get the i'th row of sampled integers, as python list
@@ -333,9 +418,11 @@ def next(self):
     # sampling
     parser.add_argument('--top-k', type=int, default=-1, help="top-k for sampling, -1 means no top-k")
     # model
-    parser.add_argument('--n-layer', type=int, default=4, help="number of layers in the transformer")
-    parser.add_argument('--n-head', type=int, default=4, help="number of heads in the transformer")
-    parser.add_argument('--n-embd', type=int, default=64, help="number of feature channels in the transformer")
+    parser.add_argument('--type', type=str, default='bigram', help="model class type to use, bigram|mlp|transformer")
+    parser.add_argument('--n-layer', type=int, default=4, help="number of layers")
+    parser.add_argument('--n-head', type=int, default=4, help="number of heads (in a transformer)")
+    parser.add_argument('--n-embd', type=int, default=64, help="number of feature channels in the model")
+    parser.add_argument('--n-embd2', type=int, default=64, help="number of feature channels elsewhere in the model")
     # optimization
     parser.add_argument('--batch-size', '-b', type=int, default=32, help="batch size during optimization")
     parser.add_argument('--learning-rate', '-l', type=float, default=5e-4, help="learning rate")
@@ -356,9 +443,14 @@ def next(self):
     print(f"dataset determined that: {vocab_size=}, {block_size=}")
 
     # init model
-    config = GPTConfig(vocab_size=vocab_size, block_size=block_size,
-                       n_layer=args.n_layer, n_head=args.n_head, n_embd=args.n_embd)
-    model = GPT(config)
+    config = ModelConfig(vocab_size=vocab_size, block_size=block_size,
+                       n_layer=args.n_layer, n_head=args.n_head,
+                       n_embd=args.n_embd, n_embd2=args.n_embd2)
+    model = {
+        'transformer': GPT,
+        'bigram': Bigram,
+        'mlp': MLP,
+    }[args.type](config)
     model.to(args.device)
     print(f"model #params: {sum(p.numel() for p in model.parameters())}")
     if args.resume or args.sample_only: # note: if we sample-only then we also assume we are resuming