|
import os |
|
import numpy as np |
|
|
|
import torch |
|
from torchvision import transforms |
|
from .pipeline_flux_fill import FluxFillPipeline |
|
|
|
import comfy.model_management as mm |
|
from .utils import convert_diffusers_flux_lora |
|
|
|
script_directory = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) |
|
|
|
class LoadCatvtonFlux: |
|
|
|
RETURN_TYPES = ("CatvtonFluxModel",) |
|
FUNCTION = "load_catvton_flux" |
|
CATEGORY = "CatvtonFluxWrapper" |
|
|
|
@classmethod |
|
def INPUT_TYPES(cls): |
|
return { |
|
"required": { |
|
}, |
|
} |
|
|
|
def load_catvton_flux(self): |
|
|
|
load_device = mm.text_encoder_device() |
|
offload_device = mm.text_encoder_offload_device() |
|
|
|
print("Start loading LoRA weights") |
|
state_dict, network_alphas = FluxFillPipeline.lora_state_dict( |
|
pretrained_model_name_or_path_or_dict="xiaozaa/catvton-flux-lora-alpha", |
|
weight_name="pytorch_lora_weights.safetensors", |
|
return_alphas=True |
|
) |
|
is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys()) |
|
if not is_correct_format: |
|
raise ValueError("Invalid LoRA checkpoint.") |
|
print('Loading diffusion model ...') |
|
pipe = FluxFillPipeline.from_pretrained( |
|
"black-forest-labs/FLUX.1-Fill-dev", |
|
torch_dtype=torch.bfloat16 |
|
).to(load_device) |
|
FluxFillPipeline.load_lora_into_transformer( |
|
state_dict=state_dict, |
|
network_alphas=network_alphas, |
|
transformer=pipe.transformer, |
|
) |
|
pipe.transformer.to(torch.bfloat16) |
|
print('Loading Finished!') |
|
|
|
model = {"pipe": pipe} |
|
|
|
return (model,) |
|
|
|
|
|
class CatvtonFluxSampler: |
|
|
|
RETURN_TYPES = ("IMAGE", "IMAGE",) |
|
RETURN_NAMES = ("TryonResult", "GarmentResult",) |
|
|
|
FUNCTION = "sample" |
|
CATEGORY = "CatvtonFluxWrapper" |
|
|
|
@classmethod |
|
def INPUT_TYPES(cls): |
|
return { |
|
"required": { |
|
"CatvtonFluxModel": ("CatvtonFluxModel",), |
|
"prompt": ("STRING",), |
|
"image": ("IMAGE",), |
|
"mask": ("MASK",), |
|
"garment": ("IMAGE",), |
|
"steps": ("INT", {"default": 30}), |
|
"guidance_scale": ("FLOAT", {"default": 30.0}), |
|
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), |
|
"width": ("INT", {"default": 768}), |
|
"height": ("INT", {"default": 1024}), |
|
"keep_in_GPU": ("BOOLEAN", {"default": False}), |
|
}, |
|
} |
|
|
|
def sample(self, CatvtonFluxModel, prompt, image, mask, garment, steps=30, guidance_scale=30.0, seed=-1, width=768, height=1024, keep_in_GPU=False): |
|
load_device = mm.text_encoder_device() |
|
offload_device = mm.text_encoder_offload_device() |
|
|
|
pipe = CatvtonFluxModel["pipe"] |
|
|
|
|
|
if not pipe.transformer.device == load_device: |
|
pipe.transformer.to(load_device) |
|
|
|
size=(width, height) |
|
|
|
|
|
transform = transforms.Compose([ |
|
transforms.Normalize([0.5], [0.5]) |
|
]) |
|
|
|
image = image.permute(0, 3, 1, 2) |
|
mask = mask[:, None, ...] |
|
garment = garment.permute(0, 3, 1, 2) |
|
|
|
|
|
image = transform(image) |
|
garment = transform(garment) |
|
|
|
|
|
inpaint_image = torch.cat([garment, image], dim=3) |
|
garment_mask = torch.zeros_like(mask) |
|
extended_mask = torch.cat([garment_mask, mask], dim=3) |
|
|
|
result = pipe( |
|
height=size[1], |
|
width=size[0] * 2, |
|
image=inpaint_image[0], |
|
mask_image=extended_mask[0], |
|
num_inference_steps=steps, |
|
generator=torch.Generator(device=load_device).manual_seed(seed), |
|
max_sequence_length=512, |
|
guidance_scale=guidance_scale, |
|
prompt=prompt, |
|
).images[0] |
|
|
|
if not keep_in_GPU: |
|
pipe.transformer.to(offload_device) |
|
|
|
|
|
width = size[0] |
|
garment_result = result.crop((0, 0, width, size[1])) |
|
tryon_result = result.crop((width, 0, width * 2, size[1])) |
|
|
|
tryon_result = torch.tensor( |
|
np.array(tryon_result) / 255.0, dtype=torch.float32 |
|
).unsqueeze(0) |
|
garment_result = torch.tensor( |
|
np.array(garment_result) / 255.0, dtype=torch.float32 |
|
).unsqueeze(0) |
|
|
|
return (tryon_result, garment_result,) |
|
|
|
class LoadCatvtonFluxLoRA: |
|
|
|
RETURN_TYPES = ("MODEL",) |
|
FUNCTION = "load_catvton_flux_lora" |
|
CATEGORY = "CatvtonFluxWrapper" |
|
|
|
@classmethod |
|
def INPUT_TYPES(cls): |
|
return { |
|
"required": { |
|
"MODEL": ("MODEL",), |
|
}, |
|
} |
|
|
|
def load_catvton_flux_lora(self, MODEL): |
|
|
|
load_device = mm.text_encoder_device() |
|
offload_device = mm.text_encoder_offload_device() |
|
|
|
print("Start loading LoRA weights") |
|
state_dict, _ = FluxFillPipeline.lora_state_dict( |
|
pretrained_model_name_or_path_or_dict="xiaozaa/catvton-flux-lora-alpha", |
|
weight_name="pytorch_lora_weights.safetensors", |
|
return_alphas=True |
|
) |
|
is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys()) |
|
if not is_correct_format: |
|
raise ValueError("Invalid LoRA checkpoint.") |
|
|
|
print('Start converting the lora ...') |
|
lora_sd = convert_diffusers_flux_lora(state_dict, "") |
|
|
|
print('Start combining the model ...') |
|
state_dict = MODEL.model.diffusion_model.state_dict() |
|
for key, value in state_dict.items(): |
|
if key in lora_sd: |
|
state_dict[key] += lora_sd[key].to(load_device) |
|
|
|
MODEL.model.diffusion_model.load_state_dict(state_dict) |
|
|
|
return (MODEL,) |
|
|
|
class ModelPrinter: |
|
|
|
RETURN_TYPES = ("MODEL", ) |
|
RETURN_NAMES = ("MODEL", ) |
|
|
|
FUNCTION = "print_model" |
|
CATEGORY = "CatvtonFluxWrapper" |
|
|
|
@classmethod |
|
def INPUT_TYPES(cls): |
|
return { |
|
"required": { |
|
"MODEL": ("MODEL",), |
|
}, |
|
} |
|
|
|
def print_model(self, MODEL): |
|
state_dict = MODEL.model.diffusion_model.state_dict() |
|
for key, value in state_dict.items(): |
|
print(f"Key: {key}, Shape: {value.shape}") |
|
return (MODEL,) |
|
|