import argparse
import os

import torch
from diffusers.utils import load_image, check_min_version

from controlnet_flux import FluxControlNetModel
from transformer_flux import FluxTransformer2DModel
from pipeline_flux_controlnet_inpaint import FluxControlNetInpaintingPipeline


def main(image, mask, prompt):
    check_min_version("0.30.2")

    # Ask the caching allocator for smaller blocks to reduce fragmentation.
    # This must be set before the first CUDA call, or it has no effect.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"

    # Enable memory optimizations
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = True
    torch.cuda.empty_cache()
    # Build pipeline components
    controlnet = FluxControlNetModel.from_pretrained(
        "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha",
        torch_dtype=torch.bfloat16,
    ).to("cuda")
    transformer = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        subfolder="transformer",
        torch_dtype=torch.bfloat16,
    ).to("cuda")
    pipe = FluxControlNetInpaintingPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        controlnet=controlnet,
        transformer=transformer,
        torch_dtype=torch.bfloat16,
    ).to("cuda")

    # Enable memory-efficient attention slicing (trades speed for lower VRAM use)
    pipe.enable_attention_slicing(1)
    # Load the input image and mask (file paths or PIL images) and resize them
    size = (384, 384)  # or even (256, 256)
    image = load_image(image).convert("RGB").resize(size)
    mask = load_image(mask).convert("RGB").resize(size)

    # Fixed seed for reproducible results
    generator = torch.Generator(device="cuda").manual_seed(24)
    # Run inference under automatic mixed precision to reduce memory use
    with torch.cuda.amp.autocast():
        result = pipe(
            prompt=prompt,
            height=size[1],
            width=size[0],
            control_image=image,
            control_mask=mask,
            num_inference_steps=28,
            generator=generator,
            controlnet_conditioning_scale=0.9,
            guidance_scale=3.5,
            negative_prompt="",
            true_guidance_scale=1.0,
        ).images[0]
    # Clear cache after generation
    torch.cuda.empty_cache()

    print("Successfully inpainted image")
    return result

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Inpaint an image using FluxControlNetInpaintingPipeline."
    )
    parser.add_argument(
        "--image_path", type=str, required=True, help="Path to the input image."
    )
    parser.add_argument(
        "--mask_path", type=str, required=True, help="Path to the mask image."
    )
    parser.add_argument(
        "--prompt", type=str, required=True, help="Prompt for the inpainting process."
    )
    args = parser.parse_args()

    result = main(args.image_path, args.mask_path, args.prompt)
    result.save("output.png")
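
# Example invocation (script name and file paths below are placeholders):
#   python flux_inpaint.py --image_path ./input.png --mask_path ./mask.png \
#       --prompt "a red leather sofa in a sunlit living room"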