hma / app.py
liruiw's picture
Update app.py
8e45d5b verified
import gradio as gr
import spaces
import gradio as gr
import numpy as np
from PIL import Image
import cv2
from sim.simulator import GenieSimulator
import os
# Fetch the pretrained checkpoint the first time the app starts; later runs
# find data/mar_ckpt/langtable already in place and skip this step.
if not os.path.exists("data/mar_ckpt/langtable"):
    # Imported lazily: these modules are only needed when a download happens.
    import shutil

    import gdown

    # download from google drive
    gdown.download_folder("https://drive.google.com/drive/u/2/folders/1XU87cRqV-IMZA6RLiabIR_uZngynvUFN")
    # Relocate the downloaded `langtable` folder with the standard library
    # rather than a shell command line, so this works without a POSIX shell
    # and a failed move raises here instead of being silently ignored.
    os.makedirs("data/mar_ckpt", exist_ok=True)
    shutil.move("langtable", "data/mar_ckpt/langtable")
# Edge length, in pixels, that frames are resized to before being displayed.
RES = 512
# Passed as `prompt_horizon` when the simulator is constructed.
PROMPT_HORIZON = 3
# Directory that is listed for the selectable prompt images.
IMAGE_DIR = "sim/assets/langtable_prompt/"

# Simulator caches shared between UI callbacks; None until a reset fills them.
cached_latent_frames = None
cached_actions = None

# Load available images
available_images = sorted(
    file_name for file_name in os.listdir(IMAGE_DIR) if file_name.endswith(".png")
)
# Helper function to reset GenieSimulator with the selected image
@spaces.GPU
def initialize_simulator(image_name):
    """Restart the simulator from a prompt image and return its first frame.

    The chosen image is tiled ``genie.prompt_horizon`` times to form the
    prompt frames and paired with all-zero prompt actions. The caches
    returned by ``genie.reset()`` are stored in the module-level
    ``cached_latent_frames`` / ``cached_actions`` so later steps can use them.

    Args:
        image_name: File name of an image inside ``IMAGE_DIR``.

    Returns:
        The reset frame as a PIL image resized to ``RES`` x ``RES``.
    """
    global cached_latent_frames, cached_actions

    source_pixels = np.array(Image.open(os.path.join(IMAGE_DIR, image_name)))

    # Repeat the image once per prompt step
    # (presumably an (H, W, C) array -- TODO confirm).
    frames = np.tile(source_pixels, (genie.prompt_horizon, 1, 1, 1)).astype(np.uint8)
    # All-zero actions, one fewer than the number of prompt frames.
    zero_actions = np.zeros(
        (genie.prompt_horizon - 1, genie.action_stride, 2)
    ).astype(np.float32)
    genie.set_initial_state((frames, zero_actions))

    # reset() hands back the first frame plus the caches for subsequent steps.
    first_frame, cached_latent_frames, cached_actions = genie.reset()
    return Image.fromarray(cv2.resize(first_frame, (RES, RES)))
@spaces.GPU
def model(direction):
    """Advance the simulator one step in the requested direction.

    Reads and then replaces the module-level ``cached_latent_frames`` and
    ``cached_actions`` with the values returned by ``genie.step``.

    Args:
        direction: One of ``'up'``, ``'down'``, ``'left'`` or ``'right'``.

    Returns:
        The predicted next frame as a PIL image resized to ``RES`` x ``RES``.

    Raises:
        ValueError: If ``direction`` is not one of the four known names.
    """
    global cached_latent_frames, cached_actions

    # Two-component action sent to the simulator for each direction name.
    moves = {
        'right': (0, 0.05),
        'left': (0, -0.05),
        'down': (0.05, 0),
        'up': (-0.05, 0),
    }
    # Non-string values can never match a name, so they are rejected the
    # same way as unknown strings.
    delta = moves.get(direction) if isinstance(direction, str) else None
    if delta is None:
        raise ValueError(f"Invalid direction: {direction}")

    outcome = genie.step(
        np.array(delta),
        cached_latent_frames=cached_latent_frames,
        cached_actions=cached_actions,
    )
    predicted = outcome['pred_next_frame']
    # Keep the caches from this step so the next call continues from it.
    cached_latent_frames = outcome['set_cached_latent_frames']
    cached_actions = outcome['set_cached_actions']

    return Image.fromarray(cv2.resize(predicted, (RES, RES)))
@spaces.GPU
def handle_input(direction):
    """Log a direction-button click and return the frame that move produces.

    Args:
        direction: Direction name forwarded unchanged to ``model``.

    Returns:
        Whatever ``model(direction)`` returns.
    """
    print(f"User clicked: {direction}")
    return model(direction)
@spaces.GPU
def handle_image_selection(image_name):
    """Log the chosen prompt image and restart the simulator from it.

    Args:
        image_name: File name forwarded unchanged to ``initialize_simulator``.

    Returns:
        Whatever ``initialize_simulator(image_name)`` returns.
    """
    print(f"User selected image: {image_name}")
    first_frame = initialize_simulator(image_name)
    return first_frame
# Build the simulator once at import time; the UI callbacks in this file all
# operate on this single shared instance.
genie = GenieSimulator(
    image_encoder_type='temporalvae',
    image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
    quantize=False,
    backbone_type='stmar',
    backbone_ckpt='data/mar_ckpt/langtable',
    prompt_horizon=PROMPT_HORIZON,
    action_stride=1,
    domain='language_table',
    device="cuda"
)

# Seed the simulator with a default prompt: one fixed image tiled across the
# prompt horizon, paired with all-zero actions.
image = Image.open("sim/assets/langtable_prompt/frame_06.png")
prompt_image = np.tile(np.array(image), (genie.prompt_horizon, 1, 1, 1)).astype(np.uint8)
# NOTE(review): this uses `prompt_horizon` zero actions, whereas
# initialize_simulator uses `prompt_horizon - 1` -- confirm which length
# set_initial_state expects.
prompt_action = np.zeros((genie.prompt_horizon, genie.action_stride, 2)).astype(np.float32)
genie.set_initial_state((prompt_image, prompt_action))

genie.device = "cuda"
# Script entry point: build the Gradio interface and serve it.
if __name__ == '__main__':
    with gr.Blocks() as demo:
        genie.device = "cuda"
        with gr.Row():
            gr.Textbox(label='HMA Demo: Select a prompt initial image from the gallery and Interact with arrow keys. \n'
                             'Note: the speed is limited due to free GPU in HF and the interface supports one user at a time.', lines=1)
        with gr.Row():
            image_selector = gr.Dropdown(
                choices=available_images,
                # Avoid an IndexError at start-up when no prompt images exist.
                value=available_images[0] if available_images else None,
                label="Select an Image",
            )
            select_button = gr.Button("Load Image")
        with gr.Row():
            image_display = gr.Image(type="pil", label="Generated Image")
        with gr.Row():
            up = gr.Button("↑ Up")
        with gr.Row():
            left = gr.Button("← Left")
            down = gr.Button("↓ Down")
            right = gr.Button("→ Right")

        # Define interactions
        # The handler is passed directly so it receives the dropdown value;
        # a zero-argument lambda could neither accept that input nor call
        # the handler.
        select_button.click(
            fn=handle_image_selection, inputs=image_selector, outputs=image_display, show_progress='hidden'
        )
        # The direction buttons have no inputs, so each one binds its own
        # direction name in a zero-argument callback.
        up.click(fn=lambda: handle_input("up"), outputs=image_display, show_progress='hidden')
        down.click(fn=lambda: handle_input("down"), outputs=image_display, show_progress='hidden')
        left.click(fn=lambda: handle_input("left"), outputs=image_display, show_progress='hidden')
        right.click(fn=lambda: handle_input("right"), outputs=image_display, show_progress='hidden')

    demo.launch(share=True)