|
import argparse |
|
import os |
|
import shutil |
|
from pathlib import Path |
|
from tempfile import TemporaryDirectory |
|
from typing import List, Optional, Tuple |
|
|
|
from huggingface_hub import (CommitOperationAdd, HfApi, get_repo_discussions, |
|
hf_hub_download) |
|
from huggingface_hub.file_download import repo_folder_name |
|
from optimum.exporters.onnx import (OnnxConfigWithPast, export, |
|
validate_model_outputs) |
|
from optimum.exporters.tasks import TasksManager |
|
from transformers import AutoConfig, AutoTokenizer, is_torch_available |
|
|
|
# Public URL of the ONNX exporter Space; it is linked from the bot's PR description.
SPACES_URL = "https://huggingface.co/spaces/optimum/exporters"
|
|
|
|
|
def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discussion"]:
    """Return the open pull request on ``model_id`` titled ``pr_title``, if any.

    Args:
        api: Hub client used to list the repository discussions.
        model_id: Repository id of the model on the Hub.
        pr_title: Exact title the pull request must have.

    Returns:
        The first discussion that is an open pull request whose title equals
        ``pr_title``, or ``None`` when there is no such PR or the discussions
        could not be fetched.
    """
    try:
        # ``get_repo_discussions`` returns a lazy iterator whose HTTP requests
        # (including pagination) happen while it is consumed, so the loop
        # itself must be inside the ``try`` for fetch errors to be handled.
        for discussion in api.get_repo_discussions(repo_id=model_id):
            if (
                discussion.status == "open"
                and discussion.is_pull_request
                and discussion.title == pr_title
            ):
                return discussion
    except Exception:
        # Best effort: failing to list discussions is treated as "no previous PR".
        return None
    return None
|
|
|
|
|
def _ensure_pad_token_id(model, onnx_config, model_id: str, task: str) -> None:
    """Fill ``model.config.pad_token_id`` from the tokenizer when the export needs it.

    Only applies to with-past ONNX configs exporting a sequence-classification
    task when the model config has no pad token id.

    Raises:
        ValueError: If the pad token id is needed but cannot be obtained from
            the model's tokenizer.
    """
    # Optimum task names are hyphenated (cf. the "-with-past" suffix handled in
    # convert_onnx); normalise so underscore spellings match as well.
    normalized_task = task.replace("_", "-")
    needs_pad_token_id = (
        isinstance(onnx_config, OnnxConfigWithPast)
        and getattr(model.config, "pad_token_id", None) is None
        and normalized_task in ("sequence-classification", "text-classification")
    )
    if not needs_pad_token_id:
        return

    message = (
        "Could not infer the pad token id, which is needed in this case, "
        "please set `pad_token_id` in the model config"
    )
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id)
    except Exception as err:
        raise ValueError(message) from err
    if tokenizer.pad_token_id is None:
        # Assigning None would leave the config unusable for this export.
        raise ValueError(message)
    model.config.pad_token_id = tokenizer.pad_token_id


def _commit_operations(folder: str) -> List["CommitOperationAdd"]:
    """Build one ``CommitOperationAdd`` per visible regular file in ``folder``.

    A single exported file is placed at the repository root; several files are
    grouped under ``onnx/``. Hidden files and sub-directories are skipped.
    """
    file_names = sorted(
        name
        for name in os.listdir(folder)
        if os.path.isfile(os.path.join(folder, name)) and not name.startswith(".")
    )
    # Hub paths always use "/" separators, independent of the local OS.
    prefix = "" if len(file_names) == 1 else "onnx/"
    return [
        CommitOperationAdd(
            path_in_repo=f"{prefix}{name}",
            path_or_fileobj=os.path.join(folder, name),
        )
        for name in file_names
    ]


def convert_onnx(model_id: str, task: str, folder: str, opset: Optional[int] = None) -> List:
    """Export ``model_id`` to ONNX in ``folder`` and build the upload operations.

    Args:
        model_id: Repository id of the model on the Hub (loaded as PyTorch).
        task: Task name understood by optimum's ``TasksManager``.
        folder: Existing local directory; ``model.onnx`` is written into it.
        opset: ONNX opset to export with. Defaults to the ONNX config's
            ``DEFAULT_ONNX_OPSET``.

    Returns:
        A list of ``CommitOperationAdd``, one per exported file in ``folder``.

    Raises:
        ValueError: If a pad token id is required but cannot be inferred.
    """
    model = TasksManager.get_model_from_task(task, model_id, framework="pt")
    model_type = model.config.model_type.replace("_", "-")
    model_name = getattr(model, "name", None)

    onnx_config_constructor = TasksManager.get_exporter_config_constructor(
        model_type, "onnx", task=task, model_name=model_name
    )
    onnx_config = onnx_config_constructor(model.config)

    _ensure_pad_token_id(model, onnx_config, model_id, task)

    if opset is None:
        opset = onnx_config.DEFAULT_ONNX_OPSET

    output = Path(folder).joinpath("model.onnx")
    _, onnx_outputs = export(model, onnx_config, opset, output)

    # The validation tolerance may be given per task; "-with-past" variants
    # share the tolerance of their base task.
    atol = onnx_config.ATOL_FOR_VALIDATION
    if isinstance(atol, dict):
        atol = atol[task.replace("-with-past", "")]

    try:
        validate_model_outputs(onnx_config, model, output, onnx_outputs, atol)
    except ValueError as err:
        # Validation failures are reported but not fatal: the exported file is
        # kept and still uploaded.
        print(f"An error occured, but the model was saved at: {output.as_posix()}")
        print(f"Validation error: {err}")
    else:
        print(f"All good, model saved at: {output}")

    return _commit_operations(folder)
|
|
|
|
|
def _build_commit_description(requesting_user: str) -> str:
    """Return the markdown body of the PR opened on behalf of ``requesting_user``."""
    return f"""
Beep boop I am the [ONNX export bot 🤖🏎️]({SPACES_URL}). On behalf of [{requesting_user}](https://huggingface.co/{requesting_user}), I would like to add to this repository the model converted to ONNX.

What is ONNX? It stands for "Open Neural Network Exchange", and is the most commonly used open standard for machine learning interoperability. You can find out more at [onnx.ai](https://onnx.ai/)!

The exported ONNX model can be then be consumed by various backends as TensorRT or TVM, or simply be used in a few lines with 🤗 Optimum through ONNX Runtime, check out how [here](https://huggingface.co/docs/optimum/main/en/onnxruntime/usage_guides/models)!
"""


def convert(
    api: "HfApi", model_id: str, task: str, force: bool = False
) -> Tuple[str, Optional["CommitInfo"]]:
    """Export ``model_id`` to ONNX and open a pull request adding the files.

    Args:
        api: Hub client; the user returned by ``api.whoami()`` is credited in
            the PR description.
        model_id: Repository id of the model on the Hub.
        task: Task to export for, or ``"auto"`` / ``None`` to infer it from
            the model.
        force: Open a new PR even if the repository already contains an ONNX
            export or an open PR with the same title.

    Returns:
        ``("0", commit_info)`` on success, where ``commit_info`` is the value
        returned by ``api.create_commit``; ``(error_markdown, None)`` when the
        task could not be inferred.

    Raises:
        Exception: If the model is already converted or already has an open
            export PR, and ``force`` is false.
    """
    pr_title = "Adding ONNX file of this model"
    info = api.model_info(model_id)
    filenames = {sibling.rfilename for sibling in info.siblings}

    requesting_user = api.whoami()["name"]

    # ``None`` is what the CLI passes when ``--task`` is omitted; treat it as "auto".
    if task is None or task == "auto":
        try:
            task = TasksManager.infer_task_from_model(model_id)
        except Exception as e:
            return (
                f"### Error: {e}. Please pass explicitely the task as it could not be infered.",
                None,
            )

    # A single exported file is uploaded at the repo root and multi-file
    # exports under "onnx/", so an existing export may live at either path.
    already_converted = not filenames.isdisjoint({"model.onnx", "onnx/model.onnx"})
    pr = previous_pr(api, model_id, pr_title)
    if already_converted and not force:
        raise Exception(f"Model {model_id} is already converted, skipping..")
    if pr is not None and not force:
        url = f"https://huggingface.co/{model_id}/discussions/{pr.num}"
        raise Exception(
            f"Model {model_id} already has an open PR check out [{url}]({url})"
        )

    # TemporaryDirectory removes the export folder and its contents on exit,
    # whether the export/upload succeeds or raises.
    with TemporaryDirectory() as d:
        folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="model"))
        os.makedirs(folder)
        operations = convert_onnx(model_id, task, folder)
        new_pr = api.create_commit(
            repo_id=model_id,
            operations=operations,
            commit_message=pr_title,
            commit_description=_build_commit_description(requesting_user),
            create_pr=True,
        )

    return "0", new_pr
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse arguments, run the export, and report the outcome.
    DESCRIPTION = """
Simple utility tool to convert automatically a model on the hub to onnx format.
It is PyTorch exclusive for now.
It works by downloading the weights (PT), converting them locally, and uploading them back
as a PR on the hub.
"""
    parser = argparse.ArgumentParser(description=DESCRIPTION)
    parser.add_argument(
        "--model_id",
        type=str,
        required=True,
        help="The name of the model on the hub to convert. E.g. `gpt2` or `facebook/wav2vec2-base-960h`",
    )
    parser.add_argument(
        "--task",
        type=str,
        default="auto",
        help='The task the model is performing. Defaults to "auto", which infers it from the model.',
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="Create the PR even if it already exists or if the model was already converted.",
    )
    args = parser.parse_args()
    api = HfApi()
    status, commit_info = convert(api, args.model_id, task=args.task, force=args.force)
    if status != "0":
        # convert() reports task-inference failures as a returned message rather
        # than raising; surface it and exit with a non-zero status.
        raise SystemExit(status)
    print(commit_info)
|
|