|
import os |
|
from os import listdir |
|
from os.path import isfile, join, exists, dirname |
|
import sys |
|
from datetime import datetime |
|
from shared_utils.log_utils import cstr |
|
|
|
def get_parent_dirpath_n_level_up(abs_path, n=1):
    """Walk ``n`` directory levels up from ``abs_path``.

    Args:
        abs_path (str): Starting path.
        n (int): How many trailing path components to strip; each step is one
            call to ``os.path.dirname``.

    Returns:
        str: The resulting ancestor path (``abs_path`` itself when ``n`` <= 0).
    """
    current = abs_path
    for _level in range(n):
        # Strip one trailing component per level.
        current = dirname(current)
    return current
|
|
|
def get_persistent_directory(folder_name):
    """Return a per-user directory for ``folder_name``, creating it if missing.

    On Windows (``sys.platform == "win32"``) the directory is
    ``~/AppData/Local/<folder_name>``; on every other platform it is the
    dot-prefixed ``~/.<folder_name>``.

    Args:
        folder_name (str): Name of the application folder.

    Returns:
        str: The directory path (guaranteed to exist on return).
    """
    home = os.path.expanduser("~")
    on_windows = sys.platform == "win32"
    folder = join(home, "AppData", "Local", folder_name) if on_windows else join(home, "." + folder_name)
    os.makedirs(folder, exist_ok=True)
    return folder
|
|
|
def parse_save_filename(save_path, output_directory, supported_extensions, class_name):
    """Validate a save path's extension and resolve it to a concrete file path.

    If the extension of ``save_path`` (lower-cased) is in
    ``supported_extensions``:

    * a relative ``save_path`` is resolved under ``output_directory``,
    * the target directory is created if it does not exist,
    * the date tokens ``%Y``, ``%m``, ``%d``, ``%H``, ``%M``, ``%S`` and ``%f``
      in the file name (not the directory part) are replaced with the
      corresponding fields of ``datetime.now()``.

    Args:
        save_path (str): Absolute path, or path relative to ``output_directory``.
        output_directory (str): Base directory used for relative ``save_path``.
        supported_extensions: Container checked with ``in`` against the
            lower-cased extension including its leading dot (e.g. ".safetensors");
            entries are therefore expected to be lower-case.
        class_name (str): Prefix shown in the log messages.

    Returns:
        str | None: The resolved save path, or ``None`` (after logging an
        error) when the extension is not supported.
    """
    folder_path, basename = os.path.split(save_path)
    filename, file_extension = os.path.splitext(basename)

    if file_extension.lower() not in supported_extensions:
        # Report the full base name so the rejected extension is visible.
        cstr(f"[{class_name}] File name {basename} does not end with supported file extensions: {supported_extensions}").error.print()
        return None

    if not os.path.isabs(save_path):
        folder_path = join(output_directory, folder_path)

    # os.makedirs("") raises FileNotFoundError; an empty folder means the
    # current working directory, which already exists.
    if folder_path:
        os.makedirs(folder_path, exist_ok=True)

    # Replace each supported date token with the current timestamp field.
    now = datetime.now()
    for date_format in ("%Y", "%m", "%d", "%H", "%M", "%S", "%f"):
        if date_format in filename:
            filename = filename.replace(date_format, now.strftime(date_format))

    save_path = join(folder_path, filename) + file_extension
    cstr(f"[{class_name}] Saving model to {save_path}").msg.print()
    return save_path
|
|
|
def get_list_filenames(directory, extension_filter=None, recursive=False):
    """List files in ``directory``, optionally filtered by extension.

    Args:
        directory (str): Directory to search. A missing path, or a path that is
            not a directory, yields an empty list.
        extension_filter (str | Iterable[str] | None): Extension(s) to keep,
            e.g. ".txt" or [".txt", ".csv"]. File names are lower-cased before
            comparison, so entries should be lower-case. ``None`` keeps every
            file.
        recursive (bool): If True, walk all sub-directories and keep a file
            when its last extension (``os.path.splitext``) equals one of the
            filter entries. If False, list only direct children and keep a file
            when its name ends with one of the filter entries.

    Returns:
        list[str]: Matching paths relative to ``directory`` (plain file names
        in non-recursive mode), in ``os.walk`` / ``os.listdir`` order.
    """
    if not os.path.isdir(directory):
        return []

    # Normalise to a tuple: str.endswith() rejects lists, and `in` against a
    # bare string would do a substring test instead of an exact match.
    if isinstance(extension_filter, str):
        extension_filter = (extension_filter,)
    elif extension_filter is not None:
        extension_filter = tuple(extension_filter)

    if recursive:
        result = []
        for root, _, files in os.walk(directory):
            for item in files:
                if extension_filter is None or os.path.splitext(item)[1].lower() in extension_filter:
                    result.append(os.path.relpath(join(root, item), directory))
        return result

    return [
        f
        for f in listdir(directory)
        if isfile(join(directory, f))
        and (extension_filter is None or f.lower().endswith(extension_filter))
    ]
|
|
|
|
|
def resume_or_download_model_from_hf(checkpoints_dir_abs, repo_id, model_name, class_name="", repo_type="model"):
    """Return the local path of ``model_name``, downloading it first if absent.

    If ``<checkpoints_dir_abs>/<model_name>`` is not an existing file, a
    warning is logged and ``huggingface_hub.hf_hub_download`` is called with
    ``local_dir=checkpoints_dir_abs``.

    Args:
        checkpoints_dir_abs (str): Local checkpoint directory.
        repo_id (str): Hub repository to download from.
        model_name (str): File name inside the repository / local directory.
        class_name (str): Prefix shown in the warning message.
        repo_type (str): Repository type passed through to ``hf_hub_download``.

    Returns:
        str: ``os.path.join(checkpoints_dir_abs, model_name)`` in both cases.
    """
    local_file = os.path.join(checkpoints_dir_abs, model_name)
    if os.path.isfile(local_file):
        return local_file

    cstr(f"[{class_name}] can't find checkpoint {local_file}, will download it from repo {repo_id} instead").warning.print()
    # Imported here so the dependency is only loaded when a download is needed.
    from huggingface_hub import hf_hub_download
    hf_hub_download(repo_id=repo_id, local_dir=checkpoints_dir_abs, filename=model_name, repo_type=repo_type)
    return local_file
|
|