# Hugging Face Space status: Running on Zero
import gradio as gr | |
import spaces | |
import gradio as gr | |
import numpy as np | |
from PIL import Image | |
import cv2 | |
from sim.simulator import GenieSimulator | |
import os | |
if not os.path.exists("data/mar_ckpt/langtable"):
    # First run: the backbone checkpoint that init_model() loads from
    # data/mar_ckpt/langtable is missing, so fetch it from Google Drive.
    import shutil

    import gdown

    gdown.download_folder("https://drive.google.com/drive/u/2/folders/1XU87cRqV-IMZA6RLiabIR_uZngynvUFN")
    # NOTE(review): assumes the Drive folder lands in ./langtable, as the
    # previous `mv langtable ...` did — confirm against gdown's output.
    # Use the stdlib instead of a shell string so a failed move raises here
    # rather than being silently ignored until the model fails to load.
    os.makedirs("data/mar_ckpt", exist_ok=True)
    shutil.move("langtable", "data/mar_ckpt/langtable")
# Side length, in pixels, that every displayed frame is resized to (RES x RES).
RES = 512
# Prompt-window length handed to GenieSimulator as `prompt_horizon`.
PROMPT_HORIZON = 3
# Directory containing the selectable prompt images.
IMAGE_DIR = "sim/assets/langtable_prompt/"

# All PNG files in IMAGE_DIR, in sorted order, offered in the image dropdown.
available_images = sorted(
    fname for fname in os.listdir(IMAGE_DIR) if fname.endswith(".png")
)
def initialize_simulator(image_name, state):
    """Seed the session's simulator from a prompt image and return its first frame.

    Args:
        image_name: File name of a prompt image inside ``IMAGE_DIR``.
        state: Per-session dict holding the simulator under the ``'genie'`` key.

    Returns:
        The frame returned by ``genie.reset()``, resized to ``RES`` x ``RES``,
        as a PIL image.
    """
    genie = state['genie']
    image_path = os.path.join(IMAGE_DIR, image_name)
    # Close the file handle promptly; np.array() forces the pixel data to be
    # read before the context manager exits.
    with Image.open(image_path) as image:
        frame = np.array(image)
    # Stack `prompt_horizon` copies of the single image as the prompt frames.
    prompt_image = np.tile(frame, (genie.prompt_horizon, 1, 1, 1)).astype(np.uint8)
    # All-zero prompt actions, one fewer than the number of prompt frames
    # (presumably one per frame transition — TODO confirm with GenieSimulator).
    prompt_action = np.zeros(
        (genie.prompt_horizon - 1, genie.action_stride, 2), dtype=np.float32
    )
    genie.set_initial_state((prompt_image, prompt_action))
    reset_image = cv2.resize(genie.reset(), (RES, RES))
    return Image.fromarray(reset_image)
def model(direction, state):
    """Advance the session's simulator one step and return the predicted frame.

    Args:
        direction: ``'right'``, ``'left'``, ``'down'`` or ``'up'``.
        state: Per-session dict holding the simulator under the ``'genie'`` key.

    Returns:
        The ``'pred_next_frame'`` from ``genie.step``, resized to ``RES`` x ``RES``,
        as a PIL image.

    Raises:
        ValueError: If ``direction`` is not one of the four names above.
    """
    # Action deltas: up/down change the first component, left/right the second.
    moves = (
        ('right', (0, 0.05)),
        ('left', (0, -0.05)),
        ('down', (0.05, 0)),
        ('up', (-0.05, 0)),
    )
    for name, delta in moves:
        if direction == name:
            action = np.array(delta)
            break
    else:
        raise ValueError(f"Invalid direction: {direction}")

    predicted = state['genie'].step(action)['pred_next_frame']
    return Image.fromarray(cv2.resize(predicted, (RES, RES)))
def handle_input(direction, state):
    """Log a direction-button click and return the simulator's next frame.

    Args:
        direction: Direction name forwarded to ``model``.
        state: Per-session dict holding the simulator under ``'genie'``.

    Returns:
        The PIL image produced by ``model`` for this step.
    """
    print(f"User clicked: {direction}")
    return model(direction, state)
def handle_image_selection(image_name, state):
    """Log the chosen prompt image and reinitialize the simulator from it.

    Args:
        image_name: File name selected in the image dropdown.
        state: Per-session dict holding the simulator under ``'genie'``.

    Returns:
        The first frame returned by ``initialize_simulator``.
    """
    print(f"User selected image: {image_name}")
    first_frame = initialize_simulator(image_name, state)
    return first_frame
def init_model():
    """Construct the GenieSimulator used by the demo.

    The initial state is deliberately not set here: ``initialize_simulator``
    calls ``set_initial_state``/``reset`` when the user loads a prompt image.
    The earlier unused prompt construction (which also left an image file
    open) has been removed.

    Returns:
        A ``GenieSimulator`` using the local ``data/mar_ckpt/langtable``
        backbone checkpoint and ``PROMPT_HORIZON`` prompt frames.
    """
    return GenieSimulator(
        image_encoder_type='temporalvae',
        image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
        quantize=False,
        backbone_type='stmar',
        backbone_ckpt='data/mar_ckpt/langtable',
        prompt_horizon=PROMPT_HORIZON,
        action_stride=1,
        domain='language_table',
    )
if __name__ == '__main__':
    with gr.Blocks() as demo:
        genie = init_model()
        # NOTE(review): Gradio copies a gr.State's initial value per session,
        # so each visitor likely gets a private simulator copy (isolated
        # rollouts, but memory grows with sessions) — confirm for the
        # installed Gradio version.
        genie_instance = gr.State({'genie': genie})

        with gr.Row():
            image_selector = gr.Dropdown(
                choices=available_images, value=available_images[0], label="Select an Image"
            )
            select_button = gr.Button("Load Image")
        with gr.Row():
            image_display = gr.Image(type="pil", label="Generated Image")
        select_button.click(
            fn=handle_image_selection, inputs=[image_selector, genie_instance], outputs=image_display, show_progress='hidden'
        )

        # Arrow pad: "Up" on its own row, "Left / Down / Right" beneath it.
        with gr.Row():
            up = gr.Button("↑ Up")
        with gr.Row():
            left = gr.Button("← Left")
            down = gr.Button("↓ Down")
            right = gr.Button("→ Right")

        def _make_move_handler(direction):
            # Bind `direction` at creation time so each button keeps its own value.
            return lambda state: handle_input(direction, state)

        for button, direction in ((up, "up"), (down, "down"), (left, "left"), (right, "right")):
            button.click(
                fn=_make_move_handler(direction), inputs=[genie_instance], outputs=image_display, show_progress='hidden'
            )
    demo.launch()