from pathlib import Path
from typing import Optional

import onnxruntime as rt
from huggingface_hub import hf_hub_download


def download_onnx(
    repo_id: str,
    filename: str = "model.onnx",
    revision: Optional[str] = None,
    token: Optional[str] = None,
) -> Path:
    """Download an ONNX model file from the Hugging Face Hub.

    Args:
        repo_id: Hub repository id, e.g. ``"org/model-name"``.
        filename: Name of the model file inside the repo. A ``.onnx``
            suffix is appended automatically if missing.
        revision: Optional git revision (branch, tag, or commit hash).
        token: Optional Hub access token for private repositories.

    Returns:
        Resolved absolute path to the downloaded ``.onnx`` file
        (served from the local Hub cache when already present).
    """
    # Normalize the filename so callers may pass either "model" or "model.onnx".
    if not filename.endswith(".onnx"):
        filename += ".onnx"
    model_path = hf_hub_download(
        repo_id=repo_id, filename=filename, revision=revision, token=token
    )
    return Path(model_path).resolve()


def create_session(
    repo_id: str,
    filename: str = "model.onnx",
    revision: Optional[str] = None,
    token: Optional[str] = None,
) -> rt.InferenceSession:
    """Download an ONNX model from the Hub and open an inference session.

    Prefers the CUDA execution provider when the installed onnxruntime
    build supports it, always falling back to CPU.

    Args:
        repo_id: Hub repository id, e.g. ``"org/model-name"``.
        filename: Model filename inside the repo (new, defaults to the
            previous hard-coded ``"model.onnx"`` — backward compatible).
        revision: Optional git revision (branch, tag, or commit hash).
        token: Optional Hub access token for private repositories.

    Returns:
        A ready-to-run ``onnxruntime.InferenceSession``.

    Raises:
        FileNotFoundError: If the downloaded path does not exist on disk.
    """
    model_path = download_onnx(repo_id, filename=filename, revision=revision, token=token)
    # hf_hub_download returns a file path or raises, so this is a defensive
    # guard; the old joinpath("model.onnx") fallback was unreachable.
    if not model_path.is_file():
        raise FileNotFoundError(f"Model not found: {model_path}")

    # Only request providers this onnxruntime build actually ships with:
    # explicitly requesting an unavailable provider (e.g. CUDA on a
    # CPU-only install) raises in modern onnxruntime versions.
    providers: list = []
    if "CUDAExecutionProvider" in rt.get_available_providers():
        providers.append(("CUDAExecutionProvider", {}))
    providers.append("CPUExecutionProvider")

    return rt.InferenceSession(str(model_path), providers=providers)