# CS581-Algos-Demo / demo.py
# Author: Andrei Cozma
import os
import time
import numpy as np
import gradio as gr
import scipy.ndimage
import cv2
from agents import AGENTS_MAP
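
# Defaults for the UI controls; the matching live_* globals below mirror the
# sliders and buttons so they can be changed while an episode is running.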
default_n_test_episodes = 10
default_max_steps = 500
default_render_fps = 5
default_epsilon = 0.0
default_paused = True
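
# Render sizes: environment frame (height, width) and the width of the
# policy-distribution strip.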
frame_env_h, frame_env_w = 512, 768
frame_policy_res = 384
# For the dropdown list of policies
policies_folder = "policies"
try:
all_policies = [
file for file in os.listdir(policies_folder) if file.endswith(".npy")
]
except FileNotFoundError:
print("ERROR: No policies folder found!")
all_policies = []
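
# Human-readable action names per environment, used to label the policy strip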
action_map = {
"CliffWalking-v0": {
0: "up",
1: "right",
2: "down",
3: "left",
},
"FrozenLake-v1": {
0: "left",
1: "down",
2: "right",
3: "up",
},
}
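
# Maps the pause/resume button label to the paused state (and back)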
pause_val_map = {
"▶️ Resume": False,
"⏸️ Pause": True,
}
pause_val_map_inv = {v: k for k, v in pause_val_map.items()}
# Global variables so the UI callbacks can change these on the fly
current_policy = None
live_render_fps = default_render_fps
live_epsilon = default_epsilon
live_paused = default_paused
live_steps_forward = None
should_reset = False
def reset(policy_fname):
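    """Handle a policy checkpoint change: flag any in-progress run to stop
    (via should_reset) and restore the live controls to their defaults."""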
global current_policy, live_render_fps, live_epsilon, live_paused, live_steps_forward, should_reset
if current_policy is not None and current_policy != policy_fname:
should_reset = True
live_paused = default_paused
live_render_fps = default_render_fps
live_epsilon = default_epsilon
live_steps_forward = None
return gr.update(value=pause_val_map_inv[not live_paused]), gr.update(
interactive=live_paused
)
def change_render_fps(x):
print("Changing render fps:", x)
global live_render_fps
live_render_fps = x
def change_epsilon(x):
print("Changing greediness:", x)
global live_epsilon
live_epsilon = x
def change_paused(x):
print("Changing paused:", x)
global live_paused
live_paused = pause_val_map[x]
return gr.update(value=pause_val_map_inv[not live_paused]), gr.update(
interactive=live_paused
)
def onclick_btn_forward():
print("Step forward")
global live_steps_forward
if live_steps_forward is None:
live_steps_forward = 0
live_steps_forward += 1
def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
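    """Load the selected policy checkpoint and run it for n_test_episodes,
    yielding one tuple of UI outputs per environment step (generator)."""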
global current_policy, live_render_fps, live_epsilon, live_paused, live_steps_forward, should_reset
current_policy = policy_fname
live_render_fps = render_fps
live_epsilon = epsilon
live_steps_forward = None
print("=" * 80)
print("Running...")
print(f"- policy_fname: {policy_fname}")
print(f"- n_test_episodes: {n_test_episodes}")
print(f"- max_steps: {max_steps}")
print(f"- render_fps: {live_render_fps}")
print(f"- epsilon: {live_epsilon}")
policy_path = os.path.join(policies_folder, policy_fname)
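    # Checkpoint filenames are expected to follow "<AgentType>_<EnvName>_*.npy",
    # e.g. a (hypothetical) "MCAgent_CliffWalking-v0_e2000.npy"; the agent type
    # must be a key of AGENTS_MAP.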
props = policy_fname.split("_")
if len(props) < 2:
yield None, None, None, None, None, None, None, None, None, None, "🚫 Please select a valid policy file."
return
agent_type, env_name = props[0], props[1]
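    # Instantiate the matching agent for this environment and load the policy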
agent = AGENTS_MAP[agent_type](env_name=env_name, render_mode="rgb_array")
agent.load_policy(policy_path)
env_action_map = action_map.get(env_name)
solved, frame_env, frame_policy = None, None, None
    episode, step, state, action, reward, last_reward = (None,) * 6
episodes_solved = 0
    def ep_str(episode):
        return f"{episode} / {n_test_episodes} ({episode / n_test_episodes * 100:.2f}%)"
def step_str(step):
return f"{step + 1}"
for episode in range(n_test_episodes):
time.sleep(0.25)
for step, (episode_hist, solved, frame_env) in enumerate(
agent.generate_episode(
max_steps=max_steps, render=True, epsilon_override=live_epsilon
)
):
_, _, last_reward = (
episode_hist[-2] if len(episode_hist) > 1 else (None, None, None)
)
state, action, reward = episode_hist[-1]
curr_policy = agent.Pi[state]
            # Draw the policy distribution as one grayscale band per action;
            # band brightness encodes that action's probability
            band_w = frame_policy_res // len(curr_policy)
            frame_policy_h = band_w
            frame_policy = np.zeros((frame_policy_h, frame_policy_res))
            for i, p in enumerate(curr_policy):
                frame_policy[:, i * band_w : (i + 1) * band_w] = p
            frame_policy = scipy.ndimage.gaussian_filter(frame_policy, sigma=1.0)
            # Blend with the epsilon-uniform exploration distribution so the strip
            # reflects the actual sampling probabilities, then clip to [0, 1]
            frame_policy = np.clip(
                frame_policy * (1.0 - live_epsilon) + live_epsilon / len(curr_policy),
                0.0,
                1.0,
            )
            # Sample the brightness behind the label and invert it for contrast
            text_color = 1.0 - frame_policy[
                frame_policy_h // 2, int((action + 0.5) * band_w)
            ]
cv2.putText(
frame_policy,
str(action),
(
                    int((action + 0.5) * band_w - 8),
frame_policy_h // 2 - 5,
),
cv2.FONT_HERSHEY_SIMPLEX,
0.8,
text_color,
1,
cv2.LINE_AA,
)
if env_action_map:
action_name = env_action_map.get(action, "")
cv2.putText(
frame_policy,
action_name,
                    (
                        int((action + 0.5) * band_w - 5 * len(action_name)),
                        frame_policy_h // 2 + 25,
                    ),
cv2.FONT_HERSHEY_SIMPLEX,
0.5,
text_color,
1,
cv2.LINE_AA,
)
            print(
                f"Episode: {ep_str(episode + 1)} - step: {step_str(step)} "
                f"- state: {state} - action: {action} - reward: {reward} "
                f"(epsilon: {live_epsilon:.2f}) (frame time: {1 / live_render_fps:.2f}s)"
            )
            yield (
                agent_type,
                env_name,
                frame_env,
                frame_policy,
                ep_str(episode + 1),
                ep_str(episodes_solved),
                step_str(step),
                state,
                action,
                last_reward,
                "Running...",
            )
            # Manual stepping: consume one queued step, then re-pause once the
            # counter runs out; otherwise pace the loop at the live FPS
            if live_steps_forward is not None:
                if live_steps_forward > 0:
                    live_steps_forward -= 1

                if live_steps_forward == 0:
                    live_steps_forward = None
                    live_paused = True
            else:
                time.sleep(1 / live_render_fps)

            # While paused, keep re-yielding the current frame so the UI stays live
            while live_paused and live_steps_forward is None:
                yield (
                    agent_type,
                    env_name,
                    frame_env,
                    frame_policy,
                    ep_str(episode + 1),
                    ep_str(episodes_solved),
                    step_str(step),
                    state,
                    action,
                    last_reward,
                    "Paused...",
                )
                time.sleep(1 / live_render_fps)
                if should_reset:
                    break

            # A new policy checkpoint was selected mid-run: emit a blank frame and stop
            if should_reset:
                should_reset = False
yield (
agent_type,
env_name,
np.ones((frame_env_h, frame_env_w, 3)),
np.ones((frame_policy_h, frame_policy_res)),
ep_str(episode + 1),
ep_str(episodes_solved),
step_str(step),
state,
action,
last_reward,
"Reset...",
)
return
if solved:
episodes_solved += 1
time.sleep(0.25)
current_policy = None
    yield (
        agent_type,
        env_name,
        frame_env,
        frame_policy,
        ep_str(episode + 1),
        ep_str(episodes_solved),
        step_str(step),
        state,
        action,
        reward,
        "Done!",
    )
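
# Gradio UI: configuration controls on top, live visualization below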
with gr.Blocks(title="CS581 Demo") as demo:
gr.components.HTML(
"<h1>CS581 Final Project Demo - Dynamic Programming & Monte-Carlo RL Methods (<a href='https://huggingface.co/spaces/acozma/CS581-Algos-Demo'>HF Space</a>)</h1>"
)
gr.components.HTML("<h2>Select Configuration:</h2>")
with gr.Row():
input_policy = gr.components.Dropdown(
label="Policy Checkpoint",
choices=all_policies,
value=all_policies[0] if all_policies else "No policies found :(",
)
out_environment = gr.components.Textbox(label="Resolved Environment")
out_agent = gr.components.Textbox(label="Resolved Agent")
with gr.Row():
input_n_test_episodes = gr.components.Slider(
minimum=1,
maximum=1000,
value=default_n_test_episodes,
label="Number of episodes",
)
input_max_steps = gr.components.Slider(
minimum=1,
maximum=1000,
value=default_max_steps,
label="Max steps per episode",
)
btn_run = gr.components.Button("👀 Select & Load", interactive=bool(all_policies))
gr.components.HTML("<h2>Live Visualization & Information:</h2>")
with gr.Row():
with gr.Column():
with gr.Row():
out_episode = gr.components.Textbox(label="Current Episode")
out_step = gr.components.Textbox(label="Current Step")
out_eps_solved = gr.components.Textbox(label="Episodes Solved")
with gr.Row():
out_state = gr.components.Textbox(label="Current State")
out_action = gr.components.Textbox(label="Chosen Action")
out_reward = gr.components.Textbox(label="Last Reward")
out_image_policy = gr.components.Image(
label="Action Sampled vs Policy Distribution for Current State",
type="numpy",
image_mode="RGB",
)
out_image_policy.style(height=200)
with gr.Row():
input_epsilon = gr.components.Slider(
minimum=0,
maximum=1,
value=live_epsilon,
            step=1 / 200,
label="Epsilon (0 = greedy, 1 = random)",
)
input_epsilon.change(change_epsilon, inputs=[input_epsilon])
        input_render_fps = gr.components.Slider(
            minimum=1,
            maximum=60,
            value=live_render_fps,
            step=1,
            label="Simulation speed (fps)",
        )
input_render_fps.change(change_render_fps, inputs=[input_render_fps])
out_image_frame = gr.components.Image(
label="Environment",
type="numpy",
image_mode="RGB",
)
out_image_frame.style(height=frame_env_h)
with gr.Row():
btn_pause = gr.components.Button(
pause_val_map_inv[not live_paused], interactive=True
)
btn_forward = gr.components.Button("⏩ Step")
btn_pause.click(
fn=change_paused,
inputs=[btn_pause],
outputs=[btn_pause, btn_forward],
)
btn_forward.click(
fn=onclick_btn_forward,
)
out_msg = gr.components.Textbox(
value=""
if all_policies
else "ERROR: No policies found! Please train an agent first or add a policy to the policies folder.",
label="Status Message",
)
input_policy.change(
fn=reset, inputs=[input_policy], outputs=[btn_pause, btn_forward]
)
btn_run.click(
fn=run,
inputs=[
input_policy,
input_n_test_episodes,
input_max_steps,
input_render_fps,
input_epsilon,
],
outputs=[
out_agent,
out_environment,
out_image_frame,
out_image_policy,
out_episode,
out_eps_solved,
out_step,
out_state,
out_action,
out_reward,
out_msg,
],
)
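
# Two queue workers so the control callbacks (pause, sliders) stay responsive
# while the run generator is streaming frames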
demo.queue(concurrency_count=2)
demo.launch()