File size: 4,695 Bytes
82ea528
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import contextlib
import copy
import unittest

import comfy.model_management
import comfy.model_patcher
import comfy.sd
import comfy.utils
import torch


class QuantizedModelPatcher(comfy.model_patcher.ModelPatcher):
    _object_to_patch_default = None
    _quantize_fn_default = None
    _lowvram_default = True
    _full_load_default = True
    _is_quantized_default = False

    _load_device = None
    _offload_device = None
    _disable_load = False

    @classmethod
    @contextlib.contextmanager
    def _override_defaults(cls, **kwargs):
        old_defaults = {}
        for k in ("object_to_patch", "quantize_fn", "lowvram", "full_load"):
            if k in kwargs:
                old_defaults[k] = getattr(cls, f"_{k}_default")
                setattr(cls, f"_{k}_default", kwargs[k])
        try:
            yield
        finally:
            for k in old_defaults:
                setattr(cls, f"_{k}_default", old_defaults[k])

    @classmethod
    @contextlib.contextmanager
    def _set_disable_load(cls, disable_load=True):
        old_disable_load = cls._disable_load
        cls._disable_load = disable_load
        try:
            yield
        finally:
            cls._disable_load = old_disable_load

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._object_to_patch = QuantizedModelPatcher._object_to_patch_default
        self._quantize_fn = QuantizedModelPatcher._quantize_fn_default
        self._lowvram = QuantizedModelPatcher._lowvram_default
        self._full_load = QuantizedModelPatcher._full_load_default
        self._is_quantized = QuantizedModelPatcher._is_quantized_default

    def load(
        self, device_to=None, force_patch_weights=False, full_load=False, **kwargs
    ):
        if self._disable_load:
            return

        if self._is_quantized:
            super().load(
                device_to=device_to,
                force_patch_weights=force_patch_weights,
                full_load=full_load,
                **kwargs,
            )
            return

        with unittest.mock.patch.object(
            QuantizedModelPatcher, "_load_device", self.load_device
        ), unittest.mock.patch.object(
            QuantizedModelPatcher, "_offload_device", self.offload_device
        ):
            # always call `patch_weight_to_device` even for lowvram
            super().load(
                torch.device("cpu") if self._lowvram else device_to,
                force_patch_weights=True,
                full_load=self._full_load or full_load,
                **kwargs,
            )

            if self._quantize_fn is not None:
                if self._object_to_patch is None:
                    target_model = self.model
                else:
                    target_model = comfy.utils.get_attr(
                        self.model, self._object_to_patch
                    )
                target_model = self._quantize_fn(target_model)
                if self._object_to_patch is None:
                    self.model = target_model
                else:
                    comfy.utils.set_attr(
                        self.model, self._object_to_patch, target_model
                    )

            if self._lowvram:
                if device_to.type == "cuda":
                    torch.cuda.empty_cache()
                self.model.to(device_to)

        self._is_quantized = True

    # def model_size(self):
    #     return super().model_size() // 2

    def clone(self, *args, **kwargs):
        n = QuantizedModelPatcher(
            self.model,
            self.load_device,
            self.offload_device,
            self.size,
            weight_inplace_update=self.weight_inplace_update,
        )
        n.patches = {}
        for k in self.patches:
            n.patches[k] = self.patches[k][:]
        n.patches_uuid = self.patches_uuid

        n.object_patches = self.object_patches.copy()
        n.model_options = copy.deepcopy(self.model_options)
        n.backup = self.backup
        n.object_patches_backup = self.object_patches_backup

        n._object_to_patch = getattr(
            self, "_object_to_patch", QuantizedModelPatcher._object_to_patch_default
        )
        n._quantize_fn = getattr(
            self, "_quantize_fn", QuantizedModelPatcher._quantize_fn_default
        )
        n._lowvram = getattr(self, "_lowvram", QuantizedModelPatcher._lowvram_default)
        n._full_load = getattr(
            self, "_full_load", QuantizedModelPatcher._full_load_default
        )
        n._is_quantized = getattr(
            self, "_is_quantized", QuantizedModelPatcher._is_quantized_default
        )
        return n