Spaces:
Running
on
Zero
Running
on
Zero
import gradio as gr | |
import spaces | |
import gradio as gr | |
import numpy as np | |
from PIL import Image | |
import cv2 | |
from sim.simulator import GenieSimulator | |
import os | |
# Fetch the pretrained LangTable checkpoint the first time the app starts.
if not os.path.exists("data/mar_ckpt/langtable"):
    # Download from Google Drive. These imports are local so they are only
    # needed when the checkpoint is actually missing.
    import shutil

    import gdown

    gdown.download_folder("https://drive.google.com/drive/u/2/folders/1XU87cRqV-IMZA6RLiabIR_uZngynvUFN")
    # The downloaded folder is expected to appear as ./langtable (the same
    # assumption the previous shell `mv` made). Moving it with the standard
    # library raises on failure instead of silently ignoring a non-zero shell
    # exit status, and does not depend on a POSIX shell being available.
    os.makedirs("data/mar_ckpt", exist_ok=True)
    shutil.move("langtable", "data/mar_ckpt/")
# Side length, in pixels, that frames are resized to before being displayed.
RES = 512
# Prompt horizon passed to GenieSimulator.
PROMPT_HORIZON = 3
# Folder holding the selectable prompt images.
IMAGE_DIR = "sim/assets/langtable_prompt/"

# Simulator caches; overwritten by the reset and step callbacks.
cached_latent_frames = None
cached_actions = None

# Sorted file names of every PNG found in IMAGE_DIR.
available_images = sorted(
    filename for filename in os.listdir(IMAGE_DIR) if filename.endswith(".png")
)
# Helper that re-seeds the shared GenieSimulator from the selected image.
def initialize_simulator(image_name):
    """Reset the simulator so that it starts from the chosen prompt image.

    The image is tiled ``genie.prompt_horizon`` times to build the prompt
    frames and paired with all-zero prompt actions. The caches returned by
    ``genie.reset()`` are stored in the module-level ``cached_latent_frames``
    and ``cached_actions``.

    Args:
        image_name: File name of an image located in ``IMAGE_DIR``.

    Returns:
        The frame returned by ``genie.reset()``, resized to ``RES`` x ``RES``
        and wrapped in a ``PIL.Image``.
    """
    global cached_latent_frames, cached_actions

    source = Image.open(os.path.join(IMAGE_DIR, image_name))
    horizon = genie.prompt_horizon

    frames = np.tile(np.array(source), (horizon, 1, 1, 1)).astype(np.uint8)
    # NOTE(review): the module-level default prompt uses `prompt_horizon` zero
    # actions, while this uses `prompt_horizon - 1` -- confirm which length
    # the simulator expects.
    zero_actions = np.zeros((horizon - 1, genie.action_stride, 2), dtype=np.float32)
    genie.set_initial_state((frames, zero_actions))

    first_frame, cached_latent_frames, cached_actions = genie.reset()
    return Image.fromarray(cv2.resize(first_frame, (RES, RES)))
def model(direction):
    """Step the simulator once in *direction* and return the predicted frame.

    The module-level ``cached_latent_frames`` and ``cached_actions`` are passed
    to ``genie.step`` and then replaced with the values it returns.

    Args:
        direction: One of ``'right'``, ``'left'``, ``'down'`` or ``'up'``.

    Returns:
        The ``'pred_next_frame'`` entry of the step result, resized to
        ``RES`` x ``RES`` and wrapped in a ``PIL.Image``.

    Raises:
        ValueError: If *direction* is not one of the four supported names.
    """
    global cached_latent_frames, cached_actions

    # Action vector for each direction, matched in the order listed.
    direction_deltas = (
        ("right", (0, 0.05)),
        ("left", (0, -0.05)),
        ("down", (0.05, 0)),
        ("up", (-0.05, 0)),
    )
    delta = next((vec for name, vec in direction_deltas if direction == name), None)
    if delta is None:
        raise ValueError(f"Invalid direction: {direction}")
    action = np.array(delta)

    step_output = genie.step(
        action,
        cached_latent_frames=cached_latent_frames,
        cached_actions=cached_actions,
    )
    predicted = step_output['pred_next_frame']
    cached_latent_frames = step_output['set_cached_latent_frames']
    cached_actions = step_output['set_cached_actions']
    return Image.fromarray(cv2.resize(predicted, (RES, RES)))
def handle_input(direction):
    """Log the clicked direction and return the frame produced by ``model``.

    Args:
        direction: Direction name forwarded unchanged to ``model``.

    Returns:
        Whatever ``model(direction)`` returns.
    """
    print(f"User clicked: {direction}")
    return model(direction)
def handle_image_selection(image_name):
    """Log the selected image and reset the simulator to start from it.

    Args:
        image_name: File name forwarded unchanged to ``initialize_simulator``.

    Returns:
        Whatever ``initialize_simulator(image_name)`` returns.
    """
    print(f"User selected image: {image_name}")
    reset_frame = initialize_simulator(image_name)
    return reset_frame
# Single module-level simulator instance used by every UI callback.
genie = GenieSimulator(
    image_encoder_type="temporalvae",
    image_encoder_ckpt="stabilityai/stable-video-diffusion-img2vid",
    quantize=False,
    backbone_type="stmar",
    backbone_ckpt="data/mar_ckpt/langtable",
    prompt_horizon=PROMPT_HORIZON,
    action_stride=1,
    domain="language_table",
    device="cuda",
)

# Default prompt: one bundled frame tiled across the prompt horizon, paired
# with all-zero actions.
image = Image.open("sim/assets/langtable_prompt/frame_06.png")
prompt_image = np.tile(np.array(image), (genie.prompt_horizon, 1, 1, 1)).astype(np.uint8)
prompt_action = np.zeros((genie.prompt_horizon, genie.action_stride, 2), dtype=np.float32)
genie.set_initial_state((prompt_image, prompt_action))
# The device attribute is set again after construction, as in the original.
genie.device = "cuda"
# Script entry point: build the Gradio UI and serve it.
if __name__ == '__main__':
    with gr.Blocks() as demo:
        genie.device = "cuda"

        # Instruction banner (shown as the label of an otherwise empty textbox).
        with gr.Row():
            gr.Textbox(label='HMA Demo: Select a prompt initial image from the gallery and Interact with arrow keys. \n'
                             'Note: the speed is limited due to free GPU in HF and the interface supports one user at a time.', lines=1)

        # Prompt image picker.
        with gr.Row():
            image_selector = gr.Dropdown(
                choices=available_images,
                # Guard against an empty image folder instead of raising
                # IndexError while the UI is being built.
                value=available_images[0] if available_images else None,
                label="Select an Image",
            )
            select_button = gr.Button("Load Image")

        with gr.Row():
            image_display = gr.Image(type="pil", label="Generated Image")

        # Direction pad. The arrow glyphs replace a mis-encoded character that
        # had been rendering in the button labels.
        with gr.Row():
            up = gr.Button("↑ Up")
        with gr.Row():
            left = gr.Button("← Left")
            down = gr.Button("↓ Down")
            right = gr.Button("→ Right")

        # Define interactions.
        # The handler is passed directly: Gradio calls it with the dropdown
        # value. The previous zero-argument lambda could not accept that value
        # and only returned the function object instead of calling it.
        select_button.click(
            fn=handle_image_selection, inputs=image_selector, outputs=image_display, show_progress='hidden'
        )
        # These callbacks take no inputs, so each lambda fixes its direction.
        up.click(fn=lambda: handle_input("up"), outputs=image_display, show_progress='hidden')
        down.click(fn=lambda: handle_input("down"), outputs=image_display, show_progress='hidden')
        left.click(fn=lambda: handle_input("left"), outputs=image_display, show_progress='hidden')
        right.click(fn=lambda: handle_input("right"), outputs=image_display, show_progress='hidden')

    demo.launch(share=True)