File size: 5,201 Bytes
61e8157
defabb3
61e8157
1b829a0
 
eb48411
 
 
1b829a0
61e8157
1b829a0
6e5055d
 
 
61e8157
 
11cf435
 
bfc70f6
55671a2
bfc70f6
456a8a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b7751ed
456a8a0
11cf435
 
 
 
 
 
 
 
 
 
 
 
caed254
11cf435
ea6de43
61e8157
bfc70f6
 
8664707
1e19bab
bfc70f6
 
 
 
 
 
 
 
 
b7751ed
bfc70f6
8664707
1e19bab
bfc70f6
1e19bab
61e8157
 
 
 
bfc70f6
361d5a5
ea6de43
 
61e8157
 
bfc70f6
01da3bb
bfc70f6
 
 
 
 
 
 
 
61e8157
ea6de43
 
 
 
 
 
 
 
61e8157
 
0f7840a
ab6e3a1
61e8157
bfc70f6
 
055c2c9
bfc70f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ea6de43
bfc70f6
 
 
 
 
 
 
11cf435
 
 
 
 
bfc70f6
ea6de43
61e8157
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import gradio as gr
import spaces
import os
import sys
import subprocess
import numpy as np
from PIL import Image
import cv2

import torch

from diffusers import StableDiffusion3ControlNetPipeline
from diffusers.models import SD3ControlNetModel, SD3MultiControlNetModel
from diffusers.utils import load_image

# Pipeline slots read by infer(). A module-level `global` statement does not
# create a name, so they are assigned explicitly; None means "not loaded yet".
pipe_canny = None
pipe_tile = None

# ControlNet weights are loaded once at import time and passed to the SD3
# pipelines that load_pipeline() builds.
controlnet_canny = SD3ControlNetModel.from_pretrained("InstantX/SD3-Controlnet-Canny")
controlnet_tile = SD3ControlNetModel.from_pretrained("InstantX/SD3-Controlnet-Tile")

def resize_image(input_path, output_path, target_height):
    """Resize an image file to a fixed height while keeping its aspect ratio.

    The source image is first rotated upright according to its EXIF
    orientation tag, so the computed width/height match how the image is
    displayed. It is then scaled with Lanczos resampling to ``target_height``
    pixels tall and saved to ``output_path``.

    Args:
        input_path: Path (or file object) of the source image.
        output_path: Destination passed to ``Image.save``; Pillow picks the
            format from its extension.
        target_height: Height of the output image in pixels; must be > 0.

    Returns:
        Tuple ``(output_path, new_width, target_height)``.

    Raises:
        ValueError: If ``target_height`` is not positive.
    """
    if target_height <= 0:
        raise ValueError(f"target_height must be positive, got {target_height}")

    from PIL import ImageOps  # only needed here; module imports stay unchanged

    # The context manager closes the underlying file handle when done.
    with Image.open(input_path) as src:
        # Raw pixel dimensions ignore EXIF rotation; transpose first so a
        # portrait phone photo is not treated as landscape.
        img = ImageOps.exif_transpose(src)

        original_width, original_height = img.size
        original_aspect_ratio = original_width / original_height

        # Keep the aspect ratio, but never let an extremely tall image
        # collapse to a zero-pixel width.
        new_width = max(1, int(target_height * original_aspect_ratio))

        img = img.resize((new_width, target_height), Image.LANCZOS)

    # JPEG cannot store alpha or palette images (e.g. RGBA / P-mode PNG
    # uploads); Pillow raises OSError for those, so flatten to RGB first.
    is_jpeg = str(output_path).lower().endswith((".jpg", ".jpeg"))
    if is_jpeg and img.mode not in ("L", "RGB", "CMYK"):
        img = img.convert("RGB")

    img.save(output_path)

    return output_path, new_width, target_height

def load_pipeline(control_type):
    """Load (once) the SD3 ControlNet pipeline for ``control_type``.

    The pipeline is stored in the module-level ``pipe_canny`` or ``pipe_tile``
    global, which ``infer`` reads. A pipeline loaded by an earlier call is
    reused instead of being rebuilt on every click.

    Args:
        control_type: ``"canny"`` or ``"tile"``.

    Raises:
        ValueError: If ``control_type`` is not a supported value.
    """
    # Without this declaration the assignments below would only bind locals,
    # leaving the globals that infer() reads undefined.
    global pipe_canny, pipe_tile

    if control_type not in ("canny", "tile"):
        raise ValueError(f"Unsupported control type: {control_type!r}")

    # globals().get() tolerates the slot not having been created yet.
    if control_type == "canny":
        if globals().get("pipe_canny") is None:
            pipe_canny = StableDiffusion3ControlNetPipeline.from_pretrained(
                "stabilityai/stable-diffusion-3-medium-diffusers",
                controlnet=controlnet_canny,
            )
    else:
        if globals().get("pipe_tile") is None:
            pipe_tile = StableDiffusion3ControlNetPipeline.from_pretrained(
                "stabilityai/stable-diffusion-3-medium-diffusers",
                controlnet=controlnet_tile,
            )

def _canny_edges(image, low_threshold=100, high_threshold=200):
    """Return the Canny edge map of ``image`` as a 3-channel PIL image.

    ``image`` is converted with ``np.array``. The single-channel edge map from
    ``cv2.Canny`` is replicated into three identical channels so it can be
    used as an RGB control image.
    """
    edges = cv2.Canny(np.array(image), low_threshold, high_threshold)
    return Image.fromarray(np.repeat(edges[:, :, None], 3, axis=2))


@spaces.GPU(duration=90)
def infer(image_in, prompt, control_type, inference_steps, guidance_scale, control_weight, progress=gr.Progress(track_tqdm=True)):
    """Run SD3 ControlNet generation for one UI request.

    Args:
        image_in: File path of the uploaded reference image (the input
            component uses ``type="filepath"``); None if nothing was uploaded.
        prompt: Text prompt.
        control_type: ``"canny"`` or ``"tile"``; selects ``pipe_canny`` or
            ``pipe_tile`` and how the control image is built.
        inference_steps: Number of inference steps passed to the pipeline.
        guidance_scale: Guidance scale passed to the pipeline.
        control_weight: Passed as ``controlnet_conditioning_scale``.
        progress: Gradio progress tracker (``track_tqdm=True``).

    Returns:
        ``(image, update)``: the generated PIL image and a ``gr.update`` for the
        "Preprocessed Canny" preview. The preview shows the edge map for canny
        and is hidden for tile.

    Raises:
        gr.Error: If no image was uploaded, the control type is unknown, or
            the selected pipeline has not been loaded.
    """
    if image_in is None:
        raise gr.Error("Please upload a reference image.")
    if control_type not in ("canny", "tile"):
        raise gr.Error(f"Unknown control type: {control_type!r}")

    n_prompt = 'NSFW, nude, naked, porn, ugly'

    pipe = pipe_canny if control_type == "canny" else pipe_tile
    if pipe is None:
        raise gr.Error(f"The {control_type} pipeline is not loaded; please try again.")
    pipe.to("cuda", torch.float16)

    # Load the reference once. diffusers' load_image applies the EXIF
    # orientation and converts to RGB, so .size is the displayed size.
    source_image = load_image(image_in)

    if control_type == "canny":
        control_image = _canny_edges(source_image)
    else:
        control_image = source_image

    image = pipe(
        prompt=prompt,
        negative_prompt=n_prompt,
        control_image=control_image,
        controlnet_conditioning_scale=control_weight,
        num_inference_steps=inference_steps,
        guidance_scale=guidance_scale,
    ).images[0]

    if control_type == "tile":
        return image, gr.update(value=None, visible=False)

    # Canny: restore the reference's aspect ratio at a fixed 1024 px height.
    # The size is taken from the in-memory, EXIF-corrected image rather than
    # by re-opening the upload and writing a scratch JPEG. That scratch write
    # failed for RGBA/palette uploads, and the raw file dimensions disagree
    # with the control image for rotated photos.
    src_width, src_height = source_image.size
    out_height = 1024
    out_width = max(1, int(out_height * (src_width / src_height)))
    image = image.resize((out_width, out_height), Image.LANCZOS)

    return image, gr.update(value=control_image, visible=True)
   

# Page-level CSS: center the main column and cap its width.
css="""
#col-container{
    margin: 0 auto;
    max-width: 1080px;
}
"""
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""
        # SD3 ControlNet

        Experiment with Stable Diffusion 3 ControlNet models proposed and maintained by the InstantX team.<br />
        """)

        with gr.Column():

            with gr.Row():
                # Left column: inputs and generation settings.
                with gr.Column():
                    image_in = gr.Image(label="Image reference", sources=["upload"], type="filepath")
                    prompt = gr.Textbox(label="Prompt")
                    control_type = gr.Radio(
                        label="Control type",
                        choices = [
                            "canny",
                            "tile"
                        ],
                        value="canny"
                    )
                    with gr.Accordion("Advanced settings", open=False):
                        with gr.Column():
                            with gr.Row():
                                inference_steps = gr.Slider(label="Inference steps", minimum=1, maximum=50, step=1, value=25)
                                guidance_scale = gr.Slider(label="Guidance scale", minimum=1.0, maximum=10.0, step=0.1, value=7.0)
                            control_weight = gr.Slider(label="Control Weight", minimum=0.0, maximum=1.0, step=0.01, value=0.7)

                    submit_canny_btn = gr.Button("Submit")
                # Right column: outputs. The canny preview starts hidden;
                # infer() toggles its visibility.
                with gr.Column():
                    result = gr.Image(label="Result")
                    canny_used = gr.Image(label="Preprocessed Canny", visible=False)


    # Load the pipeline for the selected control type, then generate.
    # .success() runs infer only if loading finished without an error;
    # .then() would run it regardless and surface a second, misleading error.
    submit_canny_btn.click(
        fn = load_pipeline,
        inputs = [control_type],
        outputs = None
    ).success(
        fn = infer,
        inputs = [image_in, prompt, control_type, inference_steps, guidance_scale, control_weight],
        outputs = [result, canny_used],
        show_api=False
    )
demo.queue().launch()