Andrei Cozma committed on
Commit
668f525
·
1 Parent(s): 49c1e2a
Files changed (2) hide show
  1. Shared.py +15 -12
  2. demo.py +28 -6
Shared.py CHANGED
@@ -2,10 +2,11 @@ import os
2
  import numpy as np
3
  import gymnasium as gym
4
 
 
5
  class Shared:
6
-
7
  def __init__(
8
- self,/,
 
9
  env="CliffWalking-v0",
10
  gamma=0.99,
11
  epsilon=0.1,
@@ -20,8 +21,9 @@ class Shared:
20
  self.run_name = run_name
21
  self.env_name = env
22
  self.epsilon, self.gamma = epsilon, gamma
 
23
 
24
- self.env_kwargs = {k:v for k,v in kwargs.items() if k in ['render_mode']}
25
  if self.env_name == "FrozenLake-v1":
26
  # Can use defaults by defining map_name (4x4 or 8x8) or custom map by defining desc
27
  # self.env_kwargs["map_name"] = "8x8"
@@ -46,23 +48,24 @@ class Shared:
46
  print(f"- n_states: {self.n_states}")
47
  print(f"- n_actions: {self.n_actions}")
48
 
49
- def choose_action(self, state, epsilon_override=None, greedy=False, **kwargs):
50
  # Sample an action from the policy.
51
  # The epsilon_override argument allows forcing the use of a new epsilon value than the one previously used during training.
52
  # The ability to override was mostly added for testing purposes and for the demo.
53
  greedy_action = np.argmax(self.Pi[state])
54
 
55
- if greedy or epsilon_override == 0:
56
  return greedy_action
57
 
58
- if epsilon_override is None:
59
  return np.random.choice(self.n_actions, p=self.Pi[state])
60
 
 
61
  return np.random.choice(
62
  [greedy_action, np.random.randint(self.n_actions)],
63
- p=[1 - epsilon_override, epsilon_override],
64
  )
65
-
66
  def generate_episode(self, max_steps=500, render=False, **kwargs):
67
  state, _ = self.env.reset()
68
  episode_hist, solved, rgb_array = (
@@ -118,9 +121,9 @@ class Shared:
118
 
119
  def run_episode(self, max_steps=500, render=False, **kwargs):
120
  # Run the generator until the end
121
- episode_hist, solved, rgb_array = list(self.generate_episode(
122
- max_steps, render, **kwargs
123
- ))[-1]
124
  return episode_hist, solved, rgb_array
125
 
126
  def test(self, n_test_episodes=100, verbose=True, greedy=True, **kwargs):
@@ -143,7 +146,7 @@ class Shared:
143
  f"Agent reached the goal in {num_successes}/{n_test_episodes} episodes ({success_rate * 100:.2f}%)"
144
  )
145
  return success_rate
146
-
147
  def save_policy(self, fname="policy.npy", save_dir=None):
148
  if save_dir is not None:
149
  os.makedirs(save_dir, exist_ok=True)
 
2
  import numpy as np
3
  import gymnasium as gym
4
 
5
+
6
  class Shared:
 
7
  def __init__(
8
+ self,
9
+ /,
10
  env="CliffWalking-v0",
11
  gamma=0.99,
12
  epsilon=0.1,
 
21
  self.run_name = run_name
22
  self.env_name = env
23
  self.epsilon, self.gamma = epsilon, gamma
24
+ self.epsilon_override = None
25
 
26
+ self.env_kwargs = {k: v for k, v in kwargs.items() if k in ["render_mode"]}
27
  if self.env_name == "FrozenLake-v1":
28
  # Can use defaults by defining map_name (4x4 or 8x8) or custom map by defining desc
29
  # self.env_kwargs["map_name"] = "8x8"
 
48
  print(f"- n_states: {self.n_states}")
49
  print(f"- n_actions: {self.n_actions}")
50
 
51
+ def choose_action(self, state, greedy=False, **kwargs):
52
  # Sample an action from the policy.
53
  # The epsilon_override argument allows forcing the use of a new epsilon value than the one previously used during training.
54
  # The ability to override was mostly added for testing purposes and for the demo.
55
  greedy_action = np.argmax(self.Pi[state])
56
 
57
+ if greedy or self.epsilon_override == 0.0:
58
  return greedy_action
59
 
60
+ if self.epsilon_override is None:
61
  return np.random.choice(self.n_actions, p=self.Pi[state])
62
 
63
+ print("epsilon_override", self.epsilon_override)
64
  return np.random.choice(
65
  [greedy_action, np.random.randint(self.n_actions)],
66
+ p=[1.0 - self.epsilon_override, self.epsilon_override],
67
  )
68
+
69
  def generate_episode(self, max_steps=500, render=False, **kwargs):
70
  state, _ = self.env.reset()
71
  episode_hist, solved, rgb_array = (
 
121
 
122
  def run_episode(self, max_steps=500, render=False, **kwargs):
123
  # Run the generator until the end
124
+ episode_hist, solved, rgb_array = list(
125
+ self.generate_episode(max_steps, render, **kwargs)
126
+ )[-1]
127
  return episode_hist, solved, rgb_array
128
 
129
  def test(self, n_test_episodes=100, verbose=True, greedy=True, **kwargs):
 
146
  f"Agent reached the goal in {num_successes}/{n_test_episodes} episodes ({success_rate * 100:.2f}%)"
147
  )
148
  return success_rate
149
+
150
  def save_policy(self, fname="policy.npy", save_dir=None):
151
  if save_dir is not None:
152
  os.makedirs(save_dir, exist_ok=True)
demo.py CHANGED
@@ -81,12 +81,24 @@ def change_render_fps(state, x):
81
  return state
82
 
83
 
 
 
 
 
 
 
84
  def change_epsilon(state, x):
85
  print("Changing greediness:", x)
86
  state.live_epsilon = x
87
  return state
88
 
89
 
 
 
 
 
 
 
90
  def change_paused(state, x):
91
  print("Changing paused:", x)
92
  state.live_paused = pause_val_map[x]
@@ -159,9 +171,9 @@ def run(
159
  agent.generate_episode(
160
  max_steps=max_steps,
161
  render=True,
162
- epsilon_override=localstate.live_epsilon,
163
  )
164
  ):
 
165
  _, _, last_reward = (
166
  episode_hist[-2] if len(episode_hist) > 1 else (None, None, None)
167
  )
@@ -207,7 +219,7 @@ def run(
207
  str(action),
208
  (
209
  label_loc_w - label_width // 2,
210
- label_loc_h + label_height // 2,
211
  ),
212
  frame_policy_label_font,
213
  action_text_scale,
@@ -230,9 +242,7 @@ def run(
230
  action_name,
231
  (
232
  int(label_loc_w - label_width / 2),
233
- frame_policy_h
234
- - (frame_policy_h - label_loc_h) // 2
235
- + label_height // 2,
236
  ),
237
  frame_policy_label_font,
238
  action_text_label_scale,
@@ -363,7 +373,14 @@ with gr.Blocks(title="CS581 Demo") as demo:
363
  label="Epsilon (0 = greedy, 1 = random)",
364
  )
365
  input_epsilon.change(
366
- change_epsilon, inputs=[localstate, input_epsilon], outputs=[localstate]
 
 
 
 
 
 
 
367
  )
368
 
369
  input_render_fps = gr.components.Slider(
@@ -378,6 +395,11 @@ with gr.Blocks(title="CS581 Demo") as demo:
378
  inputs=[localstate, input_render_fps],
379
  outputs=[localstate],
380
  )
 
 
 
 
 
381
 
382
  out_image_frame = gr.components.Image(
383
  label="Environment",
 
81
  return state
82
 
83
 
84
+ def change_render_fps_update(state, x):
85
+ print("Changing render fps:", x)
86
+ state.live_render_fps = x
87
+ return state, gr.update(value=x)
88
+
89
+
90
  def change_epsilon(state, x):
91
  print("Changing greediness:", x)
92
  state.live_epsilon = x
93
  return state
94
 
95
 
96
+ def change_epsilon_update(state, x):
97
+ print("Changing greediness:", x)
98
+ state.live_epsilon = x
99
+ return state, gr.update(value=x)
100
+
101
+
102
  def change_paused(state, x):
103
  print("Changing paused:", x)
104
  state.live_paused = pause_val_map[x]
 
171
  agent.generate_episode(
172
  max_steps=max_steps,
173
  render=True,
 
174
  )
175
  ):
176
+ agent.epsilon_override = localstate.live_epsilon
177
  _, _, last_reward = (
178
  episode_hist[-2] if len(episode_hist) > 1 else (None, None, None)
179
  )
 
219
  str(action),
220
  (
221
  label_loc_w - label_width // 2,
222
+ frame_policy_h // 3 + label_height // 2,
223
  ),
224
  frame_policy_label_font,
225
  action_text_scale,
 
242
  action_name,
243
  (
244
  int(label_loc_w - label_width / 2),
245
+ frame_policy_h - frame_policy_h // 3 + label_height // 2,
 
 
246
  ),
247
  frame_policy_label_font,
248
  action_text_label_scale,
 
373
  label="Epsilon (0 = greedy, 1 = random)",
374
  )
375
  input_epsilon.change(
376
+ change_epsilon,
377
+ inputs=[localstate, input_epsilon],
378
+ outputs=[localstate],
379
+ )
380
+ input_epsilon.release(
381
+ change_epsilon_update,
382
+ inputs=[localstate, input_epsilon],
383
+ outputs=[localstate, input_epsilon],
384
  )
385
 
386
  input_render_fps = gr.components.Slider(
 
395
  inputs=[localstate, input_render_fps],
396
  outputs=[localstate],
397
  )
398
+ input_render_fps.release(
399
+ change_render_fps_update,
400
+ inputs=[localstate, input_render_fps],
401
+ outputs=[localstate, input_render_fps],
402
+ )
403
 
404
  out_image_frame = gr.components.Image(
405
  label="Environment",