#!/bin/sh

# Application name (also used as the remote directory inside the dataset)
APP_NAME="AiChat"

# Whitelist/blacklist configuration (comma-separated). The whitelist filter is
# applied first, then the blacklist is applied to the whitelisted set.
WHITELIST="webui.db"
BLACKLIST=""
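# Illustrative example (not part of the default config): WHITELIST="webui.db,uploads"
# would sync the database plus everything under uploads/, and BLACKLIST="uploads/tmp"
# would then drop that sub-directory from the whitelisted set. Matching is
# prefix-based on the relative path.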

if [ -z "$HF_TOKEN" ] || [ -z "$DATASET_ID" ]; then
    echo "Missing required environment variable HF_TOKEN or DATASET_ID"
    exit 1
fi

mkdir -p "./data"
mkdir -p "/tmp/${APP_NAME}"

cat > /tmp/hf_sync.py << 'EOL'
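# Helper script written to /tmp/hf_sync.py. It supports two actions:
# "sync" uploads changed local files and deletes removed ones in a single
# commit, and "download" restores files from the dataset into ./data.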
import os
import sys
import hashlib
import shutil
from datetime import datetime
from zoneinfo import ZoneInfo
from huggingface_hub import HfApi, CommitOperationAdd, CommitOperationDelete

SHANG_HAI_TZ = ZoneInfo("Asia/Shanghai")

def log_print(*args, prefix="[SyncData]", **kwargs):
    timestamp = datetime.now(SHANG_HAI_TZ).strftime("%Y-%m-%d %H:%M:%S")
    print(f"[{timestamp}]", prefix, *args, **kwargs)


def calculate_file_hash(file_path):
    """
    Calculate the MD5 hash of a file.
    """
    try:
        with open(file_path, 'rb') as f:
            return hashlib.md5(f.read()).hexdigest()
    except Exception as e:
        log_print(f"Error calculating hash for {file_path}: {e}")
        return None


def compare_and_sync_directories(source_dir, target_dir, whitelist=None, blacklist=None):
    """
    Compare file hashes between source_dir and target_dir.
    When they differ, copy the file from source_dir to target_dir.
    Return the list of files that need to be uploaded.
    """
    files_to_upload = []

    def should_include_path(path, whitelist, blacklist):
        """
        Check whether a path should be included according to the whitelist/blacklist.
        """
        if whitelist:
            if not any(path.startswith(item.rstrip("/")) for item in whitelist):
                return False
        if blacklist:
            if any(path.startswith(item.rstrip("/")) for item in blacklist):
                return False
        return True
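    # Note: filtering is prefix-based on the relative path, e.g. a whitelist
    # entry like "uploads" would also cover "uploads/avatars/a.png".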

    def walk_and_filter(root_dir, rel_path="", whitelist=None, blacklist=None):
        """
        Walk the directory tree and filter files against the whitelist/blacklist.
        """
        full_path = os.path.join(root_dir, rel_path)
        if not os.path.exists(full_path):
            return []

        filtered_files = []
        try:
            entries = os.listdir(full_path)
            for entry in entries:
                entry_rel_path = os.path.join(rel_path, entry)
                entry_full_path = os.path.join(full_path, entry)

                if not should_include_path(entry_rel_path, whitelist, blacklist):
                    continue

                if os.path.isdir(entry_full_path):
                    filtered_files.extend(walk_and_filter(root_dir, entry_rel_path, whitelist, blacklist))
                else:
                    filtered_files.append(entry_rel_path)
        except Exception as e:
            log_print(f"Error processing directory {full_path}: {e}")
        return filtered_files

    source_files = {}
    if os.path.exists(source_dir):
        filtered_source_files = walk_and_filter(source_dir, whitelist=whitelist, blacklist=blacklist)
        for relative_path in filtered_source_files:
            file_path = os.path.join(source_dir, relative_path)
            file_hash = calculate_file_hash(file_path)
            if file_hash is not None:
                source_files[relative_path] = file_hash

    target_files = {}
    if os.path.exists(target_dir):
        for root, _, files in os.walk(target_dir):
            for file in files:
                file_path = os.path.join(root, file)
                relative_path = os.path.relpath(file_path, target_dir)
                file_hash = calculate_file_hash(file_path)
                if file_hash is not None:
                    target_files[relative_path] = file_hash

    for relative_path in source_files.keys():
        source_file_path = os.path.join(source_dir, relative_path)
        target_file_path = os.path.join(target_dir, relative_path)
        if relative_path not in target_files or target_files[relative_path] != source_files[relative_path]:
            os.makedirs(os.path.dirname(target_file_path), exist_ok=True)
            shutil.copy2(source_file_path, target_file_path)
            files_to_upload.append(relative_path)

    for relative_path in target_files.keys():
        if relative_path not in source_files:
            target_file_path = os.path.join(target_dir, relative_path)
            os.remove(target_file_path)

    return files_to_upload


def upload_files(api, repo_id, local_dir, remote_dir, files_to_upload, operations):
    """
    Queue upload operations (CommitOperationAdd) for local files.
    """
    for relative_path in files_to_upload:
        local_file_path = os.path.join(local_dir, relative_path)
        remote_file_path = os.path.join(remote_dir, relative_path)
        operations.append(CommitOperationAdd(path_in_repo=remote_file_path, path_or_fileobj=local_file_path))


def delete_files(api, repo_id, remote_dir, files_to_delete, operations):
    """
    Queue delete operations (CommitOperationDelete) for remote files.
    """
    for relative_path in files_to_delete:
        remote_file_path = os.path.join(remote_dir, relative_path)
        operations.append(CommitOperationDelete(path_in_repo=remote_file_path))


def commit_operations(api, repo_id, operations, commit_message):
    """
    Apply all queued operations in a single Hugging Face commit.
    """
    try:
        if operations:
            api.create_commit(
                repo_id=repo_id,
                repo_type="dataset",
                operations=operations,
                commit_message=commit_message,
            )
            log_print("Remote dataset updated successfully!")
        else:
            log_print("Already up to date, nothing to commit!")
    except Exception as e:
        log_print(f"Commit failed: {str(e)}")


def download_files(api, repo_id, remote_dir, local_dir):
    """
    Download files from the remote repository into the local directory.
    """
    try:
        remote_files = api.list_repo_files(repo_id=repo_id, repo_type="dataset")
        filtered_files = [file for file in remote_files if file.startswith(remote_dir)]

        for remote_file in filtered_files:
            relative_path = os.path.relpath(remote_file, remote_dir)
            local_file_path = os.path.join(local_dir, relative_path)
            os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
            tmp_filepath = api.hf_hub_download(
                repo_id=repo_id,
                filename=remote_file,
                repo_type="dataset",
            )
            if tmp_filepath and os.path.exists(tmp_filepath):
                shutil.copy2(tmp_filepath, local_file_path)
                # Keep a copy in the /tmp mirror so later sync passes can
                # detect local changes against the downloaded state.
                os.makedirs(os.path.dirname(f"/tmp/{remote_file}"), exist_ok=True)
                shutil.copy2(tmp_filepath, f"/tmp/{remote_file}")

            log_print(f"Downloaded {remote_file} -> {local_file_path}")

    except Exception as e:
        log_print(f"Download failed: {str(e)}")


def sync_repository(api, repo_id, remote_dir, whitelist=None, blacklist=None):
    """
    Synchronize the local data directory with the remote repository (single pass).
    """
    log_print("Starting data sync...")
    source_dir = "./data"
    target_dir = f"/tmp/{remote_dir}"

    files_to_upload = compare_and_sync_directories(source_dir, target_dir, whitelist, blacklist)

    remote_files = api.list_repo_files(repo_id=repo_id, repo_type="dataset")
    local_files = []
    if os.path.exists(target_dir):
        for root, _, files in os.walk(target_dir):
            for file in files:
                file_path = os.path.join(root, file)
                relative_path = os.path.relpath(file_path, target_dir)
                local_files.append(relative_path)
    local_files_set = set(local_files)
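    # Remote files under remote_dir that no longer exist in the local mirror
    # are queued for deletion.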

    files_to_delete = [
        os.path.relpath(remote_file, remote_dir)
        for remote_file in remote_files
        if remote_file.startswith(remote_dir) and os.path.relpath(remote_file, remote_dir) not in local_files_set
    ]

    operations = []
    upload_files(api, repo_id, source_dir, remote_dir, files_to_upload, operations)
    delete_files(api, repo_id, remote_dir, files_to_delete, operations)

    commit_operations(api, repo_id, operations, f"Sync repository: {remote_dir}")

    if files_to_upload:
        log_print(f"Files uploaded: {files_to_upload}")
    if files_to_delete:
        log_print(f"Files deleted: {files_to_delete}")


if __name__ == "__main__":
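    # CLI (invoked by the shell wrapper below):
    #   hf_sync.py sync <token> <repo_id> <remote_dir> [whitelist] [blacklist]
    #   hf_sync.py download <token> <repo_id> <remote_dir>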
    action = sys.argv[1]
    token = sys.argv[2]
    repo_id = sys.argv[3]
    remote_dir = sys.argv[4]
    api = HfApi(token=token)
    source_dir = "./data"

    if action == "sync":
        whitelist = sys.argv[5].split(",") if len(sys.argv) > 5 and sys.argv[5] not in ["", "None"] else None
        blacklist = sys.argv[6].split(",") if len(sys.argv) > 6 and sys.argv[6] not in ["", "None"] else None
        sync_repository(api, repo_id, remote_dir, whitelist, blacklist)
    elif action == "download":
        download_files(api, repo_id, remote_dir, source_dir)
EOL

sync_data() {
    SYNC_INTERVAL=${SYNC_INTERVAL:-7200}  # Default sync interval: 7200 seconds (2 hours)
    while true; do
        python3 /tmp/hf_sync.py sync "${HF_TOKEN}" "${DATASET_ID}" "${APP_NAME}" "${WHITELIST}" "${BLACKLIST}"
        sleep "${SYNC_INTERVAL}"
    done
}
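
# Startup sequence: restore ./data from the dataset first, then run the
# periodic sync loop in the background.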

python3 /tmp/hf_sync.py download "${HF_TOKEN}" "${DATASET_ID}" "${APP_NAME}"

sync_data &