-
Notifications
You must be signed in to change notification settings - Fork 28
/
Copy pathplot_results.py
59 lines (48 loc) · 1.75 KB
/
plot_results.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import csv
import numpy as np
import matplotlib.pyplot as plt
def read_trajectory(n_states, n_controls, results_path):
    """Load a state/control trajectory from a CSV results file.

    The first row is skipped (treated as a header). Every remaining non-empty
    row contributes its first ``n_states`` fields as a state. Its last
    ``n_controls`` fields are parsed as a control, unless the row's final
    field is blank: a blank final field marks the terminal state, which has no
    control. Completely empty lines are ignored.

    Args:
        n_states: Number of leading columns that hold the state vector.
        n_controls: Number of trailing columns that hold the control vector.
        results_path: Path of the CSV file to read.

    Returns:
        ``(states, controls)`` as float arrays of shape ``(T, n_states)`` and
        ``(K, n_controls)``. ``T`` is the number of data rows and ``K`` the
        number of rows that carry a control. Both arrays stay 2-D even when
        the file contains no data rows, so column indexing like ``x[:, 0]``
        keeps working.

    Raises:
        ValueError: if a state or control field is not a number, or rows do
            not provide ``n_states`` state fields.
    """
    states = []
    controls = []
    # The csv module expects newline='' so it can handle line endings itself;
    # the `with` block guarantees the file handle is closed.
    with open(results_path, newline='') as results_file:
        reader = csv.reader(results_file)
        next(reader, None)  # skip the header row
        for row in reader:
            if not row:
                continue  # blank line, e.g. an extra newline at end of file
            states.append([float(field) for field in row[:n_states]])
            # A blank final field (' ' or '') marks the terminal state, which
            # has no control to parse.
            if row[-1].strip():
                # Slice from an explicit start index: row[-0:] would return
                # the whole row when n_controls == 0.
                control_start = max(len(row) - n_controls, 0)
                controls.append([float(field) for field in row[control_start:]])
    # Reshape so empty results are (0, n) rather than 1-D (0,).
    states = np.array(states, dtype=float).reshape(len(states), n_states)
    controls = np.array(controls, dtype=float).reshape(len(controls), n_controls)
    return states, controls
if __name__ == "__main__":
    import argparse

    # Plot the first two state components and the first control of a saved
    # trajectory. The defaults reproduce the original hard-coded acrobot
    # setup, so running the script with no arguments behaves as before.
    parser = argparse.ArgumentParser(
        description="Plot a state/control trajectory stored in a CSV file.")
    parser.add_argument(
        "results_path", nargs="?", default="build/ilqr_result.csv",
        help="CSV file to plot (default: %(default)s)")
    parser.add_argument(
        "--n-states", type=int, default=4,
        help="number of leading state columns (default: %(default)s)")
    parser.add_argument(
        "--n-controls", type=int, default=1,
        help="number of trailing control columns (default: %(default)s)")
    args = parser.parse_args()

    x, u = read_trajectory(args.n_states, args.n_controls, args.results_path)
    plt.plot(x[:, 0], 'g', label='x1')
    plt.plot(x[:, 1], 'b', label='x2')
    plt.plot(u[:, 0], 'r--', label='u1')
    plt.legend()
    plt.show()

    # Disabled: side-by-side comparison of the square and smooth-abs cost
    # function results.
    # results_path = 'results/acrobot_sqr.csv'
    # x_sqr, a_sqr = read_trajectory(args.n_states, args.n_controls, results_path)
    # plt.subplot(121)
    # plt.plot(x_sqr[:, 0], 'g', label='x1')
    # plt.plot(x_sqr[:, 1], 'b', label='x2')
    # plt.plot(a_sqr[:, 0], 'r--', label='u1')
    # plt.title('Square cost function')
    # plt.legend()
    # results_path = 'results/acrobot_sabs.csv'
    # x_sabs, a_sabs = read_trajectory(args.n_states, args.n_controls, results_path)
    # plt.subplot(122)
    # plt.plot(x_sabs[:, 0], 'g', label='x1')
    # plt.plot(x_sabs[:, 1], 'b', label='x2')
    # plt.plot(a_sabs[:, 0], 'r--', label='u1')
    # plt.title('Smooth abs cost function')
    # plt.legend()
    # plt.savefig('acrobot_cost_comparison.png')
    # plt.show()