Spaces:
Runtime error
Runtime error
fix
Browse files- app.py +26 -22
- sim/simulator.py +1 -0
app.py
CHANGED
@@ -65,31 +65,35 @@ def handle_image_selection(image_name, state):
|
|
65 |
print(f"User selected image: {image_name}")
|
66 |
return initialize_simulator(image_name, state)
|
67 |
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
|
|
|
|
|
81 |
|
82 |
with gr.Blocks() as demo:
|
83 |
-
image = Image.open("sim/assets/langtable_prompt/frame_06.png")
|
84 |
-
prompt_image = np.tile(
|
85 |
-
np.array(image), (genie.prompt_horizon, 1, 1, 1)
|
86 |
-
).astype(np.uint8)
|
87 |
-
prompt_action = np.zeros(
|
88 |
-
(genie.prompt_horizon, genie.action_stride, 2)
|
89 |
-
).astype(np.float32)
|
90 |
-
genie.set_initial_state((prompt_image, prompt_action))
|
91 |
-
genie.device = "cuda"
|
92 |
-
image = genie.reset()
|
93 |
genie_instance = gr.State({'genie': genie})
|
94 |
|
95 |
with gr.Row():
|
|
|
65 |
print(f"User selected image: {image_name}")
|
66 |
return initialize_simulator(image_name, state)
|
67 |
|
68 |
+
@spaces.GPU
|
69 |
+
def init_model():
|
70 |
+
genie = GenieSimulator(
|
71 |
+
image_encoder_type='temporalvae',
|
72 |
+
image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
|
73 |
+
quantize=False,
|
74 |
+
backbone_type='stmar',
|
75 |
+
backbone_ckpt='data/mar_ckpt/langtable',
|
76 |
+
prompt_horizon=PROMPT_HORIZON,
|
77 |
+
action_stride=1,
|
78 |
+
domain='language_table',
|
79 |
+
)
|
80 |
+
|
81 |
+
image = Image.open("sim/assets/langtable_prompt/frame_06.png")
|
82 |
+
prompt_image = np.tile(
|
83 |
+
np.array(image), (genie.prompt_horizon, 1, 1, 1)
|
84 |
+
).astype(np.uint8)
|
85 |
+
prompt_action = np.zeros(
|
86 |
+
(genie.prompt_horizon, genie.action_stride, 2)
|
87 |
+
).astype(np.float32)
|
88 |
+
genie.set_initial_state((prompt_image, prompt_action))
|
89 |
+
genie.device = "cuda"
|
90 |
+
image = genie.reset()
|
91 |
+
return genie
|
92 |
|
93 |
+
if __name__ == '__main__':
|
94 |
+
genie = init_model()
|
95 |
|
96 |
with gr.Blocks() as demo:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
genie_instance = gr.State({'genie': genie})
|
98 |
|
99 |
with gr.Row():
|
sim/simulator.py
CHANGED
@@ -245,6 +245,7 @@ class GenieSimulator(LearnedSimulator):
|
|
245 |
if self.gauss_act_perturb_scale is not None:
|
246 |
action = np.random.normal(action, self.gauss_act_perturb_scale)
|
247 |
|
|
|
248 |
# encoding
|
249 |
input_latent_states = torch.cat([
|
250 |
self.cached_latent_frames,
|
|
|
245 |
if self.gauss_act_perturb_scale is not None:
|
246 |
action = np.random.normal(action, self.gauss_act_perturb_scale)
|
247 |
|
248 |
+
self.backbone = backbone.to(device=self.device).eval()
|
249 |
# encoding
|
250 |
input_latent_states = torch.cat([
|
251 |
self.cached_latent_frames,
|