Spaces:
Running
on
Zero
Running
on
Zero
fix
Browse files- sim/simulator.py +32 -32
sim/simulator.py
CHANGED
@@ -27,13 +27,13 @@ class Simulator:
|
|
27 |
@torch.inference_mode()
|
28 |
def step(self, action):
|
29 |
raise NotImplementedError
|
30 |
-
|
31 |
def reset(self):
|
32 |
raise NotImplementedError
|
33 |
-
|
34 |
def close(self):
|
35 |
raise NotImplementedError
|
36 |
-
|
37 |
@property
|
38 |
def dt(self):
|
39 |
raise NotImplementedError
|
@@ -46,16 +46,16 @@ class PhysicsSimulator(Simulator):
|
|
46 |
# physics engine should be able to update dt
|
47 |
def set_dt(self, dt):
|
48 |
raise NotImplementedError
|
49 |
-
|
50 |
# physics engine should be able to get scene state
|
51 |
# e.g., robot joint positions, object positions, etc.
|
52 |
def get_raw_state(self, port: Optional[str] = None):
|
53 |
raise NotImplementedError
|
54 |
-
|
55 |
@property
|
56 |
def action_dimension(self):
|
57 |
raise NotImplementedError
|
58 |
-
|
59 |
|
60 |
class LearnedSimulator(Simulator):
|
61 |
def __init__(self):
|
@@ -65,9 +65,9 @@ class LearnedSimulator(Simulator):
|
|
65 |
# replayed data respects physics, so we inherit from PhysicsSimulator
|
66 |
# it can be considered as a special case of PhysicsSimulator
|
67 |
class ReplaySimulator(PhysicsSimulator):
|
68 |
-
def __init__(self,
|
69 |
-
frames,
|
70 |
-
prompt_horizon: int = 0,
|
71 |
dt: Optional[float] = None
|
72 |
):
|
73 |
super().__init__()
|
@@ -76,10 +76,10 @@ class ReplaySimulator(PhysicsSimulator):
|
|
76 |
assert self.frame_idx < len(self.frames)
|
77 |
self._dt = dt
|
78 |
self.prompt_horizon = prompt_horizon
|
79 |
-
|
80 |
def __len__(self):
|
81 |
return len(self.frames) - self.prompt_horizon
|
82 |
-
|
83 |
def step(self, action):
|
84 |
frame = self.frames[self.frame_idx]
|
85 |
assert self.frame_idx < len(self.frames)
|
@@ -87,20 +87,20 @@ class ReplaySimulator(PhysicsSimulator):
|
|
87 |
return {
|
88 |
'pred_next_frame': frame
|
89 |
}
|
90 |
-
|
91 |
def reset(self): # return current frame = last frame of prompt
|
92 |
self.frame_idx = self.prompt_horizon
|
93 |
return self.prompt()[-1]
|
94 |
-
|
95 |
def prompt(self):
|
96 |
return self.frames[:self.prompt_horizon]
|
97 |
-
|
98 |
@property
|
99 |
def dt(self):
|
100 |
return self._dt
|
101 |
-
|
102 |
|
103 |
-
|
|
|
104 |
|
105 |
class GenieSimulator(LearnedSimulator):
|
106 |
|
@@ -164,7 +164,7 @@ class GenieSimulator(LearnedSimulator):
|
|
164 |
elif backbone_type == "stmar":
|
165 |
inference_iterations = 2
|
166 |
|
167 |
-
# misc
|
168 |
self.device = torch.device(device)
|
169 |
self.measure_step_time = measure_step_time
|
170 |
self.compute_psnr = compute_psnr
|
@@ -200,11 +200,11 @@ class GenieSimulator(LearnedSimulator):
|
|
200 |
else:
|
201 |
self.backbone = STMAR.from_pretrained(backbone_ckpt)
|
202 |
self.backbone = self.backbone.to(device=self.device).eval()
|
203 |
-
|
204 |
self.post_processor = post_processor
|
205 |
-
|
206 |
# load physics simulator if available
|
207 |
-
# the phys sim to get ground truth image,
|
208 |
# assume the phys sim has aligned prompt frames
|
209 |
self.gt_phys_sim = physics_simulator
|
210 |
self.gt_teacher_force = physics_simulator_teacher_force
|
@@ -237,7 +237,7 @@ class GenieSimulator(LearnedSimulator):
|
|
237 |
# return: (H, W, 3)
|
238 |
assert self.cached_latent_frames is not None and self.cached_actions is not None, \
|
239 |
"Model is not prompted yet. Please call `set_initial_state` first."
|
240 |
-
|
241 |
if action.ndim == 1:
|
242 |
action = np.tile(action, (self.action_stride, 1))
|
243 |
|
@@ -273,7 +273,7 @@ class GenieSimulator(LearnedSimulator):
|
|
273 |
start_time = time.time()
|
274 |
pred_next_latent_state = self.backbone.maskgit_generate(
|
275 |
input_latent_states,
|
276 |
-
out_t=input_latent_states.shape[1] - 1,
|
277 |
maskgit_steps=self.inference_iterations,
|
278 |
temperature=self.sampling_temperature,
|
279 |
action_ids=input_actions,
|
@@ -310,7 +310,7 @@ class GenieSimulator(LearnedSimulator):
|
|
310 |
# compute PSNR against ground truth
|
311 |
if self.compute_psnr:
|
312 |
psnr = skimage.metrics.peak_signal_noise_ratio(
|
313 |
-
image_true=gt_next_frame / 255.,
|
314 |
image_test=pred_next_frame / 255.,
|
315 |
data_range=1.0
|
316 |
)
|
@@ -348,7 +348,7 @@ class GenieSimulator(LearnedSimulator):
|
|
348 |
|
349 |
if self.gt_teacher_force is not None and self.step_count % self.gt_teacher_force == 0:
|
350 |
pred_next_latent_state = self._encode_image(gt_next_frame)
|
351 |
-
|
352 |
# update history buffer
|
353 |
self.cached_latent_frames = torch.cat([
|
354 |
self.cached_latent_frames[1:], pred_next_latent_state.unsqueeze(0)
|
@@ -356,7 +356,7 @@ class GenieSimulator(LearnedSimulator):
|
|
356 |
self.cached_actions = torch.cat([
|
357 |
self.cached_actions[1:], action.unsqueeze(0)
|
358 |
])
|
359 |
-
|
360 |
# post processing
|
361 |
if self.post_processor is not None:
|
362 |
pred_next_frame = self.post_processor(pred_next_frame, action)
|
@@ -364,7 +364,7 @@ class GenieSimulator(LearnedSimulator):
|
|
364 |
self.step_count += 1
|
365 |
|
366 |
return step_result
|
367 |
-
|
368 |
|
369 |
@torch.inference_mode()
|
370 |
def _encode_image(self, image: np.ndarray) -> torch.Tensor:
|
@@ -422,11 +422,11 @@ class GenieSimulator(LearnedSimulator):
|
|
422 |
decoded_image = decoded_image.squeeze(0).to(torch.float32).detach().cpu().numpy()
|
423 |
decoded_image = self._unnormalize_image(decoded_image).transpose(1, 2, 0)
|
424 |
return decoded_image
|
425 |
-
|
426 |
|
427 |
def _normalize_image(self, image: np.ndarray) -> np.ndarray:
|
428 |
# (H, W, 3) normalized to [-1, 1]
|
429 |
-
# if `resize`, resize the shorter side to `resized_res`
|
430 |
# and then do a center crop
|
431 |
|
432 |
image = np.asarray(image, dtype=np.float32)
|
@@ -435,7 +435,7 @@ class GenieSimulator(LearnedSimulator):
|
|
435 |
|
436 |
# resize if asked
|
437 |
if self.resize_image:
|
438 |
-
resized_res = self.resize_image_resolution
|
439 |
if H < W:
|
440 |
Hnew, Wnew = resized_res, int(resized_res * W / H)
|
441 |
else:
|
@@ -469,7 +469,7 @@ class GenieSimulator(LearnedSimulator):
|
|
469 |
|
470 |
|
471 |
def reset(self) -> np.ndarray:
|
472 |
-
# if ground truth physics simulator is provided,
|
473 |
# return the side-by-side concatenated image
|
474 |
|
475 |
# get the initial prompt from the physics simulator if not yet set
|
@@ -480,7 +480,7 @@ class GenieSimulator(LearnedSimulator):
|
|
480 |
action_prompt = np.zeros(
|
481 |
(self.prompt_horizon, self.action_stride, self.gt_phys_sim.action_dimension)
|
482 |
).astype(np.float32)
|
483 |
-
else:
|
484 |
assert self.init_prompt is not None, "Initial state is not set."
|
485 |
image_prompt, action_prompt = self.init_prompt
|
486 |
|
@@ -498,7 +498,7 @@ class GenieSimulator(LearnedSimulator):
|
|
498 |
], axis=0)
|
499 |
|
500 |
if self.resize_image:
|
501 |
-
current_image = cv2.resize(current_image,
|
502 |
(self.resize_image_resolution, self.resize_image_resolution))
|
503 |
|
504 |
if self.gt_phys_sim is not None:
|
|
|
27 |
@torch.inference_mode()
|
28 |
def step(self, action):
|
29 |
raise NotImplementedError
|
30 |
+
|
31 |
def reset(self):
|
32 |
raise NotImplementedError
|
33 |
+
|
34 |
def close(self):
|
35 |
raise NotImplementedError
|
36 |
+
|
37 |
@property
|
38 |
def dt(self):
|
39 |
raise NotImplementedError
|
|
|
46 |
# physics engine should be able to update dt
|
47 |
def set_dt(self, dt):
|
48 |
raise NotImplementedError
|
49 |
+
|
50 |
# physics engine should be able to get scene state
|
51 |
# e.g., robot joint positions, object positions, etc.
|
52 |
def get_raw_state(self, port: Optional[str] = None):
|
53 |
raise NotImplementedError
|
54 |
+
|
55 |
@property
|
56 |
def action_dimension(self):
|
57 |
raise NotImplementedError
|
58 |
+
|
59 |
|
60 |
class LearnedSimulator(Simulator):
|
61 |
def __init__(self):
|
|
|
65 |
# replayed data respects physics, so we inherit from PhysicsSimulator
|
66 |
# it can be considered as a special case of PhysicsSimulator
|
67 |
class ReplaySimulator(PhysicsSimulator):
|
68 |
+
def __init__(self,
|
69 |
+
frames,
|
70 |
+
prompt_horizon: int = 0,
|
71 |
dt: Optional[float] = None
|
72 |
):
|
73 |
super().__init__()
|
|
|
76 |
assert self.frame_idx < len(self.frames)
|
77 |
self._dt = dt
|
78 |
self.prompt_horizon = prompt_horizon
|
79 |
+
|
80 |
def __len__(self):
|
81 |
return len(self.frames) - self.prompt_horizon
|
82 |
+
|
83 |
def step(self, action):
|
84 |
frame = self.frames[self.frame_idx]
|
85 |
assert self.frame_idx < len(self.frames)
|
|
|
87 |
return {
|
88 |
'pred_next_frame': frame
|
89 |
}
|
90 |
+
|
91 |
def reset(self): # return current frame = last frame of prompt
|
92 |
self.frame_idx = self.prompt_horizon
|
93 |
return self.prompt()[-1]
|
94 |
+
|
95 |
def prompt(self):
|
96 |
return self.frames[:self.prompt_horizon]
|
97 |
+
|
98 |
@property
|
99 |
def dt(self):
|
100 |
return self._dt
|
|
|
101 |
|
102 |
+
|
103 |
+
|
104 |
|
105 |
class GenieSimulator(LearnedSimulator):
|
106 |
|
|
|
164 |
elif backbone_type == "stmar":
|
165 |
inference_iterations = 2
|
166 |
|
167 |
+
# misc
|
168 |
self.device = torch.device(device)
|
169 |
self.measure_step_time = measure_step_time
|
170 |
self.compute_psnr = compute_psnr
|
|
|
200 |
else:
|
201 |
self.backbone = STMAR.from_pretrained(backbone_ckpt)
|
202 |
self.backbone = self.backbone.to(device=self.device).eval()
|
203 |
+
|
204 |
self.post_processor = post_processor
|
205 |
+
|
206 |
# load physics simulator if available
|
207 |
+
# the phys sim to get ground truth image,
|
208 |
# assume the phys sim has aligned prompt frames
|
209 |
self.gt_phys_sim = physics_simulator
|
210 |
self.gt_teacher_force = physics_simulator_teacher_force
|
|
|
237 |
# return: (H, W, 3)
|
238 |
assert self.cached_latent_frames is not None and self.cached_actions is not None, \
|
239 |
"Model is not prompted yet. Please call `set_initial_state` first."
|
240 |
+
|
241 |
if action.ndim == 1:
|
242 |
action = np.tile(action, (self.action_stride, 1))
|
243 |
|
|
|
273 |
start_time = time.time()
|
274 |
pred_next_latent_state = self.backbone.maskgit_generate(
|
275 |
input_latent_states,
|
276 |
+
out_t=input_latent_states.shape[1] - 1,
|
277 |
maskgit_steps=self.inference_iterations,
|
278 |
temperature=self.sampling_temperature,
|
279 |
action_ids=input_actions,
|
|
|
310 |
# compute PSNR against ground truth
|
311 |
if self.compute_psnr:
|
312 |
psnr = skimage.metrics.peak_signal_noise_ratio(
|
313 |
+
image_true=gt_next_frame / 255.,
|
314 |
image_test=pred_next_frame / 255.,
|
315 |
data_range=1.0
|
316 |
)
|
|
|
348 |
|
349 |
if self.gt_teacher_force is not None and self.step_count % self.gt_teacher_force == 0:
|
350 |
pred_next_latent_state = self._encode_image(gt_next_frame)
|
351 |
+
|
352 |
# update history buffer
|
353 |
self.cached_latent_frames = torch.cat([
|
354 |
self.cached_latent_frames[1:], pred_next_latent_state.unsqueeze(0)
|
|
|
356 |
self.cached_actions = torch.cat([
|
357 |
self.cached_actions[1:], action.unsqueeze(0)
|
358 |
])
|
359 |
+
|
360 |
# post processing
|
361 |
if self.post_processor is not None:
|
362 |
pred_next_frame = self.post_processor(pred_next_frame, action)
|
|
|
364 |
self.step_count += 1
|
365 |
|
366 |
return step_result
|
367 |
+
|
368 |
|
369 |
@torch.inference_mode()
|
370 |
def _encode_image(self, image: np.ndarray) -> torch.Tensor:
|
|
|
422 |
decoded_image = decoded_image.squeeze(0).to(torch.float32).detach().cpu().numpy()
|
423 |
decoded_image = self._unnormalize_image(decoded_image).transpose(1, 2, 0)
|
424 |
return decoded_image
|
425 |
+
|
426 |
|
427 |
def _normalize_image(self, image: np.ndarray) -> np.ndarray:
|
428 |
# (H, W, 3) normalized to [-1, 1]
|
429 |
+
# if `resize`, resize the shorter side to `resized_res`
|
430 |
# and then do a center crop
|
431 |
|
432 |
image = np.asarray(image, dtype=np.float32)
|
|
|
435 |
|
436 |
# resize if asked
|
437 |
if self.resize_image:
|
438 |
+
resized_res = self.resize_image_resolution
|
439 |
if H < W:
|
440 |
Hnew, Wnew = resized_res, int(resized_res * W / H)
|
441 |
else:
|
|
|
469 |
|
470 |
|
471 |
def reset(self) -> np.ndarray:
|
472 |
+
# if ground truth physics simulator is provided,
|
473 |
# return the side-by-side concatenated image
|
474 |
|
475 |
# get the initial prompt from the physics simulator if not yet set
|
|
|
480 |
action_prompt = np.zeros(
|
481 |
(self.prompt_horizon, self.action_stride, self.gt_phys_sim.action_dimension)
|
482 |
).astype(np.float32)
|
483 |
+
else:
|
484 |
assert self.init_prompt is not None, "Initial state is not set."
|
485 |
image_prompt, action_prompt = self.init_prompt
|
486 |
|
|
|
498 |
], axis=0)
|
499 |
|
500 |
if self.resize_image:
|
501 |
+
current_image = cv2.resize(current_image,
|
502 |
(self.resize_image_resolution, self.resize_image_resolution))
|
503 |
|
504 |
if self.gt_phys_sim is not None:
|