File size: 3,789 Bytes
246c106
bce7a55
 
14420e9
 
246c106
 
 
 
14420e9
 
 
 
 
 
 
246c106
 
14420e9
 
 
 
 
 
 
aca205f
 
14420e9
 
aca205f
 
 
 
14420e9
 
246c106
aca205f
246c106
 
 
 
 
 
 
 
 
 
aca205f
246c106
 
 
aca205f
246c106
aca205f
246c106
 
aca205f
14420e9
aca205f
14420e9
246c106
 
aca205f
 
 
 
 
 
 
 
 
 
 
 
 
246c106
14420e9
 
 
 
 
 
 
 
246c106
 
 
 
 
 
 
14420e9
aca205f
14420e9
aca205f
 
 
 
246c106
aca205f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import gradio as gr
import spaces


import gradio as gr
import numpy as np
from PIL import Image
import cv2
from sim.simulator import GenieSimulator
import os

# Fetch the pretrained checkpoint on first run; skipped once the target
# directory already exists.
if not os.path.exists("data/mar_ckpt/langtable"):
    # download from google drive
    import shutil

    import gdown

    gdown.download_folder("https://drive.google.com/drive/u/2/folders/1XU87cRqV-IMZA6RLiabIR_uZngynvUFN")
    # Relocate the download with the standard library instead of a shell
    # pipeline: this is portable and a failed move raises here rather than
    # being silently ignored.
    # NOTE(review): relies on the download landing in ./langtable, as the
    # original `mv langtable ...` command did - confirm.
    os.makedirs("data/mar_ckpt", exist_ok=True)
    shutil.move("langtable", "data/mar_ckpt/")

# Side length, in pixels, that frames are resized to before being displayed.
RES = 512
# Value passed to the simulator as its prompt_horizon.
PROMPT_HORIZON = 3
# Directory holding the selectable prompt images.
IMAGE_DIR = "sim/assets/langtable_prompt/"

# Sorted file names of every PNG found directly inside IMAGE_DIR.
available_images = [entry for entry in sorted(os.listdir(IMAGE_DIR)) if entry.endswith(".png")]



def initialize_simulator(image_name, state):
    """Seed the simulator with a prompt image and return its reset frame.

    Args:
        image_name: File name of an image located in ``IMAGE_DIR``.
        state: Session dict whose ``'genie'`` entry is the simulator object.

    Returns:
        PIL.Image.Image: The frame returned by the simulator's ``reset()``,
        resized to ``RES`` x ``RES``.
    """
    genie = state['genie']
    # The tiling below needs an H x W x C array. Grayscale or palette PNGs
    # decode to 2-D arrays, which would tile into a wrongly shaped stack, so
    # normalise the mode first. The `with` block closes the file promptly.
    # NOTE(review): assumes the simulator expects 3-channel RGB - confirm.
    with Image.open(os.path.join(IMAGE_DIR, image_name)) as source:
        frame = np.array(source.convert("RGB"))
    # Repeat the single frame along a new leading axis, prompt_horizon times.
    prompt_image = np.tile(frame, (genie.prompt_horizon, 1, 1, 1)).astype(np.uint8)
    # All-zero actions: one fewer step than prompt frames.
    prompt_action = np.zeros((genie.prompt_horizon - 1, genie.action_stride, 2)).astype(np.float32)
    genie.set_initial_state((prompt_image, prompt_action))
    reset_image = cv2.resize(genie.reset(), (RES, RES))
    return Image.fromarray(reset_image)

def model(direction, state):
    """Advance the simulator one step in the given direction.

    Args:
        direction: One of ``'right'``, ``'left'``, ``'down'`` or ``'up'``.
        state: Session dict whose ``'genie'`` entry is the simulator object.

    Returns:
        PIL.Image.Image: The ``'pred_next_frame'`` entry of the step result,
        resized to ``RES`` x ``RES``.

    Raises:
        ValueError: If ``direction`` is not one of the four known names.
    """
    # Two-component action vector associated with each button.
    step_vectors = {
        'right': [0, 0.05],
        'left': [0, -0.05],
        'down': [0.05, 0],
        'up': [-0.05, 0],
    }
    try:
        vector = step_vectors[direction]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs, which are equally invalid.
        raise ValueError(f"Invalid direction: {direction}") from None
    action = np.array(vector)
    prediction = state['genie'].step(action)
    frame = cv2.resize(prediction['pred_next_frame'], (RES, RES))
    return Image.fromarray(frame)

def handle_input(direction, state):
    """Log the clicked direction and return the frame produced by ``model``."""
    print(f"User clicked: {direction}")
    return model(direction, state)

def handle_image_selection(image_name, state):
    """Log the chosen image and return the simulator's frame after re-seeding."""
    print(f"User selected image: {image_name}")
    first_frame = initialize_simulator(image_name, state)
    return first_frame

if __name__ == '__main__':
    # Build the demo UI: an image picker, the frame display, and four
    # direction buttons that each step the simulator once.
    with gr.Blocks() as demo:
        # The simulator is held in gr.State and handed to every callback.
        genie_instance = gr.State({
            'genie': GenieSimulator(
                image_encoder_type='temporalvae',
                image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
                quantize=False,
                backbone_type='stmar',
                backbone_ckpt='data/mar_ckpt/langtable',
                prompt_horizon=PROMPT_HORIZON,
                action_stride=1,
                domain='language_table',
            )
        })

        with gr.Row():
            image_selector = gr.Dropdown(
                choices=available_images,
                # Guard against an empty prompt directory, which would
                # otherwise raise IndexError before the UI is built.
                value=available_images[0] if available_images else None,
                label="Select an Image",
            )
            select_button = gr.Button("Load Image")

        with gr.Row():
            image_display = gr.Image(type="pil", label="Generated Image")

        with gr.Row():
            up = gr.Button("↑ Up")
        with gr.Row():
            left = gr.Button("← Left")
            down = gr.Button("↓ Down")
            # Label repaired: the arrow was stored as mis-decoded bytes.
            right = gr.Button("→ Right")

        select_button.click(
            fn=handle_image_selection, inputs=[image_selector, genie_instance], outputs=image_display, show_progress='hidden'
        )
        # Each button forwards its fixed direction plus the session state.
        up.click(fn=lambda state: handle_input("up", state), inputs=[genie_instance], outputs=image_display, show_progress='hidden')
        down.click(fn=lambda state: handle_input("down", state), inputs=[genie_instance], outputs=image_display, show_progress='hidden')
        left.click(fn=lambda state: handle_input("left", state), inputs=[genie_instance], outputs=image_display, show_progress='hidden')
        right.click(fn=lambda state: handle_input("right", state), inputs=[genie_instance], outputs=image_display, show_progress='hidden')

    demo.launch()