Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,789 Bytes
246c106 bce7a55 14420e9 246c106 14420e9 246c106 14420e9 aca205f 14420e9 aca205f 14420e9 246c106 aca205f 246c106 aca205f 246c106 aca205f 246c106 aca205f 246c106 aca205f 14420e9 aca205f 14420e9 246c106 aca205f 246c106 14420e9 246c106 14420e9 aca205f 14420e9 aca205f 246c106 aca205f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 |
import gradio as gr
import spaces
import gradio as gr
import numpy as np
from PIL import Image
import cv2
from sim.simulator import GenieSimulator
import os
if not os.path.exists("data/mar_ckpt/langtable"):
    # The backbone checkpoint is missing locally: download it from Google Drive.
    # Imports are local so gdown/shutil are only needed when a download happens.
    import shutil

    import gdown

    gdown.download_folder("https://drive.google.com/drive/u/2/folders/1XU87cRqV-IMZA6RLiabIR_uZngynvUFN")
    # NOTE(review): assumes the Drive folder lands in the current working directory
    # as "langtable" (the shell command this replaces moved ./langtable) — confirm.
    # No shell string, and a failed download now raises here instead of being
    # silently ignored and surfacing later as a confusing model-loading error.
    os.makedirs("data/mar_ckpt", exist_ok=True)
    shutil.move("langtable", "data/mar_ckpt/langtable")
# Side length, in pixels, that every displayed frame is resized to (square output).
RES = 512
# Number of prompt frames the simulator is seeded with (passed to GenieSimulator).
PROMPT_HORIZON = 3
# Directory that holds the selectable prompt images.
IMAGE_DIR = "sim/assets/langtable_prompt/"

# Selectable prompt images: every ".png" file name in IMAGE_DIR, in sorted order.
available_images = sorted(
    filename for filename in os.listdir(IMAGE_DIR) if filename.endswith(".png")
)
def initialize_simulator(image_name, state):
    """Seed the session's simulator with a prompt image and return its first frame.

    The chosen image is repeated ``prompt_horizon`` times to build the prompt
    frame stack. It is paired with all-zero prompt actions, and the simulator
    is then reset.

    Args:
        image_name: File name of an image inside ``IMAGE_DIR``.
        state: Session state dict whose ``'genie'`` entry is the simulator.

    Returns:
        The frame returned by ``reset()``, resized to ``RES`` x ``RES``, as a
        PIL image.
    """
    genie = state['genie']
    image_path = os.path.join(IMAGE_DIR, image_name)
    # Image.open is lazy and keeps the file open. The context manager closes it
    # deterministically once np.array has copied the pixels (Pillow only
    # auto-closes single-frame images).
    with Image.open(image_path) as image:
        # NOTE(review): PNGs may be RGBA or palette mode, which changes the array
        # shape. Confirm the prompt assets are RGB, or convert explicitly.
        frame = np.array(image)
    # Repeat the same frame prompt_horizon times along a new leading axis
    # (assumes an (H, W, C) image).
    prompt_image = np.tile(frame, (genie.prompt_horizon, 1, 1, 1)).astype(np.uint8)
    # Zero actions, shaped (prompt_horizon - 1, action_stride, 2).
    prompt_action = np.zeros(
        (genie.prompt_horizon - 1, genie.action_stride, 2), dtype=np.float32
    )
    genie.set_initial_state((prompt_image, prompt_action))
    reset_image = genie.reset()
    reset_image = cv2.resize(reset_image, (RES, RES))
    return Image.fromarray(reset_image)
def model(direction, state):
    """Advance the simulator one step in ``direction`` and return the new frame.

    Args:
        direction: One of ``'right'``, ``'left'``, ``'down'`` or ``'up'``.
        state: Session state dict whose ``'genie'`` entry is the simulator.

    Returns:
        The ``'pred_next_frame'`` produced by ``step``, resized to
        ``RES`` x ``RES``, as a PIL image.

    Raises:
        ValueError: If ``direction`` is not one of the four names above.
    """
    # Checked in order with ==. Each action is a 2-vector: the first component
    # is +/-0.05 for 'down'/'up', the second +/-0.05 for 'right'/'left'.
    moves = (
        ('right', (0, 0.05)),
        ('left', (0, -0.05)),
        ('down', (0.05, 0)),
        ('up', (-0.05, 0)),
    )
    for name, delta in moves:
        if direction == name:
            # Build a fresh array on every call so the simulator never shares one.
            action = np.array(delta)
            break
    else:
        raise ValueError(f"Invalid direction: {direction}")

    prediction = state['genie'].step(action)
    frame = cv2.resize(prediction['pred_next_frame'], (RES, RES))
    return Image.fromarray(frame)
def handle_input(direction, state):
    """Arrow-button callback: log the press, step the simulator, return the frame."""
    print(f"User clicked: {direction}")
    return model(direction, state)
def handle_image_selection(image_name, state):
    """"Load Image" callback: log the choice, then (re)seed the simulator with it."""
    print(f"User selected image: {image_name}")
    first_frame = initialize_simulator(image_name, state)
    return first_frame
if __name__ == '__main__':
    # Build the UI: an image picker, the generated-frame display, and four
    # arrow buttons that step the simulator.
    with gr.Blocks() as demo:
        # NOTE(review): Gradio deep-copies a gr.State's initial value per session,
        # so each visitor presumably gets their own copy of this simulator.
        # Confirm that is intended and affordable for the loaded model.
        genie_instance = gr.State({
            'genie': GenieSimulator(
                image_encoder_type='temporalvae',
                image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
                quantize=False,
                backbone_type='stmar',
                backbone_ckpt='data/mar_ckpt/langtable',
                prompt_horizon=PROMPT_HORIZON,
                action_stride=1,
                domain='language_table',
            )
        })

        with gr.Row():
            image_selector = gr.Dropdown(
                choices=available_images,
                # Avoid an IndexError at startup when IMAGE_DIR has no .png files.
                value=available_images[0] if available_images else None,
                label="Select an Image",
            )
            select_button = gr.Button("Load Image")
        with gr.Row():
            image_display = gr.Image(type="pil", label="Generated Image")
        with gr.Row():
            up = gr.Button("↑ Up")
        with gr.Row():
            left = gr.Button("← Left")
            down = gr.Button("↓ Down")
            right = gr.Button("→ Right")

        select_button.click(
            fn=handle_image_selection, inputs=[image_selector, genie_instance], outputs=image_display, show_progress='hidden'
        )
        # The direction strings are literals, so these lambdas have no
        # late-binding issue.
        up.click(fn=lambda state: handle_input("up", state), inputs=[genie_instance], outputs=image_display, show_progress='hidden')
        down.click(fn=lambda state: handle_input("down", state), inputs=[genie_instance], outputs=image_display, show_progress='hidden')
        left.click(fn=lambda state: handle_input("left", state), inputs=[genie_instance], outputs=image_display, show_progress='hidden')
        right.click(fn=lambda state: handle_input("right", state), inputs=[genie_instance], outputs=image_display, show_progress='hidden')
    demo.launch()