Spaces:
Runtime error
Runtime error
fix
Browse files- app.py +11 -0
- sim/simulator.py +5 -4
app.py
CHANGED
@@ -80,6 +80,17 @@ if __name__ == '__main__':
|
|
80 |
)
|
81 |
|
82 |
with gr.Blocks() as demo:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
83 |
with gr.Row():
|
84 |
image_selector = gr.Dropdown(
|
85 |
choices=available_images, value=available_images[0], label="Select an Image"
|
|
|
80 |
)
|
81 |
|
82 |
with gr.Blocks() as demo:
|
83 |
+
image = Image.open("sim/assets/langtable_prompt/frame_06.png")
|
84 |
+
prompt_image = np.tile(
|
85 |
+
np.array(image), (genie.prompt_horizon, 1, 1, 1)
|
86 |
+
).astype(np.uint8)
|
87 |
+
prompt_action = np.zeros(
|
88 |
+
(genie.prompt_horizon, genie.action_stride, 2)
|
89 |
+
).astype(np.float32)
|
90 |
+
genie.set_initial_state((prompt_image, prompt_action))
|
91 |
+
image = genie.reset()
|
92 |
+
|
93 |
+
|
94 |
with gr.Row():
|
95 |
image_selector = gr.Dropdown(
|
96 |
choices=available_images, value=available_images[0], label="Select an Image"
|
sim/simulator.py
CHANGED
@@ -5,6 +5,7 @@ import einops
|
|
5 |
import skimage
|
6 |
import time
|
7 |
|
|
|
8 |
from genie.st_mask_git import STMaskGIT
|
9 |
from genie.st_mar import STMAR
|
10 |
from datasets.utils import get_image_encoder
|
@@ -229,7 +230,7 @@ class GenieSimulator(LearnedSimulator):
|
|
229 |
def set_initial_state(self, state: Tuple[np.ndarray, np.ndarray]):
|
230 |
self.init_prompt = state
|
231 |
|
232 |
-
|
233 |
@torch.inference_mode()
|
234 |
def step(self, action: np.ndarray) -> Dict:
|
235 |
# action: (action_stride, A) OR (A,)
|
@@ -364,7 +365,7 @@ class GenieSimulator(LearnedSimulator):
|
|
364 |
|
365 |
return step_result
|
366 |
|
367 |
-
|
368 |
@torch.inference_mode()
|
369 |
def _encode_image(self, image: np.ndarray) -> torch.Tensor:
|
370 |
# (H, W, 3)
|
@@ -396,7 +397,7 @@ class GenieSimulator(LearnedSimulator):
|
|
396 |
latent = latent.squeeze(0).to(torch.float32).to(self.device)
|
397 |
return latent
|
398 |
|
399 |
-
|
400 |
@torch.inference_mode()
|
401 |
def _decode_image(self, latent: torch.Tensor) -> np.ndarray:
|
402 |
# latent can be either quantized indices or raw latent
|
@@ -467,7 +468,7 @@ class GenieSimulator(LearnedSimulator):
|
|
467 |
image = np.clip(image, 0, 255).astype(np.uint8)
|
468 |
return image
|
469 |
|
470 |
-
|
471 |
def reset(self) -> np.ndarray:
|
472 |
# if ground truth physics simulator is provided,
|
473 |
# return the the side-by-side concatenated image
|
|
|
5 |
import skimage
|
6 |
import time
|
7 |
|
8 |
+
import spaces
|
9 |
from genie.st_mask_git import STMaskGIT
|
10 |
from genie.st_mar import STMAR
|
11 |
from datasets.utils import get_image_encoder
|
|
|
230 |
def set_initial_state(self, state: Tuple[np.ndarray, np.ndarray]):
|
231 |
self.init_prompt = state
|
232 |
|
233 |
+
@spaces.GPU
|
234 |
@torch.inference_mode()
|
235 |
def step(self, action: np.ndarray) -> Dict:
|
236 |
# action: (action_stride, A) OR (A,)
|
|
|
365 |
|
366 |
return step_result
|
367 |
|
368 |
+
@spaces.GPU
|
369 |
@torch.inference_mode()
|
370 |
def _encode_image(self, image: np.ndarray) -> torch.Tensor:
|
371 |
# (H, W, 3)
|
|
|
397 |
latent = latent.squeeze(0).to(torch.float32).to(self.device)
|
398 |
return latent
|
399 |
|
400 |
+
@spaces.GPU
|
401 |
@torch.inference_mode()
|
402 |
def _decode_image(self, latent: torch.Tensor) -> np.ndarray:
|
403 |
# latent can be either quantized indices or raw latent
|
|
|
468 |
image = np.clip(image, 0, 255).astype(np.uint8)
|
469 |
return image
|
470 |
|
471 |
+
@spaces.GPU
|
472 |
def reset(self) -> np.ndarray:
|
473 |
# if ground truth physics simulator is provided,
|
474 |
# return the the side-by-side concatenated image
|