# hma / app.py
# Source-viewer metadata: last commit 8eeb719 by liruiw ("improve pred"),
# file size 2.46 kB.
import gradio as gr
import spaces
import numpy as np
from PIL import Image
import cv2
from sim.simulator import GenieSimulator
# Side length, in pixels, of the square frames shown in the UI.
RES = 512

# Starting frame used to prime the simulator.
image = Image.open("sim/assets/langtable_prompt/frame_06.png")

# World-model simulator for the language_table domain.
genie = GenieSimulator(
    image_encoder_type='temporalvae',
    image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
    quantize=False,
    backbone_type='stmar',
    backbone_ckpt='data/mar_ckpt/langtable',
    prompt_horizon=2,
    action_stride=1,
    domain='language_table',
)

# Prompt = the same frame stacked `prompt_horizon` times along a new leading
# axis, paired with all-zero actions of shape
# (prompt_horizon, action_stride, 2).
# NOTE(review): assumes the PNG decodes to a uint8 (H, W, C) array — confirm.
prompt_image = np.tile(
    np.array(image), (genie.prompt_horizon, 1, 1, 1)
).astype(np.uint8)
prompt_action = np.zeros(
    (genie.prompt_horizon, genie.action_stride, 2), dtype=np.float32
)
genie.set_initial_state((prompt_image, prompt_action))

# First frame for the UI: whatever reset() returns, resized to RES x RES and
# wrapped as a PIL image (rebinding `image`).
image = Image.fromarray(cv2.resize(genie.reset(), (RES, RES)))
# Unit components of the 2-element action vector for each UI direction:
# the first component grows for 'down', the second for 'right'. Stored as
# tuples so a fresh array is built on every call.
_DIRECTION_UNITS = {
    'right': (0, 1),
    'left': (0, -1),
    'down': (1, 0),
    'up': (-1, 0),
}


# Step the simulator in a direction and return its predicted next frame.
def model(direction: str, genie=genie, *, step: float = 0.05):
    """Advance the simulator one step and return the predicted next frame.

    Args:
        direction: One of ``'up'``, ``'down'``, ``'left'``, ``'right'``.
        genie: Simulator to step. It defaults to the module-level ``genie``
            as bound when this function was defined.
        step: Magnitude of the action component along the chosen axis. The
            default of 0.05 is the value the demo has always used.

    Returns:
        The simulator's ``'pred_next_frame'``, resized to ``RES`` x ``RES``
        and wrapped as a ``PIL.Image``.

    Raises:
        ValueError: If ``direction`` is not one of the four names above.
    """
    try:
        first, second = _DIRECTION_UNITS[direction]
    except (KeyError, TypeError):
        # TypeError covers unhashable input. The original ``==`` chain
        # reported every bad value as ValueError, so keep that contract.
        raise ValueError(f"Invalid direction: {direction}") from None
    # float64, matching the dtype the original literal arrays produced.
    action = np.array([first * step, second * step], dtype=np.float64)
    next_image = genie.step(action)['pred_next_frame']
    next_image = cv2.resize(next_image, (RES, RES))
    return Image.fromarray(next_image)
# Click callback shared by the four direction buttons.
@spaces.GPU
def handle_input(direction):
    """Log the pressed direction and return the next frame from ``model``.

    NOTE(review): ``@spaces.GPU`` presumably requests a GPU for the duration
    of this call on Hugging Face Spaces — confirm against the spaces docs.
    """
    message = f"User clicked: {direction}"
    print(message)
    frame = model(direction)  # PIL image of the predicted next frame
    return frame
if __name__ == '__main__':
    # Build the UI: one image panel plus arrow buttons laid out like arrow
    # keys ("Up" alone on one row; "Left", "Down", "Right" on the next).
    with gr.Blocks() as demo:
        with gr.Row():
            image_display = gr.Image(value=image, type="pil", label="Generated Image")
        with gr.Row():
            up = gr.Button("↑ Up")
        with gr.Row():
            left = gr.Button("← Left")
            down = gr.Button("↓ Down")
            # Label was mojibake ("β†’"); "→" matches the other arrow labels.
            right = gr.Button("→ Right")

        # Define button interactions. Listeners are registered inside the
        # Blocks context. Each zero-argument lambda binds its direction
        # literal, so there is no late-binding issue. The result replaces
        # the displayed image.
        up.click(fn=lambda: handle_input("up"), outputs=image_display, show_progress='hidden')
        down.click(fn=lambda: handle_input("down"), outputs=image_display, show_progress='hidden')
        left.click(fn=lambda: handle_input("left"), outputs=image_display, show_progress='hidden')
        right.click(fn=lambda: handle_input("right"), outputs=image_display, show_progress='hidden')

    demo.launch()