import gradio as gr
import collections
import numpy as np
import os
import torch
from safetensors.torch import serialize_file
import requests
import tempfile

def download_file(url, local_path, timeout=60):
    """Download a file from a URL to a local path.

    The body is streamed to disk in 8 KiB chunks so the whole file is never
    held in memory.

    Args:
        url: Address of the file to fetch with an HTTP GET.
        local_path: Destination path; it is opened in binary write mode and
            therefore truncated if it already exists.
        timeout: Value forwarded to ``requests.get(timeout=...)``. Without a
            timeout a stalled server would block the caller indefinitely.

    Returns:
        ``local_path``, unchanged.

    Raises:
        requests.HTTPError: If the server answers with an error status
            (raised by ``raise_for_status``).
    """
    # Using the response as a context manager releases the streamed
    # connection even when raise_for_status() or a write fails.
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with open(local_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
    return local_path

def rename_key(rename, name):
    """Return ``name`` with every matching substring from ``rename`` replaced.

    The pairs of ``rename`` are applied one after another in the mapping's
    iteration order, each on the result of the previous replacement, so an
    earlier replacement can produce text that a later pair matches.
    """
    result = name
    for old, new in rename.items():
        if old not in result:
            continue
        result = result.replace(old, new)
    return result

def convert_file(pt_filename: str, sf_filename: str, rename=None, transpose_names=None):
    """Convert a PyTorch checkpoint into a safetensors file of fp16 tensors.

    The checkpoint is loaded on CPU; when it contains a ``"state_dict"``
    entry, that entry is used as the tensor mapping. A model version is
    guessed from the key names and printed. Each tensor is then cast to half
    precision, its key is rewritten with ``rename_key`` and lower-cased, and
    tensors whose new key contains one of ``transpose_names`` get their last
    two dimensions swapped before being serialized.

    Args:
        pt_filename: Path of the checkpoint read with ``torch.load``.
        sf_filename: Path of the safetensors file to write. Its parent
            directory is created when the path has one.
        rename: Mapping ``substring -> replacement`` applied to every key.
            ``None`` (default) means no renaming.
        transpose_names: Substrings selecting the keys whose tensors are
            transposed. ``None`` (default) means no tensor is transposed.

    Returns:
        ``sf_filename``.

    Raises:
        ValueError: If two tensors end up with the same key after renaming
            and lower-casing (one of them would otherwise be lost).
    """
    # None instead of mutable literals as defaults, so no object is shared
    # between calls.
    if rename is None:
        rename = {}
    if transpose_names is None:
        transpose_names = []

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code when the checkpoint comes from an untrusted source.
    # Consider ``weights_only=True`` if the supported torch version allows it.
    loaded: collections.OrderedDict = torch.load(pt_filename, map_location="cpu")
    if "state_dict" in loaded:
        loaded = loaded["state_dict"]

    kk = list(loaded.keys())

    # Heuristic version detection based purely on which key names are present.
    version = 4
    for x in kk:
        if "ln_x" in x:
            version = max(5, version)
        if "gate.weight" in x:
            version = max(5.1, version)
        if int(version) == 5 and "att.time_decay" in x:
            # A time_decay tensor with a second dimension larger than 1
            # is taken as the marker for 5.2.
            if len(loaded[x].shape) > 1:
                if loaded[x].shape[1] > 1:
                    version = max(5.2, version)
        if "time_maa" in x:
            version = max(6, version)

    print(f"Model detected: v{version:.1f}")

    if version == 5.1:
        # Insert a dimension at index 1 and repeat each time_decay/time_faaaa
        # tensor n_emb // shape[0] times along it.
        _, n_emb = loaded["emb.weight"].shape
        for k in kk:
            if "time_decay" in k or "time_faaaa" in k:
                loaded[k] = loaded[k].unsqueeze(1).repeat(1, n_emb // loaded[k].shape[0])

    # Output entries go into a separate dict: writing them back into
    # ``loaded`` could overwrite a tensor that has not been converted yet
    # when a renamed key equals another original key.
    tensors = {}
    with torch.no_grad():
        for k in kk:
            new_k = rename_key(rename, k).lower()
            if new_k in tensors:
                raise ValueError(
                    f"Duplicate tensor name after renaming: {new_k!r} (from {k!r})"
                )
            # pop() drops the source tensor as soon as it has been converted,
            # keeping peak memory down.
            v = loaded.pop(k).half()
            if any(transpose_name in new_k for transpose_name in transpose_names):
                dims = len(v.shape)
                v = v.transpose(dims - 2, dims - 1)
            print(f"{new_k}\t{v.shape}\t{v.dtype}")
            tensors[new_k] = {
                "dtype": str(v.dtype).split(".")[-1],
                "shape": v.shape,
                "data": v.numpy().tobytes(),
            }

    # dirname() is "" for a bare file name, and os.makedirs("") raises.
    out_dir = os.path.dirname(sf_filename)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    serialize_file(tensors, sf_filename, metadata={"format": "pt"})
    return sf_filename

def _remove_quietly(path):
    """Delete ``path`` if it is set; cleanup failures are deliberately ignored."""
    if path is None:
        return
    try:
        os.remove(path)
    except OSError:
        # Best-effort cleanup of a temporary file: a failure here must not
        # replace the result or the error of the conversion itself.
        pass


def process_model(url):
    """Process the model URL and return a downloadable safetensors file.

    Downloads the checkpoint at ``url`` into a temporary ``.pth`` file,
    converts it with ``convert_file`` into a temporary ``.safetensors`` file
    and returns that file's path. The temporary ``.pth`` file is always
    removed; the ``.safetensors`` file is removed only when processing fails.

    Args:
        url: Address of the ``.pth`` file, as typed by the user.

    Returns:
        Path of the converted ``.safetensors`` file.

    Raises:
        gr.Error: If any step fails; the message contains the original error.
    """
    # Bound before the try block so the handlers below never hit an unbound
    # name when creating a temporary file is the step that fails.
    pth_path = None
    sf_path = None
    try:
        # Create temporary files (only their names are needed here)
        with tempfile.NamedTemporaryFile(suffix=".pth", delete=False) as temp_pth:
            pth_path = temp_pth.name
        with tempfile.NamedTemporaryFile(suffix=".safetensors", delete=False) as temp_sf:
            sf_path = temp_sf.name

        # Download the .pth file from the URL
        # NOTE(security): the URL is user input and is fetched as-is.
        download_file(url, pth_path)

        # Conversion parameters
        rename = {"time_faaaa": "time_first", "time_maa": "time_mix", "lora_A": "lora.0", "lora_B": "lora.1"}
        transpose_names = [
            "time_mix_w1", "time_mix_w2", "time_decay_w1", "time_decay_w2",
            "w1", "w2", "a1", "a2", "g1", "g2", "v1", "v2",
            "time_state", "lora.0"
        ]

        # Convert the file
        return convert_file(pth_path, sf_path, rename, transpose_names)
    except Exception as e:
        # The output file is useless after a failure; drop it too.
        _remove_quietly(sf_path)
        raise gr.Error(f"Error processing the model: {str(e)}") from e
    finally:
        # The downloaded checkpoint is never needed after this call.
        _remove_quietly(pth_path)

# Gradio interface: one URL textbox, a button that runs the conversion and a
# file component offering the converted result for download.
with gr.Blocks(title="PTH to Safetensors Converter") as demo:
    gr.Markdown("# PTH to Safetensors Converter")
    gr.Markdown("Enter the URL to a `.pth` model file hosted on Hugging Face to convert it to `.safetensors` format.")

    url_input = gr.Textbox(label="Model URL", placeholder="https://huggingface.co/.../model.pth")
    convert_btn = gr.Button("Convert")
    output_file = gr.File(label="Download Converted Safetensors File")

    # Clicking the button passes the textbox value to process_model and
    # shows the path it returns in the file component.
    convert_btn.click(fn=process_model, inputs=url_input, outputs=output_file)

# Start the app as soon as the module is executed.
demo.launch()