-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathsymeig.py
139 lines (114 loc) · 4.5 KB
/
symeig.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
"""
code from https://colab.research.google.com/drive/1wyt3357g8J91PQggLh8CLgvm4RiBmjQv
please see https://github.com/pytorch/pytorch/issues/47599
"""
from torch.autograd import Function
import torch
class Symeig(Function):
    """Symmetric/Hermitian eigendecomposition with a stabilized backward pass.

    ``forward(input, min_threshold=1e-6)`` returns ``(lambda_, v)``, the
    eigenvalues and eigenvectors computed from the upper triangle of ``input``
    via ``torch.linalg.eigh(..., UPLO="U")``. This is the replacement for the
    removed ``torch.symeig(..., upper=True)``.

    In ``backward``, every eigenvalue gap ``lambda_j - lambda_i`` whose magnitude
    is below ``min_threshold`` is replaced by ``inf`` before taking the
    reciprocal. This always includes the diagonal. As a result, (near-)degenerate
    pairs contribute nothing to the eigenvector part of the gradient instead of
    producing huge or non-finite values. The returned gradient is the Hermitian
    part ``(R + R^H) / 2`` of the raw result ``R``.
    """

    @staticmethod
    def forward(ctx, input, min_threshold=1e-6):
        lambda_, v = torch.linalg.eigh(input, UPLO="U")
        ctx.min_threshold = min_threshold
        ctx.save_for_backward(input, lambda_, v)
        return lambda_, v

    @staticmethod
    def backward(ctx, glambda_, gv):
        input, lambda_, v = ctx.saved_tensors
        vh = v.conj().transpose(-2, -1)

        # Contribution from the eigenvectors: V (F o (V^H gV)) V^H.
        if gv is not None:
            # F[..., i, j] = lambda_j - lambda_i
            F = lambda_.unsqueeze(-2) - lambda_.unsqueeze(-1)
            # Mask (near-)degenerate gaps, including the zero diagonal, with inf
            # so that their reciprocal is 0. A single pass suffices: afterwards
            # no entry has |F| < min_threshold.
            F[torch.abs(F) < ctx.min_threshold] = float("inf")
            F.pow_(-1)
            result = v @ ((F * (vh @ gv)) @ vh)
        else:
            result = torch.zeros_like(input)

        # Contribution from the eigenvalues: V diag(glambda_) V^H.
        if glambda_ is not None:
            glambda_ = glambda_.type_as(input)
            result = result + (v * glambda_.unsqueeze(-2)) @ vh

        grad_input = result.add(result.conj().transpose(-2, -1)).mul_(0.5)
        # ``min_threshold`` is a plain Python number and receives no gradient.
        return grad_input, None
class symeig1_fcn(torch.autograd.Function):
    """Symmetric eigendecomposition whose backward masks all degenerate gaps.

    ``forward(A, min_threshold=1e-6)`` returns ``(eival, eivec)`` computed from
    the upper triangle of ``A`` via ``torch.linalg.eigh(..., UPLO="U")``. This
    is the replacement for the removed ``torch.symeig``.

    ``backward`` treats every eigenvalue gap with magnitude below
    ``min_threshold`` as infinite, not only the diagonal. Those pairs therefore
    contribute zero to the eigenvector part of the gradient. The eigenvalue part
    ``V diag(grad_eival) V^T`` is added as well. Plain transposes are used, so
    the formula is written for real symmetric ``A``.
    """

    @staticmethod
    def forward(ctx, A, min_threshold=1e-6):
        eival, eivec = torch.linalg.eigh(A, UPLO="U")
        ctx.min_threshold = min_threshold
        ctx.save_for_backward(eival, eivec)
        return eival, eivec

    @staticmethod
    def backward(ctx, grad_eival, grad_eivec):
        eival, eivec = ctx.saved_tensors
        eivect = eivec.transpose(-2, -1)
        res = torch.zeros_like(eivec)

        # Eigenvector contribution: V (F o (V^T gV)) V^T, F_ij = 1 / (l_j - l_i).
        if grad_eivec is not None:
            F = eival.unsqueeze(-2) - eival.unsqueeze(-1)
            # Replace every (near-)degenerate gap, not just the diagonal, by inf
            # so its reciprocal is 0. After this single pass no entry has
            # |F| < min_threshold, so no further clipping is needed.
            F[torch.abs(F) < ctx.min_threshold] = float("inf")
            F = F.pow(-1) * (eivect @ grad_eivec)
            res = eivec @ F @ eivect

        # Eigenvalue contribution: V diag(grad_eival) V^T.
        if grad_eival is not None:
            res = res + (eivec * grad_eival.type_as(eivec).unsqueeze(-2)) @ eivect

        # Symmetrize; the second value is the (absent) gradient of min_threshold.
        return (res + res.transpose(-2, -1)) * 0.5, None
class symeig2_fcn(torch.autograd.Function):
    """Symmetric eigendecomposition whose backward returns ``nan`` when undefined.

    ``forward(A, min_threshold=1e-6)`` returns ``(eival, eivec)`` computed from
    the upper triangle of ``A`` via ``torch.linalg.eigh(..., UPLO="U")``. This
    is the replacement for the removed ``torch.symeig``.

    ``backward`` masks (near-)degenerate eigenvalue gaps as in ``symeig1_fcn``.
    In addition, when any off-diagonal gap is below ``min_threshold``, it checks
    the existence condition of the derivative (X^T G symmetric on the
    degenerate pairs). If that condition fails, the whole returned gradient is
    filled with ``nan``; the check spans the entire batch. The eigenvalue part
    ``V diag(grad_eival) V^T`` is included. Plain transposes are used, so the
    formula is written for real symmetric ``A``.
    """

    @staticmethod
    def forward(ctx, A, min_threshold=1e-6):
        eival, eivec = torch.linalg.eigh(A, UPLO="U")
        ctx.min_threshold = min_threshold
        ctx.save_for_backward(eival, eivec)
        return eival, eivec

    @staticmethod
    def backward(ctx, grad_eival, grad_eivec):
        eival, eivec = ctx.saved_tensors
        eivect = eivec.transpose(-2, -1)
        res = torch.zeros_like(eivec)

        # Eigenvector contribution: V (F o (V^T gV)) V^T, F_ij = 1 / (l_j - l_i).
        if grad_eivec is not None:
            F = eival.unsqueeze(-2) - eival.unsqueeze(-1)
            # Exclude each matrix's own diagonal (always 0), so that ``idx``
            # below only flags genuinely degenerate off-diagonal pairs. The
            # diagonal must be taken over the last two axes; the default
            # (dims 0 and 1) would overwrite whole rows of F for batched input.
            F.diagonal(dim1=-2, dim2=-1).fill_(float("inf"))
            idx = torch.abs(F) < ctx.min_threshold
            F[idx] = float("inf")
            xtg = eivect @ grad_eivec
            if torch.any(idx):
                # The loss must not depend on which linear combination of the
                # degenerate eigenvectors is returned
                # (ref: https://arxiv.org/pdf/2011.04366.pdf eq. 2.13).
                diff_xtg = (xtg - xtg.transpose(-2, -1))[idx]
                if not torch.allclose(diff_xtg, torch.zeros_like(diff_xtg)):
                    # Requirement not satisfied: mathematically the derivative
                    # should be `nan`.
                    return torch.full_like(eivec, float("nan")), None
            res = eivec @ (F.pow(-1) * xtg) @ eivect

        # Eigenvalue contribution: V diag(grad_eival) V^T.
        if grad_eival is not None:
            res = res + (eivec * grad_eival.type_as(eivec).unsqueeze(-2)) @ eivect

        # Symmetrize; the second value is the (absent) gradient of min_threshold.
        return (res + res.transpose(-2, -1)) * 0.5, None
# Function-style entry points for the autograd Functions above; each is called
# like a plain function, e.g. ``eival, eivec = symeig1(A)``.
symeig1, symeig2, symeig_proj = (
    symeig1_fcn.apply,
    symeig2_fcn.apply,
    Symeig.apply,
)