Spaces:
Sleeping
Sleeping
Andrei Cozma
committed on
Commit
·
e17747a
1
Parent(s):
2ad3cc0
Updates
Browse files
- MCAgent.py +7 -4
- Shared.py +1 -1
MCAgent.py
CHANGED
@@ -75,6 +75,7 @@ class MCAgent(Shared):
|
|
75 |
log_wandb=False,
|
76 |
save_best=True,
|
77 |
save_best_dir=None,
|
|
|
78 |
**kwargs,
|
79 |
):
|
80 |
print(f"Training agent for {n_train_episodes} episodes...")
|
@@ -140,15 +141,17 @@ class MCAgent(Shared):
|
|
140 |
wandb.log(stats)
|
141 |
|
142 |
if test_running_success_rate > 0.999:
|
143 |
-
print(
|
144 |
-
f"CONVERGED: test success rate running avg reached 100% after {e} episodes."
|
145 |
-
)
|
146 |
if save_best:
|
147 |
if self.run_name is None:
|
148 |
print("WARNING: run_name is None, not saving best policy.")
|
149 |
else:
|
150 |
self.save_policy(self.run_name, save_best_dir)
|
151 |
-
|
|
|
|
|
|
|
|
|
|
|
152 |
|
153 |
def wandb_log_img(self, episode=None):
|
154 |
caption_suffix = "Initial" if episode is None else f"After Episode {episode}"
|
|
|
75 |
log_wandb=False,
|
76 |
save_best=True,
|
77 |
save_best_dir=None,
|
78 |
+
early_stopping=False,
|
79 |
**kwargs,
|
80 |
):
|
81 |
print(f"Training agent for {n_train_episodes} episodes...")
|
|
|
141 |
wandb.log(stats)
|
142 |
|
143 |
if test_running_success_rate > 0.999:
|
|
|
|
|
|
|
144 |
if save_best:
|
145 |
if self.run_name is None:
|
146 |
print("WARNING: run_name is None, not saving best policy.")
|
147 |
else:
|
148 |
self.save_policy(self.run_name, save_best_dir)
|
149 |
+
|
150 |
+
if early_stopping:
|
151 |
+
print(
|
152 |
+
f"CONVERGED: test success rate running avg reached 100% after {e} episodes."
|
153 |
+
)
|
154 |
+
break
|
155 |
|
156 |
def wandb_log_img(self, episode=None):
|
157 |
caption_suffix = "Initial" if episode is None else f"After Episode {episode}"
|
Shared.py
CHANGED
@@ -12,7 +12,7 @@ class Shared:
|
|
12 |
gamma=0.99,
|
13 |
epsilon=0.1,
|
14 |
run_name=None,
|
15 |
-
frozenlake_size=
|
16 |
**kwargs,
|
17 |
):
|
18 |
print("=" * 80)
|
|
|
12 |
gamma=0.99,
|
13 |
epsilon=0.1,
|
14 |
run_name=None,
|
15 |
+
frozenlake_size=8,
|
16 |
**kwargs,
|
17 |
):
|
18 |
print("=" * 80)
|