-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathTestMatrixCompletion.py
52 lines (43 loc) · 1.72 KB
/
TestMatrixCompletion.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import Rank1MatrixCompletion as mc
import numpy as np
import matplotlib.pyplot as plt
# Rank-1 matrix completion demos.
#
# Each demo builds a random N x N rank-1 matrix and zeroes every entry outside
# a sampling mask. It plots the observed data, then runs
# mc.completeRank1Matrix on copies of the inputs. Finally it plots the inputs
# next to the returned results.

# Colour range shared by all four panels of the before/after comparison, so
# the input and completed matrices are drawn on the same scale.
# NOTE(review): the mask panels also use this -3..3 range, so mask values of
# 0 and 1 map to nearby colours; confirm that this is intended.
COMPARISON_VMIN = -3
COMPARISON_VMAX = 3


def makeRank1Matrix(N):
    """Return a random N x N rank-1 matrix.

    The matrix is the outer product of two length-N standard-normal vectors.
    """
    return np.outer(np.random.randn(N), np.random.randn(N))


def observe(A, mask):
    """Zero, in place, every entry of A where mask is falsy, and return A."""
    A[np.logical_not(mask)] = 0
    return A


def showObserved(A, mask):
    """Plot the observed matrix above its sampling mask (auto colour scale)."""
    _, (axMatrix, axMask) = plt.subplots(2, sharex=True, sharey=True)
    axMatrix.imshow(A, interpolation="nearest")
    axMask.imshow(mask, interpolation="nearest")
    plt.show()


def showCompletion(A, mask, AOut, maskOut):
    """Plot inputs (left column) against completion outputs (right column).

    The top row holds the matrices and the bottom row holds the masks.
    """
    _, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, sharex=True, sharey=True)
    for ax, image in ((ax1, A), (ax3, mask), (ax2, AOut), (ax4, maskOut)):
        ax.imshow(image, interpolation="nearest",
                  vmin=COMPARISON_VMIN, vmax=COMPARISON_VMAX)
    plt.show()


def runDemo(A, mask):
    """Observe A through mask, complete it, plot both stages, return outputs.

    A is modified in place (unsampled entries set to 0). Copies of A and mask
    are handed to mc.completeRank1Matrix, so the arrays plotted as inputs are
    unaffected by any in-place changes that function makes to its arguments.
    """
    observe(A, mask)
    showObserved(A, mask)
    # The third argument is passed as True, exactly as before; its meaning is
    # defined in Rank1MatrixCompletion, which is not visible here.
    AOut, maskOut = mc.completeRank1Matrix(np.copy(A), np.copy(mask), True)
    showCompletion(A, mask, AOut, maskOut)
    return AOut, maskOut


# Random Sampling: each entry of an N x N rank-1 matrix is observed
# independently with probability sampleRate.
N = 100
sampleRate = 0.1
A = makeRank1Matrix(N)
mask = np.random.rand(N, N) < sampleRate
AOut, maskOut = runDemo(A, mask)

# Banded Diagonal Sampling: only the main diagonal and the first
# superdiagonal of an N x N rank-1 matrix are observed.
N = 50
A = makeRank1Matrix(N)
mask = np.eye(N, dtype=bool) | np.eye(N, k=1, dtype=bool)
AOut, maskOut = runDemo(A, mask)