File size: 9,442 Bytes
7eb3676
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
from __future__ import annotations
import aiohttp
import os
import traceback
import logging
from folder_paths import models_dir
import re
from typing import Callable, Any, Optional, Awaitable, Dict
from enum import Enum
import time
from dataclasses import dataclass


class DownloadStatusType(Enum):
    """Lifecycle states a model download can be reported in.

    Each member's value is the plain string form of the state.
    """
    PENDING = "pending"  # download announced, nothing transferred yet
    IN_PROGRESS = "in_progress"  # chunks are being received
    COMPLETED = "completed"  # file is fully available on disk
    ERROR = "error"  # the download failed or the request was rejected

@dataclass
class DownloadModelStatus:
    """Snapshot of a model download's state, as handed to progress callbacks.

    The hand-written ``__init__`` below is the constructor in effect; it
    accepts a ``DownloadStatusType`` and stores only its string value.
    """
    # String value of the DownloadStatusType passed to the constructor.
    status: str
    # Completion percentage as supplied by the caller.
    progress_percentage: float
    # Human-readable description of the current state.
    message: str
    # True when the file was already on disk and nothing was downloaded.
    already_existed: bool = False

    def __init__(self,
                 status: DownloadStatusType,
                 progress_percentage: float,
                 message: str,
                 already_existed: bool = False):
        """
        Build a status record.

        Args:
            status (DownloadStatusType): The state to report; its ``.value`` is stored.
            progress_percentage (float): Completion percentage.
            message (str): Human-readable description.
            already_existed (bool): Whether the file was already present.
                Defaults to False, matching the field declaration above.
        """
        self.status = status.value  # Store the string value of the Enum
        self.progress_percentage = progress_percentage
        self.message = message
        self.already_existed = already_existed

    def to_dict(self) -> Dict[str, Any]:
        """Return the four fields as a plain dictionary."""
        return {
            "status": self.status,
            "progress_percentage": self.progress_percentage,
            "message": self.message,
            "already_existed": self.already_existed
        }

async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
                         model_name: str,
                         model_url: str,
                         model_sub_directory: str,
                         progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
                         progress_interval: float = 1.0) -> DownloadModelStatus:
    """
    Download a model file from a given URL into the models directory.

    The subdirectory and filename are validated first; a failed validation
    returns an ERROR status without invoking ``progress_callback``. If the
    target file already exists, a COMPLETED status is reported and returned
    without issuing any request.

    Args:
        model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]):
            A function that makes an HTTP request. This makes it easier to mock in unit tests.
        model_name (str):
            The name of the model file to be downloaded. This will be the filename on disk.
        model_url (str):
            The URL from which to download the model.
        model_sub_directory (str):
            The subdirectory within the main models directory where the model
            should be saved (e.g., 'checkpoints', 'loras', etc.).
        progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]):
            An asynchronous function to call with progress updates. It receives
            the model's relative path ("<sub_directory>/<name>") and the status.
        progress_interval (float):
            Minimum number of seconds between IN_PROGRESS updates while the
            body is being streamed. Defaults to 1.0.

    Returns:
        DownloadModelStatus: The result of the download operation. Errors raised
        while requesting or streaming are converted into an ERROR status.
    """
    if not validate_model_subdirectory(model_sub_directory):
        return DownloadModelStatus(
            DownloadStatusType.ERROR,
            0,
            "Invalid model subdirectory",
            False
        )

    if not validate_filename(model_name):
        return DownloadModelStatus(
            DownloadStatusType.ERROR,
            0,
            "Invalid model name",
            False
        )

    file_path, relative_path = create_model_path(model_name, model_sub_directory, models_dir)
    existing_file = await check_file_exists(file_path, model_name, progress_callback, relative_path)
    if existing_file:
        return existing_file

    try:
        status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False)
        await progress_callback(relative_path, status)

        response = await model_download_request(model_url)
        if response.status != 200:
            # NOTE(review): the response is not explicitly released on this
            # path - confirm the request function / session handles cleanup.
            error_message = f"Failed to download {model_name}. Status code: {response.status}"
            logging.error(error_message)
            # Report and return the same object, as every other path does.
            status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
            await progress_callback(relative_path, status)
            return status

        return await track_download_progress(response, file_path, model_name, progress_callback, relative_path, progress_interval)

    except Exception as e:
        logging.error(f"Error in downloading model: {e}")
        return await handle_download_error(e, model_name, progress_callback, relative_path)
    

def create_model_path(model_name: str, model_directory: str, models_base_dir: str) -> tuple[str, str]:
    """
    Resolve where a model file should live and make sure its directory exists.

    Args:
        model_name (str): Filename of the model.
        model_directory (str): Subdirectory of ``models_base_dir`` to store it in.
        models_base_dir (str): Root directory that holds all model subdirectories.

    Returns:
        tuple[str, str]: ``(file_path, relative_path)`` where ``file_path`` is
        ``models_base_dir/model_directory/model_name`` and ``relative_path`` is
        ``"model_directory/model_name"`` joined with a forward slash.

    Raises:
        Exception: If the resulting path would fall outside ``models_base_dir``.
    """
    full_model_dir = os.path.join(models_base_dir, model_directory)
    file_path = os.path.join(full_model_dir, model_name)

    # Ensure the resulting path is still within the base directory.
    # commonpath compares whole path components; commonprefix compares
    # characters and would accept a sibling such as "<base>_evil".
    abs_file_path = os.path.abspath(file_path)
    abs_base_dir = os.path.abspath(str(models_base_dir))
    try:
        inside_base = os.path.commonpath([abs_file_path, abs_base_dir]) == abs_base_dir
    except ValueError:
        # Raised e.g. for paths on different drives: cannot be inside the base.
        inside_base = False
    if not inside_base:
        raise Exception(f"Invalid model directory: {model_directory}/{model_name}")

    # Only touch the filesystem once the path is known to be safe.
    os.makedirs(full_model_dir, exist_ok=True)

    relative_path = '/'.join([model_directory, model_name])
    return file_path, relative_path

async def check_file_exists(file_path: str,
                            model_name: str,
                            progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
                            relative_path: str) -> Optional[DownloadModelStatus]:
    """
    Report a model as already downloaded when its file is present on disk.

    Args:
        file_path (str): Location to probe.
        model_name (str): Name used in the status message.
        progress_callback: Awaited with ``relative_path`` and the status, but
            only when the file exists.
        relative_path (str): Identifier forwarded to the callback.

    Returns:
        Optional[DownloadModelStatus]: A COMPLETED status flagged as already
        existing, or None when nothing is found at ``file_path``.
    """
    # Nothing on disk: let the caller proceed with the download.
    if not os.path.exists(file_path):
        return None

    already_present = DownloadModelStatus(
        DownloadStatusType.COMPLETED,
        100,
        f"{model_name} already exists",
        True,
    )
    await progress_callback(relative_path, already_present)
    return already_present


async def track_download_progress(response: aiohttp.ClientResponse,
                                  file_path: str,
                                  model_name: str,
                                  progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
                                  relative_path: str,
                                  interval: float = 1.0) -> DownloadModelStatus:
    """
    Stream a response body to disk while reporting progress.

    Args:
        response (aiohttp.ClientResponse): Response whose body is streamed in
            8192-byte chunks.
        file_path (str): Destination file; opened in binary write mode.
        model_name (str): Name used in status and log messages.
        progress_callback: Awaited with ``relative_path`` and each status update.
        relative_path (str): Identifier forwarded to the callback.
        interval (float): Minimum seconds between IN_PROGRESS updates.

    Returns:
        DownloadModelStatus: COMPLETED on success. Any exception is logged and
        converted into an ERROR status; if the failure happened while the file
        was still being written, the partial file is removed so it cannot be
        mistaken for a finished download later.
    """
    # True only between opening the destination file and finishing the write.
    partial_on_disk = False
    try:
        total_size = int(response.headers.get('Content-Length', 0))
        downloaded = 0
        last_update_time = time.time()

        async def update_progress():
            nonlocal last_update_time
            if total_size > 0:
                # Clamp: more bytes than Content-Length announced must not
                # be reported as more than 100%.
                progress = min((downloaded / total_size) * 100, 100.0)
            else:
                # Unknown total size: no meaningful percentage available.
                progress = 0
            status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False)
            await progress_callback(relative_path, status)
            last_update_time = time.time()

        with open(file_path, 'wb') as f:
            partial_on_disk = True
            async for chunk in response.content.iter_chunked(8192):
                f.write(chunk)
                downloaded += len(chunk)

                if time.time() - last_update_time >= interval:
                    await update_progress()
        partial_on_disk = False

        # Always emit one last IN_PROGRESS update with the final byte count.
        await update_progress()

        logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}")
        status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False)
        await progress_callback(relative_path, status)

        return status
    except Exception as e:
        logging.error(f"Error in track_download_progress: {e}")
        logging.error(traceback.format_exc())
        if partial_on_disk:
            # Best-effort cleanup of the truncated file; a failure to delete
            # must not mask the original error being reported below.
            try:
                os.remove(file_path)
            except OSError:
                pass
        return await handle_download_error(e, model_name, progress_callback, relative_path)

async def handle_download_error(e: Exception,
                                model_name: str,
                                progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
                                relative_path: str) -> DownloadModelStatus:
    """
    Turn an exception into an ERROR status and report it.

    Args:
        e (Exception): The failure; its string form is embedded in the message.
        model_name (str): Name used in the status message.
        progress_callback: Awaited with ``relative_path`` and the ERROR status.
        relative_path (str): Identifier forwarded to the callback.

    Returns:
        DownloadModelStatus: The same ERROR status that was sent to the callback.
    """
    error_message = f"Error downloading {model_name}: {str(e)}"
    status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
    await progress_callback(relative_path, status)
    return status

def validate_model_subdirectory(model_subdirectory: str) -> bool:
    """
    Validate that the model subdirectory is safe to install into.
    Must not contain relative paths, nested paths or special characters
    other than underscores and hyphens.

    Args:
        model_subdirectory (str): The subdirectory for the specific model type.

    Returns:
        bool: True if the subdirectory is a non-empty string of at most 50
        characters drawn from letters, digits, '_' and '-'; False otherwise
        (including for non-string input).
    """
    if not isinstance(model_subdirectory, str):
        return False

    if len(model_subdirectory) > 50:
        return False

    # fullmatch rather than match with '$': '$' also matches just before a
    # trailing newline, which would let "name\n" through. The whitelist
    # contains neither '.' nor '/', so '..' and nested paths are rejected too.
    return re.fullmatch(r'[a-zA-Z0-9_-]+', model_subdirectory) is not None

def validate_filename(filename: str) -> bool:
    """
    Validate a filename to ensure it's safe and doesn't contain any path traversal attempts.

    Args:
    filename (str): The filename to validate

    Returns:
    bool: True if the filename is valid, False otherwise. A valid filename is
    a non-blank string of at most 255 characters that ends in '.sft' or
    '.safetensors' (case-insensitive), does not start with a dot, and
    consists only of letters, digits, '_', '-', '.' and spaces, with no '..'.
    """
    # Reject None / non-strings and empty or whitespace-only names first, so
    # the string methods below are never called on an unsuitable value.
    if not isinstance(filename, str) or not filename.strip():
        return False

    # Ensure the filename isn't too long
    if len(filename) > 255:
        return False

    if not filename.lower().endswith(('.sft', '.safetensors')):
        return False

    # Check if the filename starts with a dot (hidden file)
    if filename.startswith('.'):
        return False

    # Check for any directory traversal attempts or invalid characters
    if any(token in filename for token in ('..', '/', '\\', '\n', '\r', '\t', '\0')):
        return False

    # Use a whitelist of allowed characters; fullmatch anchors both ends
    # without the trailing-newline leniency of '$'.
    return re.fullmatch(r'[a-zA-Z0-9_\-. ]+', filename) is not None