
Commit

finetuning
SoufianeNoubir committed Dec 30, 2024
1 parent edbda5a commit 5b6a68e
Showing 2 changed files with 276 additions and 44 deletions.
145 changes: 105 additions & 40 deletions gbmi/exp_indhead/finetunebound.py
@@ -36,7 +36,7 @@


# %%
def loss_bound(model, s, w):
def loss_bound(model, s):

W_pos = model.W_pos
W_E = model.W_E
@@ -53,7 +53,7 @@ def loss_bound(model, s, w):

e_p = W_E.unsqueeze(dim=0) + W_pos.unsqueeze(dim=1)

everything = (
term_0 = (
einops.einsum(
e_p,
W_Q_0,
@@ -70,16 +70,16 @@ def loss_bound(model, s, w):
for p in range(2, n_ctx): #
tmp = torch.zeros((p, d_voc))
for t_q in range(d_voc):
tmp[-1, :] = everything[p - 1, t_q, p - 1, t_q]
tmp[-1, :] = term_0[p - 1, t_q, p - 1, t_q]

for t_k in range(d_voc):
tmp[-2, :] = everything[p - 1, t_q, p - 2, t_k]
tmp[:-2, :] = everything[p - 1, t_q, : p - 2, :]
tmp[-2, :] = term_0[p - 1, t_q, p - 2, t_k]
tmp[:-2, :] = term_0[p - 1, t_q, : p - 2, :]
tmp_sm = tmp.softmax(dim=0)
table[t_q, t_k, p - 2, :] = tmp_sm[-2, :]
# Table represents the post-softmax attention paid to t_k when the other token is spammed in every earlier position, t_k sits at position p - 2, and t_q is the query at position p - 1

# everything looks like EQKE, table looks like you're indexing by query, key, position (of key?), and other token in the sequence.
# term_0 looks like EQKE, table looks like you're indexing by query, key, position (of key?), and other token in the sequence.
# Then you're computing a softmax over p - 2 copies of the other token, one copy of t_k at p - 2, and the query at p - 1.
# Then you store the post-softmax attention paid to t_k.
#
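As a minimal, self-contained sketch of this table construction (assuming term_0 is indexed as [query_pos, query_tok, key_pos, key_tok]; the shapes and names below are illustrative, not taken from the repository):

import torch

n_ctx, d_voc = 8, 26
term_0 = torch.randn(n_ctx, d_voc, n_ctx, d_voc)
table = torch.zeros(d_voc, d_voc, n_ctx - 2, d_voc)

for p in range(2, n_ctx):
    for t_q in range(d_voc):
        for t_k in range(d_voc):
            scores = torch.empty(p, d_voc)  # one column per choice of "other" token
            scores[:-2, :] = term_0[p - 1, t_q, : p - 2, :]  # other token at every earlier position
            scores[-2, :] = term_0[p - 1, t_q, p - 2, t_k]   # t_k at position p - 2
            scores[-1, :] = term_0[p - 1, t_q, p - 1, t_q]   # the query token t_q at position p - 1
            table[t_q, t_k, p - 2, :] = scores.softmax(dim=0)[-2, :]  # attention paid to t_k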
@@ -177,6 +177,9 @@ def loss_bound(model, s, w):
"q_pos q_val k, k l, l m, m n, n p, p q -> q_pos q_val q",
)

if s == -1:
return (term_0, term_1, term_2, term_3, term_4, term_5, term_6, term_7, term_8)

if s == 0:
reduced_3 = einops.einsum(
term_3, "q_pos q_val k_pos k_val -> q_pos q_val k_pos"
Expand Down Expand Up @@ -421,17 +424,27 @@ def least_attention(a, i_1, i_2, j, dic):
if s == 2:
return (attn_1, bound, bound_2)

def loss_diff_1(b, i_1, i_2, dic):
def loss_diff_1(b, i_1, i_2, dic, n=None):
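# n, if supplied, presumably names a single wrong token to bound against; n == b returns 0
# since the logit gap of b against itself vanishes, and n=None falls back to all tokens != b.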

if n == b:
return 0

n = torch.arange(d_voc)[torch.arange(d_voc) != b]
if n is None:

n = torch.arange(d_voc)[torch.arange(d_voc) != b]

return (
term_5[i_2, dic[i_2]][..., n] - term_5[i_2, :, b].unsqueeze(dim=-1)
term_5[i_2, dic[i_2]][..., n] - term_5[i_2, dic[i_2], b].unsqueeze(dim=-1)
).max()

def loss_diff_2(b, i_1, i_2, dic):
def loss_diff_2(b, i_1, i_2, dic, n=None):

if n == b:
return 0

if n is None:

n = torch.arange(d_voc)[torch.arange(d_voc) != b]
n = torch.arange(d_voc)[torch.arange(d_voc) != b]

c = (term_6[0, dic[0]][..., n] - term_6[0, dic[0], b].unsqueeze(dim=-1)).max()

@@ -460,8 +473,12 @@ def loss_diff_2(b, i_1, i_2, dic):
)
return ld_2

def loss_diff_3(b, i_1, i_2, dic):
n = torch.arange(d_voc)[torch.arange(d_voc) != b]
def loss_diff_3(b, i_1, i_2, dic, n=None):
if n == b:
return 0

if n is None:
n = torch.arange(d_voc)[torch.arange(d_voc) != b]
c = (term_7[0, dic[0]][..., n] - term_7[0, dic[0], b].unsqueeze(dim=-1)).max()
for i in range(i_1):
c = torch.max(
@@ -488,9 +505,14 @@ def loss_diff_3(b, i_1, i_2, dic):
)
return ld_3

def loss_diff_4(b, i_1, i_2, dic):
def loss_diff_4(b, i_1, i_2, dic, n=None):

n = torch.arange(d_voc)[torch.arange(d_voc) != b]
if n == b:
return 0

if n is None:

n = torch.arange(d_voc)[torch.arange(d_voc) != b]

for k in range(i_2 + 1):
if k != 0 and k != 1:
@@ -546,32 +568,57 @@ def loss_diff_4(b, i_1, i_2, dic):
)
return ld_4

def total_bound(b, i_1, i_2, dic):
def total_bound(b, i_1, i_2, dic, n=None):
return (
loss_diff_1(b, i_1, i_2, dic)
+ loss_diff_2(b, i_1, i_2, dic)
+ loss_diff_3(b, i_1, i_2, dic)
+ loss_diff_4(b, i_1, i_2, dic)
loss_diff_1(b, i_1, i_2, dic, n)
+ loss_diff_2(b, i_1, i_2, dic, n)
+ loss_diff_3(b, i_1, i_2, dic, n)
+ loss_diff_4(b, i_1, i_2, dic, n)
)

out = torch.zeros((d_voc, n_ctx, n_ctx)) + torch.inf
if s == 3:

out = torch.zeros((d_voc, n_ctx, n_ctx)) + torch.inf
# b i_2 i_1

for b in range(e_p.shape[1]):

for i_2 in range(e_p.shape[0] - 1):
for i_1 in range(1, i_2):

if (i_1 < i_2) & (i_1 > 0):
dic = {i_1: b}
for i in range(8):
dic.setdefault(i, torch.arange(26))
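# dic presumably pins token b at position i_1 and leaves every other position free over the
# full vocabulary (26 = d_voc and 8 = n_ctx, by the look of it).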

out[b, i_2, i_1] = total_bound(b, i_1, i_2, dic)

out_2 = 1 / (1 + ((d_voc - 1) * torch.exp(out)))
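# The line above appears to convert the worst-case logit-gap bound into a probability bound:
# if every wrong logit exceeds the correct logit by at most out[b, i_2, i_1], the softmax
# probability of the correct token is at least 1 / (1 + (d_voc - 1) * exp(out)).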

return (attn_1, bound, bound_2, out, out_2)

out = torch.zeros((d_voc, n_ctx, n_ctx, d_voc)) + torch.inf
# b i_2 i_1

for b in range(e_p.shape[1]):
for n in range(e_p.shape[1]):
for i_2 in range(e_p.shape[0] - 1):
for i_1 in range(1, i_2):

for i_2 in range(e_p.shape[0] - 1):
for i_1 in range(1, i_2):
if (i_1 < i_2) & (i_1 > 0):
dic = {i_1: b}
for i in range(8):
dic.setdefault(i, torch.arange(26))

if (i_1 < i_2) & (i_1 > 0):
dic = {i_1: b}
for i in range(8):
dic.setdefault(i, torch.arange(26))
out[b, i_2, i_1, n] = total_bound(b, i_1, i_2, dic, n)

out[b, i_2, i_1] = total_bound(b, i_1, i_2, dic)
out_2 = einops.einsum(out.softmax(dim=-1), "b i_2 i_1 b -> b i_2 i_1")
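# This einsum appears to read off the diagonal n == b of the softmaxed gaps, i.e. the
# probability mass that the per-token bounds still leave on the correct token b.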

out_2 = 1 / (1 + ((d_voc - 1) * torch.exp(out)))
out_3 = einops.einsum(
out - out.max(dim=-1).values.unsqueeze(dim=-1), "b i_2 i_1 b -> b i_2 i_1"
)

return (attn_1, bound, bound_2, out, out_2)
return (attn_1, bound, bound_2, out, out_2, out_3)


# %%
@@ -647,22 +694,40 @@ def total_bound(b, i_1, i_2, dic):
counter += 1
print(counter)


# %%
valid = (
ein.array(
lambda i, j, k: where(k > 0, where(j > k, where(j < 7, 1, 0), 0), 0),
sizes=[d_voc, n_ctx, n_ctx],
)
.bool()
.to(device)
)
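# valid appears to mask the (b, i_2, i_1) triples actually filled by loss_bound,
# i.e. entries with 0 < i_1 < i_2 < 7 (presumably n_ctx - 1).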
optimiser = torch.optim.AdamW(
model_1.parameters(), lr=5e-3, betas=(0.9, 0.999), weight_decay=1.0
model_1.parameters(), lr=0.5, betas=(0.9, 0.999), weight_decay=0
)
# %%
a = loss_bound(model_1, 3)[4]
loss = 1 - a[valid].min()
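# Minimising 1 - min appears to push up the worst-case probability bound out_2
# over all valid (b, i_2, i_1) triples.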
print(a[valid].min())
print(a[valid].mean())
print(a[valid].max())
for i in range(1):
print(i + 1)

a = loss_bound(model_1, 3, 8)[4]
loss = 1 - a[a != 0].mean()
for i in range(30):
print(a[a != 0].mean())
loss.backward()
optimiser.step()
optimiser.zero_grad()
a = loss_bound(model_1, 3, 8)[4][5]
loss = 1 - a[a != 0].mean()
counter += 1
print(counter)
a = loss_bound(model_1, 3)[4]
loss = 1 - a[valid].min()
print(a[valid].min())
print(a[valid].mean())
print(a[valid].max())
if i % 10 == 1:
r = loss_bound(model_1, 4)[5]
print(r[valid].min())
print(r[valid].mean())
print(r[valid].max())

# %%
'''
