File size: 1,957 Bytes
61e8157
 
 
 
 
 
 
dee73d2
61e8157
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import gradio as gr
import os
hf_token = os.environ.get("HF_TOKEN")
import torch
from diffusers import StableDiffusion3Pipeline
from diffusers.models.controlnet_sd3 import ControlNetSD3Model
from diffusers.utils.torch_utils import randn_tensor
from diffusers.examples.community.pipeline_stable_diffusion_3_controlnet import StableDiffusion3CommonPipeline

# Build the SD3 pipeline once at import time, attaching the Canny ControlNet
# by repository id, then move it to the first CUDA device in float16.
# NOTE(review): the token is forwarded under the keyword `hf_token`; confirm
# that this community pipeline's `from_pretrained` actually consumes that name.
base_model = "stabilityai/stable-diffusion-3-medium-diffusers"
pipe = StableDiffusion3CommonPipeline.from_pretrained(
    base_model,
    controlnet_list=["InstantX/SD3-Controlnet-Canny"],
    hf_token=hf_token,
)
pipe.to("cuda:0", torch.float16)

def infer(image_in, prompt):
    """Generate one 1024x1024 image with the module-level ``pipe`` and a ControlNet condition.

    Args:
        image_in: Location of the control image. The UI supplies a file path
            (its image input uses ``type="filepath"``). When empty/``None``,
            the sample canny image published with the ControlNet is used.
            NOTE(review): the ControlNet is the "Canny" variant, so the
            upload is presumably expected to already be an edge map; no edge
            detection is performed here -- confirm with the UI owner.
        prompt: Text prompt. When empty or whitespace-only, the built-in
            demo prompt is used instead.

    Returns:
        The first entry of the pipeline output's ``images`` list.
    """
    # Function-scope import: ``load_image`` was called here without ever
    # being imported, which raised NameError on the first request.
    from diffusers.utils import load_image

    default_prompt = 'Anime style illustration of a girl wearing a suit. A moon in sky. In the background we see a big rain approaching. text "InstantX" on image'
    default_control_image = 'https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg'
    n_prompt = 'NSFW, nude, naked, porn, ugly'
    height = 1024
    width = 1024

    # Honour the caller's prompt; previously it was always overwritten.
    if not prompt or not prompt.strip():
        prompt = default_prompt

    # Honour the uploaded reference image; previously it was ignored.
    control_source = image_in if image_in else default_control_image

    # controlnet config
    controlnet_conditioning = [
        dict(
            control_index=0,
            control_image=load_image(control_source),
            control_weight=0.7,
            control_pooled_projections='zeros',
        )
    ]

    # ``latents`` used to be referenced without being defined (NameError).
    # Draw fresh initial noise with the helper the file already imports.
    # Assumes 16 latent channels and 8x spatial downsampling for SD3, and the
    # same device/dtype the pipeline was moved to -- TODO confirm against the
    # pipeline implementation.
    latents = randn_tensor(
        (1, 16, height // 8, width // 8),
        device='cuda:0',
        dtype=torch.float16,
    )

    # infer
    image = pipe(
        prompt=prompt,
        negative_prompt=n_prompt,
        controlnet_conditioning=controlnet_conditioning,
        num_inference_steps=28,
        guidance_scale=7.0,
        height=height,
        width=width,
        latents=latents,
    ).images[0]

    return image


# Gradio front end: a single column holding the reference-image upload, the
# prompt box, the submit button and the output image.
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("""
        # SD3 ControlNet
        """)
        image_in = gr.Image(
            label="Image reference",
            sources=["upload"],
            type="filepath",
        )
        prompt = gr.Textbox(label="Prompt")
        submit_btn = gr.Button("Submit")
        result = gr.Image(label="Result")

    # Clicking the button calls `infer` with the image path and prompt text
    # and shows its return value in `result`.
    submit_btn.click(
        fn=infer,
        inputs=[image_in, prompt],
        outputs=[result],
        show_api=False,
    )

# Turn on the queue and start the app.
demo.queue().launch()