Spaces:
Running
on
Zero
Running
on
Zero
fix
Browse files
app.py
CHANGED
@@ -1,36 +1,52 @@
|
|
1 |
import gradio as gr
|
2 |
import spaces
|
3 |
|
|
|
|
|
4 |
import numpy as np
|
5 |
from PIL import Image
|
6 |
import cv2
|
7 |
from sim.simulator import GenieSimulator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
RES = 512
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
genie = GenieSimulator(
|
12 |
image_encoder_type='temporalvae',
|
13 |
image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
|
14 |
quantize=False,
|
15 |
backbone_type='stmar',
|
16 |
-
backbone_ckpt='data/
|
17 |
-
prompt_horizon=
|
18 |
action_stride=1,
|
19 |
domain='language_table',
|
20 |
)
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
).astype(np.
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
|
|
31 |
|
32 |
# Example model: takes a direction and returns a random image
|
33 |
-
def model(direction: str
|
34 |
if direction == 'right':
|
35 |
action = np.array([0, 0.05])
|
36 |
elif direction == 'left':
|
@@ -52,10 +68,22 @@ def handle_input(direction):
|
|
52 |
new_image = model(direction) # Get a new image from the model
|
53 |
return new_image
|
54 |
|
|
|
|
|
|
|
|
|
|
|
55 |
if __name__ == '__main__':
|
56 |
with gr.Blocks() as demo:
|
57 |
with gr.Row():
|
58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
with gr.Row():
|
60 |
up = gr.Button("↑ Up")
|
61 |
with gr.Row():
|
@@ -63,10 +91,13 @@ if __name__ == '__main__':
|
|
63 |
down = gr.Button("↓ Down")
|
64 |
right = gr.Button("→ Right")
|
65 |
|
66 |
-
# Define
|
|
|
|
|
|
|
67 |
up.click(fn=lambda: handle_input("up"), outputs=image_display, show_progress='hidden')
|
68 |
down.click(fn=lambda: handle_input("down"), outputs=image_display, show_progress='hidden')
|
69 |
left.click(fn=lambda: handle_input("left"), outputs=image_display, show_progress='hidden')
|
70 |
right.click(fn=lambda: handle_input("right"), outputs=image_display, show_progress='hidden')
|
71 |
|
72 |
-
demo.launch()
|
|
|
1 |
import gradio as gr
|
2 |
import spaces
|
3 |
|
4 |
+
|
5 |
+
import gradio as gr
|
6 |
import numpy as np
|
7 |
from PIL import Image
|
8 |
import cv2
|
9 |
from sim.simulator import GenieSimulator
|
10 |
+
import os
|
11 |
+
|
12 |
+
if not os.path.exists("data/mar_ckpt/langtable"):
|
13 |
+
# Checkpoint folder is missing locally: download it from Google Drive
|
14 |
+
import gdown
|
15 |
+
gdown.download_folder("https://drive.google.com/drive/u/2/folders/1XU87cRqV-IMZA6RLiabIR_uZngynvUFN")
|
16 |
+
os.system("mkdir -p data/mar_ckpt/; mv langtable data/mar_ckpt/")
|
17 |
|
18 |
RES = 512
|
19 |
+
PROMPT_HORIZON = 3
|
20 |
+
IMAGE_DIR = "sim/assets/langtable_prompt/"
|
21 |
+
|
22 |
+
# List the available prompt image filenames (.png), sorted by name
|
23 |
+
available_images = sorted([img for img in os.listdir(IMAGE_DIR) if img.endswith(".png")])
|
24 |
+
|
25 |
+
|
26 |
genie = GenieSimulator(
|
27 |
image_encoder_type='temporalvae',
|
28 |
image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
|
29 |
quantize=False,
|
30 |
backbone_type='stmar',
|
31 |
+
backbone_ckpt='data/mar_ckpt_long2/langtable',
|
32 |
+
prompt_horizon=PROMPT_HORIZON,
|
33 |
action_stride=1,
|
34 |
domain='language_table',
|
35 |
)
|
36 |
+
|
37 |
+
# Helper function to reset GenieSimulator with the selected image
|
38 |
+
def initialize_simulator(image_name):
|
39 |
+
image_path = os.path.join(IMAGE_DIR, image_name)
|
40 |
+
image = Image.open(image_path)
|
41 |
+
prompt_image = np.tile(np.array(image), (genie.prompt_horizon, 1, 1, 1)).astype(np.uint8)
|
42 |
+
prompt_action = np.zeros((genie.prompt_horizon - 1, genie.action_stride, 2)).astype(np.float32)
|
43 |
+
genie.set_initial_state((prompt_image, prompt_action))
|
44 |
+
reset_image = genie.reset()
|
45 |
+
reset_image = cv2.resize(reset_image, (RES, RES))
|
46 |
+
return Image.fromarray(reset_image)
|
47 |
|
48 |
# Example model: takes a direction and returns a random image
|
49 |
+
def model(direction: str):
|
50 |
if direction == 'right':
|
51 |
action = np.array([0, 0.05])
|
52 |
elif direction == 'left':
|
|
|
68 |
new_image = model(direction) # Get a new image from the model
|
69 |
return new_image
|
70 |
|
71 |
+
# Gradio function to handle image selection
|
72 |
+
def handle_image_selection(image_name):
|
73 |
+
print(f"User selected image: {image_name}")
|
74 |
+
return initialize_simulator(image_name)
|
75 |
+
|
76 |
if __name__ == '__main__':
|
77 |
with gr.Blocks() as demo:
|
78 |
with gr.Row():
|
79 |
+
image_selector = gr.Dropdown(
|
80 |
+
choices=available_images, value=available_images[0], label="Select an Image"
|
81 |
+
)
|
82 |
+
select_button = gr.Button("Load Image")
|
83 |
+
|
84 |
+
with gr.Row():
|
85 |
+
image_display = gr.Image(type="pil", label="Generated Image")
|
86 |
+
|
87 |
with gr.Row():
|
88 |
up = gr.Button("↑ Up")
|
89 |
with gr.Row():
|
|
|
91 |
down = gr.Button("↓ Down")
|
92 |
right = gr.Button("→ Right")
|
93 |
|
94 |
+
# Define interactions
|
95 |
+
select_button.click(
|
96 |
+
fn=handle_image_selection, inputs=image_selector, outputs=image_display
|
97 |
+
)
|
98 |
up.click(fn=lambda: handle_input("up"), outputs=image_display, show_progress='hidden')
|
99 |
down.click(fn=lambda: handle_input("down"), outputs=image_display, show_progress='hidden')
|
100 |
left.click(fn=lambda: handle_input("left"), outputs=image_display, show_progress='hidden')
|
101 |
right.click(fn=lambda: handle_input("right"), outputs=image_display, show_progress='hidden')
|
102 |
|
103 |
+
demo.launch(share=True)
|