amos1088's picture
uuu
a2919a7
import os
from huggingface_hub import hf_hub_download, model_info
def get_model_path(pretrained_model_or_path, filename=None, subfolder=None):
"""
Retrieves the path to the model file.
If `pretrained_model_or_path` is a file, it returns the path directly.
Otherwise, it attempts to find a `.safetensors` file associated with the given model path.
If no `.safetensors` file is found, it raises a `FileNotFoundError`.
Parameters:
- pretrained_model_or_path (str): Path to the pretrained model or directory containing the model.
- filename (str, optional): Specific filename to load. If not provided, the function will search for a `.safetensors` file.
- subfolder (str, optional): Subfolder within the model directory to look for the file.
Returns:
- str: Path to the model file.
Raises:
- FileNotFoundError: If no `.safetensors` file is found when `filename` is not provided.
"""
if os.path.isfile(pretrained_model_or_path):
return pretrained_model_or_path
if filename is None:
# If the filename is not passed, we only try to load a safetensor
info = model_info(pretrained_model_or_path)
filename = next(
(sibling.rfilename for sibling in info.siblings if sibling.rfilename.endswith(".safetensors")), None
)
if filename is None:
raise FileNotFoundError("No safetensors checkpoint found.")
return hf_hub_download(pretrained_model_or_path, filename, subfolder=subfolder)