File size: 3,368 Bytes
82ea528
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os
from os import listdir
from os.path import isfile, join, exists, dirname
import sys
from datetime import datetime
from shared_utils.log_utils import cstr

def get_parent_dirpath_n_level_up(abs_path, n=1):
    """
    Return the ancestor directory that is ``n`` levels above ``abs_path``.

    Args:
        abs_path (str): The starting path.
        n (int): How many times to apply ``dirname``. Zero or a negative
            value returns ``abs_path`` unchanged.

    Returns:
        str: The path after stripping ``n`` trailing components.
    """
    ancestor = abs_path
    for _ in range(n):
        ancestor = dirname(ancestor)
    return ancestor

def get_persistent_directory(folder_name):
    """
    Return a per-user directory for ``folder_name``, creating it if needed.

    On ``win32`` the directory is ``~/AppData/Local/<folder_name>``; on every
    other platform it is the hidden directory ``~/.<folder_name>``.

    Args:
        folder_name (str): Name of the application folder.

    Returns:
        str: Path of the (now existing) directory.
    """
    home = os.path.expanduser("~")
    if sys.platform == "win32":
        components = (home, "AppData", "Local", folder_name)
    else:
        components = (home, "." + folder_name)

    folder = join(*components)
    os.makedirs(folder, exist_ok=True)
    return folder

def parse_save_filename(save_path, output_directory, supported_extensions, class_name):
    """
    Resolve a user-supplied save path into a concrete output file path.

    The file extension is checked (case-insensitively) against
    ``supported_extensions``. Relative paths are placed under
    ``output_directory``, the target folder is created, and the strftime-style
    placeholders ``%Y %m %d %H %M %S %f`` in the file name (not in the folder
    part) are replaced with the current local time.

    Args:
        save_path (str): Absolute path, or path relative to ``output_directory``.
        output_directory (str): Base folder used when ``save_path`` is relative.
        supported_extensions (list | tuple | set): Allowed lowercase extensions
            including the leading dot, e.g. ``['.ply', '.obj']``.
        class_name (str): Caller name used as a prefix in the emitted messages.

    Returns:
        str | None: The resolved save path, or ``None`` when the extension is
        not supported (an error message is emitted through ``cstr``).
    """
    folder_path, full_name = os.path.split(save_path)
    filename, file_extension = os.path.splitext(full_name)

    if file_extension.lower() not in supported_extensions:
        # Report the name with its extension so the user can see what was rejected.
        cstr(f"[{class_name}] File name {full_name} does not end with supported file extensions: {supported_extensions}").error.print()
        return None

    if not os.path.isabs(save_path):
        folder_path = join(output_directory, folder_path)

    # os.makedirs("") raises; an empty folder means the current directory.
    if folder_path:
        os.makedirs(folder_path, exist_ok=True)

    # Replace date/time placeholders with the current time. str.replace is a
    # no-op when the placeholder is absent, so no membership check is needed.
    now = datetime.now()
    for date_format in ("%Y", "%m", "%d", "%H", "%M", "%S", "%f"):
        filename = filename.replace(date_format, now.strftime(date_format))

    save_path = join(folder_path, filename) + file_extension
    cstr(f"[{class_name}] Saving model to {save_path}").msg.print()
    return save_path

def get_list_filenames(directory, extension_filter=None, recursive=False):
    """
    List files in a directory, optionally filtered by file extension.

    Matching is a case-insensitive suffix test on the file name, applied the
    same way in both recursive and non-recursive mode.

    Args:
        directory (str): The directory path to search.
        extension_filter (str | list | tuple | None): A single extension or a
            collection of extensions (e.g., ['.txt', '.csv']). ``None``
            disables filtering.
        recursive (bool): When True, descend into sub-directories and return
            paths relative to ``directory``; otherwise return only the names
            of files directly inside ``directory``.

    Returns:
        list: Matching file names (or relative paths when ``recursive``).
        Empty when ``directory`` does not exist or is not a directory.
    """
    if not os.path.isdir(directory):
        return []

    # str.endswith accepts only a str or a tuple, so normalise lists/sets here;
    # a bare string is wrapped so it is never treated as a collection of chars.
    if extension_filter is None:
        suffixes = None
    elif isinstance(extension_filter, str):
        suffixes = (extension_filter.lower(),)
    else:
        suffixes = tuple(ext.lower() for ext in extension_filter)

    def matches(name):
        return suffixes is None or name.lower().endswith(suffixes)

    if recursive:
        return [
            os.path.relpath(os.path.join(root, item), directory)
            for root, _, files in os.walk(directory)
            for item in files
            if matches(item)
        ]

    return [f for f in listdir(directory) if isfile(join(directory, f)) and matches(f)]
    
# Download the pre-trained model if it does not exist locally
def resume_or_download_model_from_hf(checkpoints_dir_abs, repo_id, model_name, class_name="", repo_type="model"):
    """
    Return the local checkpoint path, downloading the file first if it is missing.

    Args:
        checkpoints_dir_abs (str): Directory where the checkpoint is expected.
        repo_id (str): Hugging Face repository to download from.
        model_name (str): File name of the checkpoint inside the repository.
        class_name (str): Caller name used as a prefix in the warning message.
        repo_type (str): Repository type forwarded to ``hf_hub_download``.

    Returns:
        str: ``checkpoints_dir_abs`` joined with ``model_name``.
    """
    ckpt_path = os.path.join(checkpoints_dir_abs, model_name)

    # Nothing to do when the checkpoint is already on disk.
    if os.path.isfile(ckpt_path):
        return ckpt_path

    cstr(f"[{class_name}] can't find checkpoint {ckpt_path}, will download it from repo {repo_id} instead").warning.print()

    # Imported lazily so huggingface_hub is only needed when a download happens.
    from huggingface_hub import hf_hub_download
    hf_hub_download(repo_id=repo_id, local_dir=checkpoints_dir_abs, filename=model_name, repo_type=repo_type)

    return ckpt_path