Spaces:
Runtime error
Runtime error
from pathlib import Path | |
from typing import Optional | |
import onnxruntime as rt | |
from huggingface_hub import hf_hub_download | |
def download_onnx(
    repo_id: str,
    filename: str = "model.onnx",
    revision: Optional[str] = None,
    token: Optional[str] = None,
) -> Path:
    """Download an ONNX file from a Hugging Face Hub repository.

    Args:
        repo_id: Hub repository identifier passed to ``hf_hub_download``.
        filename: File to fetch inside the repo. ``.onnx`` is appended when
            the name does not already end with it.
        revision: Optional revision forwarded to ``hf_hub_download``.
        token: Optional auth token forwarded to ``hf_hub_download``.

    Returns:
        The absolute, symlink-resolved local path returned by
        ``hf_hub_download``.
    """
    # Normalize the name so callers may omit the extension.
    onnx_name = filename if filename.endswith(".onnx") else filename + ".onnx"
    local_file = hf_hub_download(
        repo_id=repo_id,
        filename=onnx_name,
        revision=revision,
        token=token,
    )
    return Path(local_file).resolve()
def create_session(
    repo_id: str,
    revision: Optional[str] = None,
    token: Optional[str] = None,
    filename: str = "model.onnx",
    providers: Optional[list] = None,
) -> rt.InferenceSession:
    """Download an ONNX model from the Hub and open an ONNX Runtime session.

    Args:
        repo_id: Hub repository identifier forwarded to ``download_onnx``.
        revision: Optional revision forwarded to ``download_onnx``.
        token: Optional auth token forwarded to ``download_onnx``.
        filename: Model file name inside the repo, forwarded to
            ``download_onnx`` (which appends ``.onnx`` if missing).
            Defaults to ``"model.onnx"`` as before.
        providers: Execution providers for ``rt.InferenceSession``, in
            decreasing priority. When ``None``, CUDA is requested only if
            ``rt.get_available_providers()`` lists it, followed by CPU.

    Returns:
        An ``rt.InferenceSession`` for the downloaded model.

    Raises:
        FileNotFoundError: If no model file exists at the downloaded path
            (or at ``<path>/model.onnx`` when the path is not a file).
    """
    model_path = download_onnx(
        repo_id, filename=filename, revision=revision, token=token
    )
    # Defensive fallback: if the returned path is not a file (presumably a
    # directory), look for model.onnx inside it.
    if not model_path.is_file():
        model_path = model_path.joinpath("model.onnx")
    if not model_path.is_file():
        raise FileNotFoundError(f"Model not found: {model_path}")

    if providers is None:
        # Only ask for CUDA when this onnxruntime build offers it; otherwise
        # ORT would just skip it with a warning and use CPU anyway.
        providers = []
        if "CUDAExecutionProvider" in rt.get_available_providers():
            providers.append(("CUDAExecutionProvider", {}))
        providers.append("CPUExecutionProvider")

    return rt.InferenceSession(str(model_path), providers=providers)