"""Gradio demo that steps a ``GenieSimulator`` from a prompt image.

The user picks one of the PNG prompt images found in ``IMAGE_DIR``, loads it
as the simulator's initial state, and then presses direction buttons; each
press sends one 2-D action to the simulator and shows the predicted frame.
"""

import os
import shutil

import gradio as gr
# NOTE(review): `spaces` is not referenced in this file; presumably it is
# imported for Hugging Face Spaces integration. Keep it ahead of the
# simulator import, as in the original ordering.
import spaces
import numpy as np
from PIL import Image
import cv2

from sim.simulator import GenieSimulator

# Final location of the backbone checkpoint used by the simulator.
CHECKPOINT_DIR = "data/mar_ckpt/langtable"
CHECKPOINT_URL = (
    "https://drive.google.com/drive/u/2/folders/1XU87cRqV-IMZA6RLiabIR_uZngynvUFN"
)
# Folder the download is expected to create in the current working directory
# (the original script moved a folder with this name into place).
_DOWNLOADED_DIR = "langtable"


def _ensure_checkpoint() -> None:
    """Download the checkpoint folder and move it into place if it is missing.

    Raises:
        RuntimeError: If the download did not create ``_DOWNLOADED_DIR``.
    """
    if os.path.exists(CHECKPOINT_DIR):
        return

    # download from google drive
    # Imported lazily so gdown is only required when a download is needed.
    import gdown

    gdown.download_folder(CHECKPOINT_URL)
    if not os.path.isdir(_DOWNLOADED_DIR):
        raise RuntimeError(
            f"Checkpoint download did not produce a '{_DOWNLOADED_DIR}' folder; "
            f"cannot populate {CHECKPOINT_DIR}"
        )

    checkpoint_parent = os.path.dirname(CHECKPOINT_DIR)
    os.makedirs(checkpoint_parent, exist_ok=True)
    # Moving into an existing directory places the folder inside it, which
    # yields CHECKPOINT_DIR (same effect as `mv langtable data/mar_ckpt/`).
    shutil.move(_DOWNLOADED_DIR, checkpoint_parent)


_ensure_checkpoint()

RES = 512
PROMPT_HORIZON = 3
IMAGE_DIR = "sim/assets/langtable_prompt/"

# Load available images
available_images = sorted(
    [img for img in os.listdir(IMAGE_DIR) if img.endswith(".png")]
)

# Action sent to the simulator for each direction button. Up/down change the
# first component, left/right the second, always by 0.05.
_DIRECTION_ACTIONS = {
    'right': (0, 0.05),
    'left': (0, -0.05),
    'down': (0.05, 0),
    'up': (-0.05, 0),
}


def initialize_simulator(image_name: str, state: dict) -> Image.Image:
    """Reset the simulator in ``state`` to start from the given prompt image.

    The image is repeated ``prompt_horizon`` times to form the prompt frames
    and paired with all-zero prompt actions.

    Args:
        image_name: File name of a prompt image inside ``IMAGE_DIR``.
        state: Session state dict; ``state['genie']`` is the simulator.

    Returns:
        The frame returned by the simulator's ``reset()``, resized to
        ``RES`` x ``RES``, as a PIL image.
    """
    genie = state['genie']
    image_path = os.path.join(IMAGE_DIR, image_name)
    # Close the file handle as soon as the pixels are copied into an array.
    with Image.open(image_path) as image:
        frame = np.array(image)

    # Four repetition counts promote the frame to 4-D: the leading axis
    # stacks `prompt_horizon` copies and the other axes are left untouched.
    prompt_image = np.tile(frame, (genie.prompt_horizon, 1, 1, 1)).astype(np.uint8)
    prompt_action = np.zeros(
        (genie.prompt_horizon - 1, genie.action_stride, 2)
    ).astype(np.float32)
    genie.set_initial_state((prompt_image, prompt_action))

    reset_image = genie.reset()
    reset_image = cv2.resize(reset_image, (RES, RES))
    return Image.fromarray(reset_image)


def model(direction: str, state: dict) -> Image.Image:
    """Advance the simulator one step in ``direction`` and return the frame.

    Args:
        direction: One of ``'up'``, ``'down'``, ``'left'``, ``'right'``.
        state: Session state dict; ``state['genie']`` is the simulator.

    Returns:
        The ``'pred_next_frame'`` entry of the simulator's step output,
        resized to ``RES`` x ``RES``, as a PIL image.

    Raises:
        ValueError: If ``direction`` is not a known direction.
    """
    if direction not in _DIRECTION_ACTIONS:
        raise ValueError(f"Invalid direction: {direction}")
    action = np.array(_DIRECTION_ACTIONS[direction])

    next_image = state['genie'].step(action)['pred_next_frame']
    next_image = cv2.resize(next_image, (RES, RES))
    return Image.fromarray(next_image)


def handle_input(direction: str, state: dict) -> Image.Image:
    """Log a direction button click and return the next predicted frame."""
    print(f"User clicked: {direction}")
    new_image = model(direction, state)
    return new_image


def handle_image_selection(image_name: str, state: dict) -> Image.Image:
    """Log the selected prompt image and re-initialize the simulator with it.

    Raises:
        gr.Error: If no image is selected (e.g. ``IMAGE_DIR`` has no PNGs).
    """
    print(f"User selected image: {image_name}")
    if not image_name:
        # Surface a readable message in the UI instead of a path-join TypeError.
        raise gr.Error("No prompt image selected.")
    return initialize_simulator(image_name, state)


def build_demo() -> gr.Blocks:
    """Build the Gradio UI and wire the buttons to the simulator handlers."""
    with gr.Blocks() as demo:
        genie_instance = gr.State({
            'genie': GenieSimulator(
                image_encoder_type='temporalvae',
                image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
                quantize=False,
                backbone_type='stmar',
                backbone_ckpt=CHECKPOINT_DIR,
                prompt_horizon=PROMPT_HORIZON,
                action_stride=1,
                domain='language_table',
            )
        })

        with gr.Row():
            image_selector = gr.Dropdown(
                choices=available_images,
                # Avoid an IndexError at startup when no prompt images exist.
                value=available_images[0] if available_images else None,
                label="Select an Image",
            )
            select_button = gr.Button("Load Image")

        with gr.Row():
            image_display = gr.Image(type="pil", label="Generated Image")

        with gr.Row():
            up = gr.Button("↑ Up")
        with gr.Row():
            left = gr.Button("← Left")
            down = gr.Button("↓ Down")
            right = gr.Button("→ Right")

        select_button.click(
            fn=handle_image_selection,
            inputs=[image_selector, genie_instance],
            outputs=image_display,
            show_progress='hidden',
        )

        up.click(
            fn=lambda state: handle_input("up", state),
            inputs=[genie_instance],
            outputs=image_display,
            show_progress='hidden',
        )
        down.click(
            fn=lambda state: handle_input("down", state),
            inputs=[genie_instance],
            outputs=image_display,
            show_progress='hidden',
        )
        left.click(
            fn=lambda state: handle_input("left", state),
            inputs=[genie_instance],
            outputs=image_display,
            show_progress='hidden',
        )
        right.click(
            fn=lambda state: handle_input("right", state),
            inputs=[genie_instance],
            outputs=image_display,
            show_progress='hidden',
        )

    return demo


if __name__ == '__main__':
    build_demo().launch()