liruiw committed on
Commit
8d2f46e
·
1 Parent(s): e176061
Files changed (2) hide show
  1. app.py +11 -0
  2. sim/simulator.py +5 -4
app.py CHANGED
@@ -80,6 +80,17 @@ if __name__ == '__main__':
80
  )
81
 
82
  with gr.Blocks() as demo:
 
 
 
 
 
 
 
 
 
 
 
83
  with gr.Row():
84
  image_selector = gr.Dropdown(
85
  choices=available_images, value=available_images[0], label="Select an Image"
 
80
  )
81
 
82
  with gr.Blocks() as demo:
83
+ image = Image.open("sim/assets/langtable_prompt/frame_06.png")
84
+ prompt_image = np.tile(
85
+ np.array(image), (genie.prompt_horizon, 1, 1, 1)
86
+ ).astype(np.uint8)
87
+ prompt_action = np.zeros(
88
+ (genie.prompt_horizon, genie.action_stride, 2)
89
+ ).astype(np.float32)
90
+ genie.set_initial_state((prompt_image, prompt_action))
91
+ image = genie.reset()
92
+
93
+
94
  with gr.Row():
95
  image_selector = gr.Dropdown(
96
  choices=available_images, value=available_images[0], label="Select an Image"
sim/simulator.py CHANGED
@@ -5,6 +5,7 @@ import einops
5
  import skimage
6
  import time
7
 
 
8
  from genie.st_mask_git import STMaskGIT
9
  from genie.st_mar import STMAR
10
  from datasets.utils import get_image_encoder
@@ -229,7 +230,7 @@ class GenieSimulator(LearnedSimulator):
229
  def set_initial_state(self, state: Tuple[np.ndarray, np.ndarray]):
230
  self.init_prompt = state
231
 
232
-
233
  @torch.inference_mode()
234
  def step(self, action: np.ndarray) -> Dict:
235
  # action: (action_stride, A) OR (A,)
@@ -364,7 +365,7 @@ class GenieSimulator(LearnedSimulator):
364
 
365
  return step_result
366
 
367
-
368
  @torch.inference_mode()
369
  def _encode_image(self, image: np.ndarray) -> torch.Tensor:
370
  # (H, W, 3)
@@ -396,7 +397,7 @@ class GenieSimulator(LearnedSimulator):
396
  latent = latent.squeeze(0).to(torch.float32).to(self.device)
397
  return latent
398
 
399
-
400
  @torch.inference_mode()
401
  def _decode_image(self, latent: torch.Tensor) -> np.ndarray:
402
  # latent can be either quantized indices or raw latent
@@ -467,7 +468,7 @@ class GenieSimulator(LearnedSimulator):
467
  image = np.clip(image, 0, 255).astype(np.uint8)
468
  return image
469
 
470
-
471
  def reset(self) -> np.ndarray:
472
  # if ground truth physics simulator is provided,
473
  # return the side-by-side concatenated image
 
5
  import skimage
6
  import time
7
 
8
+ import spaces
9
  from genie.st_mask_git import STMaskGIT
10
  from genie.st_mar import STMAR
11
  from datasets.utils import get_image_encoder
 
230
  def set_initial_state(self, state: Tuple[np.ndarray, np.ndarray]):
231
  self.init_prompt = state
232
 
233
+ @spaces.GPU
234
  @torch.inference_mode()
235
  def step(self, action: np.ndarray) -> Dict:
236
  # action: (action_stride, A) OR (A,)
 
365
 
366
  return step_result
367
 
368
+ @spaces.GPU
369
  @torch.inference_mode()
370
  def _encode_image(self, image: np.ndarray) -> torch.Tensor:
371
  # (H, W, 3)
 
397
  latent = latent.squeeze(0).to(torch.float32).to(self.device)
398
  return latent
399
 
400
+ @spaces.GPU
401
  @torch.inference_mode()
402
  def _decode_image(self, latent: torch.Tensor) -> np.ndarray:
403
  # latent can be either quantized indices or raw latent
 
468
  image = np.clip(image, 0, 255).astype(np.uint8)
469
  return image
470
 
471
+ @spaces.GPU
472
  def reset(self) -> np.ndarray:
473
  # if ground truth physics simulator is provided,
474
  # return the side-by-side concatenated image