Flux.1-Fill-dev / app.py
vilarin's picture
Update app.py
efca2cc verified
raw
history blame
3.69 kB
import torch
import spaces
import gradio as gr
from diffusers import FluxInpaintPipeline, FluxTransformer2DModel
import random
import os
import numpy as np
from huggingface_hub import hf_hub_download
# NOTE(review): huggingface_hub reads HF_HUB_ENABLE_HF_TRANSFER when it is first
# imported, and it has already been imported above (directly and via diffusers),
# so this assignment most likely does not change how the weights below are
# downloaded. Move it above the imports for it to take effect (that also
# requires the `hf_transfer` package to be installed).
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

# Largest seed offered to users; random seeds are drawn from [0, MAX_SEED].
MAX_SEED = np.iinfo(np.int32).max

# FLUX.1-Fill-dev is published in diffusers layout (Fill transformer, VAE, text
# encoders, scheduler). Its transformer takes the masked-image and mask latents
# as extra input channels, which FluxFillPipeline prepares; FluxInpaintPipeline
# only supplies plain image latents and does not match that input layout.
model = "black-forest-labs/FLUX.1-Fill-dev"

# The model is only loaded when a CUDA device is visible; otherwise `pipe`
# stays None instead of being left undefined.
pipe = None
if torch.cuda.is_available():
    # FluxFillPipeline needs a recent diffusers (>= 0.32). It is imported here
    # so a CPU-only environment can still build the UI without it.
    from diffusers import FluxFillPipeline

    pipe = FluxFillPipeline.from_pretrained(model, torch_dtype=torch.bfloat16)
    pipe.to("cuda")
def _layer_to_mask(layer):
    """Return a single-channel mask image built from an editor brush layer.

    NOTE(review): assumes the editor returns an RGBA layer whose unpainted
    pixels are fully transparent, so the alpha channel marks exactly what was
    brushed regardless of brush colour (a plain ``convert("L")`` would turn
    black strokes into an all-zero mask). Layers without alpha fall back to a
    greyscale conversion.
    """
    if "A" in layer.getbands():
        return layer.getchannel("A")
    return layer.convert("L")


@spaces.GPU()
def inpaintGen(
    imgMask,
    inpaint_prompt: str,
    strength: float,
    guidance: float,
    num_steps: int,
    seed: int,
    randomize_seed: bool,
    progress=gr.Progress(track_tqdm=True),
):
    """Repaint the masked region of the uploaded image according to the prompt.

    Args:
        imgMask: Value of the ``gr.ImageMask`` input: a dict whose
            ``"background"`` entry is the uploaded image and whose ``"layers"``
            list holds the drawn mask as its first element (PIL images).
        inpaint_prompt: Text describing what to paint into the masked area.
        strength: Denoising strength; forwarded only when the loaded
            pipeline's ``__call__`` accepts a ``strength`` argument.
        guidance: Passed to the pipeline as ``guidance_scale``.
        num_steps: Number of denoising steps.
        seed: Generator seed, used unless ``randomize_seed`` is set (or the
            seed box is empty).
        randomize_seed: Draw a fresh seed from ``[0, MAX_SEED]`` instead.
        progress: Gradio progress tracker; ``track_tqdm=True`` mirrors the
            pipeline's tqdm bar in the UI.

    Returns:
        ``(image, seed)``: the first generated image and the seed actually
        used, matching the event's ``outputs=[image_out, seed]``.

    Raises:
        gr.Error: if the model is not loaded, no image was uploaded, or no
            mask was drawn.
    """
    import inspect  # local: only used to probe the pipeline's signature

    if pipe is None:
        raise gr.Error("The model is not loaded: no CUDA device is available.")

    # The editor value may be None, or lack a background, when nothing was uploaded.
    source_img = imgMask.get("background") if imgMask else None
    if source_img is None:
        raise gr.Error("Please upload an image.")
    layers = imgMask.get("layers") or []
    mask_img = _layer_to_mask(layers[0]) if layers else None
    # getbbox() is None when every pixel is zero, i.e. nothing was drawn.
    if mask_img is None or mask_img.getbbox() is None:
        raise gr.Error("Please draw a mask on the image.")

    # The editor may hand back RGBA; the VAE encodes 3-channel RGB input.
    source_img = source_img.convert("RGB")
    width, height = source_img.size

    if randomize_seed or seed is None:
        seed = random.randint(0, MAX_SEED)
    seed = int(seed)
    # A CPU generator works whichever device the pipeline runs on.
    generator = torch.Generator(device="cpu").manual_seed(seed)

    call_kwargs = {
        "prompt": inpaint_prompt,
        "image": source_img,
        "mask_image": mask_img,
        "width": width,
        "height": height,
        "num_inference_steps": int(num_steps),
        "generator": generator,
        "guidance_scale": guidance,
    }
    # FluxInpaintPipeline accepts `strength`, but not every FluxFillPipeline
    # release does, so forward it only when the signature has it.
    if "strength" in inspect.signature(pipe.__call__).parameters:
        call_kwargs["strength"] = strength

    result = pipe(**call_kwargs).images[0]
    return result, seed
# Gradio UI: image/mask editor and prompt on the left, result on the right,
# generation settings in a collapsed accordion. Submitting the prompt or
# clicking "Submit" runs inpaintGen; its two outputs go to `image_out` and the
# `seed` box.
with gr.Blocks(theme="ocean", title="Flux.1 dev inpaint") as demo:
    gr.HTML("<h1><center>Flux.1 dev Inpaint</center></h1>")
    gr.HTML("""
<p>
<center>
A partial redraw of the image based on your prompt words and occluded parts.
</center>
</p>
""")
    with gr.Row():
        with gr.Column():
            imgMask = gr.ImageMask(type="pil", label="Image", layers=False, height=800)
            inpaint_prompt = gr.Textbox(label='Prompts ✏️', placeholder="A hat...")
            with gr.Row():
                Inpaint_sendBtn = gr.Button(value="Submit", variant='primary')
                Inpaint_clearBtn = gr.ClearButton([imgMask, inpaint_prompt], value="Clear")
        image_out = gr.Image(type="pil", label="Output", height=960)
    with gr.Accordion("Advanced ⚙️", open=False):
        # strength == 0 leaves zero denoising steps, which FluxInpaintPipeline
        # rejects with a ValueError, so the slider starts at 0.1.
        strength = gr.Slider(label="Strength", minimum=0.1, maximum=1, value=1, step=0.1)
        guidance = gr.Slider(label="Guidance scale", minimum=1, maximum=20, value=7.5, step=0.1)
        num_steps = gr.Slider(label="Steps", minimum=1, maximum=20, value=20, step=1)
        seed = gr.Number(label="Seed", value=42, precision=0)
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

    gr.on(
        triggers=[
            inpaint_prompt.submit,
            Inpaint_sendBtn.click,
        ],
        fn=inpaintGen,
        inputs=[
            imgMask,
            inpaint_prompt,
            strength,
            guidance,
            num_steps,
            seed,
            randomize_seed,
        ],
        outputs=[image_out, seed],
    )
if __name__ == "__main__":
    # api_open=False stops REST calls from bypassing the queue;
    # show_api=False hides the "Use via API" page. No public share link.
    queued_app = demo.queue(api_open=False)
    queued_app.launch(show_api=False, share=False)