File size: 1,517 Bytes
a2919a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import os

from huggingface_hub import hf_hub_download, model_info


def get_model_path(pretrained_model_or_path, filename=None, subfolder=None):
    """
    Retrieves the path to the model file.

    Resolution order:
    - If `pretrained_model_or_path` is an existing file, it is returned unchanged.
    - If it is an existing local directory, the file is looked up inside it (inside
      `subfolder` when given). Without `filename`, the first `.safetensors` file found
      by a sorted, recursive walk of that directory is used.
    - Otherwise it is treated as a Hugging Face Hub repo id. Without `filename`, the
      first `.safetensors` file listed by `model_info` (restricted to `subfolder` when
      given) is chosen, and the file is fetched with `hf_hub_download`.

    Parameters:
    - pretrained_model_or_path (str): Path to a model file, a local directory containing
      the model, or a Hugging Face Hub repo id.
    - filename (str, optional): Specific filename to load. If not provided, the function
      will search for a `.safetensors` file.
    - subfolder (str, optional): Subfolder within the model directory / repo to look for
      the file.

    Returns:
    - str: Local path to the model file.

    Raises:
    - FileNotFoundError: If no `.safetensors` file is found when `filename` is not
      provided, or if an explicit `filename` does not exist in a local directory.
      Errors raised by `model_info` / `hf_hub_download` propagate unchanged.
    """
    if os.path.isfile(pretrained_model_or_path):
        return pretrained_model_or_path

    if os.path.isdir(pretrained_model_or_path):
        folder = (
            os.path.join(pretrained_model_or_path, subfolder)
            if subfolder
            else pretrained_model_or_path
        )
        if filename is None:
            path = _first_local_safetensors(folder)
            if path is None:
                raise FileNotFoundError("No safetensors checkpoint found.")
            return path
        path = os.path.join(folder, filename)
        if not os.path.isfile(path):
            raise FileNotFoundError(f"{path} does not exist.")
        return path

    if filename is None:
        # If the filename is not passed, we only try to load a safetensors file.
        info = model_info(pretrained_model_or_path)
        # `rfilename` is relative to the repo root, while `hf_hub_download` prepends
        # `subfolder` itself: restrict the search to that subfolder and strip its prefix
        # so the download path does not become `<subfolder>/<subfolder>/...`.
        stripped = subfolder.strip("/") if subfolder else ""
        prefix = f"{stripped}/" if stripped else ""
        filename = next(
            (
                sibling.rfilename[len(prefix):]
                for sibling in info.siblings or []
                if sibling.rfilename.startswith(prefix)
                and sibling.rfilename.endswith(".safetensors")
            ),
            None,
        )
        if filename is None:
            raise FileNotFoundError("No safetensors checkpoint found.")

    return hf_hub_download(pretrained_model_or_path, filename, subfolder=subfolder)


def _first_local_safetensors(folder):
    """Return the first `.safetensors` file under `folder` in sorted walk order, or None."""
    for root, dirs, files in os.walk(folder):
        # Sort in place so the traversal order (and thus the chosen file) is deterministic.
        dirs.sort()
        for name in sorted(files):
            if name.endswith(".safetensors"):
                return os.path.join(root, name)
    # os.walk yields nothing for a missing folder, so this also covers that case.
    return None