import gradio as gr
import spaces

import numpy as np
from PIL import Image
import cv2
from sim.simulator import GenieSimulator

# Side length, in pixels, of the square images shown in the UI.
RES = 512
# Single frame used to seed the simulator prompt.
image = Image.open("sim/assets/langtable_prompt/frame_06.png")
# Simulator instance, built once when this module is executed and reused by
# the callbacks defined further down.
genie = GenieSimulator(
    image_encoder_type='temporalvae',
    image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
    quantize=False,
    backbone_type='stmar',
    backbone_ckpt='data/mar_ckpt/langtable',
    prompt_horizon=11,
    action_stride=1,
    domain='language_table',
)
# Image prompt: the seed frame repeated `prompt_horizon` times along a
# leading axis.
# NOTE(review): assumes the PNG decodes to a 3-D (H, W, C) array, so the
# tiled result is (prompt_horizon, H, W, C) -- confirm.
prompt_image = np.tile(
    np.array(image), (genie.prompt_horizon, 1, 1, 1)
).astype(np.uint8)
# Action prompt: all zeros, shaped (prompt_horizon, action_stride, 2).
prompt_action = np.zeros(
    (genie.prompt_horizon, genie.action_stride, 2)
).astype(np.float32)
genie.set_initial_state((prompt_image, prompt_action))
# From here on `image` no longer refers to the file opened above: it is
# rebound to the output of reset() (presumably the first simulated frame as
# an array -- confirm), resized to RES x RES, and wrapped as a PIL image.
# That final value is what the UI displays before any button is clicked.
image = genie.reset()
image = cv2.resize(image, (RES, RES))
image = Image.fromarray(image)

# Example model: takes a direction and returns the simulator's next frame
def model(direction: str, genie=genie):
    """Advance the simulator one step in `direction` and return the frame.

    Args:
        direction: One of 'right', 'left', 'down' or 'up'.
        genie: Simulator to step. Defaults to the module-level instance,
            bound when this function is defined.

    Returns:
        The 'pred_next_frame' entry of the step result, resized to
        RES x RES and wrapped in a PIL image.

    Raises:
        ValueError: If `direction` is not one of the four names above.
    """
    step = 0.05
    # Two-component action per direction: the first component is
    # down (+) / up (-), the second is right (+) / left (-).
    deltas = {
        'right': [0, step],
        'left': [0, -step],
        'down': [step, 0],
        'up': [-step, 0],
    }
    # Anything that is not a known direction string is rejected the same way.
    delta = deltas.get(direction) if isinstance(direction, str) else None
    if delta is None:
        raise ValueError(f"Invalid direction: {direction}")

    action = np.array(delta)
    outcome = genie.step(action)
    frame = cv2.resize(outcome['pred_next_frame'], (RES, RES))
    return Image.fromarray(frame)

# Gradio callback for a direction button press
@spaces.GPU
def handle_input(direction):
    """Log the clicked direction and return the model's next image.

    Args:
        direction: Direction name forwarded unchanged to `model`.

    Returns:
        Whatever `model(direction)` returns (the image to display).
    """
    print(f"User clicked: {direction}")
    return model(direction)  # Next frame from the model

if __name__ == '__main__':
    with gr.Blocks() as demo:
        # Top row: the current frame, initially the image produced at startup.
        with gr.Row():
            frame_view = gr.Image(value=image, type="pil", label="Generated Image")
        # Direction pad: "up" on its own row, the other three on the next row.
        with gr.Row():
            up_button = gr.Button("↑ Up")
        with gr.Row():
            left_button = gr.Button("← Left")
            down_button = gr.Button("↓ Down")
            right_button = gr.Button("→ Right")

        def _callback_for(direction):
            """Return a zero-argument callback that steps in `direction`."""
            return lambda: handle_input(direction)

        # Wire each button so a click replaces the displayed image with the
        # frame returned by handle_input for that button's direction.
        for button, direction in (
            (up_button, "up"),
            (down_button, "down"),
            (left_button, "left"),
            (right_button, "right"),
        ):
            button.click(fn=_callback_for(direction), outputs=frame_view)

    demo.launch()