File size: 3,933 Bytes
246c106
bce7a55
 
246c106
 
 
 
14420e9
240cb6a
 
14420e9
 
 
 
 
 
246c106
 
14420e9
 
 
 
 
 
 
240cb6a
 
 
 
 
 
 
 
 
 
aa82009
 
e176061
240cb6a
14420e9
 
5ed752c
 
 
 
14420e9
 
246c106
240cb6a
e176061
240cb6a
246c106
 
 
 
 
 
 
 
 
 
5ed752c
246c106
 
 
240cb6a
b258516
ca9c8aa
246c106
240cb6a
246c106
 
240cb6a
e176061
240cb6a
14420e9
240cb6a
8c3783f
 
e176061
ca9c8aa
 
240cb6a
 
246c106
14420e9
 
 
 
 
 
 
 
246c106
 
 
 
 
 
 
aa82009
 
240cb6a
aa82009
a100e67
 
 
 
240cb6a
aa82009
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import gradio as gr
import spaces

import numpy as np
from PIL import Image
import cv2
from sim.simulator import GenieSimulator
import os
import spaces


# Fetch the language-table checkpoint from Google Drive when it is not present.
# NOTE(review): this populates data/mar_ckpt/langtable, but the simulator is
# constructed with backbone_ckpt='data/mar_ckpt_long2/langtable'; confirm which
# directory is intended, otherwise this download does not feed the model.
if not os.path.exists("data/mar_ckpt/langtable"):
    # Imported lazily: only needed when the checkpoint is missing.
    import shutil

    import gdown

    # With no output path, gdown writes the Drive folder into the current
    # working directory; the move below assumes it lands as ./langtable.
    gdown.download_folder("https://drive.google.com/drive/u/2/folders/1XU87cRqV-IMZA6RLiabIR_uZngynvUFN")
    # Same effect as `mkdir -p data/mar_ckpt/; mv langtable data/mar_ckpt/`
    # but without spawning a shell. Kept best-effort: a missing download is
    # reported instead of being silently ignored (as the shell `mv` did).
    os.makedirs("data/mar_ckpt", exist_ok=True)
    if os.path.exists("langtable"):
        shutil.move("langtable", "data/mar_ckpt/")
    else:
        print("Warning: checkpoint download did not produce ./langtable; nothing moved to data/mar_ckpt/")

# Side length (pixels) that displayed frames are resized to.
RES = 512
# Prompt length handed to the simulator as `prompt_horizon`.
PROMPT_HORIZON = 3
# Directory containing the selectable prompt images.
IMAGE_DIR = "sim/assets/langtable_prompt/"

# Alphabetically sorted names of every ".png" file in IMAGE_DIR; these are the
# choices offered by the image dropdown.
available_images = sorted(
    filename for filename in os.listdir(IMAGE_DIR) if filename.endswith(".png")
)

# A single simulator instance, constructed once at import time and shared by
# every callback in this app.
genie = GenieSimulator(
    domain='language_table',
    prompt_horizon=PROMPT_HORIZON,
    action_stride=1,
    image_encoder_type='temporalvae',
    image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
    quantize=False,
    backbone_type='stmar',
    backbone_ckpt='data/mar_ckpt_long2/langtable',
)

# Reset the shared GenieSimulator so it starts from the selected prompt image.
@spaces.GPU
def initialize_simulator(image_name):
    """Reset the shared simulator from a single prompt image.

    The image is repeated ``genie.prompt_horizon`` times to build the prompt
    frame stack. It is paired with all-zero prompt actions, and the simulator
    is then reset from that state.

    Args:
        image_name: File name of a prompt image inside ``IMAGE_DIR``.

    Returns:
        PIL.Image.Image: The frame returned by ``genie.reset()``, resized to
        ``RES`` x ``RES``.
    """
    image_path = os.path.join(IMAGE_DIR, image_name)
    # The context manager releases the file handle (Image.open is lazy and
    # otherwise keeps it open). Converting to RGB makes palette, greyscale and
    # RGBA PNGs all yield an (H, W, 3) uint8 array.
    with Image.open(image_path) as image:
        frame = np.asarray(image.convert("RGB"), dtype=np.uint8)
    # Static prompt: the same frame repeated -> (prompt_horizon, H, W, 3).
    prompt_image = np.tile(frame, (genie.prompt_horizon, 1, 1, 1))
    # One fewer action than frames. The trailing 2 matches the 2-element
    # actions used when stepping; presumably (dy, dx). TODO confirm.
    prompt_action = np.zeros(
        (genie.prompt_horizon - 1, genie.action_stride, 2), dtype=np.float32
    )
    genie.set_initial_state((prompt_image, prompt_action))
    reset_image = genie.reset()
    reset_image = cv2.resize(reset_image, (RES, RES))
    return Image.fromarray(reset_image)

# Action vector for each arrow button. Judging by the mapping, the first
# component is the vertical delta (+ = down) and the second the horizontal
# delta (+ = right). Entries are tuples, so a fresh array is built per call and
# the simulator can never mutate a shared entry.
_DIRECTION_ACTIONS = {
    'right': (0, 0.05),
    'left': (0, -0.05),
    'down': (0.05, 0),
    'up': (-0.05, 0),
}


# Advance the shared simulator by one step in the requested direction.
@spaces.GPU
def model(direction: str):
    """Step the shared simulator with the action mapped to ``direction``.

    Args:
        direction: One of ``'up'``, ``'down'``, ``'left'``, ``'right'``.

    Returns:
        PIL.Image.Image: ``genie.step(action)['pred_next_frame']`` resized to
        ``RES`` x ``RES``.

    Raises:
        ValueError: If ``direction`` is not one of the four arrow directions.
    """
    try:
        delta = _DIRECTION_ACTIONS[direction]
    except KeyError:
        raise ValueError(f"Invalid direction: {direction}") from None
    action = np.array(delta)
    next_image = genie.step(action)['pred_next_frame']
    next_image = cv2.resize(next_image, (RES, RES))
    return Image.fromarray(next_image)

# Arrow-button callback: log the click, then step the simulator that way.
@spaces.GPU
def handle_input(direction):
    """Log the clicked direction and return the simulator's next frame.

    Args:
        direction: Arrow direction, forwarded unchanged to ``model``.

    Returns:
        Whatever ``model(direction)`` returns (the next frame).
    """
    print(f"User clicked: {direction}")
    return model(direction)

# "Load Image" callback: log the chosen file and reset the simulator from it.
@spaces.GPU
def handle_image_selection(image_name):
    """Log the selected prompt image and reset the simulator with it.

    Args:
        image_name: File name chosen in the image dropdown.

    Returns:
        The frame produced by ``initialize_simulator(image_name)``.
    """
    print(f"User selected image: {image_name}")
    first_frame = initialize_simulator(image_name)
    return first_frame

if __name__ == '__main__':
    # Fail fast with a clear message rather than an IndexError on
    # available_images[0] when the prompt-image folder has no PNGs.
    if not available_images:
        raise RuntimeError(f"No .png prompt images found in {IMAGE_DIR}")

    with gr.Blocks() as demo:
        # Header / usage instructions.
        with gr.Row():
            gr.Textbox(label='HMA Demo: Select a prompt initial image from the gallery and Interact with arrow keys. \n'
            'Note: the speed is limited due to free GPU in HF and the interface supports one user at a time.', lines=1)

        # Prompt-image picker; "Load Image" resets the simulator from it.
        with gr.Row():
            image_selector = gr.Dropdown(
                choices=available_images, value=available_images[0], label="Select an Image"
            )
            select_button = gr.Button("Load Image")

        # Frame shown after every load or step.
        with gr.Row():
            image_display = gr.Image(type="pil", label="Generated Image")

        # Arrow pad: Up on its own row, Left/Down/Right beneath it.
        # The labels contain non-ASCII arrows; keep this file UTF-8 encoded.
        with gr.Row():
            up = gr.Button("↑ Up")
        with gr.Row():
            left = gr.Button("← Left")
            down = gr.Button("↓ Down")
            right = gr.Button("→ Right")

        # Define interactions
        select_button.click(
            fn=handle_image_selection, inputs=image_selector, outputs=image_display
        )
        # Arrow buttons take no Gradio inputs; each lambda supplies its direction.
        up.click(fn=lambda: handle_input("up"), outputs=image_display, show_progress='hidden')
        down.click(fn=lambda: handle_input("down"), outputs=image_display, show_progress='hidden')
        left.click(fn=lambda: handle_input("left"), outputs=image_display, show_progress='hidden')
        right.click(fn=lambda: handle_input("right"), outputs=image_display, show_progress='hidden')

    demo.launch(share=True)