"""Interactive Gradio demo driving a GenieSimulator on the language_table domain.

The simulator is primed with one prompt frame. Each button press sends a small
2-D action to ``genie.step``, and the predicted next frame is displayed,
resized to ``RES`` x ``RES`` pixels.
"""
import gradio as gr
import spaces
import numpy as np
from PIL import Image
import cv2

from sim.simulator import GenieSimulator

# Side length, in pixels, of every frame displayed in the UI.
RES = 512

# Default magnitude of the non-zero action component sent per button press.
STEP_SIZE = 0.05

image = Image.open("sim/assets/langtable_prompt/frame_06.png")
genie = GenieSimulator(
    image_encoder_type='temporalvae',
    image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
    quantize=False,
    backbone_type='stmar',
    backbone_ckpt='data/mar_ckpt/langtable',
    prompt_horizon=2,
    action_stride=1,
    domain='language_table',
)

# Initial context: the prompt frame repeated `prompt_horizon` times, paired
# with all-zero actions of shape (prompt_horizon, action_stride, 2).
# NOTE(review): assumes the PNG decodes to an (H, W, C) array, so np.tile
# yields (prompt_horizon, H, W, C) — confirm against GenieSimulator.
prompt_image = np.tile(
    np.array(image), (genie.prompt_horizon, 1, 1, 1)
).astype(np.uint8)
prompt_action = np.zeros(
    (genie.prompt_horizon, genie.action_stride, 2)
).astype(np.float32)
genie.set_initial_state((prompt_image, prompt_action))

# First frame shown in the UI: the simulator's reset frame, resized for display.
image = genie.reset()
image = cv2.resize(image, (RES, RES))
image = Image.fromarray(image)

# Unit direction for each button; `model` scales it by `step`.
# In this mapping, up/down move along the first action component and
# left/right along the second.
# NOTE(review): presumably this matches language_table's action layout — confirm.
_DIRECTION_VECTORS = {
    'right': (0.0, 1.0),
    'left': (0.0, -1.0),
    'down': (1.0, 0.0),
    'up': (-1.0, 0.0),
}


def model(direction: str, genie=genie, step: float = STEP_SIZE) -> Image.Image:
    """Send a one-step action in ``direction`` to the simulator; return the next frame.

    Args:
        direction: One of ``'up'``, ``'down'``, ``'left'``, ``'right'``.
        genie: Simulator to step. Defaults to the module-level instance, which is
            bound when this function is defined.
        step: Magnitude of the non-zero action component (default ``STEP_SIZE``).

    Returns:
        The simulator's ``'pred_next_frame'``, resized to ``RES`` x ``RES``, as a
        PIL image.

    Raises:
        ValueError: If ``direction`` is not one of the four recognised names.
    """
    try:
        unit = _DIRECTION_VECTORS[direction]
    except KeyError:
        raise ValueError(f"Invalid direction: {direction}") from None
    # float64 array of shape (2,), e.g. [0.0, 0.05] for 'right' with the default step.
    action = np.array(unit) * step
    next_image = genie.step(action)['pred_next_frame']
    next_image = cv2.resize(next_image, (RES, RES))
    return Image.fromarray(next_image)


@spaces.GPU
def handle_input(direction):
    """Gradio callback: log the click and return the next frame from ``model``.

    ``spaces.GPU`` requests GPU hardware for the duration of the call when
    running on a Hugging Face ZeroGPU Space.
    """
    print(f"User clicked: {direction}")
    new_image = model(direction)
    return new_image


if __name__ == '__main__':
    with gr.Blocks() as demo:
        with gr.Row():
            image_display = gr.Image(value=image, type="pil", label="Generated Image")
        with gr.Row():
            up = gr.Button("↑ Up")
        with gr.Row():
            left = gr.Button("← Left")
            down = gr.Button("↓ Down")
            right = gr.Button("→ Right")

        # Each button sends its own fixed direction. The callbacks take no
        # component inputs and write the new frame into `image_display`.
        up.click(fn=lambda: handle_input("up"), outputs=image_display, show_progress='hidden')
        down.click(fn=lambda: handle_input("down"), outputs=image_display, show_progress='hidden')
        left.click(fn=lambda: handle_input("left"), outputs=image_display, show_progress='hidden')
        right.click(fn=lambda: handle_input("right"), outputs=image_display, show_progress='hidden')

    demo.launch()