import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from PIL import Image
import cv2
from math import tau
import gradio as gr
from concurrent.futures import ThreadPoolExecutor
import tempfile

def process_image(input_image, img_size, blur_kernel_size, desired_range):
    """Extract the largest contour of an image as centred, scaled coordinates.

    The image is resized to ``img_size`` x ``img_size`` and converted to
    grayscale. It is then Gaussian-blurred and binarised with Otsu's threshold
    (inverted, so dark strokes become foreground). The contour enclosing the
    largest area is kept.

    Args:
        input_image: PIL image or RGB array of the drawing.
        img_size: Side length (px) the image is resized to before processing.
        blur_kernel_size: Gaussian kernel size. Coerced to an odd positive int,
            because ``cv2.GaussianBlur`` requires one.
        desired_range: Target extent of the output coordinates along each axis.

    Returns:
        (xs, ys): float arrays of contour vertices, mean-centred and scaled
        independently so each axis spans ``desired_range``. The y axis is
        flipped so that "up" in the image is positive y.

    Raises:
        ValueError: if no image is given, no contour is found, or the largest
            contour has zero width or height (it cannot be scaled).
    """
    if input_image is None:
        raise ValueError("No input image was provided.")

    img_size = int(img_size)
    # GaussianBlur needs an odd, positive integer kernel size.
    blur_kernel_size = int(blur_kernel_size) | 1

    # Normalise PIL images to 3-channel RGB first. Grayscale ("L") or palette
    # ("P") uploads would otherwise produce a 2-D array that COLOR_RGB2BGR
    # rejects.
    if isinstance(input_image, Image.Image):
        input_image = input_image.convert("RGB")

    img = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)
    img = cv2.resize(img, (img_size, img_size), interpolation=cv2.INTER_AREA)
    imgray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    blurred = cv2.GaussianBlur(imgray, (blur_kernel_size, blur_kernel_size), 0)
    _, thresh = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU)
    contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        raise ValueError("No contour found in the image.")

    largest_contour = max(contours, key=cv2.contourArea)
    # reshape (not squeeze) keeps a one-point contour as shape (1, 2).
    points = largest_contour.reshape(-1, 2).astype(float)
    xs, ys = points[:, 0], points[:, 1]

    x_range = xs.max() - xs.min()
    y_range = ys.max() - ys.min()
    if x_range == 0 or y_range == 0:
        raise ValueError("The largest contour has zero width or height and cannot be scaled.")

    xs = (xs - xs.mean()) * (desired_range / x_range)
    # Image rows grow downwards; negate so the plot is not upside down.
    ys = -(ys - ys.mean()) * (desired_range / y_range)

    return xs, ys

def compute_cn(f_exp, n, t_values):
    """Return the n-th complex Fourier coefficient of a sampled periodic signal.

    Approximates ``c_n = (1 / tau) * integral_0^tau f(t) * exp(-i n t) dt``
    with the trapezoidal rule.

    Args:
        f_exp: Complex samples of the signal at ``t_values``.
        n: Integer frequency of the coefficient.
        t_values: Sample parameters, expected to span ``[0, tau]``.

    Returns:
        The complex coefficient ``c_n``.
    """
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid. Use the
    # new name when it exists so this works on both old and new NumPy.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    integrand = f_exp * np.exp(-1j * n * t_values)
    return trapezoid(integrand, t_values) / tau

def calculate_fourier_coefficients(xs, ys, num_points, coefficients):
    """Compute the Fourier coefficients of a closed 2-D contour.

    The contour ``xs + i*ys`` is parameterised by vertex index over
    ``[0, tau]``. It is linearly resampled at ``num_points`` parameters, and
    ``compute_cn`` is evaluated for frequencies ``0, 1, -1, 2, -2, ..., N, -N``.

    Args:
        xs, ys: Contour vertex coordinates (same length, non-empty).
        num_points: Number of resampled points used for integration.
        coefficients: N, the highest absolute frequency to compute.

    Returns:
        List of ``(c_n, n)`` tuples in the order ``n = 0, 1, -1, ..., N, -N``.

    Raises:
        ValueError: if the contour is empty.
    """
    num_points = int(num_points)
    N = int(coefficients)

    path = np.asarray(xs, dtype=float) + 1j * np.asarray(ys, dtype=float)
    if path.size == 0:
        raise ValueError("Cannot compute Fourier coefficients of an empty contour.")

    # Close the contour by repeating the first vertex at t = tau. Otherwise
    # f(0) (first vertex) and f(tau) (last vertex) differ, and the periodic
    # signal being transformed jumps at the seam, which causes ringing.
    closed_path = np.append(path, path[0])
    t_list = np.linspace(0, tau, len(closed_path))
    t_values = np.linspace(0, tau, num_points)
    f_precomputed = np.interp(t_values, t_list, closed_path)

    indices = [0] + [j for i in range(1, N + 1) for j in (i, -i)]
    with ThreadPoolExecutor(max_workers=2) as executor:
        coefs = list(executor.map(lambda n: (compute_cn(f_precomputed, n, t_values), n), indices))

    return coefs

def setup_animation_env(img_size, desired_range, coefficients):
    """Create the Matplotlib figure and the artists that ``animate`` updates.

    One blue artist ("circle") and one green artist ("circle line") are
    created per frequency in ``range(-coefficients, coefficients + 1)``, plus
    a thick red line for the traced drawing. The figure is rendered once, and
    the empty axes region is cached for blitting.

    Args:
        img_size: Accepted for interface compatibility; not used here.
        desired_range: Half-width of the square plotting window.
        coefficients: Highest absolute frequency; sets the artist count.

    Returns:
        (fig, ax, background, circles, circle_lines, drawing)
    """
    fig, ax = plt.subplots()

    frequency_slots = range(-coefficients, coefficients + 1)

    circles = []
    for _ in frequency_slots:
        (circle_artist,) = ax.plot([], [], 'b-')
        circles.append(circle_artist)

    circle_lines = []
    for _ in frequency_slots:
        (radius_artist,) = ax.plot([], [], 'g-')
        circle_lines.append(radius_artist)

    (drawing,) = ax.plot([], [], 'r-', linewidth=2)

    window = (-desired_range, desired_range)
    ax.set_xlim(*window)
    ax.set_ylim(*window)
    ax.set_axis_off()
    ax.set_aspect('equal')
    fig.set_size_inches(15, 15)

    # Render once so the static background can be captured and restored
    # before each blitted frame.
    fig.canvas.draw()
    background = fig.canvas.copy_from_bbox(ax.bbox)

    return fig, ax, background, circles, circle_lines, drawing

def animate(frame, coefs, frame_times, fig, ax, background, circles, circle_lines, drawing, draw_x, draw_y, coefs_static, theta):
    """Render one frame of the epicycle animation and return it as a PIL image.

    For each coefficient ``(c_n, n)``, the rotating vector
    ``c_n * exp(i * n * tau * t)`` is chained tip-to-tail starting at the
    origin. Each vector gets a green radius line and a blue circle of radius
    ``|c_n|`` around its tail. The final tip is appended to the red trace.

    Args:
        frame: Index into ``frame_times``.
        coefs: List of ``(c_n, n)`` tuples.
        frame_times: Normalised time per frame.
        fig, ax, background: Figure, axes and cached background from
            ``setup_animation_env``.
        circles, circle_lines, drawing: Artists from ``setup_animation_env``.
        draw_x, draw_y: Lists accumulating the trace. They are mutated in place.
        coefs_static: List of ``(radius, n)`` tuples, aligned with ``coefs``.
        theta: Not used. Kept for signature compatibility; circles are drawn
            as full outlines, so no per-frame angle is needed.

    Returns:
        ``(pil_image, None)``, where ``pil_image`` is an independent RGBA
        copy of the rendered canvas.
    """
    fig.canvas.restore_region(background)

    # Unit circle, scaled and translated for each epicycle outline.
    unit_circle = np.exp(1j * np.linspace(0, tau, 64))

    center = 0j
    for idx, (radius, freq) in enumerate(coefs_static):
        c_dynamic = coefs[idx][0] * np.exp(1j * (freq * tau * frame_times[frame]))
        outline = center + radius * unit_circle
        circles[idx].set_data(outline.real, outline.imag)
        tip = center + c_dynamic
        circle_lines[idx].set_data([center.real, tip.real], [center.imag, tip.imag])
        center = tip

    draw_x.append(center.real)
    draw_y.append(center.imag)
    drawing.set_data(draw_x, draw_y)

    for circle in circles:
        ax.draw_artist(circle)
    for line in circle_lines:
        ax.draw_artist(line)
    ax.draw_artist(drawing)

    fig.canvas.blit(ax.bbox)

    # np.array copies the canvas buffer. Image.frombuffer would share it, so a
    # previously returned frame would change when the next frame is drawn.
    # The buffer's own (rows, cols, 4) shape is used rather than
    # get_width_height().
    frame_rgba = np.array(fig.canvas.buffer_rgba())
    pil_image = Image.fromarray(frame_rgba)

    return (pil_image, None)
    
def fourier_transform_drawing(input_image, frames, coefficients, img_size, blur_kernel_size, desired_range, num_points, fps=15):
    """Stream a Fourier-epicycle tracing of an image's largest contour.

    This is a generator used as the Gradio callback. Each rendered frame is
    yielded as it is produced and also written to a temporary MP4 file. After
    the last frame, the video path is yielded.

    Args:
        input_image: Uploaded PIL image.
        frames: Number of animation frames.
        coefficients: Highest absolute Fourier frequency N (2N+1 epicycles).
        img_size, blur_kernel_size, desired_range: Passed to ``process_image``.
        num_points: Integration resolution for the coefficients.
        fps: Frame rate of the written video.

    Yields:
        ``(frame_image, None)`` for every frame. Then one final
        ``(last_frame_image, video_path)``; ``video_path`` is ``None`` if
        the OpenCV video writer could not be opened or no frame was rendered.
    """
    frames = int(frames)
    coefficients = int(coefficients)

    xs, ys = process_image(input_image, img_size, blur_kernel_size, desired_range)
    coefs = calculate_fourier_coefficients(xs, ys, num_points, coefficients)

    fig, ax, background, circles, circle_lines, drawing = setup_animation_env(img_size, desired_range, coefficients)
    coefs_static = [(np.abs(c), fr) for c, fr in coefs]
    frame_times = np.linspace(0, 1, num=frames)
    thetas = np.linspace(0, tau, num=frames)
    draw_x, draw_y = [], []

    # delete=False: the file must outlive this function so Gradio can serve it.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as temp_file:
        video_path = temp_file.name

    writer = None
    pil_image = None
    video_ok = False
    try:
        for frame in range(frames):
            pil_image, _ = animate(frame, coefs, frame_times, fig, ax, background, circles, circle_lines, drawing, draw_x, draw_y, coefs_static, thetas)
            bgr = cv2.cvtColor(np.asarray(pil_image), cv2.COLOR_RGBA2BGR)
            if writer is None:
                height, width = bgr.shape[:2]
                writer = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
            writer.write(bgr)
            yield pil_image, None
    finally:
        video_ok = writer is not None and writer.isOpened()
        if writer is not None:
            writer.release()
        # pyplot keeps every figure alive until it is closed explicitly;
        # without this, each request would leak a figure.
        plt.close(fig)

    yield pil_image, (video_path if video_ok else None)
    
def setup_gradio_interface():
    """Build the Gradio Interface that wraps ``fourier_transform_drawing``.

    The inputs map positionally onto the callback's parameters:
    image, frames, coefficients, image size, blur kernel size, range and
    integration points. The outputs are the streaming frame image and the
    final video.

    Returns:
        The configured ``gr.Interface`` (not launched).
    """
    interface = gr.Interface(
        fn=fourier_transform_drawing,
        inputs=[
            gr.Image(label="Input Image", sources=['upload'], type="pil"),
            gr.Slider(minimum=5, maximum=500, value=100, label="Number of Frames"),
            gr.Slider(minimum=1, maximum=500, value=50, label="Number of Coefficients"),
            gr.Number(value=224, label="Image Size (px)", precision=0),
            gr.Slider(minimum=3, maximum=11, step=2, value=5, label="Blur Kernel Size (odd number)"),
            gr.Number(value=400, label="Desired Range for Scaling", precision=0),
            gr.Number(value=1000, label="Number of Points for Integration", precision=0),
        ],
        # "Drawing Progress" describes the streamed output frames, not the upload.
        outputs=[gr.Image(label="Drawing Progress"), gr.Video(label="Animation")],
        title="Fourier Transform Drawing",
        description="Upload an image and generate a Fourier Transform drawing animation."
    )
    return interface

if __name__ == "__main__":
    # Build the app, enable Gradio's request queue, then start serving.
    demo = setup_gradio_interface()
    demo.queue()
    demo.launch()