Spaces:
Paused
Paused
import os | |
from huggingface_hub import hf_hub_download, model_info | |
def get_model_path(pretrained_model_or_path, filename=None, subfolder=None):
    """
    Resolve a model reference to the path of a checkpoint file.

    If `pretrained_model_or_path` is an existing local file, it is returned
    unchanged. Otherwise it is treated as a Hugging Face Hub repo id and the
    requested file is fetched with `hf_hub_download`.

    When `filename` is not given, the repo's file listing is searched for the
    first `.safetensors` file. If `subfolder` is given, only files inside that
    subfolder are considered.

    Parameters:
    - pretrained_model_or_path (str): Path to a local checkpoint file, or a Hub repo id.
    - filename (str, optional): File to download, relative to `subfolder` when one is
      given. If omitted, the first `.safetensors` file found is used.
    - subfolder (str, optional): Folder inside the repo that contains the file.

    Returns:
    - str: The local file path as given, or the value returned by `hf_hub_download`.

    Raises:
    - FileNotFoundError: If `filename` is not provided and no matching
      `.safetensors` file is listed for the repo.
    """
    # A local file needs no resolution; `filename` and `subfolder` are ignored.
    if os.path.isfile(pretrained_model_or_path):
        return pretrained_model_or_path

    if filename is None:
        # If the filename is not passed, we only try to load a safetensor.
        # Sibling names are relative to the repo root, while `hf_hub_download`
        # prepends `subfolder` to `filename`. So restrict the search to the
        # subfolder and strip that prefix, otherwise the download would target
        # "<subfolder>/<subfolder>/<file>".
        prefix = subfolder.rstrip("/") + "/" if subfolder else ""
        info = model_info(pretrained_model_or_path)
        filename = next(
            (
                sibling.rfilename[len(prefix):]
                # `siblings` may be None when the listing is unavailable.
                for sibling in (info.siblings or [])
                if sibling.rfilename.startswith(prefix)
                and sibling.rfilename.endswith(".safetensors")
            ),
            None,
        )
        if filename is None:
            raise FileNotFoundError("No safetensors checkpoint found.")

    return hf_hub_download(pretrained_model_or_path, filename, subfolder=subfolder)