from huggingface_hub import hf_hub_download
from modelscope import snapshot_download
import os
import shutil
from typing_extensions import Literal, TypeAlias
from typing import List, Optional
from ..configs.model_config import preset_models_on_huggingface, preset_models_on_modelscope, Preset_model_id


def _flatten_downloaded_file(local_dir, origin_file_path):
    """Move ``local_dir/origin_file_path`` directly into ``local_dir``.

    Does nothing when ``origin_file_path`` has no directory component. Otherwise
    the file is moved to ``local_dir/<basename>`` and the top-level directory
    of ``origin_file_path`` (relative to ``local_dir``) is removed.
    """
    downloaded_file_path = os.path.join(local_dir, origin_file_path)
    target_file_path = os.path.join(local_dir, os.path.basename(origin_file_path))
    if downloaded_file_path != target_file_path:
        shutil.move(downloaded_file_path, target_file_path)
        # Remove the now-unneeded directory tree the download left behind.
        shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0]))


def download_from_modelscope(model_id, origin_file_path, local_dir):
    """Download one file of a ModelScope repository into ``local_dir``.

    Args:
        model_id: Repository identifier passed to ``snapshot_download``.
        origin_file_path: Path of the file inside the repository ("/"-separated).
        local_dir: Directory that should finally contain the file, stored under
            its base name. Created if missing.

    The download is skipped when a file with the same base name is already
    listed in ``local_dir``.
    """
    os.makedirs(local_dir, exist_ok=True)
    file_name = os.path.basename(origin_file_path)
    if file_name in os.listdir(local_dir):
        print(f" {file_name} has been already in {local_dir}.")
        return
    print(f" Start downloading {os.path.join(local_dir, file_name)}")
    snapshot_download(model_id, allow_file_pattern=origin_file_path, local_dir=local_dir)
    _flatten_downloaded_file(local_dir, origin_file_path)


def download_from_huggingface(model_id, origin_file_path, local_dir):
    """Download one file of a Hugging Face repository into ``local_dir``.

    Args:
        model_id: Repository identifier passed to ``hf_hub_download``.
        origin_file_path: Path of the file inside the repository ("/"-separated).
        local_dir: Directory that should finally contain the file, stored under
            its base name. Created if missing.

    The download is skipped when a file with the same base name is already
    listed in ``local_dir``.
    """
    os.makedirs(local_dir, exist_ok=True)
    file_name = os.path.basename(origin_file_path)
    if file_name in os.listdir(local_dir):
        print(f" {file_name} has been already in {local_dir}.")
        return
    print(f" Start downloading {os.path.join(local_dir, file_name)}")
    hf_hub_download(model_id, origin_file_path, local_dir=local_dir)
    # hf_hub_download keeps the repository's sub-folder layout under local_dir;
    # flatten it so the "already downloaded" check above (and the caller's
    # check) can find the file, mirroring download_from_modelscope.
    _flatten_downloaded_file(local_dir, origin_file_path)


Preset_model_website: TypeAlias = Literal[
    "HuggingFace",
    "ModelScope",
]

# Preset tables and download functions, keyed by website name.
website_to_preset_models = {
    "HuggingFace": preset_models_on_huggingface,
    "ModelScope": preset_models_on_modelscope,
}
website_to_download_fn = {
    "HuggingFace": download_from_huggingface,
    "ModelScope": download_from_modelscope,
}


def download_models(
    model_id_list: Optional[List[Preset_model_id]] = None,
    downloading_priority: Optional[List[Preset_model_website]] = None,
):
    """Download the files of the requested preset models.

    Args:
        model_id_list: Preset model ids to fetch. ``None`` (the default) is
            treated as an empty list.
        downloading_priority: Websites to consult, in order. ``None`` (the
            default) means ``["ModelScope", "HuggingFace"]``.

    Returns:
        List of local file paths (``local_dir`` joined with the file's base
        name) that are present in their ``local_dir`` after this call, in the
        order they were handled.

    Every website in ``downloading_priority`` that lists a preset id is
    consulted; files already recorded as downloaded are skipped, so later
    websites only act as a fallback for files not yet obtained.
    """
    if model_id_list is None:
        model_id_list = []
    if downloading_priority is None:
        downloading_priority = ["ModelScope", "HuggingFace"]

    print(f"Downloading models: {model_id_list}")
    downloaded_files = []
    for model_id in model_id_list:
        for website in downloading_priority:
            preset_models = website_to_preset_models[website]
            if model_id not in preset_models:
                continue
            download_fn = website_to_download_fn[website]
            # Use a distinct name for the repository id: rebinding `model_id`
            # here would corrupt the preset lookup for the next website.
            for repo_id, origin_file_path, local_dir in preset_models[model_id]:
                file_name = os.path.basename(origin_file_path)
                file_to_download = os.path.join(local_dir, file_name)
                # Skip files already obtained (from an earlier website or an
                # earlier preset sharing the same file).
                if file_to_download in downloaded_files:
                    continue
                download_fn(repo_id, origin_file_path, local_dir)
                # Only record the file when it is really present on disk.
                if file_name in os.listdir(local_dir):
                    downloaded_files.append(file_to_download)
    return downloaded_files