import os
import torch
import psutil
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel, PeftConfig
from huggingface_hub import login, create_repo, HfApi
import gradio as gr
import queue
import time
import shutil
import functools


# Global log state (log_queue is currently unused; current_logs backs the UI textbox)
log_queue = queue.Queue()
current_logs = []

def log(msg):
    """追加并打印日志信息"""
    print(msg)
    current_logs.append(msg)
    return "\n".join(current_logs)

def timeit(func):
    """Decorator that logs the wall-clock runtime of the wrapped function."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start_time = time.time()
        result = func(*args, **kwargs)
        end_time = time.time()
        log(f"{func.__name__}: {end_time - start_time:.2f} s")
        return result
    return wrapper
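# The decorator reports through the same log() channel used everywhere else, so
# each timed stage shows up in the UI log, e.g. "download_and_merge_model: 312.45 s".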

@timeit
def get_model_size_in_gb(model_name):
    """通过 Hugging Face Hub 元数据估算模型大小(GB)"""
    try:
        api = HfApi()
        model_info = api.model_info(model_name)
        # model_info.safetensors.total reports a parameter count, not bytes;
        # assume 2 bytes per parameter (fp16/bf16 weights) to estimate GB
        return model_info.safetensors.total * 2 / (1024 ** 3)
    except Exception as e:
        log(f"Unable to estimate model size: {e}")
        return 1  # fall back to a 1 GB default

@timeit
def check_system_resources(model_name):
    """检查系统资源,决定使用 CPU 或 GPU"""
    log("Checking system resources...")
    system_memory = psutil.virtual_memory()
    total_memory_gb = system_memory.total / (1024 ** 3)
    log(f"Total system memory: {total_memory_gb:.1f}GB")
    
    model_size_gb = get_model_size_in_gb(model_name)
    required_memory_gb = model_size_gb * 2.5  # reserve extra memory beyond the raw weights
    log(f"Estimated required memory for model: {required_memory_gb:.1f}GB")
    
    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name(0)
        gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)
        log(f"Detected GPU: {gpu_name} with {gpu_memory_gb:.1f}GB memory")
        if gpu_memory_gb >= required_memory_gb:
            log("✅ Sufficient GPU memory available; using GPU.")
            return "cuda", gpu_memory_gb
        else:
            log(f"⚠️ Insufficient GPU memory (requires {required_memory_gb:.1f}GB, found {gpu_memory_gb:.1f}GB).")
    else:
        log("❌ No GPU detected.")
        
    if total_memory_gb >= required_memory_gb:
        log("✅ Sufficient CPU memory available; using CPU.")
        return "cpu", total_memory_gb
    else:
        raise MemoryError(f"❌ Insufficient system memory (requires {required_memory_gb:.1f}GB, available {total_memory_gb:.1f}GB).")

@timeit
def setup_environment(model_name):
    """选择模型转换时使用的设备"""
    try:
        device, _ = check_system_resources(model_name)
    except Exception as e:
        log(f"Resource check failed: {e}. Defaulting to CPU.")
        device = "cpu"
    return device

@timeit
def create_hf_repo(repo_name, private=True):
    """创建 Hugging Face 仓库(如果不存在的话)"""
    try:
        api = HfApi()
        # If the repository already exists, append an index until the name is free
        if api.repo_exists(repo_name):
            retry_index = 0
            repo_name_with_index = repo_name
            while api.repo_exists(repo_name_with_index):
                retry_index += 1
                log(f"Repository {repo_name_with_index} exists; trying {repo_name}_{retry_index}")
                repo_name_with_index = f"{repo_name}_{retry_index}"
            repo_name = repo_name_with_index
        repo_url = create_repo(repo_name, private=private)
        log(f"Repository created successfully: {repo_url}")
        return repo_name
    except Exception as e:
        log(f"Failed to create repository: {e}")
        raise

@timeit
def download_and_merge_model(base_model_name, lora_model_name, output_dir, device):
    """
    1. 先加载 adapter 的 tokenizer 获取其词表大小
    2. 加载 base tokenizer 用于后续合并词表
    3. 加载 base 模型,并将嵌入层调整至 adapter 词表大小
    4. 使用高层 API 加载 LoRA adapter 并合并其权重
    5. 求 base 与 adapter tokenizer 的词表并取并集,扩展 tokenizer
    6. 调整合并模型嵌入层尺寸并保存
    """
    log("Loading base model...")
    model = AutoModelForCausalLM.from_pretrained(base_model_name, low_cpu_mem_usage=True)
    log("Loading adapter tokenizer...")
    adapter_tokenizer = AutoTokenizer.from_pretrained(lora_model_name)
    log("Resizing token embeddings...")
    added_tokens_decoder = adapter_tokenizer.added_tokens_decoder
    model.resize_token_embeddings(adapter_tokenizer.vocab_size + len(added_tokens_decoder))
    log("Loading LoRA adapter...")
    peft_model = PeftModel.from_pretrained(model, lora_model_name, low_cpu_mem_usage=True)
    log("Merging and unloading model...")
    model = peft_model.merge_and_unload()
    log("Saving model...")
    model.save_pretrained(output_dir)
    adapter_tokenizer.save_pretrained(output_dir)
    return output_dir
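# A minimal smoke test for a merged checkpoint (hypothetical output path; not
# part of the pipeline, shown only to illustrate consuming the saved model):
#
#   tokenizer = AutoTokenizer.from_pretrained("./output/my-merged-model")
#   merged = AutoModelForCausalLM.from_pretrained("./output/my-merged-model")
#   inputs = tokenizer("Hello", return_tensors="pt")
#   print(tokenizer.decode(merged.generate(**inputs, max_new_tokens=20)[0]))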

@timeit
def clone_llamacpp_and_download_build():
    """克隆 llama.cpp 并下载最新构建"""
    llamacpp_repo = "https://github.com/ggerganov/llama.cpp.git"
    llamacpp_dir = os.path.join(os.getcwd(), "llama.cpp")
    
    if not os.path.exists(llamacpp_dir):
        log(f"Cloning llama.cpp from {llamacpp_repo}...")
        os.system(f"git clone {llamacpp_repo} {llamacpp_dir}")
    
    log("Building llama.cpp...")
    build_dir = os.path.join(llamacpp_dir, "build")
    os.makedirs(build_dir, exist_ok=True)
    
    """
    cmake -B build
    cmake --build build --config Release
    """
    
    # 进入构建目录并执行 cmake 和 make
    os.chdir(build_dir)
    os.system("cmake -B build")
    os.system("cmake --build build --config Release")
    
    log("llama.cpp build completed.")
    # 返回到原始目录
    os.chdir(os.path.dirname(llamacpp_dir))
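# Note: the build above targets the default CPU backend. llama.cpp also exposes
# optional accelerated backends through CMake flags (for example -DGGML_CUDA=ON
# for NVIDIA GPUs, at the time of writing); consult the llama.cpp build docs for
# the flags that match your hardware.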

def remove_illegal_chars_in_path(text):
    """Replace characters that are unsafe in repo names and paths with underscores."""
    return text.replace(".", "_").replace(":", "_").replace("/", "_")
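# Example: remove_illegal_chars_in_path("Qwen/Qwen2.5-7B") -> "Qwen_Qwen2_5-7B"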

@timeit
def quantize(model_path, repo_id, quant_method=None):
    """
    利用 llama-cpp-python 对模型进行量化,并上传到 Hugging Face Hub。
    使用的量化预设:
      - 8-bit:  Q8_0
      - 4-bit:  Q4_K_M 或 Q4_K_L
      - 2-bit:  Q2_K_L
    模型输入(model_path)应为全精度(例如 fp16)的 GGUF 文件,
    输出文件将保存为 <model_path>_q{bits}_{quant_method}
    """
    # Use llama.cpp's conversion and quantization tooling
    llamacpp_dir = os.path.join(os.getcwd(), "llama.cpp")
    if not os.path.exists(llamacpp_dir):
        clone_llamacpp_and_download_build()

    # Ensure the quantized output directory exists
    model_output_dir = f"{model_path}/quantized/"
    os.makedirs(model_output_dir, exist_ok=True)

    # The f16 intermediate file lives in the quantized output directory
    gguf_f16 = os.path.join(model_output_dir, f"{repo_id}-f16.gguf")

    if not os.path.exists(gguf_f16):
        log("Converting model to GGUF format...")
        convert_script = os.path.join(llamacpp_dir, "convert_hf_to_gguf.py")
        convert_cmd = f"python {convert_script} {model_path} --outfile {gguf_f16}"
        print(f"syscall:[{convert_cmd}]")
        os.system(convert_cmd)
    else:
        log("GGUF intermediate file already exists; skipping conversion.")

    # The final quantized file also lives in the quantized output directory
    final_path = os.path.join(model_output_dir, f"{repo_id}-{quant_method}.gguf")
    if os.path.exists(final_path):
        log(f"{quant_method} quantized file already exists; skipping quantization.")
        return final_path

    log(f"Running {quant_method} quantization...")
    quantize_bin = os.path.join(llamacpp_dir, "build", "bin", "llama-quantize")
    quant_cmd = f"{quantize_bin} {gguf_f16} {final_path} {quant_method}"
    print(f"syscall:[{quant_cmd}]")
    os.system(quant_cmd)

    return final_path
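# Example (hypothetical paths): quantizing a merged checkpoint to 4-bit with
#   quantize("./output/my-merged-model", "my-merged-model", quant_method="Q4_K_M")
# produces ./output/my-merged-model/quantized/my-merged-model-Q4_K_M.gguf,
# alongside the f16 intermediate my-merged-model-f16.gguf.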

def create_readme(repo_name, base_model_name, lora_model_name, quant_methods):
    readme_path = os.path.join("output", repo_name, "README.md")
    readme_template = """---tags:
- autotrain
- text-generation-inference
- text-generation
- peft{quantization}
library_name: transformers
base_model: {base_model_name}
widget:
- messages:
    - role: user
      content: What is your favorite condiment?
license: other
datasets:
- {lora_model_name}
---
# Model

{repo_name}

## Details:
- base_model: {base_model_name}
- lora_model: {lora_model_name}
- quant_methods: {quant_methods}
- created_at: {created_at}
- created_by: [Steven10429/apply_lora_and_quantize](https://github.com/Steven10429/apply_lora_and_quantize)

""".format(
        quantization="\n- quantization" if len(quant_methods) > 0 else "",
        base_model_name=base_model_name,
        lora_model_name=lora_model_name,
        repo_name=repo_name,
        quant_methods=quant_methods,
        created_at=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
    )
    
    with open(readme_path, "w") as f:
        f.write(readme_template)

@timeit
def process_model(base_model_name, lora_model_name, repo_name, quant_methods, hf_token):
    """
    主处理函数:
      1. 登录并(必要时)创建 Hugging Face 仓库;
      2. 设置设备;
      3. 下载并合并 base 模型与 LoRA adapter;
      4. 异步上传合并后的模型;
      5. 同时启动四个量化任务(8-bit、2-bit、4-bit 两种模式);
      6. 最后统一等待所有 Future 完成,再返回日志。
    """
    try:
        current_logs.clear()
        if hf_token.strip().lower() == "auto":
            hf_token = os.getenv("HF_TOKEN")
        elif hf_token.startswith("hf_"):
            os.environ["HF_TOKEN"] = hf_token
        login(hf_token)
        api = HfApi(token=hf_token)
        username = api.whoami()["name"]
        
        if base_model_name.strip().lower() == "auto":
            adapter_config = PeftConfig.from_pretrained(lora_model_name)
            base_model_name = adapter_config.base_model_name_or_path
        if repo_name.strip().lower() == "auto":
            repo_name = f"{base_model_name.split('/')[-1]}_{lora_model_name.split('/')[-1]}"
            repo_name = remove_illegal_chars_in_path(repo_name)
        
        device = setup_environment(base_model_name)
        repo_name = create_hf_repo(repo_name)
        
        output_dir = os.path.join(".", "output", repo_name)
        log("Starting model merge process...")
        model_path = download_and_merge_model(base_model_name, lora_model_name, output_dir, device)
        
        
        # Quantize the merged model with each selected method
        for quant_method in quant_methods:
            quantize(output_dir, repo_name, quant_method=quant_method)
            
        create_readme(repo_name, base_model_name, lora_model_name, quant_methods)
        
        # Upload the merged model together with the quantized files
        api.upload_large_folder(
            folder_path=model_path,
            repo_id=repo_name,
            repo_type="model",
            num_workers=max(os.cpu_count() or 4, 4),  # cpu_count() may return None
            print_report_every=10,
        )
        log("Upload completed.")
        
        # Remove the local copy now that everything is on the Hub
        shutil.rmtree(model_path)
        log("Removed model from local")
        
        return "\n".join(current_logs)
    except Exception as e:
        error_message = f"Error during processing: {e}"
        log(error_message)
        raise

@timeit
def create_ui():
    """创建 Gradio 界面,仅展示日志"""
    with gr.Blocks(title="Model Merge & Quantization Tool") as app:
        gr.Markdown("""
        # 🤗 Model Merge and Quantization Tool
        
        This tool merges a base model with a LoRA adapter, creates quantized versions
        using llama.cpp's GGUF presets (Q2_K, Q4_K, IQ4_NL, Q5_K_M, Q6_K, Q8_0),
        and uploads them to the Hugging Face Hub.
        """)
        with gr.Row():
            with gr.Column():
                base_model = gr.Textbox(
                    label="Base Model Path",
                    placeholder="e.g., Qwen/Qwen2.5-14B-Instruct",
                    value="Auto"
                )
                lora_model = gr.Textbox(
                    label="LoRA Model Path",
                    placeholder="Enter the path to your LoRA model"
                )
                repo_name = gr.Textbox(
                    label="Hugging Face Repository Name",
                    placeholder="Enter the repository name to create",
                    value="Auto"
                )
                quant_method = gr.CheckboxGroup(
                    choices=["Q2_K", "Q4_K", "IQ4_NL", "Q5_K_M", "Q6_K", "Q8_0"],
                    value=["Q4_K", "Q8_0"],
                    label="Quantization Method"
                )
                hf_token = gr.Textbox(
                    label="Hugging Face Token",
                    placeholder="Enter your Hugging Face Token",
                    value="Auto"
                )
                convert_btn = gr.Button("Start Conversion", variant="primary")
            with gr.Column():
                output = gr.TextArea(
                    label="Logs",
                    placeholder="Processing logs will appear here...",
                    interactive=False,
                    autoscroll=True,
                    lines=20
                )
        convert_btn.click(
            fn=process_model,
            inputs=[base_model, lora_model, repo_name, quant_method, hf_token],
            outputs=output
        )
    return app


if __name__ == "__main__":
    app = create_ui()
    app.queue()
    app.launch()