File size: 15,868 Bytes
8ceccef
 
 
 
3e2038a
4567d2b
69d9811
8ceccef
46b0409
3e2038a
ed9cf21
 
45dcb54
 
 
ed9cf21
1ac9ba4
ec8233c
7357801
8ceccef
 
 
3e2038a
ed9cf21
 
 
 
 
 
 
1663f39
 
 
 
 
 
ec8233c
 
 
 
 
 
 
 
ed9cf21
8ceccef
d922c89
 
 
 
 
 
 
4567d2b
f221e91
 
 
 
 
 
 
 
 
 
 
 
adada5a
f221e91
 
 
 
 
 
6ee82fe
 
 
 
45dcb54
8ceccef
 
adada5a
0f41753
adada5a
 
 
 
 
 
 
 
 
 
 
f221e91
fee7b36
f221e91
 
8ceccef
 
668f525
fee7b36
668f525
 
 
 
f221e91
fee7b36
f221e91
 
8ceccef
 
668f525
fee7b36
668f525
 
 
 
f221e91
fee7b36
f221e91
 
 
 
 
45dcb54
e24c7c0
 
f221e91
e24c7c0
f221e91
 
 
 
 
 
 
 
 
 
 
 
 
e24c7c0
8ceccef
f221e91
8ceccef
 
f221e91
 
8ceccef
4567d2b
e24c7c0
6ee82fe
e173b06
 
 
30bb976
 
f221e91
e24c7c0
 
e173b06
6ee82fe
8ceccef
45dcb54
9c2fd5e
 
 
 
 
 
 
 
8ceccef
 
 
e24c7c0
 
 
8ceccef
 
 
 
 
ec8233c
53c3925
45dcb54
4567d2b
30bb976
f221e91
 
4567d2b
8ceccef
668f525
9c2fd5e
 
 
8ceccef
552dbe8
4567d2b
1ac9ba4
 
4567d2b
45dcb54
4567d2b
 
1ac9ba4
 
4567d2b
 
45dcb54
 
f221e91
 
de8a156
 
 
f221e91
 
 
 
 
fee7b36
 
 
 
 
 
f52cc9a
 
ec8233c
 
 
 
 
f221e91
 
 
 
 
 
 
de8a156
69d9811
45dcb54
69d9811
 
f52cc9a
668f525
69d9811
f52cc9a
 
 
 
69d9811
 
 
ed9cf21
53c3925
f221e91
 
 
 
 
 
 
ed9cf21
45dcb54
ed9cf21
 
f52cc9a
668f525
ed9cf21
f52cc9a
 
 
 
ed9cf21
 
 
8ceccef
f221e91
8ceccef
 
6ee82fe
ed9cf21
 
 
9c2fd5e
8ceccef
f221e91
 
 
d922c89
f221e91
 
 
d922c89
f221e91
de8a156
f221e91
6ee82fe
e24c7c0
 
 
9c2fd5e
f221e91
 
45dcb54
 
f221e91
 
a5e34df
45dcb54
f221e91
adada5a
 
45dcb54
1ac9ba4
adada5a
 
 
 
 
 
45dcb54
 
 
e24c7c0
ed9cf21
 
 
ec8233c
6f3d0ef
f221e91
6ee82fe
f221e91
ec8233c
8ceccef
 
4567d2b
ec8233c
 
 
 
 
 
 
 
 
4567d2b
8a49a12
4567d2b
8ceccef
f221e91
 
4567d2b
8ceccef
4567d2b
 
 
 
 
8ceccef
4567d2b
 
8ceccef
4567d2b
 
 
ed9cf21
 
4567d2b
 
 
 
ed9cf21
 
4567d2b
 
 
adada5a
 
 
 
 
8ceccef
6f3d0ef
4567d2b
 
 
 
 
 
8ceccef
4567d2b
 
 
ec8233c
4567d2b
 
9c2fd5e
4567d2b
 
 
1ac9ba4
4567d2b
8ceccef
4567d2b
 
 
f221e91
49c1e2a
4567d2b
 
f221e91
668f525
 
 
 
 
 
 
 
f221e91
4567d2b
 
f221e91
 
 
 
 
 
 
 
 
 
4567d2b
668f525
 
 
 
 
4567d2b
 
7357801
 
 
4567d2b
1ac9ba4
8ceccef
4567d2b
d922c89
f221e91
d922c89
6f3d0ef
e24c7c0
4567d2b
 
f221e91
 
e24c7c0
 
 
f221e91
4567d2b
 
 
 
 
e24c7c0
4567d2b
 
8ceccef
843e8ad
adada5a
f221e91
 
843e8ad
1663f39
adada5a
 
 
 
 
 
8ceccef
 
 
f221e91
4567d2b
8ceccef
 
 
4567d2b
8ceccef
 
f221e91
4567d2b
 
 
 
8ceccef
4567d2b
8ceccef
 
 
 
 
 
 
 
46e4f7b
8ceccef
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
import os
import time
import numpy as np
import gradio as gr

import scipy.ndimage
import cv2

from utils import load_agent

# Initial values for the run configuration and the live playback controls.
default_n_test_episodes = 10
default_max_steps = 500
default_epsilon = 0.0
default_render_fps = 5
default_paused = True

# Pixel sizes: environment frame (height x width) and policy-strip width.
frame_env_w = 768
frame_env_h = 512
frame_policy_res = 512

# Directory scanned for ``.npy`` policy checkpoints listed in the dropdown.
policies_folder = "policies"


# Display names for each discrete action index, keyed by the environment id
# that the loaded agent reports (``agent.env_name``). The position of a name
# in its tuple is the action index it labels.
action_map = {
    env_id: dict(enumerate(action_names))
    for env_id, action_names in (
        ("CliffWalking-v0", ("up", "right", "down", "left")),
        ("FrozenLake-v1", ("left", "down", "right", "up")),
        ("Taxi-v3", ("down", "up", "right", "left", "pickup", "dropoff")),
    )
}


# Pause-button label -> value ``live_paused`` takes when that label is clicked
# ("Resume" un-pauses, "Pause" pauses).
pause_val_map = {
    "▶️ Resume": False,
    "⏸️ Pause": True,
}
# Reverse lookup: paused flag -> button label.
pause_val_map_inv = dict(zip(pause_val_map.values(), pause_val_map.keys()))

# Mutable run state (held in a gr.State) that the UI callbacks update so settings can change on the fly


class RunState:
    """Mutable settings shared between the running rollout and the UI callbacks."""

    def __init__(self) -> None:
        # Policy file currently being rolled out; None when nothing is running.
        self.current_policy = None
        # Raised by the reset handlers to ask the running rollout to stop.
        self.should_reset = False
        # Live playback controls, initialised from the module-level defaults.
        self.live_paused = default_paused
        self.live_epsilon = default_epsilon
        self.live_render_fps = default_render_fps
        # Pending single-step count requested via the Step button; None if unset.
        self.live_steps_forward = None


def reset_change(state, policy_fname):
    """Handle a new dropdown selection.

    If a policy other than ``policy_fname`` is currently loaded, request a
    reset via ``should_reset``. In every case the live controls are restored
    to the module defaults.

    Returns:
        ``(state, pause-button update, step-button update)``. The pause button
        shows the label for the opposite of the current paused flag, and the
        step button is interactive only while paused.
    """
    loaded = state.current_policy
    if loaded is not None and loaded != policy_fname:
        state.should_reset = True

    state.live_steps_forward = None
    state.live_epsilon = default_epsilon
    state.live_render_fps = default_render_fps
    state.live_paused = default_paused

    is_paused = state.live_paused
    pause_button_label = pause_val_map_inv[not is_paused]
    return (
        state,
        gr.update(value=pause_button_label),
        gr.update(interactive=is_paused),
    )


def reset_click(state):
    """Handle the Clear button.

    Requests a reset only when a policy is loaded (``current_policy`` is not
    None), then restores the live controls to the module defaults.

    Returns:
        ``(state, pause-button update, step-button update)``, matching
        ``reset_change``.
    """
    state.should_reset = state.current_policy is not None

    state.live_steps_forward = None
    state.live_epsilon = default_epsilon
    state.live_render_fps = default_render_fps
    state.live_paused = default_paused

    is_paused = state.live_paused
    pause_button_label = pause_val_map_inv[not is_paused]
    return (
        state,
        gr.update(value=pause_button_label),
        gr.update(interactive=is_paused),
    )


def change_render_fps(state, x):
    """Store the new simulation speed ``x`` (frames per second) and return ``state``."""
    print("change_render_fps:", x)
    new_fps = x
    state.live_render_fps = new_fps
    return state


def change_render_fps_update(state, x):
    """Like ``change_render_fps``, but also echo ``x`` back to the slider.

    Returns:
        ``(state, slider update)``.
    """
    updated_state = change_render_fps(state, x)
    return updated_state, gr.update(value=x)


def change_epsilon(state, x):
    """Store the new exploration rate ``x`` and return ``state``."""
    print("change_epsilon:", x)
    new_epsilon = x
    state.live_epsilon = new_epsilon
    return state


def change_epsilon_update(state, x):
    """Like ``change_epsilon``, but also echo ``x`` back to the slider.

    Returns:
        ``(state, slider update)``.
    """
    updated_state = change_epsilon(state, x)
    return updated_state, gr.update(value=x)


def change_paused(state, x):
    """Toggle pause from the clicked pause-button label ``x``.

    ``x`` is looked up in ``pause_val_map`` to get the new paused flag.

    Returns:
        ``(state, pause-button update, step-button update)``. The pause button
        is relabelled for the opposite action, and the step button is
        interactive only while paused.
    """
    print("change_paused:", x)
    is_paused = pause_val_map[x]
    state.live_paused = is_paused
    next_label = pause_val_map_inv[not is_paused]
    return state, gr.update(value=next_label), gr.update(interactive=is_paused)


def onclick_btn_forward(state):
    """Queue one more single step for the rollout and return ``state``.

    ``live_steps_forward`` is treated as 0 when it is None (or 0).
    """
    print("Step forward")
    pending_steps = state.live_steps_forward or 0
    state.live_steps_forward = pending_steps + 1
    return state


def _render_policy_frame(curr_policy, action, epsilon, env_action_map):
    """Render one state's action distribution as a grayscale strip.

    The strip is ``frame_policy_res`` pixels wide and is split into one
    equal-width column per action. Each column is filled with that action's
    probability from ``curr_policy``. The image is then blurred with a
    Gaussian (sigma 1), blended toward the uniform distribution by
    ``epsilon``, and clipped to [0, 1]. The sampled ``action`` index is drawn
    inside its column. When ``env_action_map`` is given, the action's name is
    drawn underneath it.

    Args:
        curr_policy: Per-action probabilities for the current state.
        action: Index of the action that was taken.
        epsilon: Exploration rate used for the display blend.
        env_action_map: Optional ``{action_index: name}`` dict, or None.

    Returns:
        2-D float array of shape
        ``(frame_policy_res // len(curr_policy), frame_policy_res)``.
    """
    n_actions = len(curr_policy)
    col_w = frame_policy_res // n_actions
    frame_h = col_w  # cells are as tall as they are wide

    frame = np.zeros((frame_h, frame_policy_res))
    for i, p in enumerate(curr_policy):
        frame[:, i * col_w : (i + 1) * col_w] = p

    frame = scipy.ndimage.gaussian_filter(frame, sigma=1.0)
    frame = np.clip(frame * (1.0 - epsilon) + epsilon / n_actions, 0.0, 1.0)

    label_loc_h = frame_h // 2
    label_loc_w = int((action + 0.5) * frame_policy_res // n_actions)
    # Black text on bright cells, white text on dark ones.
    label_color = 0.0 if frame[label_loc_h, label_loc_w] > 0.5 else 1.0

    font = cv2.FONT_HERSHEY_SIMPLEX
    thickness = 1
    # Scales were tuned for 4 actions; shrink them for larger action spaces.
    scale_factor = 4 / n_actions
    action_text_scale = 1.0 * scale_factor
    action_name_scale = 0.6 * scale_factor

    (text_w, text_h), _ = cv2.getTextSize(
        str(action), font, action_text_scale, thickness
    )
    cv2.putText(
        frame,
        str(action),
        (label_loc_w - text_w // 2, frame_h // 3 + text_h // 2),
        font,
        action_text_scale,
        label_color,
        thickness,
        cv2.LINE_AA,
    )

    if env_action_map:
        action_name = env_action_map.get(action, "")
        (text_w, text_h), _ = cv2.getTextSize(
            action_name, font, action_name_scale, thickness
        )
        cv2.putText(
            frame,
            action_name,
            (int(label_loc_w - text_w / 2), frame_h - frame_h // 3 + text_h // 2),
            font,
            action_name_scale,
            label_color,
            thickness,
            cv2.LINE_AA,
        )

    return frame


def run(
    localstate: RunState, policy_fname, n_test_episodes, max_steps, render_fps, epsilon
):
    """Roll out a saved policy and stream its progress to the Gradio UI.

    This generator is the ``btn_run`` click handler. Every yielded value is a
    12-tuple in the same order as that handler's ``outputs`` list:
    (state, agent name, env name, environment frame, policy frame, episode,
    episodes solved, step, state, action, reward, status message).

    While running it follows the live controls on ``localstate``:
    - ``live_epsilon`` is copied to ``agent.epsilon_override``.
    - ``live_render_fps`` sets the delay between frames.
    - ``live_paused`` and ``live_steps_forward`` implement pause and single-step.
    - ``should_reset`` aborts the run and blanks the display.

    Args:
        localstate: Shared ``RunState``, mutated by the UI callbacks.
        policy_fname: File name inside ``policies_folder`` to load.
        n_test_episodes: Number of episodes to play (cast to int).
        max_steps: Step limit per episode (cast to int).
        render_fps: Initial simulation speed.
        epsilon: Initial exploration rate.
    """
    # Slider values may arrive as floats; range() and the counters need ints.
    n_test_episodes = int(n_test_episodes)
    max_steps = int(max_steps)

    localstate.current_policy = policy_fname
    localstate.live_render_fps = render_fps
    localstate.live_epsilon = epsilon
    localstate.live_steps_forward = None
    print("=" * 80)
    print("Running...")
    print(f"- policy_fname: {localstate.current_policy}")
    print(f"- n_test_episodes: {n_test_episodes}")
    print(f"- max_steps: {max_steps}")
    print(f"- render_fps: {localstate.live_render_fps}")
    print(f"- epsilon: {localstate.live_epsilon}")

    policy_path = os.path.join(policies_folder, policy_fname)

    try:
        agent = load_agent(
            policy_path, return_agent_env_keys=True, render_mode="rgb_array"
        )
    except ValueError as e:
        print(f"🚫 Error: {e}")
        yield (localstate, *([None] * 10), "🚫 Please select a valid policy file.")
        return

    agent_key, env_key = agent.__class__.__name__, agent.env_name
    env_action_map = action_map.get(env_key)

    solved, frame_env, frame_policy = None, None, None
    episode, step, state, action, reward, last_reward = (None,) * 6
    episodes_solved = 0

    def ep_str(episode):
        return (
            f"{episode} / {n_test_episodes} ({(episode) / n_test_episodes * 100:.2f}%)"
        )

    def step_str(step):
        # step stays None only if the agent never produced a single step.
        return None if step is None else f"{step + 1}"

    def _snapshot(msg):
        """Build the 12-tuple of current values for ``btn_run``'s outputs."""
        return (
            localstate,
            agent_key,
            env_key,
            frame_env,
            frame_policy,
            ep_str(episode + 1),
            ep_str(episodes_solved),
            step_str(step),
            state,
            action,
            last_reward,
            msg,
        )

    for episode in range(n_test_episodes):
        time.sleep(0.5)

        # Apply the live epsilon before the generator takes its first action.
        # Inside the loop below it is only refreshed after each step.
        agent.epsilon_override = localstate.live_epsilon
        for step, (episode_hist, solved, frame_env) in enumerate(
            agent.generate_episode(
                policy=agent.Pi,
                max_steps=max_steps,
                render=True,
            )
        ):
            agent.epsilon_override = localstate.live_epsilon
            # The "reward received" shown is the one stored with the previous
            # history entry; None on the first step.
            _, _, last_reward = (
                episode_hist[-2] if len(episode_hist) > 1 else (None, None, None)
            )
            state, action, reward = episode_hist[-1]
            frame_policy = _render_policy_frame(
                agent.Pi[state], action, localstate.live_epsilon, env_action_map
            )

            print(
                f"Episode: {ep_str(episode + 1)} - step: {step_str(step)} - state: {state} - action: {action} - reward: {reward} (epsilon: {localstate.live_epsilon:.2f}) (frame time: {1 / localstate.live_render_fps:.2f}s)"
            )

            yield _snapshot("Running...")

            if localstate.live_steps_forward is not None:
                # Single-step mode: consume one requested step, re-pause at zero.
                if localstate.live_steps_forward > 0:
                    localstate.live_steps_forward -= 1

                if localstate.live_steps_forward == 0:
                    localstate.live_steps_forward = None
                    localstate.live_paused = True
            else:
                time.sleep(1 / localstate.live_render_fps)

            # Keep the UI refreshed while paused; leave on step or reset.
            while localstate.live_paused and localstate.live_steps_forward is None:
                yield _snapshot("Paused...")
                time.sleep(1 / localstate.live_render_fps)
                if localstate.should_reset is True:
                    break

            if localstate.should_reset is True:
                localstate.should_reset = False
                localstate.current_policy = None
                # Blank images and cleared fields.
                yield (
                    localstate,
                    None,
                    None,
                    np.ones((frame_env_h, frame_env_w, 3)),
                    np.ones_like(frame_policy),
                    None,
                    None,
                    None,
                    None,
                    None,
                    None,
                    "Reset...",
                )
                return

        if solved:
            episodes_solved += 1

        time.sleep(0.5)

    localstate.current_policy = None
    yield _snapshot("Done!")


# Build the Gradio interface. Statement order below determines the on-screen
# layout. Each component must be created before an event listener refers to it.
with gr.Blocks(title="CS581 Demo") as demo:
    # List every ``.npy`` checkpoint in the policies folder, sorted by name.
    # A missing folder yields an empty list; the run/clear buttons are then
    # disabled and the status box shows an error.
    try:
        all_policies = [
            file for file in os.listdir(policies_folder) if file.endswith(".npy")
        ]
        all_policies.sort()
    except FileNotFoundError:
        print("ERROR: No policies folder found!")
        all_policies = []

    gr.components.HTML(
        "<h1>CS581 Final Project Demo - Dynamic Programming & Monte-Carlo RL Methods (<a href='https://github.com/andreicozma1/CS581-Algorithms-Project'>GitHub</a>) (<a href='https://huggingface.co/spaces/acozma/CS581-Algos-Demo'>HF Space</a>)</h1>"
    )

    # Shared RunState passed to `run` and to every control callback below.
    localstate = gr.State(RunState())

    # --- Run configuration -------------------------------------------------
    gr.components.HTML("<h2>Select Configuration:</h2>")
    with gr.Row():
        input_policy = gr.components.Dropdown(
            label="Policy Checkpoint",
            choices=all_policies,
            value=all_policies[0] if all_policies else "No policies found :(",
        )

        # Filled in by `run` from the loaded agent.
        out_environment = gr.components.Textbox(label="Resolved Environment")
        out_agent = gr.components.Textbox(label="Resolved Agent")

    with gr.Row():
        input_n_test_episodes = gr.components.Slider(
            minimum=1,
            maximum=1000,
            value=default_n_test_episodes,
            label="Number of episodes",
        )
        input_max_steps = gr.components.Slider(
            minimum=1,
            maximum=1000,
            value=default_max_steps,
            label="Max steps per episode",
        )

    with gr.Row():
        btn_run = gr.components.Button(
            "👀 Select & Load", interactive=bool(all_policies)
        )
        btn_clear = gr.components.Button("🗑️ Clear", interactive=bool(all_policies))

    # --- Live read-outs streamed from `run` --------------------------------
    gr.components.HTML("<h2>Live Visualization & Information:</h2>")
    with gr.Row():
        with gr.Column():
            with gr.Row():
                out_episode = gr.components.Textbox(label="Current Episode")
                out_step = gr.components.Textbox(label="Current Step")
                out_eps_solved = gr.components.Textbox(label="Episodes Solved")

            with gr.Row():
                out_state = gr.components.Textbox(label="Current State")
                out_action = gr.components.Textbox(label="Chosen Action")
                out_reward = gr.components.Textbox(label="Reward Received")

        out_image_policy = gr.components.Image(
            label="Action Sampled vs Policy Distribution for Current State",
            type="numpy",
            image_mode="RGB",
        )
        out_image_policy.style(height=200)

    # --- Live controls: write straight into the shared RunState ------------
    with gr.Row():
        input_epsilon = gr.components.Slider(
            minimum=0,
            maximum=1,
            value=default_epsilon,
            step=1 / 20,
            label="Epsilon (0 = greedy, 1 = random)",
        )
        # .change updates the state on every change; .release also echoes the
        # value back to the slider.
        input_epsilon.change(
            change_epsilon,
            inputs=[localstate, input_epsilon],
            outputs=[localstate],
        )
        input_epsilon.release(
            change_epsilon_update,
            inputs=[localstate, input_epsilon],
            outputs=[localstate, input_epsilon],
        )

        input_render_fps = gr.components.Slider(
            minimum=1,
            maximum=60,
            value=default_render_fps,
            step=1,
            label="Simulation speed (fps)",
        )
        input_render_fps.change(
            change_render_fps,
            inputs=[localstate, input_render_fps],
            outputs=[localstate],
        )
        input_render_fps.release(
            change_render_fps_update,
            inputs=[localstate, input_render_fps],
            outputs=[localstate, input_render_fps],
        )

    out_image_frame = gr.components.Image(
        label="Environment",
        type="numpy",
        image_mode="RGB",
    )
    out_image_frame.style(height=frame_env_h)

    with gr.Row():
        # The pause button's label is the action it performs ("Resume" while
        # paused). Its clicked label is passed to change_paused as the value.
        btn_pause = gr.components.Button(
            pause_val_map_inv[not default_paused], interactive=True
        )
        btn_forward = gr.components.Button("⏩ Step")

        btn_pause.click(
            fn=change_paused,
            inputs=[localstate, btn_pause],
            outputs=[localstate, btn_pause, btn_forward],
        )

        btn_forward.click(
            fn=onclick_btn_forward, inputs=[localstate], outputs=[localstate]
        )

    out_msg = gr.components.Textbox(
        value=""
        if all_policies
        else "ERROR: No policies found! Please train an agent first or add a policy to the policies folder.",
        label="Status Message",
    )

    # --- Event wiring -------------------------------------------------------
    input_policy.change(
        fn=reset_change,
        inputs=[localstate, input_policy],
        outputs=[localstate, btn_pause, btn_forward],
    )

    btn_clear.click(
        fn=reset_click,
        inputs=[localstate],
        outputs=[localstate, btn_pause, btn_forward],
    )

    # `run` is a generator. The order of `outputs` must match the 12-tuples
    # it yields.
    btn_run.click(
        fn=run,
        inputs=[
            localstate,
            input_policy,
            input_n_test_episodes,
            input_max_steps,
            input_render_fps,
            input_epsilon,
        ],
        outputs=[
            localstate,
            out_agent,
            out_environment,
            out_image_frame,
            out_image_policy,
            out_episode,
            out_eps_solved,
            out_step,
            out_state,
            out_action,
            out_reward,
            out_msg,
        ],
    )

# Enable Gradio's request queue and start the server. The queue is presumably
# what lets generator handlers such as `run` stream updates; concurrency_count
# is the number of queued jobs processed at once. NOTE(review): confirm both
# against the pinned gradio version.
demo.queue(concurrency_count=8)
demo.launch()