import contextlib
import copy
import unittest.mock

import comfy.model_management
import comfy.model_patcher
import comfy.sd
import comfy.utils
import torch


class QuantizedModelPatcher(comfy.model_patcher.ModelPatcher):
    # ModelPatcher variant that runs a user-supplied quantize_fn over the model
    # (or over a nested attribute named by object_to_patch) the first time the
    # patcher is loaded. Defaults live on the class so _override_defaults can
    # scope them to a block of code; __init__ copies them onto each instance.
    _object_to_patch_default = None
    _quantize_fn_default = None
    _lowvram_default = True
    _full_load_default = True
    _is_quantized_default = False
    _load_device = None
    _offload_device = None
    _disable_load = False

    @classmethod
    @contextlib.contextmanager
    def _override_defaults(cls, **kwargs):
        # Temporarily replace the class-level *_default attributes, restoring
        # the previous values on exit.
        old_defaults = {}
        for k in ("object_to_patch", "quantize_fn", "lowvram", "full_load"):
            if k in kwargs:
                old_defaults[k] = getattr(cls, f"_{k}_default")
                setattr(cls, f"_{k}_default", kwargs[k])
        try:
            yield
        finally:
            for k in old_defaults:
                setattr(cls, f"_{k}_default", old_defaults[k])

    @classmethod
    @contextlib.contextmanager
    def _set_disable_load(cls, disable_load=True):
        # Temporarily turn load() into a no-op (restored on exit).
        old_disable_load = cls._disable_load
        cls._disable_load = disable_load
        try:
            yield
        finally:
            cls._disable_load = old_disable_load
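
    # Illustrative use of the two context managers above (my_quantize, model,
    # load_dev and offload_dev are placeholders, not part of this module):
    #
    #     with QuantizedModelPatcher._override_defaults(quantize_fn=my_quantize):
    #         patcher = QuantizedModelPatcher(model, load_dev, offload_dev)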

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Snapshot the current class-level defaults as per-instance settings.
        self._object_to_patch = QuantizedModelPatcher._object_to_patch_default
        self._quantize_fn = QuantizedModelPatcher._quantize_fn_default
        self._lowvram = QuantizedModelPatcher._lowvram_default
        self._full_load = QuantizedModelPatcher._full_load_default
        self._is_quantized = QuantizedModelPatcher._is_quantized_default

    def load(
        self, device_to=None, force_patch_weights=False, full_load=False, **kwargs
    ):
        if self._disable_load:
            return

        # Already quantized: defer to the stock ModelPatcher behaviour.
        if self._is_quantized:
            super().load(
                device_to=device_to,
                force_patch_weights=force_patch_weights,
                full_load=full_load,
                **kwargs,
            )
            return

        # Expose the devices as class attributes so quantize_fn can read them
        # while the base class loads the unquantized weights.
        with unittest.mock.patch.object(
            QuantizedModelPatcher, "_load_device", self.load_device
        ), unittest.mock.patch.object(
            QuantizedModelPatcher, "_offload_device", self.offload_device
        ):
            # always call `patch_weight_to_device` even for lowvram
            super().load(
                torch.device("cpu") if self._lowvram else device_to,
                force_patch_weights=True,
                full_load=self._full_load or full_load,
                **kwargs,
            )

            # Quantize the whole model, or only the nested attribute named by
            # _object_to_patch, and write the result back.
            if self._quantize_fn is not None:
                if self._object_to_patch is None:
                    target_model = self.model
                else:
                    target_model = comfy.utils.get_attr(
                        self.model, self._object_to_patch
                    )
                target_model = self._quantize_fn(target_model)
                if self._object_to_patch is None:
                    self.model = target_model
                else:
                    comfy.utils.set_attr(
                        self.model, self._object_to_patch, target_model
                    )

            # In lowvram mode the weights were patched on the CPU; free cached
            # CUDA memory, then move the quantized model to the target device.
            if self._lowvram and device_to is not None:
                if device_to.type == "cuda":
                    torch.cuda.empty_cache()
                self.model.to(device_to)

            self._is_quantized = True

    # def model_size(self):
    #     return super().model_size() // 2

    def clone(self, *args, **kwargs):
        # Same bookkeeping as ModelPatcher.clone(), plus the quantization
        # settings, so a clone of an already quantized patcher is not
        # quantized again.
        n = QuantizedModelPatcher(
            self.model,
            self.load_device,
            self.offload_device,
            self.size,
            weight_inplace_update=self.weight_inplace_update,
        )
        n.patches = {}
        for k in self.patches:
            n.patches[k] = self.patches[k][:]
        n.patches_uuid = self.patches_uuid
        n.object_patches = self.object_patches.copy()
        n.model_options = copy.deepcopy(self.model_options)
        n.backup = self.backup
        n.object_patches_backup = self.object_patches_backup
        n._object_to_patch = getattr(
            self, "_object_to_patch", QuantizedModelPatcher._object_to_patch_default
        )
        n._quantize_fn = getattr(
            self, "_quantize_fn", QuantizedModelPatcher._quantize_fn_default
        )
        n._lowvram = getattr(self, "_lowvram", QuantizedModelPatcher._lowvram_default)
        n._full_load = getattr(
            self, "_full_load", QuantizedModelPatcher._full_load_default
        )
        n._is_quantized = getattr(
            self, "_is_quantized", QuantizedModelPatcher._is_quantized_default
        )
        return n
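

if __name__ == "__main__":
    # Usage sketch only. It assumes comfy.sd builds its patcher via the
    # comfy.model_patcher.ModelPatcher attribute (as in stock ComfyUI), that
    # the loaded model exposes a `diffusion_model` attribute, and it uses an
    # fp16 cast plus a placeholder checkpoint path in place of a real
    # quantization routine and a real checkpoint.
    with QuantizedModelPatcher._override_defaults(
        object_to_patch="diffusion_model",  # assumed attribute path (UNet)
        quantize_fn=lambda m: m.half(),     # placeholder quantization routine
        lowvram=True,
    ), unittest.mock.patch.object(
        comfy.model_patcher, "ModelPatcher", QuantizedModelPatcher
    ):
        out = comfy.sd.load_checkpoint_guess_config(
            "checkpoints/model.safetensors"  # placeholder path
        )
    model_patcher = out[0]
    # Quantization itself runs the first time model_patcher.load() is called
    # (e.g. when the model is moved to the GPU for sampling).
    print(type(model_patcher).__name__)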