# cartpole-demo / app.py
# Author: skar0 -- "Initial commit" (820bb68), 1.77 kB
import glob
import gradio as gr
import gym
import sys
from torch.utils.tensorboard import SummaryWriter
import yaml
import torch
from cartpole import (
make_env, reset_env, Agent, rollout_phase, get_action_shape
)
# True only when this file is executed directly as a script, False when it
# is imported as a module.
MAIN = (__name__ == "__main__")

# Example seed inputs offered to the Gradio interface below.
examples = [
    0,
    1,
    31415,
    'Hello, World!',
    'This is a seed...',
]
def _seed_from_string(string: str) -> int:
    """Map arbitrary seed text to a reproducible non-negative integer seed.

    The built-in ``hash()`` of a ``str`` is salted per interpreter process
    (see PYTHONHASHSEED), so it would give a different seed for the same text
    after every restart. A SHA-256 digest is stable across processes and
    machines. The digest is reduced into ``[0, 2 * (sys.maxsize + 1))``, the
    same range as the unsigned-hash formula ``hash(x) % (2 * (sys.maxsize + 1))``.
    """
    import hashlib

    digest = hashlib.sha256(str(string).encode("utf-8")).digest()
    return int.from_bytes(digest, "big") % ((sys.maxsize + 1) * 2)


def generate_video(
    string: str, wandb_path='wandb/run-20230303_211416-ox4d1p0u/files'
):
    """Roll out the trained CartPole agent and return the path of a recorded video.

    Args:
        string: Seed text from the UI. It is hashed deterministically into an
            integer seed, so the same text always yields the same seed.
        wandb_path: Directory containing ``config.yaml`` and
            ``model_state_dict.pt`` for the trained run.

    Returns:
        Path of the lexicographically first ``.mp4`` file found under
        ``videos/seed<seed>/`` after the rollout.

    Raises:
        FileNotFoundError: if the config/checkpoint is missing, or if no video
            file was produced for this seed.
    """
    with open(f'{wandb_path}/config.yaml', encoding='utf-8') as f_cfg:
        config = yaml.safe_load(f_cfg)
    # Each config entry is read from its nested 'value' key.
    num_envs = config['num_envs']['value']
    num_steps = config['num_steps']['value']

    seed = _seed_from_string(string)
    run_name = f'seed{seed}'
    log_dir = f'generate/{run_name}'
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # The context manager closes (and flushes) the writer even if the rollout fails.
    with SummaryWriter(log_dir) as writer:
        envs = gym.vector.SyncVectorEnv([
            make_env("CartPole-v1", seed, i, True, run_name)
            for i in range(num_envs)
        ])
        try:
            action_shape = get_action_shape(envs)
            next_obs, next_done = reset_env(envs, device)
            global_step = 0
            agent = Agent(envs).to(device)
            # map_location lets a checkpoint saved on GPU load on a CPU-only host.
            agent.load_state_dict(
                torch.load(f'{wandb_path}/model_state_dict.pt', map_location=device)
            )
            rollout_phase(
                next_obs, next_done, agent, envs, writer, device,
                global_step, action_shape, num_envs, num_steps,
            )
        finally:
            # Close the envs before looking for the video. Video-recording env
            # wrappers typically write out the .mp4 on close. This assumes
            # make_env's 4th argument enables such a recorder (TODO: confirm in
            # cartpole.py).
            envs.close()

    # Sort so the chosen file does not depend on filesystem listing order.
    videos = sorted(glob.glob(f'videos/{run_name}/*.mp4'))
    if not videos:
        raise FileNotFoundError(f'no video recorded under videos/{run_name}/')
    return videos[0]
if MAIN:
    # One single-line text field for the seed; the rendered rollout video is the output.
    seed_box = gr.components.Textbox(lines=1, label="Seed")
    video_out = gr.components.Video(label="Generated Video")
    demo = gr.Interface(
        fn=generate_video,
        inputs=[seed_box],
        outputs=video_out,
        examples=examples,
    )
    # share=True asks Gradio for a public share link in addition to the local server.
    demo.launch(share=True)