Deadmon's picture
Update app.py
48b8bb5 verified
raw
history blame
3.1 kB
import os
import torch
import gradio as gr
import numpy as np
from PIL import Image
from einops import rearrange
import requests
import spaces
from diffusers.utils import load_image
from diffusers import FluxControlNetPipeline, FluxControlNetModel
from gradio_imageslider import ImageSlider
# Hugging Face Hub repositories for the base FLUX model and the union ControlNet.
base_model = "black-forest-labs/FLUX.1-dev"
controlnet_model = "InstantX/FLUX.1-dev-Controlnet-Union"

# Both models are loaded in bfloat16; the assembled pipeline is then moved to
# the GPU. DiffusionPipeline.to() returns the pipeline itself, so the device
# move can be chained onto construction.
_MODEL_DTYPE = torch.bfloat16
controlnet = FluxControlNetModel.from_pretrained(
    controlnet_model,
    torch_dtype=_MODEL_DTYPE,
)
pipe = FluxControlNetPipeline.from_pretrained(
    base_model,
    controlnet=controlnet,
    torch_dtype=_MODEL_DTYPE,
).to("cuda")
# Control modes offered in the UI, keyed by consecutive integers starting at 0.
# The dropdown below is built from these entries in order and returns the
# selected position, which generate_image forwards as ``control_mode``.
CONTROL_MODES = dict(
    enumerate(
        (
            "Canny",
            "Tile",
            "Depth",
            "Blur",
            "Pose",
            "Gray (Low)",
            "LQ",
        )
    )
)
def preprocess_image(image, target_width, target_height):
    """Resize a control image to the requested output size.

    Args:
        image: PIL image supplied by the Gradio ``Image`` input.
        target_width: Desired width in pixels; coerced to ``int``.
        target_height: Desired height in pixels; coerced to ``int``.

    Returns:
        A new RGB PIL image of exactly ``(target_width, target_height)``,
        resampled with Lanczos filtering.
    """
    # Normalize grayscale / RGBA / palette uploads to RGB so the pipeline
    # always receives a 3-channel control image.
    if image.mode != "RGB":
        image = image.convert("RGB")
    # PIL requires integer dimensions; Gradio numeric inputs may be floats.
    size = (int(target_width), int(target_height))
    return image.resize(size, Image.LANCZOS)
@spaces.GPU(duration=120)
def generate_image(prompt, control_image, control_mode, controlnet_conditioning_scale,
                   num_steps, guidance, width, height, seed, random_seed):
    """Generate an image with the FLUX ControlNet pipeline.

    Args:
        prompt: Text prompt describing the desired image.
        control_image: PIL image used as the ControlNet condition, or ``None``
            if the user did not upload one.
        control_mode: Integer position of the selected entry in
            ``CONTROL_MODES`` (the dropdown uses ``type="index"``).
        controlnet_conditioning_scale: Strength of the ControlNet conditioning.
        num_steps: Number of inference steps.
        guidance: Guidance scale passed to the pipeline.
        width: Requested output width; rounded down to a multiple of 16.
        height: Requested output height; rounded down to a multiple of 16.
        seed: RNG seed; may arrive as a float from ``gr.Number``.
        random_seed: If truthy, ignore ``seed`` and draw one in ``[0, 10000)``.

    Returns:
        A ``(control_image, generated_image)`` pair — the before/after format
        expected by the ``ImageSlider`` output component.

    Raises:
        gr.Error: If no control image was provided.
    """
    if control_image is None:
        raise gr.Error("Please upload a control image.")

    if random_seed:
        seed = np.random.randint(0, 10000)
    seed = int(seed)

    # Round dimensions down to multiples of 16; cast first because Gradio
    # slider values may be floats.
    width = 16 * (int(width) // 16)
    height = 16 * (int(height) // 16)

    # A per-call generator gives reproducible results for a given seed without
    # mutating the process-wide RNG shared by concurrent requests.
    generator = torch.Generator(device="cuda").manual_seed(seed)

    control_image = preprocess_image(control_image, width, height)

    with torch.no_grad():
        image = pipe(
            prompt,
            control_image=control_image,
            control_mode=control_mode,  # Pass control mode explicitly
            width=width,
            height=height,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            num_inference_steps=num_steps,
            guidance_scale=guidance,
            generator=generator,
        ).images[0]

    # ImageSlider displays a pair of images; show the resized condition
    # alongside the generated result.
    return control_image, image
# Define the Gradio interface.
# Dropdown labels are "<index>: <name>". With type="index" the callback receives
# the label's list position, which equals the CONTROL_MODES key because the dict
# is keyed 0..N-1 in insertion order.
_CONTROL_MODE_CHOICES = [f"{i}: {name}" for i, name in CONTROL_MODES.items()]

interface = gr.Interface(
    fn=generate_image,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Image(type="pil", label="Control Image"),
        # The default must be one of the choice strings, not an index;
        # otherwise index preprocessing cannot find the default in the choices.
        gr.Dropdown(
            choices=_CONTROL_MODE_CHOICES,
            type="index",
            label="Control Mode",
            value=_CONTROL_MODE_CHOICES[0],
        ),
        gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.5, label="ControlNet Conditioning Scale"),
        gr.Slider(step=1, minimum=1, maximum=64, value=24, label="Num Steps"),
        gr.Slider(minimum=0.1, maximum=10, value=3.5, label="Guidance"),
        gr.Slider(minimum=128, maximum=1024, step=128, value=512, label="Width"),
        gr.Slider(minimum=128, maximum=1024, step=128, value=512, label="Height"),
        # precision=0 makes the component deliver an int rather than a float.
        gr.Number(value=42, precision=0, label="Seed"),
        gr.Checkbox(label="Random Seed"),
    ],
    outputs=ImageSlider(label="Generated Image"),
    title="FLUX.1 Controlnet with Multiple Modes",
    description="Generate images using ControlNet and a text prompt with adjustable control modes.",
)
def _main():
    """Start the Gradio server for the module-level ``interface``."""
    interface.launch()


if __name__ == "__main__":
    _main()