Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,933 Bytes
246c106 bce7a55 246c106 14420e9 240cb6a 14420e9 246c106 14420e9 240cb6a aa82009 e176061 240cb6a 14420e9 5ed752c 14420e9 246c106 240cb6a e176061 240cb6a 246c106 5ed752c 246c106 240cb6a b258516 ca9c8aa 246c106 240cb6a 246c106 240cb6a e176061 240cb6a 14420e9 240cb6a 8c3783f e176061 ca9c8aa 240cb6a 246c106 14420e9 246c106 aa82009 240cb6a aa82009 a100e67 240cb6a aa82009 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
import gradio as gr
import spaces
import numpy as np
from PIL import Image
import cv2
from sim.simulator import GenieSimulator
import os
import spaces
if not os.path.exists("data/mar_ckpt/langtable"):
    # Checkpoint not present locally: download it from Google Drive.
    # NOTE(review): this checks/populates data/mar_ckpt/langtable, but the
    # GenieSimulator below loads backbone_ckpt='data/mar_ckpt_long2/langtable'
    # -- confirm which location is intended.
    import shutil

    import gdown

    gdown.download_folder("https://drive.google.com/drive/u/2/folders/1XU87cRqV-IMZA6RLiabIR_uZngynvUFN")
    # The move below expects the download to land in ./langtable. Unlike the
    # previous `os.system("mkdir -p ...; mv ...")`, these calls raise if the
    # folder is missing instead of silently continuing with no checkpoint.
    os.makedirs("data/mar_ckpt", exist_ok=True)
    shutil.move("langtable", "data/mar_ckpt/")
RES = 512  # every frame shown in the UI is resized to RES x RES
PROMPT_HORIZON = 3  # forwarded to GenieSimulator(prompt_horizon=...)
IMAGE_DIR = "sim/assets/langtable_prompt/"  # source of the selectable prompt images

# Names of all ".png" files in IMAGE_DIR, in sorted order (dropdown choices).
available_images = sorted(
    fname
    for fname in os.listdir(IMAGE_DIR)
    if fname.endswith(".png")
)
# Single simulator instance built at import time; the Gradio callbacks
# (initialize_simulator, model) all read and mutate this same object.
# NOTE(review): backbone_ckpt points at data/mar_ckpt_long2/langtable, while the
# download step at the top of the file populates data/mar_ckpt/langtable --
# confirm the intended checkpoint location.
genie = GenieSimulator(
    domain='language_table',
    backbone_type='stmar',
    backbone_ckpt='data/mar_ckpt_long2/langtable',
    image_encoder_type='temporalvae',
    image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
    quantize=False,
    prompt_horizon=PROMPT_HORIZON,
    action_stride=1,
)
# Helper function to reset GenieSimulator with the selected image
@spaces.GPU
def initialize_simulator(image_name):
    """Reset the shared simulator using one prompt image from IMAGE_DIR.

    The image is repeated ``genie.prompt_horizon`` times to form the prompt
    frames. These are paired with ``genie.prompt_horizon - 1`` all-zero
    actions, each of shape ``(genie.action_stride, 2)``.

    Args:
        image_name: File name of a prompt image inside IMAGE_DIR.

    Returns:
        The simulator's reset frame as a PIL image, resized to RES x RES.
    """
    image_path = os.path.join(IMAGE_DIR, image_name)
    # Context manager closes the file handle once the pixels are copied out;
    # np.array() forces Pillow to load the data while the file is still open.
    # NOTE(review): assumes the PNG decodes to the channel layout genie expects
    # (e.g. RGB vs RGBA) -- confirm.
    with Image.open(image_path) as image:
        frame = np.array(image)
    prompt_image = np.tile(frame, (genie.prompt_horizon, 1, 1, 1)).astype(np.uint8)
    prompt_action = np.zeros(
        (genie.prompt_horizon - 1, genie.action_stride, 2), dtype=np.float32
    )
    genie.set_initial_state((prompt_image, prompt_action))
    reset_image = genie.reset()
    reset_image = cv2.resize(reset_image, (RES, RES))
    return Image.fromarray(reset_image)
# Action fed to genie.step() for each UI direction: up/down change the first
# component, left/right change the second.
_DIRECTION_TO_ACTION = {
    'right': (0, 0.05),
    'left': (0, -0.05),
    'down': (0.05, 0),
    'up': (-0.05, 0),
}


# Step the simulator one action in a direction and return the predicted frame.
@spaces.GPU
def model(direction: str):
    """Advance the shared simulator by one action and return its prediction.

    Args:
        direction: One of ``'up'``, ``'down'``, ``'left'``, ``'right'``.

    Returns:
        The simulator's ``'pred_next_frame'`` as a PIL image, resized to
        RES x RES.

    Raises:
        ValueError: If ``direction`` is not one of the four supported values.
    """
    try:
        delta = _DIRECTION_TO_ACTION[direction]
    except KeyError:
        raise ValueError(f"Invalid direction: {direction}") from None
    # Fresh float64 array per call, matching the previous np.array([...]) literals.
    action = np.array(delta)
    next_image = genie.step(action)['pred_next_frame']
    next_image = cv2.resize(next_image, (RES, RES))
    return Image.fromarray(next_image)
# Gradio callback wired to the arrow buttons.
@spaces.GPU
def handle_input(direction):
    """Log the clicked direction, then return the frame produced by ``model``."""
    print(f"User clicked: {direction}")
    return model(direction)
# Gradio callback wired to the "Load Image" button.
@spaces.GPU
def handle_image_selection(image_name):
    """Log the chosen prompt image and return the simulator's reset frame for it."""
    print(f"User selected image: {image_name}")
    first_frame = initialize_simulator(image_name)
    return first_frame
if __name__ == '__main__':
    # Build the UI:
    #   - an instruction banner;
    #   - a prompt-image picker with a "Load Image" button;
    #   - the frame display;
    #   - four direction buttons.
    with gr.Blocks() as demo:
        with gr.Row():
            gr.Textbox(label='HMA Demo: Select a prompt initial image from the gallery and Interact with arrow keys. \n'
                       'Note: the speed is limited due to free GPU in HF and the interface supports one user at a time.', lines=1)
        with gr.Row():
            image_selector = gr.Dropdown(
                choices=available_images,
                # Start with no selection when IMAGE_DIR has no PNGs, instead of
                # crashing with IndexError on available_images[0].
                value=available_images[0] if available_images else None,
                label="Select an Image",
            )
            select_button = gr.Button("Load Image")
        with gr.Row():
            image_display = gr.Image(type="pil", label="Generated Image")
        # The arrow glyphs in these labels had been mis-encoded (shown as "β").
        with gr.Row():
            up = gr.Button("↑ Up")
        with gr.Row():
            left = gr.Button("← Left")
            down = gr.Button("↓ Down")
            right = gr.Button("→ Right")

        # Define interactions
        select_button.click(
            fn=handle_image_selection, inputs=image_selector, outputs=image_display
        )
        # Each lambda captures a string literal, so there is no late-binding issue.
        up.click(fn=lambda: handle_input("up"), outputs=image_display, show_progress='hidden')
        down.click(fn=lambda: handle_input("down"), outputs=image_display, show_progress='hidden')
        left.click(fn=lambda: handle_input("left"), outputs=image_display, show_progress='hidden')
        right.click(fn=lambda: handle_input("right"), outputs=image_display, show_progress='hidden')

    demo.launch(share=True)