-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathsimple_example_prior_posterior.py
44 lines (33 loc) · 1.16 KB
/
simple_example_prior_posterior.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import numpy as np
from matplotlib import pyplot as plt
plt.rcParams["text.usetex"] = True
def f_(x, y):
return x**3 + y**2
def manifold_density_reweight_desired_(x, y):
return ((3 * x**2) ** 2 + (y) ** 2) ** (-1 / 2)
x_max = 2
x_min = -2
y_max = np.sqrt(1 - (-2) ** 3)
y_min = -np.sqrt(1 - (-2) ** 3)
general_max = np.maximum(y_max, x_max)
general_min = np.minimum(y_min, x_min)
n = 500
x = np.linspace(general_min, general_max, n)
y = np.linspace(general_min, general_max, n)
X, Y = np.meshgrid(x, y)
Z = manifold_density_reweight_desired_(X, Y)
Z = (Z - Z.min()) / (Z.max() - Z.min())
Z = Z * (2 - Z)
mask = f_(X, Y)
for Zm, name in [(Z, "prior"), (np.ma.masked_where((1.3 < mask) | (mask < 0.7), Z), "posterior_trivial")]:
plt.pcolormesh(X, Y, Zm, shading="gouraud", color="blue", cmap="Blues")
plt.xlim(general_min, general_max)
plt.ylim(general_min, general_max)
plt.xticks(fontsize=18)
plt.yticks(fontsize=18)
plt.xlabel(r"$\theta_1$", loc="center", fontsize=18)
plt.ylabel(r"$\theta_2$", loc="center", fontsize=18)
plt.gca().set_aspect("equal", adjustable="box")
plt.tight_layout()
plt.savefig(f"{name}.png")
plt.show()