# HURA (Hexagonal Uniformly Redundant Arrays) are used for aperture masks, imaging, and encoding.
# See https://ntrs.nasa.gov/citations/19850026627
# by Surn (Charles Fettinger) 4/5/2025
import functools
import math
from tempfile import NamedTemporaryFile

import gradio as gr
from PIL import Image
from transformers.models.deprecated.vit_hybrid import image_processing_vit_hybrid

import utils.color_utils as color_utils
import utils.constants as constants


class HuraConfig:
    """Configuration for Hexagonal Uniformly Redundant Array pattern generation.

    Holds the V/R pattern parameters, the constants used by the hexagon and
    vignette math, the default colour palette, and a lazily computed list of
    primes used to snap V to a prime value.
    """

    # Single source of truth for the default palette (used by __init__ and
    # reset_colors). A tuple, so the template itself can never be mutated.
    _DEFAULT_COLORS = (
        (255, 0, 0),   # Red
        (0, 255, 0),   # Green
        (0, 0, 255),   # Blue
    )

    def __init__(self):
        # Core parameters
        self.v = 139  # Prime number parameter (affects pattern complexity)
        self.r = 42   # Pattern frequency parameter
        self.version = "0.2.2"

        # Pattern generation constants
        self.hex_ratio = 0.5773503  # sqrt(3)/3
        self.pattern_scale = 21.0   # Controls pattern frequency
        self.vignette_inner = 0.97
        self.vignette_outer = 1.01

        # Colors: a fresh list per instance so callers may mutate it freely.
        self.default_colors = list(self._DEFAULT_COLORS)

        # Prime number calculation
        self.prime_range_start = 1
        self.prime_range_end = 5001
        self._primes_cache = None  # Lazily filled by get_primes()

    @staticmethod
    def _validate_param(name, value):
        """Return ``value`` if it is a finite int/float >= 1, else raise ValueError.

        ``bool`` is rejected even though it subclasses ``int``. NaN and
        infinities are rejected because the pattern generators feed V and R
        into ``math.fmod``/``range``, where they fail or produce garbage.
        """
        if (
            isinstance(value, bool)
            or not isinstance(value, (int, float))
            or (isinstance(value, float) and not math.isfinite(value))
            or value < 1
        ):
            raise ValueError(f"{name} value must be a finite number >= 1, got {value!r}")
        return value

    def get_v(self):
        """Get the current V parameter value."""
        return self.v

    def set_v(self, value):
        """Set the V parameter value.

        Raises:
            ValueError: If ``value`` is not a finite int/float >= 1.
        """
        self.v = self._validate_param("V", value)

    def get_r(self):
        """Get the current R parameter value."""
        return self.r

    def set_r(self, value):
        """Set the R parameter value.

        Raises:
            ValueError: If ``value`` is not a finite int/float >= 1.
        """
        self.r = self._validate_param("R", value)

    def get_primes(self):
        """Return the primes in [prime_range_start, prime_range_end], computed once."""
        if self._primes_cache is None:
            self._primes_cache = get_primes_in_range(self.prime_range_start, self.prime_range_end)
        return self._primes_cache

    def find_nearest_prime(self, value):
        """Return the prime closest to ``value``; on a tie the first (smaller) one wins."""
        return min(self.get_primes(), key=lambda x: abs(x - value))

    def reset_colors(self):
        """Reset ``default_colors`` to a fresh copy of the original palette and return it."""
        self.default_colors = list(self._DEFAULT_COLORS)
        return self.default_colors

# Shared configuration instance used by every helper in this module.
config = HuraConfig()

# Backwards-compatible module-level aliases (consider deprecating these).
# NOTE: _V and _R are import-time snapshots; set_v()/set_r() update `config`
# but never rebind these names. generate_pattern_background() uses them as
# its default V/R arguments.
__version__ = config.version
_V, _R = config.v, config.r

def get_v():
    """Return the V parameter held by the shared ``config``."""
    v_current = config.get_v()
    return v_current

def set_v(val):
    """Store a new V parameter on the shared ``config`` (which validates it)."""
    config.set_v(val)
    return None

def get_r():
    """Return the R parameter held by the shared ``config``."""
    r_current = config.get_r()
    return r_current

def set_r(val):
    """Store a new R parameter on the shared ``config`` (which validates it)."""
    config.set_r(val)
    return None

# Module-level placeholder; render() rebinds this global to a gr.State that
# holds the colour palette.
state_colors = []

def smoothstep(edge0, edge1, x):
    """
    Hermite smoothstep used for the vignette and cell-edge shading.

    Maps x to 0 at edge0 and 1 at edge1, following the cubic 3t^2 - 2t^3 in
    between and clamping outside. edge0 may exceed edge1, which inverts the
    ramp. When the two edges coincide, the result is a hard step at edge0.
    """
    if edge0 == edge1:
        # Degenerate ramp: behave as a hard step at edge0.
        return 0.0 if x < edge0 else 1.0
    t = (x - edge0) / (edge1 - edge0)
    # Clamp t into [0, 1].
    t = 0.0 if t < 0.0 else (1.0 if t > 1.0 else t)
    return t * t * (3 - 2 * t)

# Define the hexagon function to compute coordinates
def hexagon(p):
    """
    Compute hexagon-cell coordinates and distance metrics for point p.

    Written in shader style: GLSL-like floor/fract/step/mod operations are
    emulated with math.floor and comparisons. Statement order and constants
    are significant; change with care.

    Args:
        p (tuple): Point (x, y). The generators in this module pass centred
            pixel coordinates in [-aspect, aspect] × [-1, 1] that have already
            been multiplied by ``config.pattern_scale``.

    Returns:
        tuple: (hex_x, hex_y, edge_distance, center_distance)
            - hex_x, hex_y: Integer-valued (but float-typed) coordinates that
              identify the hexagon cell containing p.
            - edge_distance: Edge term computed from p's position inside its
              lattice cell (presumably in [0, 1] — TODO confirm).
            - center_distance: Length of (fract(p2) - 0.5) with the y component
              scaled by 0.85, so it lies in [0, ~0.66].
    """
    # Transform to hexagonal coordinate system
    q = (p[0] * 2.0 * config.hex_ratio, p[1] + p[0] * config.hex_ratio)
    # Integer (floor) and fractional (fract) parts of the skewed point.
    pi = (math.floor(q[0]), math.floor(q[1]))
    pf = (q[0] - pi[0], q[1] - pi[1])
    # Cell-class index in {0.0, 1.0, 2.0}; Python % on ints is never negative.
    mod_val = (pi[0] + pi[1]) % 3.0  # renamed from v
    # step(1, mod_val) and step(2, mod_val).
    ca = 1.0 if mod_val >= 1.0 else 0.0
    cb = 1.0 if mod_val >= 2.0 else 0.0
    # Triangle selector: (pf.y >= pf.x, pf.x >= pf.y); both are 1 on the diagonal.
    ma = (1.0 if pf[1] >= pf[0] else 0.0, 1.0 if pf[0] >= pf[1] else 0.0)
    temp = (
        1.0 - pf[1] + ca * (pf[0] + pf[1] - 1.0) + cb * (pf[1] - 2.0 * pf[0]),
        1.0 - pf[0] + ca * (pf[0] + pf[1] - 1.0) + cb * (pf[0] - 2.0 * pf[1])
    )
    # Edge term: dot(ma, temp).
    e = ma[0] * temp[0] + ma[1] * temp[1]
    # Second mapping of the point, used only for the centre-distance term.
    p2_x = (q[0] + math.floor(0.5 + p[1] / 1.5)) * 0.5 + 0.5
    p2_y = (4.0 * p[1] / 3.0) * 0.5 + 0.5
    fract_p2 = (p2_x - math.floor(p2_x), p2_y - math.floor(p2_y))
    f = math.sqrt((fract_p2[0] - 0.5)**2 + ((fract_p2[1] - 0.5) * 0.85)**2)
    # Cell id: pi + ca - cb * ma.
    h_xy = (pi[0] + ca - cb * ma[0], pi[1] + ca - cb * ma[1])
    return (h_xy[0], h_xy[1], e, f)

# important note: this is not a true hexagonal pattern, but a hexagonal grid
def ura(p):
    """
    Generate binary pattern value based on Uniformly Redundant Array algorithm.
    
    Args:
        p (tuple): Hexagon coordinates (x,y)
        
    Returns:
        float: 1.0 for pattern, 0.0 for background

    future consideration.. add animation
    #ifdef INCREMENT_R
        float l = mod(p.y + floor(time*1.5)*p.x, v);
    #else
        float l = mod(p.y + r*p.x, v);
    """    
    r = get_r()
    v = get_v()
    l = math.fmod(abs(p[1]) + r * abs(p[0]), v)
    rz = 1.0
    for i in range(1, int(v/2) + 1):
        if math.isclose(math.fmod(i * i, v), l, abs_tol=1e-6):
            rz = 0.0
            break
    return rz

# Generate an image with a colorful hexagonal pattern
def generate_image_color(width, height, colors=None):
    """
    Generate an RGB image with a colorful hexagonal pattern.

    Every pixel is mapped to a hexagon cell with ``hexagon``. The cell's
    integer coordinates are combined with the current R and V parameters to
    pick a palette entry. A square vignette fades the border to black.

    Args:
        width (int): Image width in pixels.
        height (int): Image height in pixels.
        colors (list[tuple[int, int, int]] | None): Palette of RGB tuples.
            ``None`` or any empty sequence falls back to ``config.default_colors``.

    Returns:
        PIL.Image.Image: The generated RGB image.
    """
    img = Image.new('RGB', (width, height))
    if not colors:
        # None or an empty palette (list, tuple, ...) uses the defaults; an
        # empty palette would otherwise hit a float modulo by zero below.
        colors = config.default_colors
    n_colors = len(colors)
    # Loop invariants, read once per call rather than once per pixel.
    r = config.get_r()
    v = config.get_v()
    scale = config.pattern_scale
    vig_outer = config.vignette_outer
    vig_inner = config.vignette_inner
    aspect = width / height

    pixels = []
    for j in range(height):
        # Centred y in [-1, 1); used for both the pattern and the vignette.
        y = (j / height) * 2.0 - 1.0
        abs_y = abs(y)
        for i in range(width):
            x = (i / width) * 2.0 - 1.0
            # Aspect-corrected, frequency-scaled point -> hexagon cell.
            h = hexagon((x * aspect * scale, y * scale))
            # Palette index from the cell's integer coordinates.
            rz = math.fmod(abs(int(h[0])) + r * abs(int(h[1])), v)
            col = colors[int(rz % n_colors)]
            # Square vignette: fades to black near the image border.
            vignette = smoothstep(vig_outer, vig_inner, max(abs(x), abs_y))
            pixels.append(tuple(int(c * vignette) for c in col))
    # One bulk write instead of a putpixel() call per pixel.
    img.putdata(pixels)
    return img

def generate_image_grayscale(width, height):
    """
    Generate a grayscale HURA-style pattern on an RGB canvas.

    Each pixel is mapped to a hexagon cell. ``ura`` decides whether the cell is
    "on" (rz > 0.5) or "off". On-cells are shaded with
    smoothstep(-0.2, 0.13, edge_distance) and off-cells with its complement.
    The result is multiplied by the square vignette and written as equal R, G
    and B values.

    Args:
        width (int): Image width in pixels.
        height (int): Image height in pixels.

    Returns:
        PIL.Image.Image: The generated RGB image (all channels equal).
    """
    img = Image.new('RGB', (width, height))
    aspect = width / height
    # Loop invariants, read once per call rather than once per pixel.
    scale = config.pattern_scale
    vig_outer = config.vignette_outer
    vig_inner = config.vignette_inner
    # ura() depends only on the cell coordinates (assuming R and V do not
    # change during this call), so evaluate it once per cell, not per pixel.
    cell_value = {}
    pixels = []
    for j in range(height):
        # Centred y in [-1, 1); used for both the pattern and the vignette.
        y = (j / height) * 2.0 - 1.0
        abs_y = abs(y)
        for i in range(width):
            x = (i / width) * 2.0 - 1.0
            h = hexagon((x * aspect * scale, y * scale))
            key = (h[0], h[1])
            rz = cell_value.get(key)
            if rz is None:
                rz = cell_value[key] = ura(key)
            smooth = smoothstep(-0.2, 0.13, h[2])
            col = smooth if rz > 0.5 else 1.0 - smooth
            col *= smoothstep(vig_outer, vig_inner, max(abs(x), abs_y))
            level = int(abs(col) * 255)
            pixels.append((level, level, level))
    # One bulk write instead of a putpixel() call per pixel.
    img.putdata(pixels)
    return img

def get_primes_in_range(start: int, end: int) -> list:
    """
    Return a list of prime numbers between start and end (inclusive).

    Uses the Sieve of Eratosthenes. A ``start`` below 2 is treated as 2, so
    negative bounds never index the sieve from the end (which previously
    reported numbers such as -4 as prime).

    Parameters:
        start (int): The starting number of the range.
        end (int): The ending number of the range.

    Returns:
        list: Primes p with start <= p <= end, in ascending order. Empty when
        the range contains no primes (including when start > end).
    """
    if end < 2:
        return []
    sieve = [True] * (end + 1)
    sieve[0] = sieve[1] = False
    # math.isqrt is exact, unlike int(end ** 0.5) for large values.
    for i in range(2, math.isqrt(end) + 1):
        if sieve[i]:
            # Clear every multiple from i*i onwards in one slice assignment.
            sieve[i * i::i] = [False] * len(range(i * i, end + 1, i))
    return [i for i in range(max(start, 2), end + 1) if sieve[i]]

def find_nearest_prime(value):
    """Snap ``value`` to the closest prime known to the shared ``config``."""
    nearest = config.find_nearest_prime(value)
    return nearest

def generate_pattern_background(pattern_type="color", width=1024, height=768, v_value=_V, r_value=_R, colors=None):
    """
    Generate a hexagonal pattern image and save it to a temporary PNG file.

    Updates the module-wide V and R parameters and renders either the colour
    or the grayscale pattern. The file path is registered in
    ``constants.temp_files``. Do not pass gr.State objects here; pass their
    raw values.

    Args:
        pattern_type (str): "color" for the palette pattern; any other value
            renders the grayscale pattern.
        width (int): Image width in pixels.
        height (int): Image height in pixels.
        v_value: V parameter (validated by ``set_v``).
        r_value: R parameter (validated by ``set_r``).
        colors (list | None): RGB palette for the colour pattern. ``None`` or
            empty uses the default palette.

    Returns:
        str: Path of the generated PNG file.
    """
    set_v(v_value)
    set_r(r_value)
    print(f"Generating pattern with V: {v_value}, R: {r_value}, Colors: {colors}")

    if pattern_type == "color":
        img = generate_image_color(width, height, colors)
        # Count the palette actually used: an empty list falls back to the
        # defaults, so it must not yield a "hura_0_..." file name.
        color_count = len(colors) if colors else len(config.default_colors)
    else:  # grayscale
        img = generate_image_grayscale(width, height)
        color_count = 1

    prefix = f"hura_{color_count}_v{v_value}_r{r_value}_"
    # Write through the open handle instead of reopening tmp.name while the
    # file is still open, which is not portable to Windows on older Pythons.
    with NamedTemporaryFile(delete=False, prefix=prefix, suffix=".png") as tmp:
        img.save(tmp, format="PNG")
    constants.temp_files.append(tmp.name)
    return tmp.name


def create_color_swatch_html(colors):
    """
    Build an HTML row of 50x50 px colour swatches.

    Args:
        colors (Iterable[Sequence[int]]): Colour components (R, G, B) per
            swatch. Tuples and lists are both accepted.

    Returns:
        str: A flex ``<div>`` containing one swatch ``<div>`` per colour.
    """
    def css_rgb(components):
        # Format the components explicitly rather than relying on the tuple's
        # repr, so lists (and numeric scalar types) still yield valid CSS.
        return "rgb(" + ", ".join(str(x) for x in components) + ")"

    swatches = ''.join(
        f'<div style="width: 50px; height: 50px; background-color: {css_rgb(c)}; '
        f'border: 1px solid #ccc;"></div>'
        for c in colors
    )
    return f'<div style="display: flex; gap: 10px;">{swatches}</div>'

def _add_color(color, color_list):
    """
    Append a picked colour to the palette (callback for the "Add Color" button).

    Args:
        color (str | None): Value from ``gr.ColorPicker``, converted with
            ``color_utils.hex_to_rgb``. ``None`` or an empty value leaves the
            palette unchanged.
        color_list (list): Current palette from ``gr.State``. It is not mutated.

    Returns:
        tuple: ``(new_color_list, swatch_html)``. Both paths return exactly two
        values, matching the two outputs (state, HTML display) the button is
        wired to.
    """
    if color:
        # Build a new list so the previous state value is never changed in place.
        color_list = color_list + [color_utils.hex_to_rgb(color)]
    return color_list, create_color_swatch_html(color_list)

def _init_colors():
    """Return the swatch HTML for a copy of the palette in ``config.default_colors``."""
    return create_color_swatch_html(list(config.default_colors))

def reset_colors():
    """Restore the default palette and return ``(colors, swatch_html)`` for the UI outputs."""
    restored = config.reset_colors()
    return restored, create_color_swatch_html(restored)

def _generate_pattern_from_state(pt, width, height, v_val, r_val, colors_list):
    """Gradio callback that forwards the UI inputs to generate_pattern_background.

    Gradio passes the raw value held by the gr.State input as ``colors_list``,
    so it can be handed straight through.
    """
    return generate_pattern_background(
        pattern_type=pt,
        width=width,
        height=height,
        v_value=v_val,
        r_value=r_val,
        colors=colors_list,
    )

def render() -> dict:
    """
    Render the colorful/grayscale hexagonal pattern creation interface.

    Creates the Gradio components and wires their events. Components are
    created directly, so this is presumably meant to be called inside a
    ``gr.Blocks()`` context.

    Returns:
        dict: A dictionary containing:
            - target_image (gr.Image): The generated pattern image component
            - run_generate_hex_pattern (function): Generates a colour pattern
              for a given width and height and returns its file path
            - width_slider (gr.Slider): The width slider component
            - height_slider (gr.Slider): The height slider component
    """

    # Initialize state
    global state_colors
    state_colors = gr.State(config.default_colors)
    init_colors_html = _init_colors()

    target_image = gr.Image(label="Generated Pattern", type="filepath")
    with gr.Row():
        pattern_type = gr.Radio(
            label="Pattern Type",
            choices=["color", "grayscale"],
            value="grayscale",
            type="value"
        )
        with gr.Column():
            with gr.Row():
                width_slider = gr.Slider(minimum=256, maximum=2560, value=1024, label="Width", step=8)
                height_slider = gr.Slider(minimum=256, maximum=2560, value=768, label="Height", step=8)
            v_value_slider = gr.Slider(minimum=config.prime_range_start, maximum=config.prime_range_end, value=config.v, label="V Value (Prime Number)", step=1)
            r_value_slider = gr.Slider(minimum=1, maximum=100, value=config.r, label="R Value")
        # TODO: this checkbox is not wired to any event or generator yet.
        show_borders_chbox = gr.Checkbox(label="Show Borders", value=True)

    # Palette controls; only shown when the "color" pattern type is selected.
    with gr.Row(visible=False) as color_row:
        color_picker = gr.ColorPicker(label="Pick a Color")
        add_button = gr.Button("Add Color")
        with gr.Column():
            color_display = gr.HTML(label="Color Swatches", value=init_colors_html)
        with gr.Row():
            delete_colors_button = gr.Button("Delete Colors")
            reset_colors_button = gr.Button("Reset Colors")
    with gr.Row():
        generate_button = gr.Button("Generate Pattern")

    def run_generate_hex_pattern(width: int, height: int) -> str:
        """
        Generate a colored hexagonal pattern image with the given width and height.

        Uses the current module-wide V and R values and the palette stored as
        the ``state_colors`` gr.State's ``.value`` (its initial value, not any
        per-session value).

        Returns:
            str: The filepath of the generated image.
        """
        global state_colors
        # NOTE(review): assigning .value presumably only changes the sliders'
        # stored default, not an already-rendered UI; confirm intent.
        width_slider.value = width
        height_slider.value = height
        filepath = generate_pattern_background(
            pattern_type="color",
            width=width,
            height=height,
            v_value=get_v(),
            r_value=get_r(),
            colors=state_colors.value
        )
        return filepath

    pattern_type.change(
        fn=lambda x: gr.update(visible=(x == "color")),
        inputs=pattern_type,
        outputs=color_row
    )
    add_button.click(
        fn=_add_color,
        inputs=[color_picker, state_colors],
        outputs=[state_colors, color_display]
    )
    delete_colors_button.click(
        # inputs=[] means Gradio calls fn with no arguments, so fn takes none.
        fn=lambda: ([], "<div>Add Colors</div>"),
        inputs=[],
        outputs=[state_colors, color_display]
    )
    reset_colors_button.click(
        fn=reset_colors,
        inputs=[],
        outputs=[state_colors, color_display]
    )
    generate_button.click(
        fn=_generate_pattern_from_state,
        inputs=[pattern_type, width_slider, height_slider, v_value_slider, r_value_slider, state_colors],
        outputs=target_image, scroll_to_output=True
    )

    # Snap the V slider to the nearest prime while dragging and on release.
    v_value_slider.input(
        lambda x: config.find_nearest_prime(x),
        inputs=v_value_slider,
        outputs=v_value_slider
    )
    v_value_slider.release(
        lambda x: config.find_nearest_prime(x),
        inputs=v_value_slider,
        outputs=v_value_slider, queue=False
    )

    return {
        "target_image": target_image,
        "run_generate_hex_pattern": run_generate_hex_pattern,
        "width_slider": width_slider,
        "height_slider": height_slider
    }