jaxmetaverse's picture
Upload folder using huggingface_hub
82ea528 verified
# Credit to Acly for this module, adapted from:
# https://github.com/Acly/comfyui-inpaint-nodes
import torch
import torch.nn.functional as F
import comfy
from comfy.model_base import BaseModel
from comfy.model_patcher import ModelPatcher
from comfy.model_management import cast_to_device
from ..libs.log import log_node_warn, log_node_error, log_node_info
class InpaintHead(torch.nn.Module):
    """3x3 convolutional head mapping a 5-channel input to 320 channels.

    In this file the input is the mask concatenated with the latent along the
    channel axis (see ``InpaintWorker.patch``). Spatial size is preserved
    because the input is replicate-padded by one pixel before the convolution.
    The kernel is stored in the parameter ``head``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Uninitialized (320, 5, 3, 3) kernel; real values presumably come from a
        # checkpoint via load_state_dict (parameter name "head") — confirm with loader.
        self.head = torch.nn.Parameter(torch.empty(size=(320, 5, 3, 3), device="cpu"))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the head to ``x`` of shape (B, 5, H, W), returning (B, 320, H, W).

        Implemented as ``forward`` (not ``__call__``) so that ``module(x)`` still
        goes through ``nn.Module.__call__`` and honours registered hooks.
        """
        # Pad by replicating edge pixels rather than letting conv2d zero-pad,
        # so the border does not see artificial zeros.
        x = F.pad(x, (1, 1, 1, 1), "replicate")
        return F.conv2d(x, weight=self.head)
# injected_model_patcher_calculate_weight = False
# original_calculate_weight = None
class applyFooocusInpaint:
    """Context manager that temporarily installs a Fooocus-aware ``calculate_weight``.

    On enter, ComfyUI's ``comfy.lora.calculate_weight`` (or, if that attribute
    is missing, ``ModelPatcher.calculate_weight``) is replaced by
    :meth:`calculate_weight_patched`. On exit, the same target is restored to
    the original function. Exceptions raised inside the ``with`` body are never
    suppressed.
    """

    def calculate_weight_patched(self, patches, weight, key, intermediate_dtype=torch.float32):
        """Merge Fooocus patches into ``weight`` and delegate all other patches.

        Args:
            patches: Iterable of patch entries. Element 0 is the strength (alpha)
                and element 1 is the value. Fooocus entries have a value of the
                form ``("fooocus", (w1, w_min, w_max))``.
            weight: Tensor being patched. Fooocus deltas are added to it in place.
            key: Weight name, used in the mismatch message and forwarded as-is.
            intermediate_dtype: Forwarded unchanged to the original implementation.

        Returns:
            The original ``calculate_weight`` applied to the non-Fooocus patches,
            or ``weight`` itself when there are no such patches.
        """
        remaining = []
        for p in patches:
            alpha = p[0]
            v = p[1]
            is_fooocus_patch = (
                isinstance(v, tuple)
                and len(v) == 2
                and isinstance(v[0], str)
                and v[0] == "fooocus"
            )
            if not is_fooocus_patch:
                remaining.append(p)
                continue
            if alpha == 0.0:
                continue
            packed = v[1]
            w1 = cast_to_device(packed[0], weight.device, torch.float32)
            if w1.shape != weight.shape:
                print(
                    f"[ApplyFooocusInpaint] Shape mismatch {key}, weight not merged ({w1.shape} != {weight.shape})"
                )
                continue
            w_min = cast_to_device(packed[1], weight.device, torch.float32)
            w_max = cast_to_device(packed[2], weight.device, torch.float32)
            # w1 appears to be stored on a 0..255 scale; map it linearly back
            # into [w_min, w_max] before adding the scaled delta.
            w1 = (w1 / 255.0) * (w_max - w_min) + w_min
            weight += alpha * cast_to_device(w1, weight.device, weight.dtype)
        if remaining:
            # Must go through self: there is no module-level original_calculate_weight.
            # NOTE(review): on the ModelPatcher fallback path, original_calculate_weight
            # is the unbound ModelPatcher.calculate_weight, so `remaining` would be
            # bound to its `self` — confirm whether that legacy path is still reachable.
            return self.original_calculate_weight(remaining, weight, key, intermediate_dtype)
        return weight

    def __enter__(self):
        """Install the patched function and remember which object was patched."""
        try:
            print("[comfyui-easy-use] Injecting patched comfy.lora.calculate_weight.calculate_weight")
            self.original_calculate_weight = comfy.lora.calculate_weight
            comfy.lora.calculate_weight = self.calculate_weight_patched
            self._patched_owner = comfy.lora
        except AttributeError:
            print("[comfyui-easy-use] Injecting patched comfy.model_patcher.ModelPatcher.calculate_weight")
            self.original_calculate_weight = ModelPatcher.calculate_weight
            ModelPatcher.calculate_weight = self.calculate_weight_patched
            self._patched_owner = ModelPatcher
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Restore the original function on the object patched by ``__enter__``."""
        owner = getattr(self, "_patched_owner", None)
        if owner is not None:
            owner.calculate_weight = self.original_calculate_weight
            self._patched_owner = None
        # Never swallow exceptions raised inside the with-body.
        return False
# def inject_patched_calculate_weight():
# global injected_model_patcher_calculate_weight
# if not injected_model_patcher_calculate_weight:
# try:
# print("[comfyui-easy-use] Injecting patched comfy.lora.calculate_weight.calculate_weight")
# original_calculate_weight = comfy.lora.calculate_weight
# comfy.lora.original_calculate_weight = original_calculate_weight
# comfy.lora.calculate_weight = calculate_weight_patched
# except AttributeError:
# print("[comfyui-easy-use] Injecting patched comfy.model_patcher.ModelPatcher.calculate_weight")
# original_calculate_weight = ModelPatcher.calculate_weight
# ModelPatcher.original_calculate_weight = original_calculate_weight
# ModelPatcher.calculate_weight = calculate_weight_patched
# injected_model_patcher_calculate_weight = True
class InpaintWorker:
    """Applies a Fooocus inpaint patch (inpaint head + weight patches) to a model.

    ``patch`` runs the inpaint head on the masked latent, then registers an
    input-block patch that adds the resulting feature to the hidden state of
    the input block whose index (``transformer_options["block"][1]``) is 0. It
    also attaches the Fooocus weight patches to a clone of the model. The
    head feature is kept on this worker instance, so a worker should serve one
    ``patch`` call at a time.
    """

    def __init__(self, node_name):
        self.node_name = node_name if node_name is not None else ""
        # Filled in by patch() and read by _input_block_patch().
        self._inpaint_head_feature = None
        self._inpaint_block = None

    def load_fooocus_patch(self, lora: dict, to_load: dict):
        """Select the Fooocus patch entries that target known model weights.

        Args:
            lora: Mapping of weight key to packed Fooocus patch value.
            to_load: Mapping whose *values* are model weight keys to look up in
                ``lora``.

        Returns:
            ``{key: ("fooocus", value)}`` for every looked-up key present in
            ``lora``. Logs how many ``lora`` keys were not matched.
        """
        patch_dict = {}
        loaded_keys = set()
        for key in to_load.values():
            value = lora.get(key)
            # Explicit None check: truth-testing a tensor value would raise for
            # multi-element tensors.
            if value is not None:
                patch_dict[key] = ("fooocus", value)
                loaded_keys.add(key)
        not_loaded = sum(1 for x in lora if x not in loaded_keys)
        if not_loaded > 0:
            log_node_info(self.node_name,
                f"{len(loaded_keys)} Lora keys loaded, {not_loaded} remaining keys not found in model."
            )
        return patch_dict

    def _input_block_patch(self, h: torch.Tensor, transformer_options: dict):
        """Add the inpaint head feature to ``h`` at input block index 0; pass others through."""
        if transformer_options["block"][1] != 0:
            return h
        if self._inpaint_block is None or self._inpaint_block.shape != h.shape:
            assert self._inpaint_head_feature is not None
            # h's batch may be a multiple of the latent batch (presumably cond and
            # uncond stacked together), so tile the feature along dim 0 to match.
            batch = h.shape[0] // self._inpaint_head_feature.shape[0]
            self._inpaint_block = self._inpaint_head_feature.to(h).repeat(batch, 1, 1, 1)
        return h + self._inpaint_block

    def patch(self, model, latent, patch):
        """Return a clone of ``model`` with the Fooocus inpaint patch applied.

        Args:
            model: Model patcher to clone; ``model.model`` is the base model.
            latent: Latent dict; ``"samples"`` and ``"noise_mask"`` are read.
            patch: ``(inpaint_head_model, inpaint_lora)`` pair.

        Returns:
            A 1-tuple holding the patched clone.
        """
        base_model: BaseModel = model.model
        latent_pixels = base_model.process_latent_in(latent["samples"])
        noise_mask = latent["noise_mask"].round()
        # Shrink the mask 8x with max-pooling (a cell counts as masked if any of
        # its pixels is), presumably to match the latent resolution.
        latent_mask = F.max_pool2d(noise_mask, (8, 8)).round().to(latent_pixels)

        inpaint_head_model, inpaint_lora = patch
        feed = torch.cat([latent_mask, latent_pixels], dim=1)
        inpaint_head_model.to(device=feed.device, dtype=feed.dtype)
        self._inpaint_head_feature = inpaint_head_model(feed)
        self._inpaint_block = None

        lora_keys = comfy.lora.model_lora_keys_unet(model.model, {})
        # Also accept patches addressed directly by state-dict key.
        lora_keys.update({x: x for x in base_model.state_dict().keys()})
        loaded_lora = self.load_fooocus_patch(inpaint_lora, lora_keys)

        m = model.clone()
        m.set_model_input_block_patch(self._input_block_patch)
        # Set for O(1) membership tests below.
        patched = set(m.add_patches(loaded_lora, 1.0))
        m.model_options['transformer_options']['fooocus'] = True
        not_patched_count = sum(1 for x in loaded_lora if x not in patched)
        if not_patched_count > 0:
            log_node_error(self.node_name, f"Failed to patch {not_patched_count} keys")
        return (m,)