Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,106 Bytes
aca205f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
import gradio as gr
import spaces
import gradio as gr
import numpy as np
from PIL import Image
import cv2
from sim.simulator import GenieSimulator
import os
if not os.path.exists("data/mar_ckpt/langtable"):
    # The backbone checkpoint is fetched from Google Drive on first start-up
    # when it is not already on disk. gdown/shutil are only needed here.
    import gdown
    import shutil

    gdown.download_folder("https://drive.google.com/drive/u/2/folders/1XU87cRqV-IMZA6RLiabIR_uZngynvUFN")
    # NOTE(review): assumes gdown saves the Drive folder as ./langtable in the
    # current working directory (the original `mv langtable ...` relied on
    # the same thing) — confirm against the gdown version in use.
    if not os.path.isdir("langtable"):
        raise RuntimeError(
            "Checkpoint download failed: expected ./langtable after gdown.download_folder"
        )
    # Pure-Python replacement for `mkdir -p data/mar_ckpt/; mv langtable data/mar_ckpt/`,
    # which silently ignored failures when run through os.system.
    os.makedirs("data/mar_ckpt", exist_ok=True)
    shutil.move("langtable", "data/mar_ckpt/langtable")

RES = 512  # side length (pixels) frames are resized to before display
PROMPT_HORIZON = 3  # passed to GenieSimulator as prompt_horizon
IMAGE_DIR = "sim/assets/langtable_prompt/"

# Names of the .png prompt images offered in the dropdown, in sorted order.
available_images = sorted(name for name in os.listdir(IMAGE_DIR) if name.endswith(".png"))
# Helper: reset the global GenieSimulator from a selected prompt image.
@spaces.GPU
def initialize_simulator(image_name):
    """Reset the module-level ``genie`` using *image_name* as the prompt.

    The image at ``IMAGE_DIR/image_name`` is repeated ``genie.prompt_horizon``
    times to form the prompt frames and paired with all-zero prompt actions.
    The frame returned by ``genie.reset()`` is resized to ``RES`` x ``RES``.

    Args:
        image_name: file name of one of the entries in ``available_images``.

    Returns:
        The reset frame as a PIL image.

    Raises:
        gr.Error: if *image_name* is not one of ``available_images``.
    """
    # Only accept names offered in the dropdown; this also rejects None and
    # path-like values such as "../something.png".
    if image_name not in available_images:
        raise gr.Error(f"Unknown prompt image: {image_name!r}")

    image_path = os.path.join(IMAGE_DIR, image_name)
    # Context manager releases the file handle once the pixels are decoded.
    with Image.open(image_path) as image:
        frame = np.array(image)

    # Presumably frame is (H, W, C), giving (prompt_horizon, H, W, C) — TODO confirm.
    prompt_image = np.tile(frame, (genie.prompt_horizon, 1, 1, 1)).astype(np.uint8)
    # Same prompt-action shape the __main__ start-up priming passes to
    # set_initial_state, so both reset paths feed the simulator identically.
    # NOTE(review): verify against GenieSimulator.set_initial_state.
    prompt_action = np.zeros((genie.prompt_horizon, genie.action_stride, 2)).astype(np.float32)
    genie.set_initial_state((prompt_image, prompt_action))

    reset_image = cv2.resize(genie.reset(), (RES, RES))
    return Image.fromarray(reset_image)
# Action passed to genie.step for each arrow button. Up/down change the first
# component and left/right the second; what the axes and the 0.05 step mean
# is defined by the simulator's action space, which is not visible here.
_DIRECTION_ACTIONS = {
    'right': (0, 0.05),
    'left': (0, -0.05),
    'down': (0.05, 0),
    'up': (-0.05, 0),
}


# Advance the simulator one step in the requested direction.
@spaces.GPU
def model(direction: str):
    """Step the module-level ``genie`` simulator once in *direction*.

    Args:
        direction: one of 'up', 'down', 'left', 'right'.

    Returns:
        The simulator's ``'pred_next_frame'``, resized to ``RES`` x ``RES``,
        as a PIL image.

    Raises:
        ValueError: if *direction* is not one of the four known directions.
    """
    try:
        delta = _DIRECTION_ACTIONS[direction]
    except KeyError:
        raise ValueError(f"Invalid direction: {direction}") from None
    # Build a fresh array each call so a shared constant can never be mutated
    # by the simulator.
    action = np.array(delta)
    next_image = genie.step(action)['pred_next_frame']
    next_image = cv2.resize(next_image, (RES, RES))
    return Image.fromarray(next_image)
# Callback wired to the four arrow buttons.
@spaces.GPU
def handle_input(direction):
    """Log the clicked *direction* and return the next frame from ``model``."""
    message = f"User clicked: {direction}"
    print(message)
    frame = model(direction)
    return frame
# Callback wired to the "Load Image" button.
@spaces.GPU
def handle_image_selection(image_name):
    """Log the chosen prompt image, then reset the simulator from it.

    Returns whatever ``initialize_simulator`` returns for *image_name*.
    """
    print(f"User selected image: {image_name}")
    reset_frame = initialize_simulator(image_name)
    return reset_frame
if __name__ == '__main__':
    # Build the simulator once; the callbacks above read this module-level
    # `genie`.
    genie = GenieSimulator(
        image_encoder_type='temporalvae',
        image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
        quantize=False,
        backbone_type='stmar',
        backbone_ckpt='data/mar_ckpt/langtable',
        prompt_horizon=PROMPT_HORIZON,
        action_stride=1,
        domain='language_table',
    )

    # Prime the simulator with a default prompt image so the arrow buttons
    # work before "Load Image" is clicked.
    default_image = "frame_06.png"
    with Image.open(os.path.join(IMAGE_DIR, default_image)) as image:
        frame = np.array(image)
    prompt_image = np.tile(frame, (genie.prompt_horizon, 1, 1, 1)).astype(np.uint8)
    prompt_action = np.zeros(
        (genie.prompt_horizon, genie.action_stride, 2)
    ).astype(np.float32)
    genie.set_initial_state((prompt_image, prompt_action))
    # Show the reset frame immediately instead of discarding it.
    initial_frame = Image.fromarray(cv2.resize(genie.reset(), (RES, RES)))

    with gr.Blocks() as demo:
        with gr.Row():
            # Default to the image the simulator was primed with, so the
            # dropdown and the displayed frame agree on start-up. It is in
            # available_images because it was just opened from IMAGE_DIR.
            image_selector = gr.Dropdown(
                choices=available_images, value=default_image, label="Select an Image"
            )
            select_button = gr.Button("Load Image")
        with gr.Row():
            image_display = gr.Image(
                value=initial_frame, type="pil", label="Generated Image"
            )
        with gr.Row():
            up = gr.Button("↑ Up")
        with gr.Row():
            left = gr.Button("← Left")
            down = gr.Button("↓ Down")
            right = gr.Button("→ Right")

        # Define interactions (must be registered inside the Blocks context).
        select_button.click(
            fn=handle_image_selection, inputs=image_selector, outputs=image_display
        )
        up.click(fn=lambda: handle_input("up"), outputs=image_display, show_progress='hidden')
        down.click(fn=lambda: handle_input("down"), outputs=image_display, show_progress='hidden')
        left.click(fn=lambda: handle_input("left"), outputs=image_display, show_progress='hidden')
        right.click(fn=lambda: handle_input("right"), outputs=image_display, show_progress='hidden')

    demo.launch()
|