Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: utilize Einops in appropriate areas to enhance readability #42

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 32 additions & 60 deletions src/efficient_kan/kan.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
import torch.nn.functional as F
import math

from einops import rearrange, repeat, reduce

class KANLinear(torch.nn.Module):
def __init__(
Expand Down Expand Up @@ -87,21 +87,14 @@ def b_splines(self, x: torch.Tensor):
"""
assert x.dim() == 2 and x.size(1) == self.in_features

grid: torch.Tensor = (
self.grid
) # (in_features, grid_size + 2 * spline_order + 1)
x = x.unsqueeze(-1)
bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
grid: torch.Tensor = self.grid # (in_features, grid_size + 2 * spline_order + 1)
x = rearrange(x, 'b f -> b f 1')
bases = (x >= grid[:, :-1]) & (x < grid[:, 1:])
bases = bases.to(x.dtype)
for k in range(1, self.spline_order + 1):
bases = (
(x - grid[:, : -(k + 1)])
/ (grid[:, k:-1] - grid[:, : -(k + 1)])
* bases[:, :, :-1]
) + (
(grid[:, k + 1 :] - x)
/ (grid[:, k + 1 :] - grid[:, 1:(-k)])
* bases[:, :, 1:]
)
left = (x - grid[:, :-(k + 1)]) / (grid[:, k:-1] - grid[:, :-(k + 1)])
right = (grid[:, k + 1:] - x) / (grid[:, k + 1:] - grid[:, 1:(-k)])
bases = left * bases[:, :, :-1] + right * bases[:, :, 1:]

assert bases.size() == (
x.size(0),
Expand All @@ -124,16 +117,13 @@ def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
assert x.dim() == 2 and x.size(1) == self.in_features
assert y.size() == (x.size(0), self.in_features, self.out_features)

A = self.b_splines(x).transpose(
0, 1
) # (in_features, batch_size, grid_size + spline_order)
B = y.transpose(0, 1) # (in_features, batch_size, out_features)
solution = torch.linalg.lstsq(
A, B
).solution # (in_features, grid_size + spline_order, out_features)
result = solution.permute(
2, 0, 1
) # (out_features, in_features, grid_size + spline_order)
A = self.b_splines(x) # (batch_size, in_features, grid_size + spline_order)
A = rearrange(A, 'b f g -> f b g') # (in_features, batch_size, grid_size + spline_order)
B = rearrange(y, 'b f o -> f b o') # (in_features, batch_size, out_features)

solution = torch.linalg.lstsq(A, B).solution # (in_features, grid_size + spline_order, out_features)

result = rearrange(solution, 'f g o -> o f g') # (out_features, in_features, grid_size + spline_order)

assert result.size() == (
self.out_features,
Expand All @@ -142,6 +132,7 @@ def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
)
return result.contiguous()


@property
def scaled_spline_weight(self):
return self.spline_weight * (
Expand Down Expand Up @@ -171,47 +162,28 @@ def update_grid(self, x: torch.Tensor, margin=0.01):
batch = x.size(0)

splines = self.b_splines(x) # (batch, in, coeff)
splines = splines.permute(1, 0, 2) # (in, batch, coeff)
splines = rearrange(splines, 'b in coeff -> in b coeff')
orig_coeff = self.scaled_spline_weight # (out, in, coeff)
orig_coeff = orig_coeff.permute(1, 2, 0) # (in, coeff, out)
unreduced_spline_output = torch.bmm(splines, orig_coeff) # (in, batch, out)
unreduced_spline_output = unreduced_spline_output.permute(
1, 0, 2
) # (batch, in, out)
orig_coeff = rearrange(orig_coeff, 'out in coeff -> in coeff out')
unreduced_spline_output = torch.einsum('ibc, ico -> ibo', splines, orig_coeff)

# sort each channel individually to collect data distribution
x_sorted = torch.sort(x, dim=0)[0]
grid_adaptive = x_sorted[
torch.linspace(
0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device
)
]
indices = torch.linspace(0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device)
grid_adaptive = x_sorted[indices]

uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
grid_uniform = (
torch.arange(
self.grid_size + 1, dtype=torch.float32, device=x.device
).unsqueeze(1)
* uniform_step
+ x_sorted[0]
- margin
)
grid_uniform = repeat(torch.arange(self.grid_size + 1, dtype=torch.float32, device=x.device), 'g -> g 1')
grid_uniform = grid_uniform * uniform_step + x_sorted[0] - margin

grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
grid = torch.concatenate(
[
grid[:1]
- uniform_step
* torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
grid,
grid[-1:]
+ uniform_step
* torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
],
dim=0,
)

self.grid.copy_(grid.T)
spline_order_range = torch.arange(self.spline_order, 0, -1, device=x.device)
grid = torch.cat([
grid[:1] - uniform_step * repeat(spline_order_range, 'o -> o 1'),
grid,
grid[-1:] + uniform_step * repeat(torch.arange(1, self.spline_order + 1, device=x.device), 'o -> o 1')
], dim=0)

self.grid.copy_(rearrange(grid, 'g 1 -> 1 g'))
self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))

def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
Expand All @@ -227,7 +199,7 @@ def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0)
weights. The authors implementation also includes this term in addition to the
sample-based regularization.
"""
l1_fake = self.spline_weight.abs().mean(-1)
l1_fake = reduce(self.spline_weight, '... -> ...', 'abs').mean(-1)
regularization_loss_activation = l1_fake.sum()
p = l1_fake / regularization_loss_activation
regularization_loss_entropy = -torch.sum(p * p.log())
Expand Down