import os

# Set the allocator config before torch initializes CUDA so it actually takes effect
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"

import argparse

import torch
from diffusers.utils import load_image, check_min_version

from controlnet_flux import FluxControlNetModel
from transformer_flux import FluxTransformer2DModel
from pipeline_flux_controlnet_inpaint import FluxControlNetInpaintingPipeline
def main(image_path, mask_path, prompt):
    check_min_version("0.30.2")

    # Enable memory optimizations
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.cuda.empty_cache()
    torch.backends.cudnn.benchmark = True
    # Build pipeline components
    controlnet = FluxControlNetModel.from_pretrained(
        "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha",
        torch_dtype=torch.bfloat16,
    ).to("cuda")
    transformer = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        subfolder="transformer",
        torch_dtype=torch.bfloat16,
    ).to("cuda")
    pipe = FluxControlNetInpaintingPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        controlnet=controlnet,
        transformer=transformer,
        torch_dtype=torch.bfloat16,
    ).to("cuda")

    # Enable memory-efficient attention
    pipe.enable_attention_slicing(1)
    # Load the input image and mask from disk and resize them
    size = (384, 384)  # or even (256, 256) to save more memory
    image = load_image(image_path).convert("RGB").resize(size)
    mask = load_image(mask_path).convert("RGB").resize(size)

    # Set generator for reproducible results
    generator = torch.Generator(device="cuda").manual_seed(24)
    # Run inference with memory optimizations
    with torch.cuda.amp.autocast():  # Enable automatic mixed precision
        result = pipe(
            prompt=prompt,
            height=size[1],
            width=size[0],
            control_image=image,
            control_mask=mask,
            num_inference_steps=28,
            generator=generator,
            controlnet_conditioning_scale=0.9,
            guidance_scale=3.5,
            negative_prompt="",
            true_guidance_scale=1.0,
        ).images[0]

    # Clear cache after generation
    torch.cuda.empty_cache()

    print("Successfully inpainted image")
    return result
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Inpaint an image using FluxControlNetInpaintingPipeline."
    )
    parser.add_argument(
        "--image_path", type=str, required=True, help="Path to the input image."
    )
    parser.add_argument(
        "--mask_path", type=str, required=True, help="Path to the mask image."
    )
    parser.add_argument(
        "--prompt", type=str, required=True, help="Prompt for the inpainting process."
    )
    args = parser.parse_args()

    result = main(args.image_path, args.mask_path, args.prompt)
    result.save("output.png")
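
For reference, the script can be run from the command line like this (assuming it is saved as inpaint.py; the filename, the input/mask paths, and the prompt below are just placeholders, and the result is written to output.png in the working directory):

python inpaint.py --image_path input.png --mask_path mask.png --prompt "a red sofa in a bright living room"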