import gradio as gr
import json
from pathlib import Path

from huggingface_hub import hf_hub_download, HfApi
from coremltools import ComputeUnit
from transformers.onnx.utils import get_preprocessor

from exporters.coreml import export
from exporters.coreml.features import FeaturesManager
from exporters.coreml.validate import validate_model_outputs

# UI label -> coremltools compute unit; `convert` forwards the chosen value to the exporter.
compute_units_mapping = dict(
    [
        ("All", ComputeUnit.ALL),
        ("CPU", ComputeUnit.CPU_ONLY),
        ("CPU + GPU", ComputeUnit.CPU_AND_GPU),
        ("CPU + NE", ComputeUnit.CPU_AND_NE),
    ]
)
# Radio choices in mapping order; the UI uses the first one as the default.
compute_units_labels = [*compute_units_mapping]

# UI label -> framework code passed as `framework=` to FeaturesManager.get_model_from_feature.
framework_mapping = dict(PyTorch="pt", TensorFlow="tf")
framework_labels = [*framework_mapping]

# UI label -> value forwarded to `export(..., quantize=...)`; also used in the output file name.
precision_mapping = dict(
    [
        ("Float32", "float32"),
        ("Float16 quantization", "float16"),
    ]
)
precision_labels = [*precision_mapping]

# UI label -> absolute validation tolerance. None defers to the Core ML
# config's own `atol_for_validation`.
tolerance_mapping = dict(
    [
        ("Model default", None),
        ("1e-2", 1e-2),
        ("1e-3", 1e-3),
        ("1e-4", 1e-4),
    ]
)
tolerance_labels = [*tolerance_mapping]

def error_str(error, title="Error"):
    """
    Render `error` as Markdown under a "#### {title}" heading.

    Returns "" when `error` is falsy (None or an empty string), which clears the
    Markdown output it is shown in. The error text sits on its own line with a
    12-space indent. Markdown renderers typically treat that indent as an
    indented code block (NOTE(review): confirm that is the intended look).
    """
    if not error:
        return ""
    indent = " " * 12
    return f"#### {title}\n{indent}{error}"

def url_to_model_id(model_id_str):
    """
    Normalize a model reference typed by the user into a Hub model id.

    Plain ids ("gpt2", "apple/mobilevit-small") are returned unchanged apart from
    stripping surrounding whitespace. Strings starting with
    "https://huggingface.co/" keep at most the first two non-empty path segments.
    Examples:
      - "https://huggingface.co/gpt2"                          -> "gpt2"
      - "https://huggingface.co/apple/mobilevit-small/"        -> "apple/mobilevit-small"
      - "https://huggingface.co/apple/mobilevit-small/tree/main" -> "apple/mobilevit-small"

    Limitation: for un-namespaced repos, a sub-page URL such as
    ".../gpt2/tree/main" cannot be told apart from a namespaced id, and yields
    "gpt2/tree".
    """
    model_id_str = model_id_str.strip()
    prefix = "https://huggingface.co/"
    if not model_id_str.startswith(prefix):
        return model_id_str
    # Drop empty segments so trailing or doubled slashes do not leak into the id.
    segments = [segment for segment in model_id_str[len(prefix):].split("/") if segment]
    return "/".join(segments[:2])

def supported_frameworks(model_id):
    """
    Return the supported frameworks ("PyTorch" and/or "TensorFlow") for `model_id`.

    Detection uses the model's Hub tags: the `pytorch` tag maps to "PyTorch"
    and the `tf` tag maps to "TensorFlow"; all other tags are ignored. The result
    is sorted and contains each framework at most once, so "PyTorch" comes first
    when both are present. An empty list means neither tag was found. Errors
    raised by `HfApi.model_info` (e.g. unknown repo, network) are not caught.
    """
    tag_to_label = {"pytorch": "PyTorch", "tf": "TensorFlow"}
    # `tags` may be None on some huggingface_hub versions/responses; treat that as "no tags".
    tags = HfApi().model_info(model_id).tags or []
    return sorted({tag_to_label[tag] for tag in tags if tag in tag_to_label})

def on_model_change(model):
    """
    Handler for the "Load" button: read the model's config and refresh the settings UI.

    `model` is a model id or a https://huggingface.co/ URL; it is normalized with
    `url_to_model_id`. The function always returns four `gr.update` values, in
    the order of the button's outputs: settings column, tasks radio, framework
    radio, error Markdown.

    On success:
      - the settings column is shown;
      - the task radio lists the features supported for the config's
        `model_type`, with the first one selected;
      - the framework radio is visible only when the model has more than one
        framework.

    On any failure (download, missing `model_type`, unsupported type, Hub
    errors), the settings are hidden and the error is displayed instead.
    """
    model = url_to_model_id(model)

    try:
        config_file = hf_hub_download(model, filename="config.json")
        if config_file is None:
            raise Exception(f"Model {model} not found")

        with open(config_file, "r", encoding="utf-8") as f:
            config = json.load(f)
        model_type = config["model_type"]

        features = FeaturesManager.get_supported_features_for_model_type(model_type)
        tasks = list(features.keys())

        frameworks = supported_frameworks(model)
        selected_framework = frameworks[0] if frameworks else None
    except Exception as e:
        # Gradio needs one value per output, so report the failure through the
        # error component instead of returning nothing.
        return (
            gr.update(visible=False),                   # Settings column
            gr.update(choices=[], value=None),          # Tasks
            gr.update(visible=False),                   # Frameworks
            gr.update(value=error_str(e)),              # Error
        )

    return (
        gr.update(visible=bool(model_type)),                                                    # Settings column
        gr.update(choices=tasks, value=tasks[0] if tasks else None),                            # Tasks
        gr.update(visible=len(frameworks) > 1, choices=frameworks, value=selected_framework),   # Frameworks
        gr.update(value=""),                                                                    # Error (cleared)
    )


def convert_model(preprocessor, model, model_coreml_config,
                  compute_units, precision, tolerance, output,
                  use_past=False, seq2seq=None,
                  progress=None, progress_start=0.1, progress_end=0.8):
    """
    Export one model (or one seq2seq half) to Core ML, save it, and validate it.

    Args:
        preprocessor: forwarded to `export` and `validate_model_outputs`.
        model: loaded transformers model; its `.config` builds the Core ML config.
        model_coreml_config: Core ML config class, instantiated as
            `model_coreml_config(model.config, use_past=..., seq2seq=...)`.
        compute_units: forwarded to `export(compute_units=...)`.
        precision: forwarded to `export(quantize=...)`.
        tolerance: absolute validation tolerance. None means the config's
            `atol_for_validation` is used.
        output: `pathlib.Path` of the `.mlpackage` to write. When `seq2seq` is
            "encoder" or "decoder", the file name is prefixed with "encoder_" /
            "decoder_" in the same directory.
        use_past: forwarded to the Core ML config.
        seq2seq: forwarded to the Core ML config. None for single models,
            otherwise "encoder" or "decoder".
        progress: callable used as `progress(fraction, desc=...)` (e.g. a
            `gr.Progress`). None disables progress reporting.
        progress_start, progress_end: progress range covered by this call.

    Exceptions from `export`, `mlmodel.save` and `validate_model_outputs` are
    not caught here.
    """
    if progress is None:
        progress = lambda *args, **kwargs: None  # noqa: E731 - no-op reporter

    coreml_config = model_coreml_config(model.config, use_past=use_past, seq2seq=seq2seq)

    model_label = "model" if seq2seq is None else seq2seq
    progress(progress_start, desc=f"Converting {model_label}")
    mlmodel = export(
        preprocessor,
        model,
        coreml_config,
        quantize=precision,
        compute_units=compute_units,
    )

    filename = output
    if seq2seq in ("encoder", "decoder"):
        filename = filename.with_name(f"{seq2seq}_{filename.name}")
    mlmodel.save(filename.as_posix())

    # Report validation 80% of the way through *this call's* range so the bar
    # never moves backwards, wherever the range sits in the overall job.
    progress(progress_start + 0.8 * (progress_end - progress_start), desc=f"Validating {model_label}")
    if tolerance is None:
        tolerance = coreml_config.atol_for_validation
    validate_model_outputs(coreml_config, preprocessor, model, mlmodel, tolerance)
    progress(progress_end, desc=f"Done converting {model_label}")


def convert(model, task, compute_units, precision, tolerance, framework, progress=gr.Progress()):
    """
    Handler for the "Convert" button: export `model` for `task` to Core ML.

    Args:
        model: model id or https://huggingface.co/ URL.
        task: selected task.
        compute_units, precision, tolerance, framework: UI labels, translated
            through the module-level `*_mapping` dicts.
        progress: progress tracker that Gradio fills in through the
            `gr.Progress()` default.

    Packages are written to exported/<model id>/coreml/<task>/<precision>_model.mlpackage.
    The seq2seq tasks ("seq2seq-lm", "speech-seq2seq") instead write
    encoder_/decoder_-prefixed packages in that folder.

    Returns a Markdown string: "Done" on success, otherwise the formatted error.
    Every failure is reported this way, including invalid or empty UI selections.
    """
    try:
        model_id = url_to_model_id(model)
        compute_units = compute_units_mapping[compute_units]
        precision = precision_mapping[precision]
        tolerance = tolerance_mapping[tolerance]
        framework = framework_mapping[framework]

        # TODO: support legacy format
        output_dir = Path("exported") / model_id / "coreml" / task
        output_dir.mkdir(parents=True, exist_ok=True)
        output = output_dir / f"{precision}_model.mlpackage"

        progress(0, desc="Downloading model")
        preprocessor = get_preprocessor(model_id)
        hf_model = FeaturesManager.get_model_from_feature(task, model_id, framework=framework)
        _, model_coreml_config = FeaturesManager.check_supported_model_or_raise(hf_model, feature=task)

        # (seq2seq part, progress start, progress end) for each package to export.
        # The ranges are contiguous, so the progress bar only moves forward.
        if task in ("seq2seq-lm", "speech-seq2seq"):
            parts = [("encoder", 0.1, 0.45), ("decoder", 0.45, 0.8)]
        else:
            parts = [(None, 0.1, 0.8)]

        for seq2seq, start, end in parts:
            convert_model(
                preprocessor,
                hf_model,
                model_coreml_config,
                compute_units,
                precision,
                tolerance,
                output,
                seq2seq=seq2seq,
                progress=progress,
                progress_start=start,
                progress_end=end,
            )

        # TODO: push to hub, whatever
        progress(1, desc="Done")
        return "Done"
    except Exception as e:
        return error_str(e)

# Markdown introduction rendered at the top of the page via gr.Markdown(DESCRIPTION).
# NOTE(review): the last paragraph promises submitting a PR or creating a repo
# after conversion, but `convert` does not push anything to the Hub yet (it only
# has a TODO for it). Confirm before advertising the feature.
DESCRIPTION = """
## Convert a transformers model to Core ML

With this Space you can try to convert a transformers model to Core ML. It uses the 🤗 Hugging Face [Exporters repo](https://huggingface.co/exporters) under the hood.

Note that not all models are supported. If you get an error on a model you'd like to convert, please open an issue on the [repo](https://github.com/huggingface/exporters).

After conversion, you can choose to submit a PR to the original repo, or create your own repo with just the converted Core ML weights.
"""

# Page layout. The left column takes the model id and a "Load" button. The right
# column holds the conversion settings and stays hidden until a model loads.
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("## 1. Load model info")
            # Accepts a plain model id or a https://huggingface.co/ URL (see url_to_model_id).
            input_model = gr.Textbox(
                max_lines=1,
                label="Model name or URL, such as apple/mobilevit-small",
                placeholder="distilbert-base-uncased",
                value="distilbert-base-uncased",
            )
            btn_get_tasks = gr.Button("Load")
        with gr.Column(scale=3):
            # Hidden until on_model_change returns a visible=True update for it.
            with gr.Column(visible=False) as group_settings:
                gr.Markdown("## 2. Select Task")
                # Choices are filled in by on_model_change from the model's supported features.
                radio_tasks = gr.Radio(label="Choose the task for the converted model.")
                gr.Markdown("The `default` task is suitable for feature extraction.")
                # on_model_change shows this only when the model has more than one framework.
                radio_framework = gr.Radio(
                    visible=False,
                    label="Framework",
                    choices=framework_labels,
                    value=framework_labels[0],
                )
                radio_compute = gr.Radio(
                    label="Compute Units",
                    choices=compute_units_labels,
                    value=compute_units_labels[0],
                )
                radio_precision = gr.Radio(
                    label="Precision",
                    choices=precision_labels,
                    value=precision_labels[0],
                )
                radio_tolerance = gr.Radio(
                    label="Absolute Tolerance for Validation",
                    choices=tolerance_labels,
                    value=tolerance_labels[0],
                )
                btn_convert = gr.Button("Convert")
                gr.Markdown("Conversion will take a few minutes.")


    # Shared Markdown output: Load errors and the Convert result/error both land here.
    error_output = gr.Markdown(label="Output")

    # `outputs` order must match the 4-tuple returned by on_model_change.
    # queue=False: this lightweight lookup bypasses the request queue.
    btn_get_tasks.click(
        fn=on_model_change,
        inputs=input_model,
        outputs=[group_settings, radio_tasks, radio_framework, error_output],
        queue=False,
        scroll_to_output=True
    )
        
    # `inputs` order matches convert's positional parameters
    # (model, task, compute_units, precision, tolerance, framework).
    # Its `progress` argument is supplied by Gradio through the gr.Progress() default.
    btn_convert.click(
        fn=convert,
        inputs=[input_model, radio_tasks, radio_compute, radio_precision, radio_tolerance, radio_framework],
        outputs=error_output,
        scroll_to_output=True
    )

    # NOTE(review): commented-out footer below; either restore it or delete it.
    # gr.HTML("""
    # <div style="border-top: 1px solid #303030;">
    #   <br>
    #   <p>Footer</p><br>
    #   <p><img src="https://visitor-badge.glitch.me/badge?page_id=pcuenq.transformers-to-coreml" alt="visitors"></p>
    # </div>
    # """)
    
# Queued events (e.g. Convert) run one at a time, with at most 10 waiting.
demo.queue(max_size=10, concurrency_count=1)
# Serve locally only (no public share link), with debug output enabled.
demo.launch(share=False, debug=True)