# Hugging Face Spaces status: "Running on Zero" (ZeroGPU).
import gradio as gr | |
import spaces | |
import numpy as np | |
from PIL import Image | |
import cv2 | |
from sim.simulator import GenieSimulator | |
# Side length, in pixels, of the square frames shown in the UI.
RES = 512

# Single frame used to prime the simulator's prompt context.
PROMPT_IMAGE_PATH = "sim/assets/langtable_prompt/frame_06.png"

# Read the prompt frame inside a context manager so the file handle is
# released deterministically once the pixels have been copied out.
# NOTE(review): assumes the PNG decodes to an (H, W, 3) RGB array; an RGBA
# file would yield 4 channels — confirm what GenieSimulator expects.
with Image.open(PROMPT_IMAGE_PATH) as _prompt_frame:
    _prompt_pixels = np.array(_prompt_frame)

genie = GenieSimulator(
    image_encoder_type='temporalvae',
    image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
    quantize=False,
    backbone_type='stmar',
    backbone_ckpt='data/mar_ckpt/langtable',
    prompt_horizon=3,
    action_stride=1,
    domain='language_table',
)

# Prompt context: the same frame repeated genie.prompt_horizon times along a
# new leading axis (uint8), plus an all-zero float32 action array of shape
# (prompt_horizon - 1, action_stride, 2) — presumably one entry per
# transition between consecutive prompt frames.
prompt_image = np.tile(
    _prompt_pixels, (genie.prompt_horizon, 1, 1, 1)
).astype(np.uint8)
prompt_action = np.zeros(
    (genie.prompt_horizon - 1, genie.action_stride, 2), dtype=np.float32
)
genie.set_initial_state((prompt_image, prompt_action))

# Initial frame for the UI: the simulator's reset output resized to
# RES x RES and wrapped as a PIL image.
image = Image.fromarray(cv2.resize(genie.reset(), (RES, RES)))
# Unit action per UI direction, reproducing the original hand-written
# mapping: 'down'/'up' set the first action component (+/-) and
# 'right'/'left' set the second (+/-). Scaled by `step` in model().
_DIRECTION_UNITS = {
    'right': (0.0, 1.0),
    'left': (0.0, -1.0),
    'down': (1.0, 0.0),
    'up': (-1.0, 0.0),
}


def model(direction: str, genie=genie, *, step: float = 0.05):
    """Step the simulator once in *direction* and return the predicted frame.

    Args:
        direction: One of ``'up'``, ``'down'``, ``'left'``, ``'right'``.
        genie: Simulator to step. Defaults to the module-level ``genie``,
            bound once when this function is defined, so every caller that
            does not pass one steps the same shared object.
        step: Magnitude of the non-zero action component (default 0.05,
            the value previously hard-coded for all four directions).

    Returns:
        The ``'pred_next_frame'`` entry of ``genie.step(action)``, resized
        to ``RES x RES`` and converted to a PIL image.

    Raises:
        ValueError: If *direction* is not one of the four supported values.
    """
    # NOTE(review): `spaces` is imported at the top of the file but nothing
    # is decorated with `@spaces.GPU`; confirm whether `genie.step` needs it
    # when running on ZeroGPU hardware.
    try:
        unit = _DIRECTION_UNITS[direction]
    except (KeyError, TypeError):
        # TypeError covers unhashable input; the original also rejected it
        # with ValueError because no equality comparison matched.
        raise ValueError(f"Invalid direction: {direction}") from None

    # float64 array, numerically identical to the original literals
    # (e.g. [0.0, 0.05]) when step is the default.
    action = np.array(unit) * step
    next_image = genie.step(action)['pred_next_frame']
    next_image = cv2.resize(next_image, (RES, RES))
    return Image.fromarray(next_image)
# Gradio click callback shared by all four direction buttons.
def handle_input(direction):
    """Log which *direction* button was pressed and return the next frame.

    Delegates to ``model(direction)`` and returns its result unchanged.
    """
    print(f"User clicked: {direction}")
    return model(direction)
if __name__ == '__main__':
    # Layout: the current frame on top, then an arrow-key arrangement of
    # buttons — "Up" alone on one row, "Left / Down / Right" on the next.
    with gr.Blocks() as demo:
        with gr.Row():
            image_display = gr.Image(value=image, type="pil", label="Generated Image")
        with gr.Row():
            up = gr.Button("↑ Up")
        with gr.Row():
            left = gr.Button("← Left")
            down = gr.Button("↓ Down")
            right = gr.Button("→ Right")

        # Each button advances the simulator in its direction and replaces
        # the displayed frame. No Gradio inputs are wired, so the callbacks
        # take no arguments; each lambda closes over a string literal, so
        # there is no late-binding hazard.
        up.click(fn=lambda: handle_input("up"), outputs=image_display, show_progress='hidden')
        down.click(fn=lambda: handle_input("down"), outputs=image_display, show_progress='hidden')
        left.click(fn=lambda: handle_input("left"), outputs=image_display, show_progress='hidden')
        right.click(fn=lambda: handle_input("right"), outputs=image_display, show_progress='hidden')
    demo.launch()