jordandotzel committed
Commit 38f60df · verified · 1 parent: 2ff3a11

add threshold and negative prompt

Files changed (1)
  app.py: +13 −14
app.py CHANGED
@@ -10,10 +10,10 @@ import cv2
 import torch
 
 from diffusers import StableDiffusion3ControlNetPipeline
-from diffusers.models import SD3ControlNetModel, SD3MultiControlNetModel
+from diffusers.models import SD3ControlNetModel
 from diffusers.utils import load_image
 
-# load pipeline
+# Load pipeline
 controlnet_canny = SD3ControlNetModel.from_pretrained("InstantX/SD3-Controlnet-Canny")
 pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
     "stabilityai/stable-diffusion-3-medium-diffusers",
@@ -41,38 +41,35 @@ def resize_image(input_path, output_path, target_height):
 
 
 @spaces.GPU(duration=90)
-def infer(image_in, prompt, inference_steps, guidance_scale, control_weight, progress=gr.Progress(track_tqdm=True)):
-
-    n_prompt = 'NSFW, nude, naked, porn, ugly'
+def infer(image_in, prompt, negative_prompt, inference_steps, guidance_scale, control_weight, low_threshold, high_threshold, progress=gr.Progress(track_tqdm=True)):
 
     # Canny preprocessing
     image_to_canny = load_image(image_in)
     image_to_canny = np.array(image_to_canny)
-    image_to_canny = cv2.Canny(image_to_canny, 100, 200)
+    image_to_canny = cv2.Canny(image_to_canny, low_threshold, high_threshold)
     image_to_canny = image_to_canny[:, :, None]
     image_to_canny = np.concatenate([image_to_canny, image_to_canny, image_to_canny], axis=2)
     image_to_canny = Image.fromarray(image_to_canny)
 
     control_image = image_to_canny
 
-    # infer
+    # Infer
     image = pipe(
         prompt=prompt,
-        negative_prompt=n_prompt,
+        negative_prompt=negative_prompt,
         control_image=control_image,
         controlnet_conditioning_scale=control_weight,
         num_inference_steps=inference_steps,
         guidance_scale=guidance_scale,
     ).images[0]
 
-
     image_redim, w, h = resize_image(image_in, "resized_input.jpg", 1024)
     image = image.resize((w, h), Image.LANCZOS)
 
     return image, gr.update(value=image_to_canny, visible=True)
-
 
-css="""
+
+css = """
 #col-container{
     margin: 0 auto;
     max-width: 1080px;
@@ -92,6 +89,7 @@ with gr.Blocks(css=css) as demo:
         with gr.Column():
             image_in = gr.Image(label="Image reference", sources=["upload"], type="filepath")
             prompt = gr.Textbox(label="Prompt")
+            negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="Enter negative prompts here")
 
             with gr.Accordion("Advanced settings", open=False):
                 with gr.Column():
@@ -99,22 +97,23 @@ with gr.Blocks(css=css) as demo:
                     inference_steps = gr.Slider(label="Inference steps", minimum=1, maximum=50, step=1, value=25)
                     guidance_scale = gr.Slider(label="Guidance scale", minimum=1.0, maximum=10.0, step=0.1, value=7.0)
                     control_weight = gr.Slider(label="Control Weight", minimum=0.0, maximum=1.0, step=0.01, value=0.7)
+                    low_threshold = gr.Slider(label="Canny Low Threshold", minimum=0, maximum=255, step=1, value=100)
+                    high_threshold = gr.Slider(label="Canny High Threshold", minimum=0, maximum=255, step=1, value=200)
 
             submit_canny_btn = gr.Button("Submit")
 
         with gr.Column():
            result = gr.Image(label="Result")
            canny_used = gr.Image(label="Preprocessed Canny", visible=False)
-
 
 
    submit_canny_btn.click(
        fn=infer,
-        inputs=[image_in, prompt, inference_steps, guidance_scale, control_weight],
+        inputs=[image_in, prompt, negative_prompt, inference_steps, guidance_scale, control_weight, low_threshold, high_threshold],
        outputs=[result, canny_used],
        api_name="predict",
        show_api=True
    )
 
-# Enable API by setting enable_api=True
 demo.queue().launch(show_api=True)
+
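
Reviewer note: the two new sliders map directly onto the hysteresis thresholds of cv2.Canny. Gradients above the high threshold become edges, gradients below the low threshold are dropped, and values in between are kept only when connected to a strong edge, so lowering both values produces a denser edge map for the ControlNet. A minimal standalone sketch of the same preprocessing (the file name and example values below are illustrative, not from the commit):

import cv2
import numpy as np
from PIL import Image

def make_canny_control(path, low_threshold=100, high_threshold=200):
    # Mirrors the preprocessing in infer(): RGB array -> Canny edges -> 3-channel control image.
    img = np.array(Image.open(path).convert("RGB"))
    edges = cv2.Canny(img, low_threshold, high_threshold)      # hysteresis thresholds from the sliders
    edges = np.concatenate([edges[:, :, None]] * 3, axis=2)    # replicate the single channel to 3 channels
    return Image.fromarray(edges)

# Lower thresholds keep more detail in the edge map.
control = make_canny_control("reference.jpg", low_threshold=50, high_threshold=150)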
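
Note for API users: since inputs= now carries eight components instead of five, callers of the predict endpoint must also pass the negative prompt and both thresholds. A sketch using gradio_client, assuming a recent release that provides handle_file; the Space id is a placeholder, and the prompt values are illustrative:

from gradio_client import Client, handle_file

client = Client("jordandotzel/SD3-Controlnet-Canny")  # placeholder Space id

# Positional inputs follow the order of inputs= in submit_canny_btn.click().
result = client.predict(
    handle_file("reference.jpg"),        # image_in
    "a watercolor painting of a cat",    # prompt
    "NSFW, nude, naked, porn, ugly",     # negative_prompt (the value hard-coded before this commit)
    25,                                  # inference_steps
    7.0,                                 # guidance_scale
    0.7,                                 # control_weight
    100,                                 # low_threshold
    200,                                 # high_threshold
    api_name="/predict",
)
# result holds the generated image plus the Canny-preview update returned by infer().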