File size: 6,142 Bytes
82ea528 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
#credit to Acly for this module
#from https://github.com/Acly/comfyui-inpaint-nodes
import torch
import torch.nn.functional as F
import comfy
from comfy.model_base import BaseModel
from comfy.model_patcher import ModelPatcher
from comfy.model_management import cast_to_device
from ..libs.log import log_node_warn, log_node_error, log_node_info
class InpaintHead(torch.nn.Module):
    """Single 3x3 convolution mapping a 5-channel input to 320 channels.

    The kernel is the parameter ``head`` with shape (320, 5, 3, 3). It is
    allocated uninitialized (``torch.empty``) on the CPU, so its values are
    presumably loaded from a checkpoint afterwards — NOTE(review): confirm
    with the loader.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.head = torch.nn.Parameter(torch.empty(size=(320, 5, 3, 3), device="cpu"))

    def forward(self, x):
        """Convolve ``x`` with ``head`` after one pixel of replicate padding.

        Defined as ``forward`` (not an ``__call__`` override) so that calling the
        module goes through ``nn.Module.__call__`` and registered hooks run.
        The 1-pixel padding on every side keeps the spatial size unchanged.
        ``x`` must carry 5 channels in dim 1 to match the kernel.
        """
        x = F.pad(x, (1, 1, 1, 1), "replicate")
        return F.conv2d(x, weight=self.head)
# injected_model_patcher_calculate_weight = False
# original_calculate_weight = None
class applyFooocusInpaint:
    """Context manager that temporarily installs a Fooocus-aware ``calculate_weight``.

    On ``__enter__`` the current ``comfy.lora.calculate_weight`` is saved and
    replaced by :meth:`calculate_weight_patched`. If that attribute cannot be
    read, ``ModelPatcher.calculate_weight`` is used instead. ``__exit__`` puts
    the saved function back on the same object it was taken from.
    """

    def calculate_weight_patched(
        self,
        patches,
        weight,
        key,
        intermediate_dtype=torch.float32,
        original_weights=None,
    ):
        """Merge ``("fooocus", ...)`` patches into ``weight`` and delegate the rest.

        Args:
            patches: Sequence of patch entries. ``p[0]`` is the strength (alpha)
                and ``p[1]`` the payload. Payloads shaped ``("fooocus", data)``
                are handled here, with ``data[0..2]`` read as the encoded
                weight, min and max. All other entries are forwarded.
            weight: Tensor to patch. Fooocus deltas are added to it in place.
            key: Weight name; used in the mismatch message and forwarded.
            intermediate_dtype: Forwarded unchanged to the original function.
            original_weights: Forwarded as a keyword only when not ``None``.
                NOTE(review): newer ComfyUI ``calculate_weight`` signatures
                appear to take this argument — confirm for the version in use.

        Returns:
            ``weight`` when every patch was a Fooocus patch; otherwise whatever
            the saved original function returns for the remaining patches.
        """
        remaining = []
        for p in patches:
            alpha = p[0]
            v = p[1]
            is_fooocus_patch = isinstance(v, tuple) and len(v) == 2 and v[0] == "fooocus"
            if not is_fooocus_patch:
                remaining.append(p)
                continue
            if alpha == 0.0:
                continue  # zero strength: nothing to merge
            payload = v[1]
            w1 = cast_to_device(payload[0], weight.device, torch.float32)
            if w1.shape != weight.shape:
                print(
                    f"[ApplyFooocusInpaint] Shape mismatch {key}, weight not merged ({w1.shape} != {weight.shape})"
                )
                continue
            w_min = cast_to_device(payload[1], weight.device, torch.float32)
            w_max = cast_to_device(payload[2], weight.device, torch.float32)
            # Linearly map w1 / 255 onto [w_min, w_max]
            # (presumably a uint8-quantized delta — TODO confirm).
            w1 = (w1 / 255.0) * (w_max - w_min) + w_min
            weight += alpha * cast_to_device(w1, weight.device, weight.dtype)

        if not remaining:
            return weight
        # Use the function saved by __enter__ (the module-level global this
        # used to reference does not exist).
        # NOTE(review): in the ModelPatcher fallback the saved object was read
        # from the class and presumably expects a ModelPatcher as its first
        # argument. This bound method never receives one, so forwarding on
        # that path is likely incorrect.
        if original_weights is None:
            return self.original_calculate_weight(remaining, weight, key, intermediate_dtype)
        return self.original_calculate_weight(
            remaining, weight, key, intermediate_dtype, original_weights=original_weights
        )

    def __enter__(self):
        """Save the current ``calculate_weight`` and install the patched one.

        Returns:
            ``self``, so ``with applyFooocusInpaint() as ctx:`` works.
        """
        try:
            print("[comfyui-easy-use] Injecting patched comfy.lora.calculate_weight.calculate_weight")
            owner = comfy.lora
            original = owner.calculate_weight
        except AttributeError:
            # comfy.lora (or its calculate_weight) is unavailable; fall back to
            # the ModelPatcher class attribute.
            print("[comfyui-easy-use] Injecting patched comfy.model_patcher.ModelPatcher.calculate_weight")
            owner = ModelPatcher
            original = ModelPatcher.calculate_weight
        self.original_calculate_weight = original
        owner.calculate_weight = self.calculate_weight_patched
        # Remember exactly which object was patched so __exit__ restores it.
        self._patched_owner = owner
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Restore the saved ``calculate_weight`` on the object it came from.

        Returns ``None``, so exceptions raised inside the ``with`` body propagate.
        """
        owner = getattr(self, "_patched_owner", None)
        if owner is not None:
            owner.calculate_weight = self.original_calculate_weight
            self._patched_owner = None
# def inject_patched_calculate_weight():
# global injected_model_patcher_calculate_weight
# if not injected_model_patcher_calculate_weight:
# try:
# print("[comfyui-easy-use] Injecting patched comfy.lora.calculate_weight.calculate_weight")
# original_calculate_weight = comfy.lora.calculate_weight
# comfy.lora.original_calculate_weight = original_calculate_weight
# comfy.lora.calculate_weight = calculate_weight_patched
# except AttributeError:
# print("[comfyui-easy-use] Injecting patched comfy.model_patcher.ModelPatcher.calculate_weight")
# original_calculate_weight = ModelPatcher.calculate_weight
# ModelPatcher.original_calculate_weight = original_calculate_weight
# ModelPatcher.calculate_weight = calculate_weight_patched
# injected_model_patcher_calculate_weight = True
class InpaintWorker:
    """Applies a Fooocus inpaint patch (head model + encoded weight deltas) to a model.

    ``patch`` computes the inpaint-head feature and stores it on this instance.
    It then registers ``_input_block_patch`` on a clone of the model, and that
    callback reads the stored feature every time it runs.
    """

    def __init__(self, node_name):
        """
        Args:
            node_name: Passed to the log helpers as the node name; ``None``
                becomes an empty string.
        """
        self.node_name = node_name if node_name is not None else ""
        # Populated by patch(). Initialised here so that _input_block_patch
        # used too early fails on its assertion rather than AttributeError.
        self._inpaint_head_feature = None
        self._inpaint_block = None

    def load_fooocus_patch(self, lora: dict, to_load: dict):
        """Wrap the ``lora`` entries named by ``to_load`` as Fooocus patches.

        Args:
            lora: Patch payloads, keyed by the same names as ``to_load``'s values.
            to_load: Mapping whose *values* are the keys to look up in ``lora``.

        Returns:
            ``{key: ("fooocus", payload)}`` for every key found in ``lora``.
        """
        patch_dict = {}
        loaded_keys = set()
        for key in to_load.values():
            value = lora.get(key)
            # Explicit None check: truth-testing a multi-element tensor payload
            # would raise instead of meaning "present".
            if value is not None:
                patch_dict[key] = ("fooocus", value)
                loaded_keys.add(key)
        not_loaded = sum(1 for x in lora if x not in loaded_keys)
        if not_loaded > 0:
            log_node_info(self.node_name,
                f"{len(loaded_keys)} Lora keys loaded, {not_loaded} remaining keys not found in model."
            )
        return patch_dict

    def _input_block_patch(self, h: torch.Tensor, transformer_options: dict):
        """Add the inpaint-head feature to ``h`` for input block index 0.

        Calls for any other block index return ``h`` unchanged. The feature is
        tiled along dim 0 to match ``h``. The tiled tensor is cached and rebuilt
        only when ``h``'s shape changes.
        """
        if transformer_options["block"][1] == 0:
            if self._inpaint_block is None or self._inpaint_block.shape != h.shape:
                assert self._inpaint_head_feature is not None
                batch = h.shape[0] // self._inpaint_head_feature.shape[0]
                self._inpaint_block = self._inpaint_head_feature.to(h).repeat(batch, 1, 1, 1)
            h = h + self._inpaint_block
        return h

    def patch(self, model, latent, patch):
        """Return a clone of ``model`` with the Fooocus inpaint patch applied.

        Args:
            model: Model patcher; ``model.model`` is the underlying base model.
            latent: Dict providing ``"samples"`` and ``"noise_mask"`` tensors.
            patch: Pair ``(inpaint_head_model, inpaint_lora)``.

        Returns:
            One-element tuple holding the patched clone.
        """
        base_model: BaseModel = model.model
        latent_pixels = base_model.process_latent_in(latent["samples"])
        noise_mask = latent["noise_mask"].round()
        # Downsample the rounded mask by 8 with max-pooling, presumably to
        # latent resolution (VAE factor 8) — TODO confirm.
        latent_mask = F.max_pool2d(noise_mask, (8, 8)).round().to(latent_pixels)

        inpaint_head_model, inpaint_lora = patch
        feed = torch.cat([latent_mask, latent_pixels], dim=1)
        inpaint_head_model.to(device=feed.device, dtype=feed.dtype)
        self._inpaint_head_feature = inpaint_head_model(feed)
        self._inpaint_block = None  # force re-tiling for the new feature

        lora_keys = comfy.lora.model_lora_keys_unet(model.model, {})
        # Also accept entries keyed directly by state-dict names.
        lora_keys.update({x: x for x in base_model.state_dict().keys()})
        loaded_lora = self.load_fooocus_patch(inpaint_lora, lora_keys)

        m = model.clone()
        m.set_model_input_block_patch(self._input_block_patch)
        # NOTE(review): the ("fooocus", ...) patches added here need a
        # Fooocus-aware calculate_weight at merge time — confirm callers apply
        # them inside this file's applyFooocusInpaint context.
        patched = m.add_patches(loaded_lora, 1.0)
        m.model_options['transformer_options']['fooocus'] = True

        not_patched_count = sum(1 for x in loaded_lora if x not in patched)
        if not_patched_count > 0:
            log_node_error(self.node_name, f"Failed to patch {not_patched_count} keys")
        return (m,)