import os
import torch
import gradio as gr
import numpy as np
from PIL import Image
from diffusers import FluxControlNetPipeline, FluxControlNetModel
from gradio_imageslider import ImageSlider  # Before/after slider component used for the output
import cv2  # OpenCV, used for Canny edge detection

# Model identifiers: FLUX.1-dev base model and the InstantX Union ControlNet
base_model = 'black-forest-labs/FLUX.1-dev'
controlnet_model = 'InstantX/FLUX.1-dev-Controlnet-Union'

# Prefer the GPU when available; fall back to CPU so the script still starts without CUDA
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load both models in bfloat16 to keep the memory footprint manageable
controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
pipe = FluxControlNetPipeline.from_pretrained(base_model, controlnet=controlnet, torch_dtype=torch.bfloat16)
pipe.to(device)
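
# Optional memory saver (assumption: the `accelerate` package is installed); uncomment to
# offload pipeline sub-modules to the CPU between forward passes when VRAM is tight.
# pipe.enable_model_cpu_offload()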

def preprocess_image(image, target_width, target_height, crop=True):
    if crop:
        image = image.crop((0, 0, min(image.size), min(image.size)))  # Crop the image to square
        original_width, original_height = image.size

        # Resize to match the target size without stretching
        scale = max(target_width / original_width, target_height / original_height)
        resized_width = int(scale * original_width)
        resized_height = int(scale * original_height)

        image = image.resize((resized_width, resized_height), Image.LANCZOS)
        
        # Center crop to match the target dimensions
        left = (resized_width - target_width) // 2
        top = (resized_height - target_height) // 2
        image = image.crop((left, top, left + target_width, top + target_height))
    else:
        image = image.resize((target_width, target_height), Image.LANCZOS)
    
    return image

def preprocess_canny_image(image, target_width, target_height, crop=True):
    image = preprocess_image(image, target_width, target_height, crop=crop)
    image = np.array(image.convert('L'))  # Convert to grayscale for Canny processing
    image = cv2.Canny(image, 100, 200)  # Apply Canny edge detection
    image = Image.fromarray(image)
    return image
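
# Hypothetical usage sketch (the file name "example.jpg" is illustrative, not part of the app):
# produce a 1024x1024 Canny edge map of the kind generate_image() feeds to the ControlNet.
#
#   edges = preprocess_canny_image(Image.open("example.jpg"), 1024, 1024)
#   edges.save("example_canny.png")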

def generate_image(prompt, control_image, num_steps=24, guidance=3.5, width=512, height=512, seed=42, random_seed=False, control_mode=0):
    if random_seed:
        seed = np.random.randint(0, 10000)
    
    os.makedirs("./controlnet_results/", exist_ok=True)

    # Cast before seeding: gr.Number may deliver the seed as a float, which torch.manual_seed rejects
    torch.manual_seed(int(seed))

    # The control image is always Canny-processed here, so control_mode 0 (canny) is the intended setting
    control_image = preprocess_canny_image(control_image, width, height)

    controlnet_conditioning_scale = 0.5  # Strength of the ControlNet guidance (0 = ignore it, 1 = follow it closely)
    
    # Generate the image using the pipeline
    image = pipe(
        prompt, 
        control_image=control_image,
        control_mode=control_mode,
        width=width,
        height=height,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        num_inference_steps=num_steps, 
        guidance_scale=guidance,
    ).images[0]

    return [control_image, image]  # Return both images for slider

interface = gr.Interface(
    fn=generate_image,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Image(type="pil", label="Control Image"),
        gr.Slider(minimum=1, maximum=64, step=1, value=24, label="Num Steps"),
        gr.Slider(minimum=0.1, maximum=10, value=3.5, label="Guidance"),
        gr.Slider(minimum=128, maximum=2048, step=128, value=1024, label="Width"),
        gr.Slider(minimum=128, maximum=2048, step=128, value=1024, label="Height"),
        gr.Number(value=42, label="Seed"),
        gr.Checkbox(label="Random Seed"),
        gr.Radio(choices=[0, 1, 2, 3, 4, 5, 6], value=0, label="Control Mode")  # Union ControlNet mode index; 0 (canny) matches the preprocessing above
    ],
    outputs=ImageSlider(label="Before / After"),  # Use ImageSlider as the output
    title="FLUX.1 Controlnet Canny",
    description="Generate images using ControlNet and a text prompt.\n[[non-commercial license, Flux.1 Dev](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)]"
)

if __name__ == "__main__":
    interface.launch()
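
    # To expose the demo beyond localhost (assumption: running this script locally rather than
    # on a hosted Space), Gradio's built-in tunnel can be enabled with: interface.launch(share=True)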