File size: 2,949 Bytes
2fdce3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from typing import Union, BinaryIO
from huggingface_hub import HfApi
from pathlib import Path
import argparse
import os
from library.utils import fire_in_thread


def exists_repo(repo_id: str, repo_type: str, revision: str = "main", token: str = None):
    """Report whether a Hugging Face Hub repository can be queried.

    Args:
        repo_id: Repository identifier forwarded to ``HfApi.repo_info``.
        repo_type: Repository type forwarded to ``HfApi.repo_info``.
        revision: Revision to look up; defaults to ``"main"``.
        token: Optional access token used to build the ``HfApi`` client.

    Returns:
        ``True`` when ``repo_info`` succeeds, ``False`` when it raises any
        ``Exception``.
    """
    api = HfApi(token=token)
    try:
        api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
    except Exception:
        # Catch Exception instead of using a bare ``except`` so that
        # KeyboardInterrupt / SystemExit still propagate to the caller
        # rather than being reported as "repository does not exist".
        return False
    return True


def upload(
    args: argparse.Namespace,
    src: Union[str, Path, bytes, BinaryIO],
    dest_suffix: str = "",
    force_sync_upload: bool = False,
):
    """Upload a file or a folder to a Hugging Face Hub repository.

    The repository is created first when ``exists_repo`` reports that it
    cannot be found. Failures while creating the repository or while
    uploading are printed and not re-raised, so the caller keeps running.

    Args:
        args: Namespace providing ``huggingface_repo_id``,
            ``huggingface_repo_type``, ``huggingface_token``,
            ``huggingface_path_in_repo``, ``huggingface_repo_visibility``
            and ``async_upload``.
        src: What to upload. A ``str`` or ``Path`` naming an existing
            directory is sent with ``upload_folder``; anything else is
            handed to ``upload_file`` as ``path_or_fileobj``.
        dest_suffix: Text appended to ``args.huggingface_path_in_repo`` when
            that option is set.
        force_sync_upload: When true, upload in the calling thread even if
            ``args.async_upload`` is set.
    """
    repo_id = args.huggingface_repo_id
    repo_type = args.huggingface_repo_type
    token = args.huggingface_token
    # NOTE(review): path_in_repo stays None when the option is unset; confirm
    # the hub API accepts that for single-file uploads.
    path_in_repo = args.huggingface_path_in_repo + dest_suffix if args.huggingface_path_in_repo is not None else None
    # Anything other than an explicit "public" (including None) means private.
    private = args.huggingface_repo_visibility != "public"
    api = HfApi(token=token)
    if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token):
        try:
            api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private)
        except Exception as e:  # RepositoryNotFoundError has been observed; catch broadly in case there are others
            print("===========================================")
            print(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}")
            print("===========================================")

    # Only str/Path values are treated as filesystem paths; bytes and file
    # objects always go through upload_file.
    is_folder = isinstance(src, (str, Path)) and os.path.isdir(src)

    def uploader():
        try:
            if is_folder:
                api.upload_folder(
                    repo_id=repo_id,
                    repo_type=repo_type,
                    folder_path=src,
                    path_in_repo=path_in_repo,
                )
            else:
                api.upload_file(
                    repo_id=repo_id,
                    repo_type=repo_type,
                    path_or_fileobj=src,
                    path_in_repo=path_in_repo,
                )
        except Exception as e:  # RuntimeError has been observed; catch broadly in case there are others
            print("===========================================")
            print(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}")
            print("===========================================")

    if args.async_upload and not force_sync_upload:
        fire_in_thread(uploader)
    else:
        uploader()


def list_dir(
    repo_id: str,
    subfolder: str,
    repo_type: str,
    revision: str = "main",
    token: str = None,
):
    """List repository entries whose file name starts with ``subfolder``.

    Args:
        repo_id: Repository identifier forwarded to ``HfApi.repo_info``.
        subfolder: Prefix compared against each entry's ``rfilename`` with
            ``str.startswith`` (a plain string prefix, not a path-component
            match).
        repo_type: Repository type forwarded to ``HfApi.repo_info``.
        revision: Revision to inspect; defaults to ``"main"``.
        token: Optional access token used to build the ``HfApi`` client.

    Returns:
        A list of the matching items from ``repo_info(...).siblings``, in
        their original order.
    """
    hub = HfApi(token=token)
    info = hub.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)

    matches = []
    for sibling in info.siblings:
        if sibling.rfilename.startswith(subfolder):
            matches.append(sibling)
    return matches