Spaces:
Paused
Paused
File size: 1,517 Bytes
a2919a7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
import os
from huggingface_hub import hf_hub_download, model_info
def _first_hub_safetensors(rfilenames, subfolder=None):
    """Return the first `.safetensors` repo file under `subfolder`, relative to it, or None.

    `rfilenames` are repo-relative paths such as ``"unet/model.safetensors"``.
    When `subfolder` is set, only files inside it are considered and the
    ``subfolder/`` prefix is stripped, because `hf_hub_download` joins
    `subfolder` and `filename` itself.
    """
    prefix = f"{subfolder}/" if subfolder else ""
    for rfilename in rfilenames:
        if rfilename.startswith(prefix) and rfilename.endswith(".safetensors"):
            return rfilename[len(prefix):]
    return None


def _first_local_safetensors(directory):
    """Return the path of the first `.safetensors` file under `directory`, or None.

    Walks top-down in sorted order, so the result is deterministic and files
    directly in `directory` win over files in nested subdirectories.
    """
    for root, dirnames, filenames in os.walk(directory):
        dirnames.sort()  # in-place sort controls the order os.walk descends
        for name in sorted(filenames):
            if name.endswith(".safetensors"):
                return os.path.join(root, name)
    return None


def get_model_path(pretrained_model_or_path, filename=None, subfolder=None):
    """
    Retrieves the path to the model file.

    Resolution order:
    1. If `pretrained_model_or_path` is an existing file, it is returned unchanged.
    2. If it is an existing local directory, the file is looked up on disk
       inside ``pretrained_model_or_path[/subfolder]``.
    3. Otherwise it is treated as a Hugging Face Hub repo id and the file is
       downloaded with `hf_hub_download` (which returns a local cached path).

    When `filename` is not given, only a `.safetensors` checkpoint is searched
    for (restricted to `subfolder` when one is given).

    Parameters:
    - pretrained_model_or_path (str): Path to a model file, a local directory
      containing the model, or a Hub repo id.
    - filename (str, optional): Specific filename to load, relative to
      `subfolder`. If not provided, the function searches for a `.safetensors` file.
    - subfolder (str, optional): Subfolder within the model directory / repo
      to look for the file.

    Returns:
    - str: Path to the model file.

    Raises:
    - FileNotFoundError: If no `.safetensors` file is found when `filename` is
      not provided, or if an explicit `filename` does not exist in a local
      directory. Errors from `huggingface_hub` (e.g. a missing repo or file)
      propagate unchanged.
    """
    if os.path.isfile(pretrained_model_or_path):
        return pretrained_model_or_path

    if os.path.isdir(pretrained_model_or_path):
        if subfolder:
            base = os.path.join(pretrained_model_or_path, subfolder)
        else:
            base = pretrained_model_or_path
        if filename is not None:
            candidate = os.path.join(base, filename)
            if not os.path.isfile(candidate):
                raise FileNotFoundError(f"{candidate} does not exist.")
            return candidate
        found = _first_local_safetensors(base)
        if found is None:
            raise FileNotFoundError("No safetensors checkpoint found.")
        return found

    if filename is None:
        # If the filename is not passed, we only try to load a safetensor
        info = model_info(pretrained_model_or_path)
        # ModelInfo.siblings is Optional in huggingface_hub; treat None as "no files".
        rfilenames = (sibling.rfilename for sibling in (info.siblings or []))
        filename = _first_hub_safetensors(rfilenames, subfolder)
        if filename is None:
            raise FileNotFoundError("No safetensors checkpoint found.")
    return hf_hub_download(pretrained_model_or_path, filename, subfolder=subfolder)
|