Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,565 Bytes
246c106 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
import numpy as np
import imageio
from sim.simulator import Simulator
from sim.policy import Policy
from sim.viewer import ImageViewer
from typing import List, Tuple
# Module-level lists for scalar metrics: step time, PSNR, and delta PSNR.
step_time, psnr, delta_psnr = [], [], []
class InteractiveDigitalWorld:
    """Drive a simulator with a policy, record the frames, and optionally show them live.

    Each ``step`` asks the policy for an action given the current observation,
    advances the simulator, and records the resulting frame. Scalar metrics the
    simulator reports (``step_time``, ``psnr``, ``delta_psnr``) are collected per
    instance and summarized when ``close`` is called.
    """

    def __init__(self,
                 simulator: Simulator,
                 policy: Policy,
                 offscreen: bool = True,  # if False, show live window
                 window_size: Tuple[int, int] = (512, 512),
                 ):
        """Create the world and reset the simulator to obtain the first observation.

        Args:
            simulator: Object exposing ``dt``, ``reset()``, ``step(action)`` and
                ``close()``; ``step`` returns a dict with at least ``'pred_next_frame'``.
            policy: Object whose ``generate_action(obs)`` returns the next action.
            offscreen: If False, open an ``ImageViewer`` window and show frames live.
            window_size: Window size forwarded to ``ImageViewer`` when not offscreen.
        """
        self.simulator = simulator
        self.policy = policy
        self.offscreen = offscreen
        self.video_frames: List[np.ndarray] = []
        # Per-instance metric buffers: several worlds in one process must not mix
        # (or endlessly accumulate) each other's statistics.
        self.step_time: List[float] = []
        self.psnr: List[float] = []
        self.delta_psnr: List[float] = []
        self.dt = simulator.dt
        self.obs = self.simulator.reset()  # input to policy
        self.video_frames.append(self.obs)
        if not offscreen:
            self.viewer = ImageViewer(
                window_name=(
                    f"Simulator: {simulator.__class__.__name__} | "
                    f"Policy: {policy.__class__.__name__}"
                ),
                refresh_rate=self.dt,
                window_size=window_size
            )
            self.viewer.update_image(self.obs)

    def step(self) -> None:
        """Advance the world by one simulator step and record the resulting frame."""
        action = self.policy.generate_action(self.obs)
        result = self.simulator.step(action)
        next_frame = result['pred_next_frame']
        if 'gt_next_frame' in result:
            # Place prediction and ground truth next to each other (concatenated on axis 1).
            next_frame = np.concatenate([next_frame, result['gt_next_frame']], axis=1)
        # Collect whichever optional scalar metrics the simulator reported.
        for key, values in (('psnr', self.psnr),
                            ('delta_psnr', self.delta_psnr),
                            ('step_time', self.step_time)):
            if key in result:
                values.append(result[key])
        # NOTE(review): when ground truth is present, the concatenated pred|gt frame
        # also becomes the policy's next input — confirm the policy expects that.
        self.obs = next_frame
        if not self.offscreen:
            self.viewer.update_image(next_frame)
        self.video_frames.append(next_frame)

    def save_video(self, save_path: str, as_gif: bool = False) -> None:
        """Write all recorded frames to ``save_path`` as a GIF or an MP4.

        The frame rate is ``1 / self.dt`` (presumably ``dt`` is seconds per step —
        confirm against the simulator).
        """
        if as_gif:
            imageio.mimsave(save_path, self.video_frames, format='GIF', fps=1/self.dt)
        else:
            imageio.mimsave(save_path, self.video_frames, format='mp4', fps=1/self.dt)
        print(f"{'GIF' if as_gif else 'MP4'} saved to {save_path}")

    def reset(self) -> None:
        """Reset the simulator and start a new recording from the fresh observation.

        Collected metrics are kept so that ``close`` reports over the whole session.
        """
        self.obs = self.simulator.reset()
        # Mirror __init__: the reset observation is the first frame of the new
        # recording, and a live window should display it right away.
        self.video_frames = [self.obs]
        if not self.offscreen:
            self.viewer.update_image(self.obs)

    def close(self) -> None:
        """Close the simulator and viewer, then print summaries of the collected metrics."""
        self.simulator.close()
        if not self.offscreen:
            self.viewer.stop()
        # report stats; each mean is taken over the data between q1 and q3
        for title, data in (("Timing", self.step_time),
                            ("PSNR", self.psnr),
                            ("Delta PSNR", self.delta_psnr)):
            if data:
                mean, median = analyze_scalar_sequence(data)
                print(
                    f"=========== {title} ===========\n"
                    f"Mean: {mean}\n"
                    f"Median: {median}\n"
                )


def analyze_scalar_sequence(data: List[float]) -> Tuple[float, float]:
    """Return ``(interquartile_mean, median)`` for a non-empty sequence of scalars.

    q1 and q3 use ``method='nearest'``, so both are actual samples. The mean is
    taken only over samples with ``q1 <= x <= q3``, which discards outliers in the
    bottom and top quarters. That filtered set always contains q1, so it is never
    empty for NaN-free input.
    """
    q1 = np.percentile(data, 25, method='nearest')
    median = np.median(data)
    q3 = np.percentile(data, 75, method='nearest')
    mean = np.mean([t for t in data if q1 <= t <= q3])
    return mean, median