gaur3009 committed
Commit d3caf74 · verified · 1 Parent(s): ae61bc8

Update app.py

Files changed (1)
  1. app.py +39 -108
app.py CHANGED
@@ -1,118 +1,49 @@
  import gradio as gr
- import spaces
- import os
- import sys
- import subprocess
- import numpy as np
  from PIL import Image
- import cv2
-
- import torch
-
- from diffusers import StableDiffusion3ControlNetPipeline
- from diffusers.models import SD3ControlNetModel, SD3MultiControlNetModel
- from diffusers.utils import load_image
-
- # load pipeline
- controlnet_canny = SD3ControlNetModel.from_pretrained("InstantX/SD3-Controlnet-Canny")
- pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
-     "stabilityai/stable-diffusion-3-medium-diffusers",
-     controlnet=controlnet_canny
- ).to("cuda", torch.float16)
-
- def resize_image(input_path, output_path, target_height):
-     # Open the input image
-     img = Image.open(input_path)
-
-     # Calculate the aspect ratio of the original image
-     original_width, original_height = img.size
-     original_aspect_ratio = original_width / original_height
-
-     # Calculate the new width while maintaining the aspect ratio and the target height
-     new_width = int(target_height * original_aspect_ratio)
-
-     # Resize the image while maintaining the aspect ratio and fixing the height
-     img = img.resize((new_width, target_height), Image.LANCZOS)
-
-     # Save the resized image
-     img.save(output_path)
-
-     return output_path, new_width, target_height
-
- @spaces.GPU(duration=90)
- def infer(image_in, prompt, inference_steps, guidance_scale, control_weight, progress=gr.Progress(track_tqdm=True)):
-
-     n_prompt = 'NSFW, nude, naked, porn, ugly'
-
-     # Canny preprocessing
-     image_to_canny = load_image(image_in)
-     image_to_canny = np.array(image_to_canny)
-     image_to_canny = cv2.Canny(image_to_canny, 100, 200)
-     image_to_canny = image_to_canny[:, :, None]
-     image_to_canny = np.concatenate([image_to_canny, image_to_canny, image_to_canny], axis=2)
-     image_to_canny = Image.fromarray(image_to_canny)
-
-     control_image = image_to_canny
-
-     # infer
-     image = pipe(
-         prompt=prompt,
-         negative_prompt=n_prompt,
-         control_image=control_image,
-         controlnet_conditioning_scale=control_weight,
-         num_inference_steps=inference_steps,
-         guidance_scale=guidance_scale,
-     ).images[0]
-
-     image_redim, w, h = resize_image(image_in, "resized_input.jpg", 1024)
-     image = image.resize((w, h), Image.LANCZOS)
-
-     return image, gr.update(value=image_to_canny, visible=True)
-
- css="""
- #col-container{
-     margin: 0 auto;
-     max-width: 1080px;
- }
- """
- with gr.Blocks(css=css) as demo:
-     with gr.Column(elem_id="col-container"):
-         gr.Markdown("""
-         # SD3 ControlNet
-
-         Experiment with Stable Diffusion 3 ControlNet models proposed and maintained by the InstantX team.<br />
-         Model card: [InstantX/SD3-Controlnet-Canny](https://huggingface.co/InstantX/SD3-Controlnet-Canny)
-         """)
-
-         with gr.Column():
-
-             with gr.Row():
-                 with gr.Column():
-                     image_in = gr.Image(label="Image reference", sources=["upload"], type="filepath")
-                     prompt = gr.Textbox(label="Prompt")
-
-                     with gr.Accordion("Advanced settings", open=False):
-                         with gr.Column():
-                             with gr.Row():
-                                 inference_steps = gr.Slider(label="Inference steps", minimum=1, maximum=50, step=1, value=25)
-                                 guidance_scale = gr.Slider(label="Guidance scale", minimum=1.0, maximum=10.0, step=0.1, value=7.0)
-                                 control_weight = gr.Slider(label="Control Weight", minimum=0.0, maximum=1.0, step=0.01, value=0.7)
-
-                     submit_canny_btn = gr.Button("Submit")
-
-                 with gr.Column():
-                     result = gr.Image(label="Result")
-                     canny_used = gr.Image(label="Preprocessed Canny", visible=False)
-
-     submit_canny_btn.click(
-         fn = infer,
-         inputs = [image_in, prompt, inference_steps, guidance_scale, control_weight],
-         outputs = [result, canny_used],
-         show_api=False
-     )
- demo.queue().launch()
+ import torch
+ from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
+ from diffusers import DiffusionPipeline
  import gradio as gr
  from PIL import Image
+
+ # Load Stable Diffusion 3 (from InstantX)
+ model_id = "instantx/stable-diffusion-3-medium"
+
+ # Load the ControlNet model (use an appropriate pre-trained controlnet model)
+ controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
+
+ # Set up the pipeline using both SD3 and ControlNet
+ pipe = StableDiffusionControlNetPipeline.from_pretrained(
+     model_id,
+     controlnet=controlnet,
+     torch_dtype=torch.float16
+ )
+
+ # Use GPU if available
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ pipe.to(device)
+
+ # Function for Img2Img with ControlNet
+ def controlnet_img2img(image, prompt, strength=0.8, guidance=7.5):
+     image = Image.fromarray(image).convert("RGB")  # Convert to RGB
+
+     # Run the pipeline
+     result = pipe(prompt=prompt, image=image, strength=strength, guidance_scale=guidance).images[0]
+     return result
+
+ # Gradio Interface
+ def img_editor(input_image, prompt):
+     result = controlnet_img2img(input_image, prompt)
+     return result
+
+ # Create Gradio UI
+ with gr.Blocks() as demo:
+     gr.Markdown("## Img2Img Editor with ControlNet and Stable Diffusion 3")
+     with gr.Row():
+         image_input = gr.Image(source="upload", type="numpy", label="Input Image")
+         prompt_input = gr.Textbox(label="Prompt")
+         result_output = gr.Image(label="Output Image")
+
+     submit_btn = gr.Button("Generate")
+     submit_btn.click(fn=img_editor, inputs=[image_input, prompt_input], outputs=result_output)
+
+ # Launch Gradio interface
+ demo.launch()
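
Review note: as committed, the new code mixes incompatible pieces. `lllyasviel/sd-controlnet-canny` is an SD 1.x ControlNet and cannot be paired with an SD3 base such as `instantx/stable-diffusion-3-medium`; `StableDiffusionControlNetPipeline` is text-to-image, so it interprets `image` as the ControlNet conditioning map and does not accept a `strength` argument; and the `DiffusionPipeline` import is unused. A self-consistent img2img variant might look like the sketch below: the `runwayml/stable-diffusion-v1-5` base id is an assumption chosen only to match the ControlNet checkpoint, and the Canny preprocessing is carried over from the removed SD3 version of the app.

```python
import cv2
import numpy as np
import torch
from PIL import Image
from diffusers import ControlNetModel, StableDiffusionControlNetImg2ImgPipeline

# SD 1.5 base matching the SD 1.x Canny ControlNet; both model ids are
# assumptions for illustration, not what this commit ships.
controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16
)
pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    torch_dtype=torch.float16,
).to("cuda")  # fp16 weights assume a CUDA device

def controlnet_img2img(image: Image.Image, prompt: str, strength=0.8, guidance=7.5):
    # Canny edge map of the input is the ControlNet conditioning image
    # (thresholds 100/200, as in the removed version of this app)
    edges = cv2.Canny(np.array(image.convert("RGB")), 100, 200)
    control_image = Image.fromarray(np.stack([edges] * 3, axis=-1))
    return pipe(
        prompt=prompt,
        image=image,                  # init image being edited (img2img)
        control_image=control_image,  # edge map steering the ControlNet
        strength=strength,            # how far to denoise away from the init
        guidance_scale=guidance,
    ).images[0]
```

Wired into the existing `img_editor` callback, only the conversion step changes, since Gradio passes the input as a numpy array: `controlnet_img2img(Image.fromarray(input_image), prompt)`.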