File size: 2,948 Bytes
703e263
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import os
import shutil
from typing import List, Optional

from huggingface_hub import hf_hub_download
from modelscope import snapshot_download
from typing_extensions import Literal, TypeAlias

from ..configs.model_config import preset_models_on_huggingface, preset_models_on_modelscope, Preset_model_id


def download_from_modelscope(model_id, origin_file_path, local_dir):
    """Fetch a single file of a ModelScope repository into ``local_dir``.

    The file is stored directly in ``local_dir`` under its base name. If a
    file with that base name is already present, nothing is downloaded.

    Args:
        model_id: ModelScope repository id handed to ``snapshot_download``.
        origin_file_path: Repo-relative, ``/``-separated path of the file;
            also used as ``allow_file_pattern`` for ``snapshot_download``.
        local_dir: Destination directory; created when missing.

    Returns:
        None.
    """
    file_name = os.path.basename(origin_file_path)
    target_file_path = os.path.join(local_dir, file_name)

    os.makedirs(local_dir, exist_ok=True)
    if file_name in os.listdir(local_dir):
        print(f"    {file_name} has been already in {local_dir}.")
        return

    print(f"    Start downloading {target_file_path}")
    snapshot_download(model_id, allow_file_pattern=origin_file_path, local_dir=local_dir)

    # The snapshot keeps the repo layout, so a nested file lands in a
    # sub-directory of local_dir; a top-level file is already in place.
    downloaded_file_path = os.path.join(local_dir, origin_file_path)
    if downloaded_file_path == target_file_path:
        return

    # Flatten: move the file up to local_dir, then drop the repo's top-level
    # folder. NOTE: rmtree removes that whole folder, including anything else
    # that was already stored in it.
    top_level_dir = os.path.join(local_dir, origin_file_path.split("/")[0])
    shutil.move(downloaded_file_path, target_file_path)
    shutil.rmtree(top_level_dir)


def download_from_huggingface(model_id, origin_file_path, local_dir):
    """Download a single file of a HuggingFace repository into ``local_dir``.

    The file always ends up directly in ``local_dir`` under its base name,
    the same layout ``download_from_modelscope`` produces. Without this,
    ``download_models`` would not find a nested file in ``local_dir``.
    The "already downloaded" check below would also never match, so the
    file would be re-requested on every call.

    Args:
        model_id: HuggingFace repository id passed to ``hf_hub_download``.
        origin_file_path: Repo-relative, ``/``-separated path of the file.
        local_dir: Destination directory; created when missing.

    Returns:
        None. Returns early, without downloading, if a file with the same
        base name already exists in ``local_dir``.
    """
    os.makedirs(local_dir, exist_ok=True)
    file_name = os.path.basename(origin_file_path)
    target_file_path = os.path.join(local_dir, file_name)
    if file_name in os.listdir(local_dir):
        print(f"    {file_name} has been already in {local_dir}.")
        return
    print(f"    Start downloading {target_file_path}")

    # Remember whether the repo's top-level folder already existed, so the
    # cleanup below never deletes files that were stored there beforehand.
    top_level_dir = os.path.join(local_dir, origin_file_path.split("/")[0])
    top_level_dir_existed = os.path.exists(top_level_dir)

    hf_hub_download(model_id, origin_file_path, local_dir=local_dir)

    # NOTE(review): assumes hf_hub_download mirrors the repo layout under
    # local_dir (file at local_dir/origin_file_path) - confirm for the
    # installed huggingface_hub version.
    downloaded_file_path = os.path.join(local_dir, origin_file_path)
    if downloaded_file_path != target_file_path:
        # Nested repo file: flatten it into local_dir like the ModelScope path.
        shutil.move(downloaded_file_path, target_file_path)
        if not top_level_dir_existed:
            shutil.rmtree(top_level_dir)


# Hub names accepted in download_models' ``downloading_priority`` argument.
Preset_model_website: TypeAlias = Literal["HuggingFace", "ModelScope"]

# Hub name -> preset table for that hub (both imported from configs.model_config).
website_to_preset_models = dict(
    HuggingFace=preset_models_on_huggingface,
    ModelScope=preset_models_on_modelscope,
)

# Hub name -> function that fetches one file from that hub into a local directory.
website_to_download_fn = dict(
    HuggingFace=download_from_huggingface,
    ModelScope=download_from_modelscope,
)


def download_models(
    model_id_list: Optional[List[Preset_model_id]] = None,
    downloading_priority: Optional[List[Preset_model_website]] = None,
):
    """Download the files of preset models, trying hubs in priority order.

    For each preset model id, every hub in ``downloading_priority`` whose
    preset table contains that id is visited in order. Each
    ``(repo_id, origin_file_path, local_dir)`` entry of that table is then
    handed to the hub's download function. A local file that an earlier
    (higher-priority) hub already delivered is skipped, so lower-priority
    hubs only fill in what is still missing.

    Args:
        model_id_list: Preset model ids to fetch. ``None`` (default) means
            an empty list.
        downloading_priority: Hub names to try, highest priority first.
            ``None`` (default) means ``["ModelScope", "HuggingFace"]``.

    Returns:
        List of ``os.path.join(local_dir, <file base name>)`` paths that were
        found in their ``local_dir`` after the download attempt.

    Raises:
        KeyError: If ``downloading_priority`` names an unknown hub.
        Exceptions raised by the download functions are not caught.
    """
    # None sentinels replace the former shared mutable list defaults.
    if model_id_list is None:
        model_id_list = []
    if downloading_priority is None:
        downloading_priority = ["ModelScope", "HuggingFace"]

    print(f"Downloading models: {model_id_list}")
    downloaded_files = []
    for model_id in model_id_list:
        for website in downloading_priority:
            preset_models = website_to_preset_models[website]
            if model_id not in preset_models:
                continue
            download_fn = website_to_download_fn[website]
            # `repo_id` must not reuse the name `model_id`: rebinding the
            # outer loop variable made every lower-priority hub be looked up
            # with the repo id instead of the preset id, disabling fallback.
            for repo_id, origin_file_path, local_dir in preset_models[model_id]:
                file_name = os.path.basename(origin_file_path)
                file_to_download = os.path.join(local_dir, file_name)
                # Already delivered by a higher-priority hub.
                if file_to_download in downloaded_files:
                    continue
                download_fn(repo_id, origin_file_path, local_dir)
                # Record the file only if it is actually present now.
                if file_name in os.listdir(local_dir):
                    downloaded_files.append(file_to_download)
    return downloaded_files