import spaces
import os
import gradio as gr
import json
import logging
logging.getLogger("diffusers").setLevel(logging.ERROR)
import diffusers
diffusers.utils.logging.set_verbosity(40)
import warnings
warnings.filterwarnings(action="ignore", category=FutureWarning, module="diffusers")
warnings.filterwarnings(action="ignore", category=UserWarning, module="diffusers")
warnings.filterwarnings(action="ignore", category=FutureWarning, module="transformers")
from pathlib import Path
from env import (hf_token, hf_read_token, # to use only for private repos
    CIVITAI_API_KEY, HF_LORA_PRIVATE_REPOS1, HF_LORA_PRIVATE_REPOS2, HF_LORA_ESSENTIAL_PRIVATE_REPO,
    HF_VAE_PRIVATE_REPO, directory_models, directory_loras, directory_vaes,
    download_model_list, download_lora_list, download_vae_list)
from modutils import (to_list, list_uniq, list_sub, get_lora_model_list, download_private_repo,
    safe_float, escape_lora_basename, to_lora_key, to_lora_path, get_local_model_list,
    get_private_lora_model_lists, get_valid_lora_name, get_valid_lora_path, get_valid_lora_wt,
    get_lora_info, normalize_prompt_list, get_civitai_info, search_lora_on_civitai)


def download_things(directory, url, hf_token="", civitai_api_key=""):
    """Download ``url`` into ``directory`` using ``gdown`` or ``aria2c``.

    Routing by URL:
      * Google Drive  -> ``gdown --fuzzy`` run inside ``directory``.
      * Hugging Face  -> ``aria2c`` with ``/blob/`` rewritten to ``/resolve/``;
                         sends a Bearer header when ``hf_token`` is given.
      * Civitai       -> ``aria2c`` with ``?token=<civitai_api_key>``; skipped
                         (with a message) when no key is given.
      * anything else -> plain ``aria2c``.

    Failures are best-effort: the tool's exit status is not checked, and a
    missing executable or directory is reported with a printed message rather
    than raised. Empty/blank URLs are ignored.
    """
    import subprocess

    def _run(cmd, cwd=None):
        # Argument list with no shell: URLs may come from user input (e.g. the
        # LoRA download box), so they must never be interpreted by a shell.
        try:
            subprocess.run(cmd, cwd=cwd, check=False)
        except OSError as e:
            print(f"\033[91mFailed to run {cmd[0]}: {e}\033[0m")

    url = url.strip() if url else ""
    if not url:
        return
    directory = str(directory)
    aria2c_base = ["aria2c", "--console-log-level=error", "--summary-interval=10",
                   "-c", "-x", "16", "-k", "1M", "-s", "16"]

    if "drive.google.com" in url:
        # cwd= instead of os.chdir: no process-wide working-directory change.
        _run(["gdown", "--fuzzy", url], cwd=directory)
    elif "huggingface.co" in url:
        url = url.replace("?download=true", "")
        # url = urllib.parse.quote(url, safe=':/')  # fix encoding
        if "/blob/" in url:
            url = url.replace("/blob/", "/resolve/")
        cmd = list(aria2c_base)
        if hf_token:
            cmd.append(f"--header=Authorization: Bearer {hf_token}")
        else:
            cmd.append("--optimize-concurrent-downloads")
        _run(cmd + [url, "-d", directory, "-o", url.split('/')[-1]])
    elif "civitai.com" in url:
        if "?" in url:
            url = url.split("?")[0]
        if civitai_api_key:
            url = url + f"?token={civitai_api_key}"
            _run(aria2c_base + ["-d", directory, url])
        else:
            print("\033[91mYou need an API key to download Civitai models.\033[0m")
    else:
        _run(aria2c_base + ["-d", directory, url])


def get_model_list(directory_path):
    """Return the paths of model-weight files directly inside ``directory_path``.

    A file qualifies when its extension (case-insensitive) is one of
    ``.ckpt``, ``.pt``, ``.pth``, ``.safetensors`` or ``.bin``. Each hit is
    printed. Results are sorted by filename for a stable UI order; a missing
    directory yields an empty list instead of raising.
    """
    valid_extensions = {'.ckpt', '.pt', '.pth', '.safetensors', '.bin'}
    if not os.path.isdir(directory_path):
        return []
    model_list = []
    for filename in sorted(os.listdir(directory_path)):
        if os.path.splitext(filename)[1].lower() not in valid_extensions:
            continue
        file_path = os.path.join(directory_path, filename)
        model_list.append(file_path)
        print('\033[34mFILE: ' + file_path + '\033[0m')
    return model_list


# - **Download Models**
download_model = ", ".join(download_model_list)
# - **Download VAEs**
download_vae = ", ".join(download_vae_list)
# - **Download LoRAs**
# NOTE: the name `download_lora` is rebound to a function later in this module;
# the string value is only used by the startup download below.
download_lora = ", ".join(download_lora_list)

#download_private_repo(HF_LORA_ESSENTIAL_PRIVATE_REPO, directory_loras, True)
#download_private_repo(HF_VAE_PRIVATE_REPO, directory_vaes, False)

# These re-read the environment and replace the values imported from `env`.
CIVITAI_API_KEY = os.environ.get("CIVITAI_API_KEY")
hf_token = os.environ.get("HF_TOKEN")


def _download_missing(directory, url_csv):
    """Download each comma-separated URL in ``url_csv`` into ``directory``,
    skipping blank entries and files (by last URL segment) already present."""
    for url in (u.strip() for u in url_csv.split(',')):
        if not url:
            continue
        # Check inside the configured directory (not a hard-coded ./models etc.).
        if not os.path.exists(os.path.join(directory, url.split('/')[-1])):
            download_things(directory, url, hf_token, CIVITAI_API_KEY)


# Download stuffs
_download_missing(directory_models, download_model)
_download_missing(directory_vaes, download_vae)
_download_missing(directory_loras, download_lora)

lora_model_list = get_lora_model_list()
vae_model_list = get_model_list(directory_vaes)
vae_model_list.insert(0, "None")


def get_t2i_model_info(repo_id: str):
    """Return a Markdown summary of a public diffusers model on the HF Hub.

    The summary covers pipeline family (FLUX.1 / SDXL / SD1.5), card tags,
    downloads, likes, last-modified date and a repo link, wrapped in
    ``gr.update``. Returns ``""`` when the id is empty or contains a space,
    when the repo does not exist, is private or gated, is not tagged
    ``diffusers``, or when the Hub lookup fails (the error is printed).
    """
    # Validate locally first so obviously bad ids never hit the Hub API.
    if not repo_id or " " in repo_id: return ""
    from huggingface_hub import HfApi
    api = HfApi()
    try:
        if not api.repo_exists(repo_id): return ""
        model = api.model_info(repo_id=repo_id)
    except Exception as e:
        print(f"Error: Failed to get {repo_id}'s info. ")
        print(e)
        return ""
    if model.private or model.gated: return ""
    tags = model.tags or []  # tags may be missing (None) on sparse metadata
    if 'diffusers' not in tags: return ""
    info = []
    url = f"https://huggingface.co/{repo_id}/"
    if 'diffusers:FluxPipeline' in tags:
        info.append("FLUX.1")
    elif 'diffusers:StableDiffusionXLPipeline' in tags:
        info.append("SDXL")
    elif 'diffusers:StableDiffusionPipeline' in tags:
        info.append("SD1.5")
    if model.card_data and model.card_data.tags:
        info.extend(list_sub(model.card_data.tags, ['text-to-image', 'stable-diffusion', 'stable-diffusion-api', 'safetensors', 'stable-diffusion-xl']))
    info.append(f"DLs: {model.downloads}")
    info.append(f"likes: {model.likes}")
    if model.last_modified is not None:
        info.append(model.last_modified.strftime("lastmod: %Y-%m-%d"))
    md = f"Model Info: {', '.join(info)}, [Model Repo]({url})"
    return gr.update(value=md)


# Optional per-LoRA metadata loaded from lora_dict.json, keyed by escaped basename.
private_lora_dict = {"": ["", "", "", "", ""]}
try:
    with open('lora_dict.json', encoding='utf-8') as f:
        _private_lora_json = json.load(f)
except FileNotFoundError:
    _private_lora_json = {}  # the file is optional
except (OSError, ValueError) as e:
    # ValueError covers JSONDecodeError and UnicodeDecodeError: report, don't crash.
    print(f"Warning: could not load lora_dict.json: {e}")
    _private_lora_json = {}
if isinstance(_private_lora_json, dict):
    for k, v in _private_lora_json.items():
        private_lora_dict[escape_lora_basename(k)] = v
else:
    print("Warning: lora_dict.json does not contain a JSON object; ignored.")


private_lora_model_list = get_private_lora_model_lists()
# `|` already builds a new dict, so no explicit copy of private_lora_dict is needed.
loras_dict = {"None": ["", "", "", "", ""], "": ["", "", "", "", ""]} | private_lora_dict
loras_url_to_path_dict = {} # {"URL to download": "local filepath", ...}
civitai_lora_last_results = {}  # {"URL to download": {search results}, ...}
all_lora_list = []


def get_all_lora_list():
    """Fetch the current LoRA list, cache a copy in ``all_lora_list``, return it."""
    global all_lora_list
    current = get_lora_model_list()
    all_lora_list = current.copy()
    return current


def get_all_lora_tupled_list():
    """Return ``(label, path)`` dropdown choices for every known LoRA.

    Metadata comes from ``loras_dict``. On a cache miss it is fetched with
    ``get_civitai_info``; a non-None result is cached. When the metadata list
    has a non-empty third element, the label is decorated as
    ``"<stem> (for <items[1]>, <items[2]>)"``, with a 🐴 after "Pony".
    NOTE(review): the meaning of items[1]/items[2] (presumably base model and
    extra info) is defined by modutils — confirm there.
    """
    models = get_all_lora_list()
    if not models: return []
    tupled_list = []
    for model in models:
        #if not model: continue # to avoid GUI-related bug
        basename = Path(model).stem
        key = to_lora_key(model)
        if key in loras_dict:
            items = loras_dict[key]
        else:
            items = get_civitai_info(model)
            if items is not None:
                loras_dict[key] = items
        name = basename
        # Length guard: entries loaded from lora_dict.json may be shorter than expected.
        if items and len(items) > 2 and items[2] != "":
            base_model = f"{items[1]}🐴" if items[1] == "Pony" else items[1]
            name = f"{basename} (for {base_model}, {items[2]})"
        tupled_list.append((name, model))
    return tupled_list


def update_lora_dict(path: str):
    """Cache Civitai metadata for the LoRA at ``path`` in ``loras_dict``.

    No-op if the key is already cached or ``get_civitai_info`` returns None.
    """
    key = to_lora_key(path)
    if key in loras_dict: return
    items = get_civitai_info(path)
    if items is None: return
    loras_dict[key] = items


def download_lora(dl_urls: str):
    """Download comma-separated LoRA URLs into ``directory_loras``.

    Each new file is renamed to its escaped basename (``escape_lora_basename``)
    and recorded in ``loras_url_to_path_dict`` under the URL that produced it.
    It is also registered via ``update_lora_dict``. URLs whose last path segment
    already exists in ``directory_loras`` are skipped, as are blank entries.

    Returns the path of the last file renamed, or "" if nothing new appeared.
    """
    global loras_url_to_path_dict
    dl_path = ""
    for url in [u.strip() for u in (dl_urls or "").split(',')]:
        if not url: continue
        local_path = f"{directory_loras}/{url.split('/')[-1]}"
        if Path(local_path).exists(): continue
        # Snapshot the directory around each single download so every new file
        # is paired with the URL that produced it. One snapshot around all
        # downloads mis-pairs files/URLs when a download fails or listing
        # order differs, and raises IndexError when extra files appear.
        before = get_local_model_list(directory_loras)
        download_things(directory_loras, url, hf_token, CIVITAI_API_KEY)
        after = get_local_model_list(directory_loras)
        for file in list_sub(after, before):
            path = Path(file)
            if not path.exists(): continue
            # NOTE(review): target is "<parent dir name>/<escaped name>", i.e.
            # relative to the cwd; kept because download_my_lora uses this
            # string as the dropdown value.
            new_path = Path(f'{path.parent.name}/{escape_lora_basename(path.stem)}{path.suffix}')
            path.resolve().rename(new_path.resolve())
            loras_url_to_path_dict[url] = str(new_path)
            update_lora_dict(str(new_path))
            dl_path = str(new_path)
    return dl_path


def copy_lora(path: str, new_path: str):
    """Copy the LoRA file at ``path`` to ``new_path`` and register it.

    Returns ``new_path`` on success. It is also returned when both arguments
    name the same file, without copying or registering. Returns ``None`` if
    the source does not exist or the copy fails with an OS error.
    """
    import shutil
    if path == new_path: return new_path
    cpath = Path(path)
    npath = Path(new_path)
    if not cpath.exists(): return None
    # Different spellings of one file (e.g. "d/x" vs "d/./x") would make
    # shutil.copy raise SameFileError; the file is already where it should be.
    if cpath.resolve() == npath.resolve(): return new_path
    try:
        shutil.copy(str(cpath.resolve()), str(npath.resolve()))
    except OSError:  # includes shutil.Error, PermissionError, missing target dir
        return None
    update_lora_dict(str(npath))
    return new_path


def download_my_lora(dl_urls: str, lora):
    """Download the given LoRA URLs and refresh the LoRA dropdown.

    The newly downloaded file becomes the selected value when there is one;
    otherwise the current selection ``lora`` is kept.
    """
    downloaded = download_lora(dl_urls)
    selected = downloaded if downloaded else lora
    return gr.update(value=selected, choices=get_all_lora_tupled_list())


def apply_lora_prompt(lora_info: str):
    """Turn a LoRA trigger-word string into a de-duplicated prompt fragment.

    Both "/" and "," act as separators. The pieces are normalized with
    ``normalize_prompt_list``, de-duplicated with ``list_uniq`` and joined
    with ", ". Returns "" for ``None``, "" or the literal "None".
    """
    if not lora_info or lora_info == "None": return ""
    lora_tags = lora_info.replace("/", ",").split(",")
    lora_prompts = normalize_prompt_list(lora_tags)
    return ", ".join(list_uniq(lora_prompts))


def update_loras(prompt, lora, lora_wt):
    """Build the five Gradio updates for a LoRA selection change.

    The updates are, in order: prompt, LoRA dropdown (with refreshed choices),
    weight, trigger-tag box and info Markdown. The last two are shown or
    hidden according to ``get_lora_info``.
    """
    on, label, tag, md = get_lora_info(lora)
    choices = get_all_lora_tupled_list()
    prompt_update = gr.update(value=prompt)
    lora_update = gr.update(value=lora, choices=choices)
    weight_update = gr.update(value=lora_wt)
    tag_update = gr.update(value=tag, label=label, visible=on)
    md_update = gr.update(value=md, visible=on)
    return prompt_update, lora_update, weight_update, tag_update, md_update


def search_civitai_lora(query, base_model):
    """Search Civitai for LoRAs and build four Gradio updates.

    On success, ``civitai_lora_last_results`` is replaced by a map from
    download URL to result item. The updates are: results dropdown (first
    hit selected), that hit's Markdown, and two components made visible.
    With no results, the dropdown and Markdown are cleared and hidden; the
    previous cached results are left untouched.
    """
    global civitai_lora_last_results

    def _no_results():
        return gr.update(choices=[("", "")], value="", visible=False),\
            gr.update(value="", visible=False), gr.update(visible=True), gr.update(visible=True)

    items = search_lora_on_civitai(query, base_model)
    if not items: return _no_results()
    civitai_lora_last_results = {}
    choices = []
    for item in items:
        base_model_name = "Pony🐴" if item['base_model'] == "Pony" else item['base_model']
        name = f"{item['name']} (for {base_model_name} / By: {item['creator']} / Tags: {', '.join(item['tags'])})"
        value = item['dl_url']
        choices.append((name, value))
        civitai_lora_last_results[value] = item
    if not choices: return _no_results()
    first_url = choices[0][1]
    # Default None (not the string "None", which would make result['md'] raise TypeError).
    result = civitai_lora_last_results.get(first_url)
    md = result['md'] if result else ""
    return gr.update(choices=choices, value=first_url, visible=True), gr.update(value=md, visible=True),\
        gr.update(visible=True), gr.update(visible=True)


def select_civitai_lora(search_result):
    """Propagate a chosen Civitai search result to the URL box and info Markdown.

    Non-URL (or empty/None) selections clear the URL box. A URL missing from
    ``civitai_lora_last_results`` (e.g. after a newer search replaced the
    cache) shows empty Markdown instead of raising.
    """
    if not search_result or "http" not in search_result:
        return gr.update(value=""), gr.update(value="None", visible=True)
    # Default None: the old "None" string default crashed on ['md'] for unknown URLs.
    result = civitai_lora_last_results.get(search_result)
    md = result['md'] if result else ""
    return gr.update(value=search_result), gr.update(value=md, visible=True)


def search_civitai_lora_json(query, base_model):
    """Return a Gradio update whose value maps each result's download URL to its item.

    An empty search yields an empty dict; duplicate URLs keep the last item.
    """
    items = search_lora_on_civitai(query, base_model)
    by_url = {entry['dl_url']: entry for entry in items} if items else {}
    return gr.update(value=by_url)