File size: 1,678 Bytes
d3caf74
 
 
61e8157
eb48411
456a8a0
d3caf74
 
456a8a0
d3caf74
 
456a8a0
d3caf74
 
 
 
 
 
456a8a0
d3caf74
 
 
11cf435
d3caf74
 
 
bfc70f6
d3caf74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bfc70f6
d3caf74
 
bfc70f6
d3caf74
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import gradio as gr
import torch
from diffusers import (
    ControlNetModel,
    DiffusionPipeline,
    StableDiffusionControlNetImg2ImgPipeline,
    StableDiffusionControlNetPipeline,
)
from PIL import Image, ImageFilter

# Base checkpoint + ControlNet. Both must come from the same model family:
# "lllyasviel/sd-controlnet-canny" is a Stable Diffusion 1.5 ControlNet, so it
# is paired with an SD 1.5 checkpoint. Stable Diffusion 3 uses a different
# architecture (and SD3-specific ControlNet classes in diffusers), so an SD3
# checkpoint cannot be loaded through these SD 1.x pipeline classes.
model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
controlnet_id = "lllyasviel/sd-controlnet-canny"

# float16 halves GPU memory, but many CPU kernels have no half-precision
# implementation, so fall back to float32 when running on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

controlnet = ControlNetModel.from_pretrained(controlnet_id, torch_dtype=dtype)

# The Img2Img variant is what makes `strength` meaningful: it starts from a
# noised copy of the input image, while the ControlNet conditions on a
# separate control (edge) image. The plain StableDiffusionControlNetPipeline
# is text-to-image and has no `strength` argument at all.
pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
    model_id,
    controlnet=controlnet,
    torch_dtype=dtype,
)
pipe.to(device)


def _edge_map(image: Image.Image, threshold: int = 32) -> Image.Image:
    """Return a white-on-black edge map of *image* for the canny ControlNet.

    OpenCV is not a dependency of this script, so Canny edge detection is
    approximated with Pillow: a light blur to suppress noise, the FIND_EDGES
    kernel, then a binary threshold so the result looks like ``cv2.Canny``
    output (edges 255, background 0).

    Args:
        image: Source picture (any mode; it is converted to grayscale).
        threshold: Edge-strength cutoff in 0-255. This is a tunable starting
            point, not a calibrated value.

    Returns:
        An RGB image containing only pure black and pure white pixels.
    """
    gray = image.convert("L").filter(ImageFilter.GaussianBlur(radius=1))
    edges = gray.filter(ImageFilter.FIND_EDGES)
    edges = edges.point(lambda v: 255 if v >= threshold else 0)
    return edges.convert("RGB")


def controlnet_img2img(
    image,
    prompt: str,
    strength: float = 0.8,
    guidance: float = 7.5,
    *,
    max_side: int = 768,
) -> Image.Image:
    """Re-render *image* according to *prompt*, keeping its edge structure.

    Args:
        image: Input picture as an array accepted by ``PIL.Image.fromarray``
            (the Gradio input component delivers ``type="numpy"``).
        prompt: Text description of the desired result.
        strength: How far to move away from the input pixels, in [0, 1];
            higher values keep less of the original image.
        guidance: Classifier-free guidance scale.
        max_side: The longer side is shrunk to at most this many pixels
            (aspect ratio kept) to bound GPU memory; SD 1.5 was trained at
            512px, so very large inputs only cost memory and time.

    Returns:
        The first generated image.

    Raises:
        ValueError: If *image* is ``None``.
    """
    if image is None:
        raise ValueError("An input image is required.")

    init_image = Image.fromarray(image).convert("RGB")
    # thumbnail() works in place, only ever shrinks, and keeps the aspect ratio.
    init_image.thumbnail((max_side, max_side))
    control_image = _edge_map(init_image)

    # `image` seeds the latents (the img2img part); `control_image` is the
    # structural hint the canny ControlNet conditions on.
    result = pipe(
        prompt=prompt,
        image=init_image,
        control_image=control_image,
        strength=strength,
        guidance_scale=guidance,
    ).images[0]
    return result

def img_editor(input_image, prompt):
    """Gradio click handler: validate the inputs, then run the edit.

    Args:
        input_image: Image array from the ``gr.Image(type="numpy")`` input,
            or ``None`` when nothing has been uploaded.
        prompt: Text from the prompt box.

    Returns:
        The image produced by ``controlnet_img2img``.

    Raises:
        gr.Error: If no image was uploaded. Gradio shows the message in the
            UI instead of failing with an opaque exception.
    """
    if input_image is None:
        raise gr.Error("Please upload an input image.")
    return controlnet_img2img(input_image, prompt)

# `gr.Image` takes `source="upload"` in Gradio 3.x but `sources=["upload"]`
# from Gradio 4 on, where the old keyword raises a TypeError. Pass whichever
# keyword the installed version understands.
_GRADIO_MAJOR = int(gr.__version__.split(".")[0])
_upload_only = {"sources": ["upload"]} if _GRADIO_MAJOR >= 4 else {"source": "upload"}

# Create Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## Img2Img Editor with ControlNet and Stable Diffusion")
    with gr.Row():
        image_input = gr.Image(type="numpy", label="Input Image", **_upload_only)
        prompt_input = gr.Textbox(label="Prompt")
    result_output = gr.Image(label="Output Image")

    submit_btn = gr.Button("Generate")
    submit_btn.click(fn=img_editor, inputs=[image_input, prompt_input], outputs=result_output)


# Launch only when run as a script, so importing this module does not start
# a web server as a side effect.
if __name__ == "__main__":
    demo.launch()