# InPaiting_with_mask / control.py
# swoyam-sarvam
# initial commit
# 2a0517c
import torch
import argparse
from diffusers.utils import load_image, check_min_version
from controlnet_flux import FluxControlNetModel
from transformer_flux import FluxTransformer2DModel
from pipeline_flux_controlnet_inpaint import FluxControlNetInpaintingPipeline
def main(image, mask, prompt, *, size=(384, 384), seed=24):
    """Inpaint ``image`` under ``mask`` with the FLUX.1-dev ControlNet pipeline.

    Args:
        image: Source picture, either a PIL image or a path string
            (anything ``diffusers.utils.load_image`` can open).
        mask: Inpainting mask, accepted in the same forms as ``image``.
        prompt: Text prompt forwarded to the pipeline.
        size: ``(width, height)`` that both inputs are resized to and that
            the output is generated at. Smaller sizes such as ``(256, 256)``
            reduce memory use.
        seed: Seed for the CUDA random generator.

    Returns:
        The first entry of the pipeline output's ``.images`` list.
    """
    import os

    # Set before any CUDA call so the caching allocator sees the setting
    # when it initialises.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"

    check_min_version("0.30.2")

    # Accept paths as well as already-opened images: the command-line entry
    # point passes path strings, which have no ``convert``/``resize`` methods.
    if isinstance(image, (str, os.PathLike)):
        image = load_image(os.fspath(image))
    if isinstance(mask, (str, os.PathLike)):
        mask = load_image(os.fspath(mask))

    # Prepare the inputs before the expensive model load so that a bad path
    # or unreadable file fails immediately.
    image = image.convert("RGB").resize(size)
    mask = mask.convert("RGB").resize(size)

    # Enable memory / speed optimizations
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.cuda.empty_cache()
    torch.backends.cudnn.benchmark = True

    # Build pipeline components
    controlnet = FluxControlNetModel.from_pretrained(
        "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha",
        torch_dtype=torch.bfloat16,
    ).to("cuda")
    transformer = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        subfolder="transformer",
        torch_dtype=torch.bfloat16,
    ).to("cuda")
    pipe = FluxControlNetInpaintingPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        controlnet=controlnet,
        transformer=transformer,
        torch_dtype=torch.bfloat16,
    ).to("cuda")

    # Enable memory efficient attention
    pipe.enable_attention_slicing(1)

    # Fixed seed so repeated runs use the same random stream
    generator = torch.Generator(device="cuda").manual_seed(seed)

    try:
        # Every component is loaded as bfloat16, while CUDA autocast defaults
        # to float16; pin the dtype so autocast does not down-cast to fp16.
        # ``torch.autocast`` also replaces the deprecated
        # ``torch.cuda.amp.autocast``.
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            result = pipe(
                prompt=prompt,
                height=size[1],
                width=size[0],
                control_image=image,
                control_mask=mask,
                num_inference_steps=28,
                generator=generator,
                controlnet_conditioning_scale=0.9,
                guidance_scale=3.5,
                negative_prompt="",
                true_guidance_scale=1.0,
            ).images[0]
    finally:
        # Release cached blocks whether or not generation succeeded
        torch.cuda.empty_cache()

    print("Successfully inpaint image")
    return result
if __name__ == "__main__":
    # Command-line entry point: read an image and a mask from disk, inpaint
    # according to the prompt, and write the result to a file.
    parser = argparse.ArgumentParser(
        description="Inpaint an image using FluxControlNetInpaintingPipeline."
    )
    parser.add_argument(
        "--image_path", type=str, required=True, help="Path to the input image."
    )
    parser.add_argument(
        "--mask_path", type=str, required=True, help="Path to the mask image."
    )
    parser.add_argument(
        "--prompt", type=str, required=True, help="Prompt for the inpainting process."
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="output.png",
        help="Where to save the inpainted image (default: output.png).",
    )
    args = parser.parse_args()

    # main() calls PIL methods (convert/resize) on its inputs, so open the
    # files here instead of handing it the raw path strings.
    source_image = load_image(args.image_path)
    mask_image = load_image(args.mask_path)

    result = main(source_image, mask_image, args.prompt)
    result.save(args.output_path)