Andrei Cozma committed on
Commit
69d9811
1 Parent(s): 4567d2b
Files changed (1) hide show
  1. demo.py +20 -5
demo.py CHANGED
@@ -5,6 +5,7 @@ import numpy as np
5
  import gradio as gr
6
  from MonteCarloAgent import MonteCarloAgent
7
  import scipy.ndimage
 
8
 
9
  # For the dropdown list of policies
10
  policies_folder = "policies"
@@ -90,7 +91,8 @@ def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
90
  state, action, reward = episode_hist[-1]
91
  curr_policy = agent.Pi[state]
92
 
93
- viz_w, viz_h = 128, 16
 
94
  policy_viz = np.zeros((viz_h, viz_w))
95
  for i, p in enumerate(curr_policy):
96
  policy_viz[
@@ -100,6 +102,22 @@ def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
100
  * (viz_w // len(curr_policy)),
101
  ] = p
102
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  policy_viz = scipy.ndimage.gaussian_filter(policy_viz, sigma=1)
104
  policy_viz = np.clip(
105
  policy_viz * (1 - live_epsilon) + live_epsilon / len(curr_policy), 0, 1
@@ -152,9 +170,7 @@ with gr.Blocks(title="CS581 Demo") as demo:
152
  label="Max steps per episode",
153
  )
154
 
155
- btn_run = gr.components.Button(
156
  "▶️ Start", interactive=True if all_policies else False
157
- )
158
 
159
  gr.components.HTML("<h2>Live Statistics & Policy Visualization:</h2>")
160
  with gr.Row():
@@ -196,7 +212,6 @@ with gr.Blocks(title="CS581 Demo") as demo:
196
  )
197
 
198
  with gr.Row():
199
- # Pause/resume button
200
  btn_pause = gr.components.Button("⏸️ Pause", interactive=True)
201
  btn_pause.click(
202
  fn=change_paused,
 
5
  import gradio as gr
6
  from MonteCarloAgent import MonteCarloAgent
7
  import scipy.ndimage
8
+ import cv2
9
 
10
  # For the dropdown list of policies
11
  policies_folder = "policies"
 
91
  state, action, reward = episode_hist[-1]
92
  curr_policy = agent.Pi[state]
93
 
94
+ viz_w = 512
95
+ viz_h = viz_w // len(curr_policy)
96
  policy_viz = np.zeros((viz_h, viz_w))
97
  for i, p in enumerate(curr_policy):
98
  policy_viz[
 
102
  * (viz_w // len(curr_policy)),
103
  ] = p
104
 
105
+ policy_viz = np.stack([policy_viz] * 3, axis=-1)
106
+ text_offset = 15
107
+ cv2.putText(
108
+ policy_viz,
109
+ str(action),
110
+ (
111
+ int((action + 0.5) * viz_w // len(curr_policy) - text_offset),
112
+ viz_h // 2 + text_offset,
113
+ ),
114
+ cv2.FONT_HERSHEY_SIMPLEX,
115
+ 1.5,
116
+ (255, 255, 255),
117
+ 1,
118
+ cv2.LINE_AA,
119
+ )
120
+
121
  policy_viz = scipy.ndimage.gaussian_filter(policy_viz, sigma=1)
122
  policy_viz = np.clip(
123
  policy_viz * (1 - live_epsilon) + live_epsilon / len(curr_policy), 0, 1
 
170
  label="Max steps per episode",
171
  )
172
 
173
+ btn_run = gr.components.Button("▶️ Start", interactive=bool(all_policies))
 
 
174
 
175
  gr.components.HTML("<h2>Live Statistics & Policy Visualization:</h2>")
176
  with gr.Row():
 
212
  )
213
 
214
  with gr.Row():
 
215
  btn_pause = gr.components.Button("⏸️ Pause", interactive=True)
216
  btn_pause.click(
217
  fn=change_paused,