File size: 3,565 Bytes
246c106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import numpy as np
import imageio
from sim.simulator import Simulator
from sim.policy import Policy
from sim.viewer import ImageViewer
from typing import List, Tuple

# Process-wide metric accumulators. InteractiveDigitalWorld.step() appends the
# 'step_time', 'psnr' and 'delta_psnr' values found in each simulator step
# result, and InteractiveDigitalWorld.close() prints summary statistics for
# every non-empty list. They are shared by all instances in the process.
step_time, psnr, delta_psnr = [], [], []


class InteractiveDigitalWorld:
    """Run a policy against a simulator, optionally display it live, and record frames.

    Each ``step()`` asks ``policy`` for an action on the current observation,
    advances ``simulator`` with it, and appends the resulting frame to
    ``video_frames``. Scalar metrics reported by the simulator (``psnr``,
    ``delta_psnr``, ``step_time``) are appended to the module-level lists of
    the same names and summarised when ``close()`` is called.
    """

    def __init__(self,
        simulator: Simulator,
        policy: Policy,
        offscreen: bool = True,                 # if False, show live window
        window_size: Tuple[int, int] = (512, 512),
    ):
        """Reset the simulator and record its initial observation.

        Args:
            simulator: Object providing ``dt``, ``reset()``, ``step(action)``
                and ``close()``.
            policy: Object providing ``generate_action(obs)``.
            offscreen: If False, open an ``ImageViewer`` window and show
                frames live.
            window_size: Window size passed to the viewer (only used when
                ``offscreen`` is False).
        """
        self.simulator = simulator
        self.policy = policy
        self.offscreen = offscreen
        self.video_frames: List[np.ndarray] = []
        self.dt = simulator.dt

        self.obs = self.simulator.reset()   # input to policy
        self.video_frames.append(self.obs)

        if not offscreen:
            self.viewer = ImageViewer(
                window_name=(
                    f"Simulator: {simulator.__class__.__name__} | "
                    f"Policy: {policy.__class__.__name__}"
                ),
                refresh_rate=self.dt,
                window_size=window_size
            )
            self.viewer.update_image(self.obs)

    def step(self) -> None:
        """Advance the world by one policy action and record the new frame.

        The recorded/displayed frame is the simulator's ``pred_next_frame``.
        When the step result also contains ``gt_next_frame``, prediction and
        ground truth are concatenated side by side (axis 1). Any ``psnr``,
        ``delta_psnr`` or ``step_time`` entries in the result are appended to
        the module-level lists of the same names.
        """
        action = self.policy.generate_action(self.obs)
        result = self.simulator.step(action)
        next_frame = result['pred_next_frame']
        if 'gt_next_frame' in result:
            next_frame = np.concatenate([next_frame, result['gt_next_frame']], axis=1)
        # Collect optional per-step metrics (same order as before: psnr,
        # delta_psnr, step_time).
        for key, sink in (('psnr', psnr), ('delta_psnr', delta_psnr), ('step_time', step_time)):
            if key in result:
                sink.append(result[key])
        # NOTE(review): when ground truth is present, the side-by-side frame is
        # also what the policy receives as its next observation — confirm this
        # is intended rather than feeding only ``pred_next_frame``.
        self.obs = next_frame
        if not self.offscreen:
            self.viewer.update_image(next_frame)
        self.video_frames.append(next_frame)

    def save_video(self, save_path: str, as_gif: bool = False) -> None:
        """Write the recorded frames to ``save_path`` at ``1 / dt`` frames per second.

        Args:
            save_path: Output file path.
            as_gif: Write a GIF instead of an MP4.
        """
        fmt = 'GIF' if as_gif else 'mp4'
        imageio.mimsave(save_path, self.video_frames, format=fmt, fps=1/self.dt)
        print(f"{'GIF' if as_gif else 'MP4'} saved to {save_path}")

    def reset(self) -> None:
        """Reset the simulator and start a new recording from its initial frame.

        Mirrors ``__init__``: the initial observation becomes the first video
        frame and, when a live window is open, it is displayed immediately.
        """
        self.obs = self.simulator.reset()
        self.video_frames = [self.obs]
        if not self.offscreen:
            self.viewer.update_image(self.obs)

    @staticmethod
    def _analyze_scalar_sequence(data: List[float]) -> Tuple[float, float]:
        """Return ``(interquartile mean, median)`` of a non-empty sequence.

        The mean is taken only over values within [25th, 75th] percentile
        (inclusive) to reduce the influence of outliers. ``method='nearest'``
        makes both bounds actual data values, so the filtered set is never
        empty.
        """
        q1 = np.percentile(data, 25, method='nearest')
        q3 = np.percentile(data, 75, method='nearest')
        median = np.median(data)
        mean = np.mean([t for t in data if q1 <= t <= q3])
        return mean, median

    def close(self) -> None:
        """Shut down the simulator (and viewer) and print metric summaries.

        For each non-empty module-level metric list (timing, PSNR, delta PSNR)
        prints the interquartile mean and the median.
        """
        self.simulator.close()
        if not self.offscreen:
            self.viewer.stop()

        # report stats
        for title, data in (('Timing', step_time), ('PSNR', psnr), ('Delta PSNR', delta_psnr)):
            if data:
                mean, median = self._analyze_scalar_sequence(data)
                print(
                    f"=========== {title} ===========\n"
                    f"Mean: {mean}\n"
                    f"Median: {median}\n"
                )