# sd3-ControlNet / app.py
# Hugging Face Space by gaur3009 (commit 9c80116).
# Img2Img-style editor: Gradio front end over a diffusers ControlNet pipeline.
import torch
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from diffusers import DiffusionPipeline
import gradio as gr
from PIL import Image
# --- Model setup (module level: runs once at import; downloads weights on first run) ---

# Use GPU if available. fp16 weights are only reliable on CUDA — many fp16 ops
# are unimplemented on CPU — so fall back to fp32 there.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# Base checkpoint.
# FIX(review): the original pointed at "stabilityai/stable-diffusion-3-medium",
# but StableDiffusionControlNetPipeline is the SD1.x/2.x pipeline and
# "lllyasviel/sd-controlnet-canny" is a ControlNet trained against SD1.5, so
# loading an SD3 checkpoint here fails (SD3 uses a transformer, not a UNet).
# Use an SD1.5 base so all three components actually fit together.
model_id = "runwayml/stable-diffusion-v1-5"

# ControlNet conditioned on Canny edge maps (SD1.5-compatible).
controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-canny", torch_dtype=dtype
)

# Compose the base model with the ControlNet adapter.
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    model_id,
    controlnet=controlnet,
    torch_dtype=dtype,
)
pipe.to(device)
# Function for Img2Img with ControlNet
def controlnet_img2img(image, prompt, strength=0.8, guidance=7.5):
    """Generate an image from *prompt*, conditioned on *image* via ControlNet.

    Args:
        image: Input image as a numpy array (as delivered by Gradio) or a
            PIL Image.
        prompt: Text prompt guiding generation.
        strength: Kept for backward compatibility but NOT forwarded —
            ``StableDiffusionControlNetPipeline.__call__`` has no ``strength``
            parameter (that belongs to the img2img pipeline variant), so
            passing it raises ``TypeError``.
        guidance: Classifier-free guidance scale.

    Returns:
        The first generated image as a PIL Image.

    Raises:
        ValueError: If *image* is None (e.g. nothing was uploaded).
    """
    if image is None:
        raise ValueError("No input image provided.")
    # Gradio hands us a numpy array; also tolerate a PIL image directly.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    image = image.convert("RGB")
    # NOTE(review): the conditioning image is passed raw; a Canny ControlNet
    # normally expects an edge map — consider preprocessing with an edge
    # detector before the call. Confirm against the model card.
    result = pipe(prompt=prompt, image=image, guidance_scale=guidance).images[0]
    return result
# Gradio Interface
def img_editor(input_image, prompt):
    """Gradio callback: forward the uploaded image and prompt to the pipeline."""
    return controlnet_img2img(input_image, prompt)
# Create Gradio UI
# --- Gradio UI (module level: builds the interface, then serves it) ---
with gr.Blocks() as demo:
    gr.Markdown("## Img2Img Editor with ControlNet and Stable Diffusion 3")
    with gr.Row():
        # FIX(review): the ``source`` kwarg was removed in Gradio 4.x and now
        # raises TypeError; uploading is the default behavior, so just drop it.
        image_input = gr.Image(type="numpy", label="Input Image")
        prompt_input = gr.Textbox(label="Prompt")
    result_output = gr.Image(label="Output Image")
    submit_btn = gr.Button("Generate")
    # Wire the button to the generation callback.
    submit_btn.click(fn=img_editor, inputs=[image_input, prompt_input], outputs=result_output)

# Launch the Gradio server (blocking call).
demo.launch()