Spaces:
Runtime error
Runtime error
File size: 1,034 Bytes
2b6048b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
from pathlib import Path
from typing import Optional
import onnxruntime as rt
from huggingface_hub import hf_hub_download
def download_onnx(
    repo_id: str,
    filename: str = "model.onnx",
    revision: Optional[str] = None,
    token: Optional[str] = None,
) -> Path:
    """Fetch an ONNX file from a Hugging Face Hub repository.

    Args:
        repo_id: Hub repository id to download from.
        filename: Name of the file inside the repository. If it does not
            already end in ``.onnx``, that suffix is appended before download.
        revision: Optional revision forwarded to ``hf_hub_download``.
        token: Optional access token forwarded to ``hf_hub_download``.

    Returns:
        The local path returned by ``hf_hub_download``, made absolute with
        symlinks resolved (``Path.resolve``).
    """
    onnx_name = filename if filename.endswith(".onnx") else filename + ".onnx"
    local_file = hf_hub_download(
        repo_id=repo_id,
        filename=onnx_name,
        revision=revision,
        token=token,
    )
    return Path(local_file).resolve()
def create_session(
    repo_id: str,
    revision: Optional[str] = None,
    token: Optional[str] = None,
    filename: str = "model.onnx",
    providers: Optional[list] = None,
) -> rt.InferenceSession:
    """Download an ONNX model from the Hugging Face Hub and open a session on it.

    Args:
        repo_id: Hub repository id to download from.
        revision: Optional revision forwarded to ``download_onnx``.
        token: Optional access token forwarded to ``download_onnx``.
        filename: Name of the ONNX file inside the repository. Defaults to
            ``"model.onnx"``, which matches the previously hard-coded behavior.
        providers: Execution providers passed verbatim to
            ``onnxruntime.InferenceSession``. When omitted, CUDA is preferred
            with CPU as fallback. Any default provider that this onnxruntime
            build does not list in ``get_available_providers()`` is dropped,
            so CPU-only installs do not get an "unavailable provider" warning.

    Returns:
        An ``onnxruntime.InferenceSession`` for the downloaded model.

    Raises:
        FileNotFoundError: If no model file exists at the resolved location.
    """
    model_path = download_onnx(repo_id, filename=filename, revision=revision, token=token)
    if not model_path.is_file():
        # Defensive fallback: if the resolved path is a directory rather than
        # a file, look for the model file inside it.
        onnx_name = Path(filename).name
        if not onnx_name.endswith(".onnx"):
            onnx_name += ".onnx"
        model_path = model_path.joinpath(onnx_name)
    if not model_path.is_file():
        raise FileNotFoundError(f"Model not found: {model_path}")

    if providers is None:
        available = set(rt.get_available_providers())
        preferred = [("CUDAExecutionProvider", {}), "CPUExecutionProvider"]
        # Entries are either a provider name or a (name, options) tuple.
        providers = [
            p for p in preferred
            if (p[0] if isinstance(p, tuple) else p) in available
        ] or ["CPUExecutionProvider"]

    return rt.InferenceSession(str(model_path), providers=providers)
|