from optimum.exporters.tasks import TasksManager

from optimum.exporters.onnx import OnnxConfigWithPast, export, validate_model_outputs

from tempfile import TemporaryDirectory

from transformers import AutoConfig, AutoTokenizer, is_torch_available

from pathlib import Path

import os
import shutil
import argparse

from typing import Optional, Tuple, List

from huggingface_hub import CommitOperationAdd, HfApi, hf_hub_download, get_repo_discussions
from huggingface_hub.file_download import repo_folder_name

def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discussion"]:
    """Look up an already-open pull request with a given title on a repo.

    Args:
        api: Client used to list the discussions of the repository.
        model_id: Repository whose discussions are inspected.
        pr_title: Exact title the pull request must have.

    Returns:
        The first discussion that is open, is a pull request and whose title
        equals ``pr_title``; ``None`` if there is no such discussion or if the
        discussions could not be requested.
    """
    try:
        candidates = api.get_repo_discussions(repo_id=model_id)
    except Exception:
        # Best effort: a failure to request the discussions is treated the
        # same as "no previous PR".
        return None

    matching = (
        item
        for item in candidates
        if item.status == "open" and item.is_pull_request and item.title == pr_title
    )
    return next(matching, None)

def convert_onnx(model_id: str, task: str, folder: str) -> List:
    """Export a model to ONNX inside ``folder`` and describe the files to upload.

    The model is loaded with the PyTorch framework, exported to
    ``<folder>/model.onnx`` using the default opset of its ONNX config, and the
    exported model's outputs are validated against the original model.

    Args:
        model_id: Identifier of the model to load and export.
        task: Task used to pick the model class and the ONNX config.
        folder: Existing directory that receives the exported file(s).

    Returns:
        One ``CommitOperationAdd`` per regular, non-hidden file found in
        ``folder`` after the export. A single file is placed at the root of the
        repository; several files are placed under ``onnx/``.

    Raises:
        ValueError: If a pad token id is required but the tokenizer could not
            be loaded to provide one.
    """
    # Load the PyTorch model matching the task.
    model = TasksManager.get_model_from_task(task, model_id, framework="pt")
    model_type = model.config.model_type.replace("_", "-")
    model_name = getattr(model, "name", None)

    onnx_config_constructor = TasksManager.get_exporter_config_constructor(
        model_type, "onnx", task=task, model_name=model_name
    )
    onnx_config = onnx_config_constructor(model.config)

    needs_pad_token_id = (
        isinstance(onnx_config, OnnxConfigWithPast)
        and getattr(model.config, "pad_token_id", None) is None
        and task in ["sequence_classification"]
    )
    if needs_pad_token_id:
        # NOTE(review): the message below mentions a --pad_token_id argument,
        # but this script's CLI does not define one; the tokenizer is the only
        # source for the value here.
        try:
            tok = AutoTokenizer.from_pretrained(model_id)
            model.config.pad_token_id = tok.pad_token_id
        except Exception as err:
            # Chain the cause so the underlying tokenizer failure stays visible.
            raise ValueError(
                "Could not infer the pad token id, which is needed in this case, please provide it with the --pad_token_id argument"
            ) from err

    # Export with the default opset declared by the ONNX config.
    opset = onnx_config.DEFAULT_ONNX_OPSET

    output = Path(folder).joinpath("model.onnx")
    onnx_inputs, onnx_outputs = export(
        model,
        onnx_config,
        opset,
        output,
    )

    # The tolerance may be given per task; "-with-past" variants share the
    # entry of their base task.
    atol = onnx_config.ATOL_FOR_VALIDATION
    if isinstance(atol, dict):
        atol = atol[task.replace("-with-past", "")]

    validate_model_outputs(onnx_config, model, output, onnx_outputs, atol)
    print(f"All good, model saved at: {output}")

    # Upload exactly the entries that are counted: regular, non-hidden files.
    # Listing the folder a second time without the filter would hand
    # directories or dot-files to CommitOperationAdd.
    file_names = sorted(
        name
        for name in os.listdir(folder)
        if os.path.isfile(os.path.join(folder, name)) and not name.startswith(".")
    )

    # Repository paths always use "/" regardless of the local OS, so they are
    # not built with os.path.join.
    repo_prefix = "" if len(file_names) == 1 else "onnx/"

    return [
        CommitOperationAdd(
            path_in_repo=f"{repo_prefix}{name}",
            path_or_fileobj=os.path.join(folder, name),
        )
        for name in file_names
    ]


def convert(api: "HfApi", model_id: str, task: str, force: bool = False) -> Tuple[str, Optional["CommitInfo"]]:
    """Convert a Hub model to ONNX and open a pull request with the result.

    Args:
        api: Client used to inspect the repository and create the commit.
        model_id: Repository of the model to convert.
        task: Task of the model, or ``"auto"`` to infer it from the model.
        force: When true, convert and open a pull request even if the model
            already has an ONNX file or an open conversion pull request.

    Returns:
        ``("0", commit_info)`` on success, where ``commit_info`` is the value
        returned by ``api.create_commit``. If ``task`` is ``"auto"`` and the
        task cannot be inferred, ``(error_message, None)`` is returned instead.

    Raises:
        Exception: If ``force`` is false and the repository already contains
            ``model.onnx`` or ``onnx/model.onnx``, or already has an open pull
            request with the conversion title.
    """
    pr_title = "Adding ONNX file of this model"
    info = api.model_info(model_id)
    filenames = {s.rfilename for s in info.siblings}

    if task == "auto":
        try:
            task = TasksManager.infer_task_from_model(model_id)
        except Exception as e:
            return f"### Error: {e}. Please pass explicitely the task as it could not be infered.", None

    if not force:
        # Multi-file exports are uploaded under "onnx/", so both layouts count
        # as an existing conversion.
        if filenames & {"model.onnx", "onnx/model.onnx"}:
            raise Exception(f"Model {model_id} is already converted, skipping..")

        # Only query the discussions when the answer can change the outcome.
        pr = previous_pr(api, model_id, pr_title)
        if pr is not None:
            url = f"https://huggingface.co/{model_id}/discussions/{pr.num}"
            raise Exception(f"Model {model_id} already has an open PR check out {url}")

    # The temporary directory (and everything below it) is deleted when the
    # ``with`` block exits, so the commit must be created inside it.
    with TemporaryDirectory() as d:
        folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
        os.makedirs(folder)

        operations = convert_onnx(model_id, task, folder)
        new_pr = api.create_commit(
            repo_id=model_id,
            operations=operations,
            commit_message=pr_title,
            create_pr=True,
        )
    return "0", new_pr


if __name__ == "__main__":
    # Command-line entry point: parse the arguments, run the conversion and
    # report the outcome.
    DESCRIPTION = """
    Simple utility tool to convert automatically a model on the hub to onnx format.
    It is PyTorch exclusive for now.
    It works by downloading the weights (PT), converting them locally, and uploading them back
    as a PR on the hub.
    """
    parser = argparse.ArgumentParser(description=DESCRIPTION)
    parser.add_argument(
        "--model_id",
        type=str,
        # Without a model there is nothing to convert; fail at parse time
        # instead of passing None further down.
        required=True,
        help="The name of the model on the hub to convert. E.g. `gpt2` or `facebook/wav2vec2-base-960h`",
    )
    parser.add_argument(
        "--task",
        type=str,
        # convert() treats "auto" as "infer the task from the model".
        default="auto",
        help="The task the model is performing. Defaults to `auto`, which tries to infer it from the model.",
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="Create the PR even if it already exists or if the model was already converted.",
    )
    args = parser.parse_args()
    api = HfApi()
    status, commit_info = convert(api, args.model_id, task=args.task, force=args.force)
    if commit_info is None:
        # convert() reports a failed task inference through its status string
        # rather than by raising; surface it and exit with a non-zero code.
        raise SystemExit(status)
    print(f"Pull request opened: {commit_info.pr_url}")