File size: 2,519 Bytes
3261e0d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os
import subprocess
import tempfile

import gymnasium as gym
import imageio
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.vec_env import VecVideoRecorder


def _recorded_video_path(recorder):
    """Return the path of the video file written by ``recorder`` (a VecVideoRecorder)."""
    # NOTE(review): newer Stable-Baselines3 releases appear to expose this as
    # ``video_path`` while older ones expose ``video_recorder.path``; both are
    # tried so the helper works across versions — confirm against the pinned SB3.
    path = getattr(recorder, "video_path", None)
    if path is None:
        path = recorder.video_recorder.path
    return path


def _drop_frame(src_fp, dst_fp, frame_index):
    """Copy the video at ``src_fp`` to ``dst_fp`` with ffmpeg, omitting frame ``frame_index``.

    Raises:
        FileNotFoundError: If the ``ffmpeg`` executable cannot be found.
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status
            (its stderr is attached to the exception).
    """
    # Arguments are passed as a list (no shell), so paths containing spaces or
    # shell metacharacters are safe. The single quotes inside the filter are
    # ffmpeg filtergraph quoting (protecting the comma), not shell quoting.
    subprocess.run(
        [
            "ffmpeg",
            "-y",
            "-i",
            str(src_fp),
            "-vf",
            f"select='not(eq(n,{frame_index}))'",
            str(dst_fp),
        ],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.PIPE,
        check=True,
    )


def generate_video(model, video_fp, video_length_in_episodes=5):
    """Record ``model`` acting in its own environment and save the result to ``video_fp``.

    The model's VecEnv (``model.get_env()``) is wrapped in a ``VecVideoRecorder``
    that writes into a temporary directory. The policy is stepped
    deterministically until ``video_length_in_episodes`` episodes have finished
    or the step budget runs out. ffmpeg then copies the recording to
    ``video_fp`` with its final frame removed.

    Args:
        model: Stable-Baselines3 model; ``model.get_env()`` must return a VecEnv
            whose ``spec`` attribute provides ``max_episode_steps``.
        video_fp: Destination path of the output video.
        video_length_in_episodes: Number of episodes to record. Multiplied by
            ``max_episode_steps`` to bound the number of steps taken.

    Raises:
        FileNotFoundError: If the ``ffmpeg`` executable cannot be found.
        subprocess.CalledProcessError: If ffmpeg fails to write ``video_fp``.
    """
    eval_env = model.get_env()

    # NOTE(review): assumes the env spec defines max_episode_steps (not None).
    max_video_length_in_steps = (
        video_length_in_episodes * eval_env.get_attr("spec")[0].max_episode_steps
    )

    with tempfile.TemporaryDirectory() as temp_dp:
        vec_env = VecVideoRecorder(
            eval_env,
            temp_dp,
            record_video_trigger=lambda step: step == 0,  # record only from the first step
            video_length=max_video_length_in_steps,
        )

        frame_count = 0
        episode_count = 0
        obs = vec_env.reset()
        for _ in range(max_video_length_in_steps):
            action, _ = model.predict(obs, deterministic=True)
            obs, _, dones, _ = vec_env.step(action)
            frame_count += 1
            # Count finished episodes across all sub-envs; a bare ``if dones:``
            # raises ValueError when the VecEnv holds more than one env.
            episode_count += int(np.count_nonzero(dones))
            if episode_count >= video_length_in_episodes:
                break

        # NOTE(review): closing the recorder presumably also closes the wrapped
        # (model's) env — confirm before reusing model.get_env() afterwards.
        vec_env.close()

        # The last recorded frame is the first frame of a new episode, so it is
        # dropped. This must run before the temporary directory is deleted.
        _drop_frame(_recorded_video_path(vec_env), video_fp, frame_index=frame_count)


def generate_gif(
    model,
    file_path,
    video_length_in_episodes=5,
    *,
    deterministic=False,
    fps=25,
    frame_stride=2,
):
    """Roll out ``model`` in its own environment and save the rendered frames as a GIF.

    Frames come from ``model.get_env().render(mode="rgb_array")``. One frame is
    captured after the reset and one after every step. Stepping stops once
    ``video_length_in_episodes`` episodes have finished or the step budget is
    exhausted.

    Args:
        model: Stable-Baselines3 model; ``model.get_env()`` must return a VecEnv
            whose ``spec`` attribute provides ``max_episode_steps``.
        file_path: Destination path of the GIF.
        video_length_in_episodes: Number of episodes to capture. Multiplied by
            ``max_episode_steps`` to bound the number of steps taken.
        deterministic: Passed to ``model.predict``. The default ``False`` keeps
            the original stochastic rollout; ``True`` matches ``generate_video``.
        fps: Playback rate handed to ``imageio.mimsave``.
        frame_stride: Keep every ``frame_stride``-th captured frame.

    Raises:
        ValueError: If ``frame_stride`` is smaller than 1.
    """
    if frame_stride < 1:
        raise ValueError(f"frame_stride must be >= 1, got {frame_stride}")

    eval_env = model.get_env()

    # NOTE(review): assumes the env spec defines max_episode_steps (not None).
    max_video_length_in_steps = (
        video_length_in_episodes * eval_env.get_attr("spec")[0].max_episode_steps
    )

    def render_image():
        return eval_env.render(mode="rgb_array")

    images = []
    episode_count = 0
    obs = eval_env.reset()
    images.append(render_image())
    for _ in range(max_video_length_in_steps):
        action, _ = model.predict(obs, deterministic=deterministic)
        obs, _, dones, _ = eval_env.step(action)
        # Count finished episodes across all sub-envs; a bare ``if dones:``
        # raises ValueError when the VecEnv holds more than one env.
        episode_count += int(np.count_nonzero(dones))
        if episode_count >= video_length_in_episodes:
            # Stop before rendering: presumably (as with the recorded video) the
            # next frame would already belong to a freshly reset episode.
            break
        images.append(render_image())

    imageio.mimsave(
        file_path, [np.array(img) for img in images[::frame_stride]], fps=fps
    )


def load_ppo_model_for_video(model_fp, env_id):
    """Load a saved PPO model bound to a single-env VecEnv suitable for rendering.

    The environment is ``gym.make(env_id, render_mode="rgb_array")``, wrapped in
    ``Monitor`` and then in a one-element ``DummyVecEnv``.

    Args:
        model_fp: Path of the saved PPO model.
        env_id: Gymnasium environment id to instantiate.

    Returns:
        The loaded ``PPO`` model, attached to the constructed environment.
    """

    def make_render_env():
        return Monitor(gym.make(env_id, render_mode="rgb_array"))

    return PPO.load(model_fp, env=DummyVecEnv([make_render_env]))