Andrei Cozma commited on
Commit
080e344
·
1 Parent(s): 73cd2cf
Files changed (1) hide show
  1. MonteCarloAgent.py +16 -5
MonteCarloAgent.py CHANGED
@@ -102,13 +102,24 @@ class MonteCarloAgent:
102
  1 - self.epsilon + self.epsilon / self.n_actions
103
  )
104
 
105
- def train(self, n_train_episodes=2000, test_every=100, log_wandb=False, **kwargs):
 
 
 
 
 
 
 
106
  print(f"Training agent for {n_train_episodes} episodes...")
 
107
  train_running_success_rate, test_success_rate = 0.0, 0.0
108
  stats = {
109
  "train_running_success_rate": train_running_success_rate,
110
  "test_success_rate": test_success_rate,
111
  }
 
 
 
112
  tqrange = tqdm(range(n_train_episodes))
113
  tqrange.set_description("Training")
114
 
@@ -122,7 +133,7 @@ class MonteCarloAgent:
122
  train_running_success_rate = (
123
  0.99 * train_running_success_rate + 0.01 * finished
124
  )
125
- self.update_first_visit(episode_hist)
126
 
127
  stats = {
128
  "train_running_success_rate": train_running_success_rate,
@@ -232,9 +243,9 @@ def main():
232
  parser.add_argument(
233
  "--update_type",
234
  type=str,
235
- choices=["first-visit", "every-visit"],
236
- default="first-visit",
237
- help="The type of update to use. (default: first-visit)",
238
  )
239
 
240
  parser.add_argument(
 
102
  1 - self.epsilon + self.epsilon / self.n_actions
103
  )
104
 
105
+ def train(
106
+ self,
107
+ n_train_episodes=2000,
108
+ test_every=100,
109
+ update_type="first_visit",
110
+ log_wandb=False,
111
+ **kwargs,
112
+ ):
113
  print(f"Training agent for {n_train_episodes} episodes...")
114
+
115
  train_running_success_rate, test_success_rate = 0.0, 0.0
116
  stats = {
117
  "train_running_success_rate": train_running_success_rate,
118
  "test_success_rate": test_success_rate,
119
  }
120
+
121
+ update_func = getattr(self, f"update_{update_type}")
122
+
123
  tqrange = tqdm(range(n_train_episodes))
124
  tqrange.set_description("Training")
125
 
 
133
  train_running_success_rate = (
134
  0.99 * train_running_success_rate + 0.01 * finished
135
  )
136
+ update_func(episode_hist)
137
 
138
  stats = {
139
  "train_running_success_rate": train_running_success_rate,
 
243
  parser.add_argument(
244
  "--update_type",
245
  type=str,
246
+ choices=["first_visit", "every_visit"],
247
+ default="first_visit",
248
+ help="The type of update to use. (default: first_visit)",
249
  )
250
 
251
  parser.add_argument(