ledmands
commited on
Commit
•
47aa47a
1
Parent(s):
91b9fbd
Removed unecessary comments in plot_improvement.py
Browse files- plot_improvement.py +3 -9
plot_improvement.py
CHANGED
@@ -1,10 +1,9 @@
|
|
1 |
-
import argparse
|
2 |
import numpy as np
|
3 |
import os
|
4 |
from matplotlib import pyplot as plt
|
5 |
|
6 |
def calc_stats(filepath):
|
7 |
-
# load the numpy file
|
8 |
data = np.load(filepath)["results"]
|
9 |
# sort the arrays and delete the first and last elements
|
10 |
data = np.sort(data, axis=1)
|
@@ -19,11 +18,6 @@ def calc_stats(filepath):
|
|
19 |
# parser.add_argument("-s", "--save", help="Specify whether to save the chart.", action="store_const", const=True)
|
20 |
# args = parser.parse_args()
|
21 |
|
22 |
-
# Get the file paths and store in list.
|
23 |
-
# For each file path, I want to calculate the mean reward. This would be the mean reward for the training run over all evaluations.
|
24 |
-
# For each file path, append the mean reward to an averages list
|
25 |
-
# Plot the averages!
|
26 |
-
|
27 |
filepaths = []
|
28 |
for d in os.listdir("agents/"):
|
29 |
if "dqn_v2" in d:
|
@@ -40,8 +34,8 @@ for path in filepaths:
|
|
40 |
runs = []
|
41 |
for i in range(len(filepaths)):
|
42 |
runs.append(i + 1)
|
43 |
-
plt.xlabel("
|
44 |
-
plt.ylabel("
|
45 |
plt.bar(runs, means)
|
46 |
plt.bar(runs, stds)
|
47 |
plt.legend(["Mean evaluation score", "Standard deviation"])
|
|
|
1 |
+
# import argparse
|
2 |
import numpy as np
|
3 |
import os
|
4 |
from matplotlib import pyplot as plt
|
5 |
|
6 |
def calc_stats(filepath):
|
|
|
7 |
data = np.load(filepath)["results"]
|
8 |
# sort the arrays and delete the first and last elements
|
9 |
data = np.sort(data, axis=1)
|
|
|
18 |
# parser.add_argument("-s", "--save", help="Specify whether to save the chart.", action="store_const", const=True)
|
19 |
# args = parser.parse_args()
|
20 |
|
|
|
|
|
|
|
|
|
|
|
21 |
filepaths = []
|
22 |
for d in os.listdir("agents/"):
|
23 |
if "dqn_v2" in d:
|
|
|
34 |
runs = []
|
35 |
for i in range(len(filepaths)):
|
36 |
runs.append(i + 1)
|
37 |
+
plt.xlabel("Training Run")
|
38 |
+
plt.ylabel("Score")
|
39 |
plt.bar(runs, means)
|
40 |
plt.bar(runs, stds)
|
41 |
plt.legend(["Mean evaluation score", "Standard deviation"])
|