-
Notifications
You must be signed in to change notification settings - Fork 31
/
Copy pathtest_train.small.gemma.infini.py
150 lines (122 loc) · 4.3 KB
/
test_train.small.gemma.infini.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "1" # TODO: set the GPU device
os.environ["WANDB_PROJECT"] = "InfiniTransformer"
# os.environ["WANDB_MODE"] = "offline"
from itertools import chain
import torch
from datasets import load_dataset
from transformers import (
AutoTokenizer,
Trainer,
TrainingArguments,
set_seed,
default_data_collator,
)
from infini_gemma import GemmaForCausalLM, GemmaConfig
# Fix the random seed so repeated runs start from the same state.
set_seed(42)

# Probe CUDA once and report the runtime environment.
_cuda_available = torch.cuda.is_available()
print("Torch Version:", torch.__version__)
print("CUDA:", _cuda_available)

# "cuda:0" is the first *visible* GPU; choose the physical card through
# CUDA_VISIBLE_DEVICES rather than by editing this index.
device = "cuda:0" if _cuda_available else "cpu"
# Build the model. A local checkpoint under ./models/gemma-2b is reused when it
# exists; otherwise a fresh Infini-attention Gemma is created from the
# "google/gemma-2b" config and seeded with the matching pretrained weights.
if os.path.exists("./models/gemma-2b"):
    model = GemmaForCausalLM.from_pretrained(
        "./models/gemma-2b", torch_dtype="auto", device_map="auto"
    )
    config = model.config
    print(config)
    print(model)
else:
    config = GemmaConfig.from_pretrained(
        "google/gemma-2b",
        attn_implementation="eager",
    )
    # config.max_position_embeddings = 128
    config.use_cache = False
    # One segment spans the full positional range of the base model.
    config.segment_size = config.max_position_embeddings
    print(config)
    # Create the Gemma model with Infini-attention
    model = GemmaForCausalLM(config)
    # model = model.from_pretrained("google/gemma-2b")
    pretrained_model = GemmaForCausalLM.from_pretrained(
        "google/gemma-2b", torch_dtype="auto"
    )
    # Transfer weights.
    # Only parameters that exist under the same name AND have the same shape
    # are copied; everything else keeps its fresh initialization.
    # state_dict() builds a new mapping on every call, so take it once here
    # instead of rebuilding it several times per parameter.
    pretrained_state = pretrained_model.state_dict()
    for name, param in model.named_parameters():
        source = pretrained_state.get(name)
        if source is None:
            # Parameter has no counterpart in the pretrained checkpoint.
            continue
        if param.size() == source.size():
            # clone() so the two models never share storage.
            param.data = source.data.clone()
        else:
            print(f"Skipping {name} due to size mismatch.")
    print(model)
    # model = model.to(torch.bfloat16)
    # NOTE(review): the checkpoint branch above already places the weights via
    # device_map="auto", so only this branch moves the model explicitly.
    # Confirm this matches the intended original indentation.
    model = model.to(device)
# Training corpus: the raw WikiText-2 configuration of "wikitext".
# A Wikipedia slice is kept here as a commented-out alternative.
# wiki = load_dataset("wikipedia", "20220301.en", split="train[:20000]")
wiki = load_dataset(path="wikitext", name="wikitext-2-raw-v1")

# The tokenizer always comes from the "google/gemma-2b" hub entry, regardless
# of which branch produced the model.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
def tokenize_function(examples):
    """Tokenize the ``"text"`` entry of ``examples`` with the module-level tokenizer.

    Args:
        examples: Mapping that contains a ``"text"`` key. It is used with
            ``map(..., batched=True)``, so the value is presumably a list of
            strings — verify against the dataset schema.

    Returns:
        The object returned by calling ``tokenizer`` on that text.
    """
    texts = examples["text"]
    return tokenizer(texts)
# Work out which raw columns to drop once the text has been tokenized.
# The features are read from the "train" entry when indexing by that key
# succeeds, and from ``wiki`` itself when it raises KeyError.
try:
    _raw_features = wiki["train"].features
except KeyError:
    _raw_features = wiki.features
column_names = list(_raw_features)

# Tokenize in batches and keep only the tokenizer's output columns.
tokenized_datasets = wiki.map(
    tokenize_function,
    batched=True,
    remove_columns=column_names,
)

# Each training block spans four memory segments.
_SEGMENTS_PER_BLOCK = 4
block_size = config.segment_size * _SEGMENTS_PER_BLOCK  # will be 32768
print("block_size:", block_size)
def group_texts(examples):
    """Concatenate a batch of tokenized texts and cut it into ``block_size`` chunks.

    Every column of ``examples`` is flattened into one long sequence, the
    tail that does not fill a whole block is dropped, and the rest is split
    into consecutive chunks of the module-level ``block_size``. If the batch
    holds fewer than ``block_size`` tokens, every column comes back empty.

    Args:
        examples: Mapping from column name to a list of sequences. It must
            contain an ``"input_ids"`` column.

    Returns:
        dict: The chunked columns plus a ``"labels"`` column, which is a
        shallow copy of the chunked ``"input_ids"`` list.
    """
    # Flatten each column into a single sequence.
    flattened = {
        key: list(chain.from_iterable(examples[key])) for key in examples.keys()
    }

    # All columns are assumed to have the same flattened length, so the first
    # one decides how many tokens are usable.
    first_key = list(examples.keys())[0]
    usable_length = len(flattened[first_key])
    # Drop the remainder that would not fill a complete block.
    usable_length -= usable_length % block_size

    # Slice every column at the same block boundaries.
    block_starts = range(0, usable_length, block_size)
    result = {}
    for key, tokens in flattened.items():
        result[key] = [tokens[start : start + block_size] for start in block_starts]

    result["labels"] = result["input_ids"].copy()
    return result
# Regroup the tokenized text into fixed-length blocks, batch by batch.
lm_datasets = tokenized_datasets.map(group_texts, batched=True)
print(lm_datasets)
# print(lm_datasets["train"]["input_ids"][0])
# Hyper-parameters for the run, gathered in one place and then handed to
# TrainingArguments unchanged.
_training_kwargs = dict(
    # Output and checkpointing
    output_dir="./models/gemma-2b-wikitext",
    overwrite_output_dir=True,
    save_strategy="epoch",
    save_total_limit=1,
    # Schedule and optimisation
    num_train_epochs=1,
    per_device_train_batch_size=1,  # to test batch dim
    optim="adafactor",
    learning_rate=1e-4,
    # warmup_ratio=0.1,
    max_grad_norm=1.0,
    bf16=True,
    gradient_checkpointing=True,  # Reduce vram 69G -> 43G
    # Logging
    report_to="wandb",  # "none" if you don't want to report to wandb
    run_name="gemma-2b-wikitext",
    logging_first_step=True,
    logging_steps=1,
)
training_args = TrainingArguments(**_training_kwargs)
# Use the "train" entry when ``lm_datasets`` can be indexed by it; if that
# lookup raises KeyError, train on ``lm_datasets`` as a whole.
try:
    train_dataset = lm_datasets["train"]
except KeyError:
    train_dataset = lm_datasets

# Wire everything together. No evaluation set is attached for this run.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    # eval_dataset=lm_datasets["validation"],
    tokenizer=tokenizer,
    data_collator=default_data_collator,
)

# Start training.
trainer.train()