# Proxy types used to bypass ComfyUI's input/output type validation so a node
# socket can accept or return values of any type.
class AlwaysEqualProxy(str):
    """A string that compares equal to everything (a wildcard "any" type)."""
    def __eq__(self, _):
        return True

    def __ne__(self, _):
        return False

class TautologyStr(str):
    """A string that never compares unequal to anything."""
    def __ne__(self, other):
        return False

class ByPassTypeTuple(tuple):
    """A tuple that clamps every index to 0 and wraps string items in
    TautologyStr, so any RETURN_TYPES lookup passes validation."""
    def __getitem__(self, index):
        if index > 0:
            index = 0
        item = super().__getitem__(index)
        if isinstance(item, str):
            return TautologyStr(item)
        return item
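
# A minimal usage sketch (the node class and socket names are illustrative, not
# part of this module): a socket declared with AlwaysEqualProxy("*") accepts a
# connection of any type, and ByPassTypeTuple does the same for RETURN_TYPES.
#
#   any_type = AlwaysEqualProxy("*")
#
#   class ExampleAnyPassthrough:
#       @classmethod
#       def INPUT_TYPES(cls):
#           return {"required": {"value": (any_type,)}}
#       RETURN_TYPES = ByPassTypeTuple((any_type,))
#       FUNCTION = "run"
#
#       def run(self, value):
#           return (value,)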

comfy_ui_revision = None
def get_comfyui_revision():
    """Return the ComfyUI commit count, or "Unknown" when it cannot be read."""
    try:
        import git
        import os
        import folder_paths
        repo = git.Repo(os.path.dirname(folder_paths.__file__))
        comfy_ui_revision = len(list(repo.iter_commits('HEAD')))
    except Exception:
        comfy_ui_revision = "Unknown"
    return comfy_ui_revision


import sys
import importlib.util
import importlib.metadata
import comfy.model_management as mm
import gc
from packaging import version
from server import PromptServer
def is_package_installed(package):
    try:
        module = importlib.util.find_spec(package)
        return module is not None
    except ImportError as e:
        print(e)
        return False

def install_package(package, v=None, compare=True, compare_version=None):
    """Install a pip package if it is missing or older than the required version."""
    run_install = True
    if is_package_installed(package):
        try:
            installed_version = importlib.metadata.version(package)
            if v is not None:
                if compare_version is None:
                    compare_version = v
                if not compare or version.parse(installed_version) >= version.parse(compare_version):
                    run_install = False
            else:
                run_install = False
        except Exception:
            run_install = False

    if run_install:
        import subprocess
        package_command = package + '==' + v if v is not None else package
        PromptServer.instance.send_sync("easyuse-toast", {'content': f"Installing {package_command}...", 'duration': 5000})
        result = subprocess.run([sys.executable, '-s', '-m', 'pip', 'install', package_command], capture_output=True, text=True)
        if result.returncode == 0:
            PromptServer.instance.send_sync("easyuse-toast", {'content': f"{package} installed successfully", 'type': 'success', 'duration': 5000})
            print(f"Package {package} installed successfully")
            return True
        else:
            PromptServer.instance.send_sync("easyuse-toast", {'content': f"{package} installation failed", 'type': 'error', 'duration': 5000})
            print(f"Package {package} installation failed")
            return False
    else:
        return False
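
# A minimal usage sketch (package names and versions are illustrative):
#
#   install_package("opencv-python", v="4.8.0.74")  # install when missing or older
#   install_package("gitpython")                    # install only when missing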

def compare_revision(num):
    """Return True when the ComfyUI revision is unknown or at least `num`."""
    global comfy_ui_revision
    if not comfy_ui_revision:
        comfy_ui_revision = get_comfyui_revision()
    return comfy_ui_revision == 'Unknown' or int(comfy_ui_revision) >= num

def find_tags(string: str, sep="/") -> list[str]:
    """
    Find tags in a string, splitting on `sep`.
    Note: the string may use either \\ or / as a path separator.
    """
    if not string:
        return []
    string = string.replace("\\", "/")
    while "//" in string:
        string = string.replace("//", "/")
    if sep in string:
        return string.split(sep)[:-1]
    return []
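
# Example (illustrative): the folder components of a model path become its tags.
#
#   find_tags("characters\\anime/girl.safetensors")  # -> ["characters", "anime"]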


from comfy.model_base import BaseModel
import comfy.supported_models
import comfy.supported_models_base
def get_sd_version(model):
    """Map a loaded model's config class to a short architecture name."""
    base: BaseModel = model.model
    model_config: comfy.supported_models_base.BASE = base.model_config
    if isinstance(model_config, comfy.supported_models.SDXL):
        return 'sdxl'
    elif isinstance(model_config, comfy.supported_models.SDXLRefiner):
        return 'sdxl_refiner'
    elif isinstance(model_config, (comfy.supported_models.SD15, comfy.supported_models.SD20)):
        return 'sd1'
    elif isinstance(model_config, comfy.supported_models.SVD_img2vid):
        return 'svd'
    elif isinstance(model_config, comfy.supported_models.SD3):
        return 'sd3'
    elif isinstance(model_config, comfy.supported_models.HunyuanDiT):
        return 'hydit'
    elif isinstance(model_config, comfy.supported_models.Flux):
        return 'flux'
    elif isinstance(model_config, comfy.supported_models.GenmoMochi):
        return 'mochi'
    else:
        return 'unknown'

def find_nearest_steps(clip_id, prompt):
    """Find the nearest KSampler or preSampling node that references the given id
    and return its `steps` input, defaulting to 1."""
    def check_link_to_clip(node_id, clip_id, visited=None, node=None):
        """Check if a given node links directly or indirectly to a loader node."""
        if visited is None:
            visited = set()

        if node_id in visited:
            return False
        visited.add(node_id)
        if "pipe" in node["inputs"]:
            link_ids = node["inputs"]["pipe"]
            for id in link_ids:
                if id != 0 and id == str(clip_id):
                    return True
        return False

    for id in prompt:
        node = prompt[id]
        class_type = node["class_type"]
        if "Sampler" in class_type or "sampler" in class_type or "Sampling" in class_type:
            # Check if this sampler node directly or indirectly references the given node
            if check_link_to_clip(id, clip_id, None, node):
                return node["inputs"].get("steps", 1)
    return 1
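
# The `prompt` argument is ComfyUI's prompt graph: a dict mapping node ids to
# {"class_type": ..., "inputs": ...}, where a linked input is stored as a
# [source_node_id, output_index] pair. A minimal sketch (ids and class names
# illustrative):
#
#   prompt = {
#       "3": {"class_type": "easy fullLoader", "inputs": {}},
#       "5": {"class_type": "easy kSampler",
#             "inputs": {"pipe": ["3", 0], "steps": 20}},
#   }
#   find_nearest_steps("3", prompt)  # -> 20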

def find_wildcards_seed(clip_id, text, prompt):
    """Find the seed value of the linked easy wildcards node, if any."""
    def find_link_clip_id(id, seed, wildcard_id):
        node = prompt[id]
        if "positive" in node['inputs']:
            link_ids = node["inputs"]["positive"]
            if isinstance(link_ids, list):
                for id in link_ids:
                    if id != 0:
                        if id == wildcard_id:
                            wildcard_node = prompt[wildcard_id]
                            seed = wildcard_node["inputs"].get("seed")
                            if seed is None:
                                seed = wildcard_node["inputs"].get("seed_num")
                            return seed
                        else:
                            return find_link_clip_id(id, seed, wildcard_id)
            else:
                return None
        else:
            return None

    if "__" in text:
        seed = None
        for id in prompt:
            node = prompt[id]
            if "wildcards" in node["class_type"]:
                wildcard_id = id
                return find_link_clip_id(str(clip_id), seed, wildcard_id)
        return seed
    else:
        return None

def is_linked_styles_selector(prompt, unique_id, prompt_type='positive'):
    unique_id = unique_id.split('.')[-1] if "." in unique_id else unique_id
    inputs_values = prompt[unique_id]['inputs'].get(prompt_type)
    if isinstance(inputs_values, list) and inputs_values[0]:
        return bool(prompt[inputs_values[0]] and prompt[inputs_values[0]]['class_type'] == 'easy stylesSelector')
    return False

use_mirror = False
def get_local_filepath(url, dirname, local_file_name=None):
    """Return the local path of a file, downloading it first if necessary."""
    import os
    from server import PromptServer
    from urllib.parse import urlparse
    from torch.hub import download_url_to_file
    global use_mirror
    if not os.path.exists(dirname):
        os.makedirs(dirname)
    if not local_file_name:
        parsed_url = urlparse(url)
        local_file_name = os.path.basename(parsed_url.path)
    destination = os.path.join(dirname, local_file_name)
    if not os.path.exists(destination):
        try:
            if use_mirror:
                url = url.replace('huggingface.co', 'hf-mirror.com')
            print(f'downloading {url} to {destination}')
            PromptServer.instance.send_sync("easyuse-toast", {'content': f'Downloading model to {destination}, please wait...', 'duration': 10000})
            download_url_to_file(url, destination)
        except Exception:
            # Fall back to the hf-mirror.com mirror when huggingface.co is unreachable
            use_mirror = True
            url = url.replace('huggingface.co', 'hf-mirror.com')
            print(f'Unable to download from huggingface, trying {url} instead...')
            PromptServer.instance.send_sync("easyuse-toast", {'content': f'Unable to reach huggingface, trying {url} instead...', 'duration': 10000})
            try:
                download_url_to_file(url, destination)
            except Exception as err:
                PromptServer.instance.send_sync("easyuse-toast",
                                                {'content': f'Failed to download the model from {url}', 'type': 'error'})
                raise Exception(f'Failed to download from {url}, error: {str(err.args[0])}')
    return destination
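
# A minimal usage sketch (URL and target folder are illustrative; assumes
# ComfyUI's folder_paths module for the models directory):
#
#   import os
#   import folder_paths
#   path = get_local_filepath(
#       "https://huggingface.co/some-org/some-model/resolve/main/model.safetensors",
#       os.path.join(folder_paths.models_dir, "checkpoints"),
#   )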

def to_lora_patch_dict(state_dict: dict) -> dict:
    """Convert a raw lora state_dict, whose keys look like
    'model_key::patch_type::weight_index', into a patch dict that can be
    applied on a ModelPatcher."""
    patch_dict = {}
    for k, w in state_dict.items():
        model_key, patch_type, weight_index = k.split('::')
        if model_key not in patch_dict:
            patch_dict[model_key] = {}
        if patch_type not in patch_dict[model_key]:
            patch_dict[model_key][patch_type] = [None] * 16
        patch_dict[model_key][patch_type][int(weight_index)] = w

    patch_flat = {}
    for model_key, v in patch_dict.items():
        for patch_type, weight_list in v.items():
            patch_flat[model_key] = (patch_type, weight_list)

    return patch_flat
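
# A minimal sketch of the expected key layout (keys and tensors illustrative):
# weights sharing a model key and patch type are gathered into one 16-slot list.
#
#   state_dict = {
#       "diffusion_model.layer::lora::0": w_up,
#       "diffusion_model.layer::lora::1": w_down,
#   }
#   to_lora_patch_dict(state_dict)
#   # -> {"diffusion_model.layer": ("lora", [w_up, w_down, None, ..., None])}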

def easySave(images, filename_prefix, output_type, prompt=None, extra_pnginfo=None):
    """Save or Preview Image"""
    from nodes import PreviewImage, SaveImage
    if output_type in ["Hide", "None"]:
        return list()
    elif output_type in ["Preview", "Preview&Choose"]:
        filename_prefix = 'easyPreview'
        results = PreviewImage().save_images(images, filename_prefix, prompt, extra_pnginfo)
        return results['ui']['images']
    else:
        results = SaveImage().save_images(images, filename_prefix, prompt, extra_pnginfo)
        return results['ui']['images']

def getMetadata(filepath):
    """Read and return the raw JSON header bytes of a safetensors file."""
    with open(filepath, "rb") as file:
        # https://github.com/huggingface/safetensors#format
        # 8 bytes: N, an unsigned little-endian 64-bit integer, containing the size of the header
        header_size = int.from_bytes(file.read(8), "little", signed=False)

        if header_size <= 0:
            raise BufferError("Invalid header size")

        header = file.read(header_size)
        if len(header) < header_size:
            raise BufferError("Invalid header")

        return header
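
# The returned bytes are the JSON header of the safetensors file, so callers
# can inspect embedded metadata like this (path illustrative):
#
#   import json
#   header = json.loads(getMetadata("model.safetensors"))
#   metadata = header.get("__metadata__", {})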

def cleanGPUUsedForce():
    """Force-release GPU memory: collect garbage, unload all models and empty the cache."""
    gc.collect()
    mm.unload_all_models()
    mm.soft_empty_cache()