liruiw committed on
Commit
4c4632b
1 Parent(s): 8eeb719
Files changed (1) hide show
  1. sim/simulator.py +32 -32
sim/simulator.py CHANGED
@@ -27,13 +27,13 @@ class Simulator:
27
  @torch.inference_mode()
28
  def step(self, action):
29
  raise NotImplementedError
30
-
31
  def reset(self):
32
  raise NotImplementedError
33
-
34
  def close(self):
35
  raise NotImplementedError
36
-
37
  @property
38
  def dt(self):
39
  raise NotImplementedError
@@ -46,16 +46,16 @@ class PhysicsSimulator(Simulator):
46
  # physics engine should be able to update dt
47
  def set_dt(self, dt):
48
  raise NotImplementedError
49
-
50
  # physics engine should be able to get scene state
51
  # e.g., robot joint positions, object positions, etc.
52
  def get_raw_state(self, port: Optional[str] = None):
53
  raise NotImplementedError
54
-
55
  @property
56
  def action_dimension(self):
57
  raise NotImplementedError
58
-
59
 
60
  class LearnedSimulator(Simulator):
61
  def __init__(self):
@@ -65,9 +65,9 @@ class LearnedSimulator(Simulator):
65
  # replayed data respects physics, so we inherit from PhysicsSimulator
66
  # it can be considered as a special case of PhysicsSimulator
67
  class ReplaySimulator(PhysicsSimulator):
68
- def __init__(self,
69
- frames,
70
- prompt_horizon: int = 0,
71
  dt: Optional[float] = None
72
  ):
73
  super().__init__()
@@ -76,10 +76,10 @@ class ReplaySimulator(PhysicsSimulator):
76
  assert self.frame_idx < len(self.frames)
77
  self._dt = dt
78
  self.prompt_horizon = prompt_horizon
79
-
80
  def __len__(self):
81
  return len(self.frames) - self.prompt_horizon
82
-
83
  def step(self, action):
84
  frame = self.frames[self.frame_idx]
85
  assert self.frame_idx < len(self.frames)
@@ -87,20 +87,20 @@ class ReplaySimulator(PhysicsSimulator):
87
  return {
88
  'pred_next_frame': frame
89
  }
90
-
91
  def reset(self): # return current frame = last frame of prompt
92
  self.frame_idx = self.prompt_horizon
93
  return self.prompt()[-1]
94
-
95
  def prompt(self):
96
  return self.frames[:self.prompt_horizon]
97
-
98
  @property
99
  def dt(self):
100
  return self._dt
101
-
102
 
103
-
 
104
 
105
  class GenieSimulator(LearnedSimulator):
106
 
@@ -164,7 +164,7 @@ class GenieSimulator(LearnedSimulator):
164
  elif backbone_type == "stmar":
165
  inference_iterations = 2
166
 
167
- # misc
168
  self.device = torch.device(device)
169
  self.measure_step_time = measure_step_time
170
  self.compute_psnr = compute_psnr
@@ -200,11 +200,11 @@ class GenieSimulator(LearnedSimulator):
200
  else:
201
  self.backbone = STMAR.from_pretrained(backbone_ckpt)
202
  self.backbone = self.backbone.to(device=self.device).eval()
203
-
204
  self.post_processor = post_processor
205
-
206
  # load physics simulator if available
207
- # the phys sim to get ground truth image,
208
  # assume the phys sim has aligned prompt frames
209
  self.gt_phys_sim = physics_simulator
210
  self.gt_teacher_force = physics_simulator_teacher_force
@@ -237,7 +237,7 @@ class GenieSimulator(LearnedSimulator):
237
  # return: (H, W, 3)
238
  assert self.cached_latent_frames is not None and self.cached_actions is not None, \
239
  "Model is not prompted yet. Please call `set_initial_state` first."
240
-
241
  if action.ndim == 1:
242
  action = np.tile(action, (self.action_stride, 1))
243
 
@@ -273,7 +273,7 @@ class GenieSimulator(LearnedSimulator):
273
  start_time = time.time()
274
  pred_next_latent_state = self.backbone.maskgit_generate(
275
  input_latent_states,
276
- out_t=self.prompt_horizon,
277
  maskgit_steps=self.inference_iterations,
278
  temperature=self.sampling_temperature,
279
  action_ids=input_actions,
@@ -310,7 +310,7 @@ class GenieSimulator(LearnedSimulator):
310
  # compute PSNR against ground truth
311
  if self.compute_psnr:
312
  psnr = skimage.metrics.peak_signal_noise_ratio(
313
- image_true=gt_next_frame / 255.,
314
  image_test=pred_next_frame / 255.,
315
  data_range=1.0
316
  )
@@ -348,7 +348,7 @@ class GenieSimulator(LearnedSimulator):
348
 
349
  if self.gt_teacher_force is not None and self.step_count % self.gt_teacher_force == 0:
350
  pred_next_latent_state = self._encode_image(gt_next_frame)
351
-
352
  # update history buffer
353
  self.cached_latent_frames = torch.cat([
354
  self.cached_latent_frames[1:], pred_next_latent_state.unsqueeze(0)
@@ -356,7 +356,7 @@ class GenieSimulator(LearnedSimulator):
356
  self.cached_actions = torch.cat([
357
  self.cached_actions[1:], action.unsqueeze(0)
358
  ])
359
-
360
  # post processing
361
  if self.post_processor is not None:
362
  pred_next_frame = self.post_processor(pred_next_frame, action)
@@ -364,7 +364,7 @@ class GenieSimulator(LearnedSimulator):
364
  self.step_count += 1
365
 
366
  return step_result
367
-
368
 
369
  @torch.inference_mode()
370
  def _encode_image(self, image: np.ndarray) -> torch.Tensor:
@@ -422,11 +422,11 @@ class GenieSimulator(LearnedSimulator):
422
  decoded_image = decoded_image.squeeze(0).to(torch.float32).detach().cpu().numpy()
423
  decoded_image = self._unnormalize_image(decoded_image).transpose(1, 2, 0)
424
  return decoded_image
425
-
426
 
427
  def _normalize_image(self, image: np.ndarray) -> np.ndarray:
428
  # (H, W, 3) normalized to [-1, 1]
429
- # if `resize`, resize the shorter side to `resized_res`
430
  # and then do a center crop
431
 
432
  image = np.asarray(image, dtype=np.float32)
@@ -435,7 +435,7 @@ class GenieSimulator(LearnedSimulator):
435
 
436
  # resize if asked
437
  if self.resize_image:
438
- resized_res = self.resize_image_resolution
439
  if H < W:
440
  Hnew, Wnew = resized_res, int(resized_res * W / H)
441
  else:
@@ -469,7 +469,7 @@ class GenieSimulator(LearnedSimulator):
469
 
470
 
471
  def reset(self) -> np.ndarray:
472
- # if ground truth physics simulator is provided,
473
  # return the side-by-side concatenated image
474
 
475
  # get the initial prompt from the physics simulator if not yet set
@@ -480,7 +480,7 @@ class GenieSimulator(LearnedSimulator):
480
  action_prompt = np.zeros(
481
  (self.prompt_horizon, self.action_stride, self.gt_phys_sim.action_dimension)
482
  ).astype(np.float32)
483
- else:
484
  assert self.init_prompt is not None, "Initial state is not set."
485
  image_prompt, action_prompt = self.init_prompt
486
 
@@ -498,7 +498,7 @@ class GenieSimulator(LearnedSimulator):
498
  ], axis=0)
499
 
500
  if self.resize_image:
501
- current_image = cv2.resize(current_image,
502
  (self.resize_image_resolution, self.resize_image_resolution))
503
 
504
  if self.gt_phys_sim is not None:
 
27
  @torch.inference_mode()
28
  def step(self, action):
29
  raise NotImplementedError
30
+
31
  def reset(self):
32
  raise NotImplementedError
33
+
34
  def close(self):
35
  raise NotImplementedError
36
+
37
  @property
38
  def dt(self):
39
  raise NotImplementedError
 
46
  # physics engine should be able to update dt
47
  def set_dt(self, dt):
48
  raise NotImplementedError
49
+
50
  # physics engine should be able to get scene state
51
  # e.g., robot joint positions, object positions, etc.
52
  def get_raw_state(self, port: Optional[str] = None):
53
  raise NotImplementedError
54
+
55
  @property
56
  def action_dimension(self):
57
  raise NotImplementedError
58
+
59
 
60
  class LearnedSimulator(Simulator):
61
  def __init__(self):
 
65
  # replayed data respects physics, so we inherit from PhysicsSimulator
66
  # it can be considered as a special case of PhysicsSimulator
67
  class ReplaySimulator(PhysicsSimulator):
68
+ def __init__(self,
69
+ frames,
70
+ prompt_horizon: int = 0,
71
  dt: Optional[float] = None
72
  ):
73
  super().__init__()
 
76
  assert self.frame_idx < len(self.frames)
77
  self._dt = dt
78
  self.prompt_horizon = prompt_horizon
79
+
80
  def __len__(self):
81
  return len(self.frames) - self.prompt_horizon
82
+
83
  def step(self, action):
84
  frame = self.frames[self.frame_idx]
85
  assert self.frame_idx < len(self.frames)
 
87
  return {
88
  'pred_next_frame': frame
89
  }
90
+
91
  def reset(self): # return current frame = last frame of prompt
92
  self.frame_idx = self.prompt_horizon
93
  return self.prompt()[-1]
94
+
95
  def prompt(self):
96
  return self.frames[:self.prompt_horizon]
97
+
98
  @property
99
  def dt(self):
100
  return self._dt
 
101
 
102
+
103
+
104
 
105
  class GenieSimulator(LearnedSimulator):
106
 
 
164
  elif backbone_type == "stmar":
165
  inference_iterations = 2
166
 
167
+ # misc
168
  self.device = torch.device(device)
169
  self.measure_step_time = measure_step_time
170
  self.compute_psnr = compute_psnr
 
200
  else:
201
  self.backbone = STMAR.from_pretrained(backbone_ckpt)
202
  self.backbone = self.backbone.to(device=self.device).eval()
203
+
204
  self.post_processor = post_processor
205
+
206
  # load physics simulator if available
207
+ # the phys sim to get ground truth image,
208
  # assume the phys sim has aligned prompt frames
209
  self.gt_phys_sim = physics_simulator
210
  self.gt_teacher_force = physics_simulator_teacher_force
 
237
  # return: (H, W, 3)
238
  assert self.cached_latent_frames is not None and self.cached_actions is not None, \
239
  "Model is not prompted yet. Please call `set_initial_state` first."
240
+
241
  if action.ndim == 1:
242
  action = np.tile(action, (self.action_stride, 1))
243
 
 
273
  start_time = time.time()
274
  pred_next_latent_state = self.backbone.maskgit_generate(
275
  input_latent_states,
276
+ out_t=input_latent_states.shape[1] - 1,
277
  maskgit_steps=self.inference_iterations,
278
  temperature=self.sampling_temperature,
279
  action_ids=input_actions,
 
310
  # compute PSNR against ground truth
311
  if self.compute_psnr:
312
  psnr = skimage.metrics.peak_signal_noise_ratio(
313
+ image_true=gt_next_frame / 255.,
314
  image_test=pred_next_frame / 255.,
315
  data_range=1.0
316
  )
 
348
 
349
  if self.gt_teacher_force is not None and self.step_count % self.gt_teacher_force == 0:
350
  pred_next_latent_state = self._encode_image(gt_next_frame)
351
+
352
  # update history buffer
353
  self.cached_latent_frames = torch.cat([
354
  self.cached_latent_frames[1:], pred_next_latent_state.unsqueeze(0)
 
356
  self.cached_actions = torch.cat([
357
  self.cached_actions[1:], action.unsqueeze(0)
358
  ])
359
+
360
  # post processing
361
  if self.post_processor is not None:
362
  pred_next_frame = self.post_processor(pred_next_frame, action)
 
364
  self.step_count += 1
365
 
366
  return step_result
367
+
368
 
369
  @torch.inference_mode()
370
  def _encode_image(self, image: np.ndarray) -> torch.Tensor:
 
422
  decoded_image = decoded_image.squeeze(0).to(torch.float32).detach().cpu().numpy()
423
  decoded_image = self._unnormalize_image(decoded_image).transpose(1, 2, 0)
424
  return decoded_image
425
+
426
 
427
  def _normalize_image(self, image: np.ndarray) -> np.ndarray:
428
  # (H, W, 3) normalized to [-1, 1]
429
+ # if `resize`, resize the shorter side to `resized_res`
430
  # and then do a center crop
431
 
432
  image = np.asarray(image, dtype=np.float32)
 
435
 
436
  # resize if asked
437
  if self.resize_image:
438
+ resized_res = self.resize_image_resolution
439
  if H < W:
440
  Hnew, Wnew = resized_res, int(resized_res * W / H)
441
  else:
 
469
 
470
 
471
  def reset(self) -> np.ndarray:
472
+ # if ground truth physics simulator is provided,
473
  # return the side-by-side concatenated image
474
 
475
  # get the initial prompt from the physics simulator if not yet set
 
480
  action_prompt = np.zeros(
481
  (self.prompt_horizon, self.action_stride, self.gt_phys_sim.action_dimension)
482
  ).astype(np.float32)
483
+ else:
484
  assert self.init_prompt is not None, "Initial state is not set."
485
  image_prompt, action_prompt = self.init_prompt
486
 
 
498
  ], axis=0)
499
 
500
  if self.resize_image:
501
+ current_image = cv2.resize(current_image,
502
  (self.resize_image_resolution, self.resize_image_resolution))
503
 
504
  if self.gt_phys_sim is not None: