File size: 7,961 Bytes
47c5626
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
import os
import time
import datetime
import threading
import gradio as gr
import subprocess
import logging
from modules import script_callbacks, shared
from git import Repo
import shutil

# Constants
# Name of the Hugging Face repository that receives the backups.
REPO_NAME = 'sd-webui-backups'
# One hour, expressed in seconds.
BACKUP_INTERVAL = 60 * 60
# Keys looked up in the settings mapping returned by script.load().
HF_TOKEN_KEY = 'hf_token'
BACKUP_PATHS_KEY = 'backup_paths'
SD_PATH_KEY = 'sd_path'
HF_USER_KEY = 'hf_user'
# Folders, relative to script.sd_path, used when no backup paths were saved.
DEFAULT_BACKUP_PATHS = [
    'models/Stable-diffusion',
    'models/VAE',
    'embeddings',
    'loras',
]

# --- Logging Setup ---
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler()],
)
logger = logging.getLogger(__name__)

# --- Helper function for updating the status ---
def update_status(script, status, file=None):
    """Store a status message on ``script.status`` and echo it to stdout.

    When ``file`` is truthy the message becomes ``"<status>: <file>"``;
    otherwise ``status`` is stored and printed unchanged.
    """
    message = f"{status}: {file}" if file else status
    script.status = message
    print(message)  # console echo, in addition to any logger output

# --- Git Related Functions ---
def clone_or_create_repo(repo_url, repo_path, script):
    """Return a GitPython ``Repo`` for the backup working copy at ``repo_path``.

    If ``repo_path`` is already a directory it is opened as-is; no fetch or
    pull is performed. Otherwise ``repo_url`` is cloned into it.

    When the ``git_credential_store`` option is disabled, the clone
    authenticates by embedding ``script.hf_user`` and a token in the URL. The
    token comes from the ``HF_TOKEN`` environment variable, falling back to
    ``script.hf_token``. After cloning, ``origin`` is reset to the clean
    ``repo_url`` so the token is not persisted in ``.git/config``. The token is
    also redacted from any logged or displayed error text.

    Raises:
        Exception: if no token is available, or if cloning fails. The error is
            logged and reported via ``update_status`` before being re-raised.
    """
    update_status(script, "Checking/Cloning Repo...")
    if os.path.isdir(repo_path):
        # The existing working copy is reused without pulling remote changes.
        logger.info(f"Repository already exists at {repo_path}, reusing it.")
        repo = Repo(repo_path)
        if repo.is_dirty():
            logger.warning("Local repo has uncommitted changes. Commit those before running to make sure nothing breaks.")
            update_status(script, "Local repo has uncommitted changes")
    else:
        logger.info(f"Cloning repository from {repo_url} to {repo_path}")
        update_status(script, "Cloning repository")
        token = None
        try:
            use_git_credential_store = shared.opts.data.get("git_credential_store", True)
            if use_git_credential_store:
                repo = Repo.clone_from(repo_url, repo_path)
            else:
                # Prefer the environment variable; fall back to the token saved from the UI.
                token = os.environ.get("HF_TOKEN") or getattr(script, "hf_token", "")
                if not token:
                    update_status(script, "HF_TOKEN environment variable not found")
                    raise Exception("HF_TOKEN environment variable not found")
                auth_url = repo_url.replace("https://", f"https://{script.hf_user}:{token}@", 1)
                repo = Repo.clone_from(auth_url, repo_path)
                # Don't leave the credential in .git/config; the token push path
                # builds its own authenticated URL.
                repo.remote(name="origin").set_url(repo_url)
        except Exception as e:
            message = str(e)
            if token:
                # Never surface the secret in logs or the UI status.
                message = message.replace(token, "***")
            logger.error(f"Error creating or cloning repo: {message}")
            update_status(script, f"Error creating or cloning repo: {message}")
            raise
    update_status(script, "Repo ready")
    return repo

def git_push_files(repo_path, commit_message, script):
    """Stage everything in ``repo_path``, commit it, and push it to the remote.

    With the ``git_credential_store`` option enabled (the default), the commit
    is pushed to ``origin`` using git's configured credentials. Otherwise the
    current branch (``HEAD``) is pushed to an authenticated Hugging Face URL.
    That URL is built from ``script.hf_user`` and a token taken from the
    ``HF_TOKEN`` environment variable, falling back to ``script.hf_token``.

    Raises:
        Exception: on any git failure, including refs rejected by the remote.
            The error (with the token redacted) is logged and reported via
            ``update_status`` before being re-raised.
    """
    update_status(script, "Pushing changes...")
    token = None
    try:
        repo = Repo(repo_path)
        repo.git.add(all=True)
        repo.index.commit(commit_message)
        use_git_credential_store = shared.opts.data.get("git_credential_store", True)
        if use_git_credential_store:
            origin = repo.remote(name='origin')
            push_results = origin.push()
            # Remote.push() reports rejected refs through PushInfo flags instead
            # of raising, so check them explicitly.
            failed = [info for info in push_results if info.flags & info.ERROR]
            if failed:
                details = "; ".join(str(info.summary).strip() for info in failed)
                raise Exception(f"Push rejected: {details}")
        else:
            token = os.environ.get("HF_TOKEN") or getattr(script, "hf_token", "")
            if not token:
                update_status(script, "HF_TOKEN environment variable not found")
                raise Exception("HF_TOKEN environment variable not found")
            push_url = f"https://{script.hf_user}:{token}@huggingface.co/{script.hf_user}/{REPO_NAME}"
            # Remote.push()'s first parameter is a refspec, not a URL, so push to
            # the URL through the git CLI. This raises GitCommandError on failure.
            repo.git.push(push_url, "HEAD")

        logger.info("Changes pushed successfully to remote repository.")
        update_status(script, "Pushing Complete")
    except Exception as e:
        message = str(e)
        if token:
            # Never surface the secret in logs or the UI status.
            message = message.replace(token, "***")
        logger.error(f"Git push failed: {message}")
        update_status(script, f"Git push failed: {message}")
        raise

# --- Backup Logic ---
def backup_files(paths, hf_client, script):
    """Mirror the configured folders into the local backup repo and push it.

    The local repo is ``<script.basedir>/backup``. It is prepared with
    ``clone_or_create_repo`` for ``huggingface.co/<script.hf_user>/REPO_NAME``.

    Each entry of ``paths`` is resolved relative to ``script.sd_path`` and
    walked recursively. Each file is copied to the same relative location
    inside the repo. Files whose copy already matches (same size and mtime)
    are skipped. Everything is then committed and pushed via
    ``git_push_files``.

    Args:
        paths: Iterable of folder paths relative to ``script.sd_path``.
        hf_client: Unused; kept for interface compatibility.
        script: Object providing ``hf_user``, ``basedir``, ``sd_path`` and a
            writable ``status``.

    Clone, copy and push errors are logged and reported through
    ``update_status``, and the function returns early. The first copy error
    aborts the run before anything is pushed.
    """
    logger.info("Starting backup...")
    update_status(script, "Starting Backup...")
    repo_id = script.hf_user + "/" + REPO_NAME
    repo_path = os.path.join(script.basedir, 'backup')
    sd_path = script.sd_path

    try:
        clone_or_create_repo(f"https://huggingface.co/{repo_id}", repo_path, script)
    except Exception:
        # logger.exception records the traceback this message refers to.
        logger.exception("Error starting the backup, please see the traceback.")
        return

    for base_path in paths:
        logger.info(f"Backing up: {base_path}")
        source_dir = os.path.join(sd_path, base_path)
        if not os.path.isdir(source_dir):
            # os.walk() yields nothing for a missing folder; make that visible.
            logger.warning(f"Skipping missing backup path: {source_dir}")
            continue
        for root, _, files in os.walk(source_dir):
            for file in files:
                local_file_path = os.path.join(root, file)
                repo_file_path = os.path.relpath(local_file_path, start=sd_path)
                if repo_file_path == os.pardir or repo_file_path.startswith(os.pardir + os.sep):
                    # e.g. a "../x" entry would otherwise be written outside the repo.
                    logger.warning(f"Skipping {local_file_path}: outside the SD webui path")
                    continue
                target_path = os.path.join(repo_path, repo_file_path)
                try:
                    if _is_unchanged_copy(local_file_path, target_path):
                        continue
                    os.makedirs(os.path.dirname(target_path), exist_ok=True)
                    shutil.copy2(local_file_path, target_path)
                    logger.info(f"Copied: {repo_file_path}")
                    update_status(script, "Copied", repo_file_path)
                except Exception as e:
                    logger.error(f"Error copying {repo_file_path}: {e}")
                    update_status(script, f"Error copying: {repo_file_path}: {e}")
                    return

    try:
        git_push_files(repo_path, f"Backup at {datetime.datetime.now()}", script)
    except Exception as e:
        # The %s placeholder is required: without it logging reports a
        # formatting error instead of this message.
        logger.error("Error pushing to the repo: %s", e)
        return
    logger.info("Backup complete")
    update_status(script, "Backup Complete")


def _is_unchanged_copy(src, dst):
    """Return True if ``dst`` exists with the same size and mtime as ``src``.

    shutil.copy2 preserves mtime, so a previous copy of an unmodified file
    matches and can be skipped instead of being recopied on every run.
    """
    try:
        dst_stat = os.stat(dst)
    except FileNotFoundError:
        return False
    src_stat = os.stat(src)
    return (src_stat.st_size == dst_stat.st_size
            and src_stat.st_mtime_ns == dst_stat.st_mtime_ns)

# Serialises the "is a backup already running?" check in start_backup_thread.
_backup_start_lock = threading.Lock()


def start_backup_thread(script):
    """Run ``backup_files(script.backup_paths, None, script)`` on a daemon thread.

    If a backup thread started by an earlier call is still alive, the request
    is logged and ignored. This stops repeated clicks from running two backups
    against the same git working copy at once. The running thread is
    remembered on ``script._backup_thread``.
    """
    with _backup_start_lock:
        previous = getattr(script, "_backup_thread", None)
        if previous is not None and previous.is_alive():
            logging.getLogger(__name__).warning("A backup is already running; ignoring the new request.")
            return
        worker = threading.Thread(target=backup_files, args=(script.backup_paths, None, script), daemon=True)
        script._backup_thread = worker
        worker.start()

# Gradio UI Setup
def on_ui(script):
    """Build the backup settings panel bound to ``script``.

    Each text field writes its value back onto ``script`` and calls
    ``script.save()`` whenever it changes. The "Start Backup" button launches
    a background backup and shows "Starting Backup" in the status box.
    """
    # Handlers keep their original names (Gradio can derive API names from them).
    def on_token_change(token):
        script.hf_token = token
        script.save()

    def on_start_button():
        start_backup_thread(script)
        return "Starting Backup"

    def on_sd_path_change(path):
        script.sd_path = path
        script.save()

    def on_hf_user_change(user):
        script.hf_user = user
        script.save()

    def on_backup_paths_change(paths):
        # One path per line; whitespace is trimmed and blank lines dropped.
        script.backup_paths = [line.strip() for line in paths.split('\n') if line.strip()]
        script.save()

    with gr.Column():
        with gr.Row():
            with gr.Column(scale=3):
                token_input = gr.Textbox(label="Huggingface Token", type='password', value=script.hf_token)
                token_input.change(on_token_change, inputs=[token_input], outputs=None)
            with gr.Column(scale=1):
                status_output = gr.Textbox(label="Status", value=script.status)
                backup_button = gr.Button(value="Start Backup")
                backup_button.click(on_start_button, inputs=None, outputs=[status_output])

        with gr.Row():
            with gr.Column():
                sd_path_input = gr.Textbox(label="SD Webui Path", value=script.sd_path)
                sd_path_input.change(on_sd_path_change, inputs=[sd_path_input], outputs=None)
            with gr.Column():
                hf_user_input = gr.Textbox(label="Huggingface Username", value=script.hf_user)
                hf_user_input.change(on_hf_user_change, inputs=[hf_user_input], outputs=None)

        with gr.Row():
            paths_input = gr.Textbox(
                label="Backup Paths (one path per line)",
                lines=4,
                value='\n'.join(script.backup_paths),
            )
            paths_input.change(on_backup_paths_change, inputs=[paths_input], outputs=None)

def on_run(script, p, *args):
    """Run hook placeholder: accepts any arguments, does nothing, returns None."""
    return None
  
def on_script_load(script):
    """Populate ``script``'s settings from ``script.load()`` and reset its status.

    ``script.load()`` is called once; a falsy result is treated as no saved
    settings. A missing (or ``None``) backup-path entry falls back to a fresh
    copy of ``DEFAULT_BACKUP_PATHS``. The copy stops in-place edits of
    ``script.backup_paths`` from mutating the module-level default.
    """
    saved = script.load() or {}
    script.hf_token = saved.get(HF_TOKEN_KEY, '')
    backup_paths = saved.get(BACKUP_PATHS_KEY)
    script.backup_paths = list(DEFAULT_BACKUP_PATHS) if backup_paths is None else backup_paths
    script.sd_path = saved.get(SD_PATH_KEY, '')
    script.hf_user = saved.get(HF_USER_KEY, '')
    script.status = "Not running"


# Register this module's UI and load hooks with the webui callback system.
# NOTE(review): on_ui is declared as on_ui(script) and returns nothing; confirm
# that script_callbacks.on_ui_tabs passes a script object and does not expect
# the callback to return a list of tabs.
script_callbacks.on_ui_tabs(on_ui)
script_callbacks.on_script_load(on_script_load)