craftgamesnetwork commited on
Commit
d02ac40
·
verified ·
1 Parent(s): 086aec8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -1
app.py CHANGED
@@ -10,10 +10,11 @@ import gradio as gr
10
  import numpy as np
11
  import spaces
12
  import torch
 
13
  from PIL import Image
14
  from io import BytesIO
15
  from diffusers.utils import load_image
16
- from diffusers import AutoencoderKL, DiffusionPipeline, AutoPipelineForImage2Image, AutoPipelineForInpainting
17
 
18
  DESCRIPTION = "# Run any LoRA or SD Model"
19
  if not torch.cuda.is_available():
@@ -27,6 +28,7 @@ ENABLE_USE_LORA = os.getenv("ENABLE_USE_LORA", "1") == "1"
27
  ENABLE_USE_VAE = os.getenv("ENABLE_USE_VAE", "1") == "1"
28
  ENABLE_USE_IMG2IMG = os.getenv("ENABLE_USE_IMG2IMG", "1") == "1"
29
  ENABLE_USE_INPAINTING = os.getenv("ENABLE_USE_INPAINTING", "1") == "1"
 
30
 
31
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
32
 
@@ -59,6 +61,7 @@ def generate(
59
  lora_scale: float = 0.7,
60
  use_img2img: bool = False,
61
  use_inpainting: bool = False,
 
62
  url = '',
63
  img_url = '',
64
  mask_url = '',
@@ -84,6 +87,19 @@ def generate(
84
 
85
  image_init = load_image(img_url)
86
  mask_image = load_image(mask_url)
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
  if use_lora:
89
  pipe.load_lora_weights(lora)
@@ -107,6 +123,21 @@ def generate(
107
  if not use_negative_prompt_2:
108
  negative_prompt_2 = None # type: ignore
109
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  if use_inpainting:
111
  image = pipe(
112
  prompt=prompt,
@@ -185,6 +216,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css="style.css") as demo:
185
  result = gr.Image(label="Result", show_label=False)
186
  with gr.Accordion("Advanced options", open=False):
187
  with gr.Row():
 
188
  use_inpainting = gr.Checkbox(label='Use Inpainting', value=False, visible=ENABLE_USE_INPAINTING)
189
  use_img2img = gr.Checkbox(label='Use Img2Img', value=False, visible=ENABLE_USE_IMG2IMG)
190
  use_vae = gr.Checkbox(label='Use VAE', value=False, visible=ENABLE_USE_VAE)
@@ -319,6 +351,13 @@ with gr.Blocks(theme=gr.themes.Soft(), css="style.css") as demo:
319
  queue=False,
320
  api_name=False,
321
  )
 
 
 
 
 
 
 
322
 
323
  gr.on(
324
  triggers=[
@@ -360,6 +399,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css="style.css") as demo:
360
  url,
361
  img_url,
362
  mask_url,
 
363
  ],
364
  outputs=result,
365
  api_name="run",
 
10
  import numpy as np
11
  import spaces
12
  import torch
13
+ import cv2
14
  from PIL import Image
15
  from io import BytesIO
16
  from diffusers.utils import load_image
17
+ from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, AutoencoderKL, DiffusionPipeline, AutoPipelineForImage2Image, AutoPipelineForInpainting
18
 
19
  DESCRIPTION = "# Run any LoRA or SD Model"
20
  if not torch.cuda.is_available():
 
28
  ENABLE_USE_VAE = os.getenv("ENABLE_USE_VAE", "1") == "1"
29
  ENABLE_USE_IMG2IMG = os.getenv("ENABLE_USE_IMG2IMG", "1") == "1"
30
  ENABLE_USE_INPAINTING = os.getenv("ENABLE_USE_INPAINTING", "1") == "1"
31
+ ENABLE_USE_CONTROLNET = os.getenv("ENABLE_USE_CONTROLNET", "1") == "1"
32
 
33
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
34
 
 
61
  lora_scale: float = 0.7,
62
  use_img2img: bool = False,
63
  use_inpainting: bool = False,
64
+ use_controlnet: bool = False,
65
  url = '',
66
  img_url = '',
67
  mask_url = '',
 
87
 
88
  image_init = load_image(img_url)
89
  mask_image = load_image(mask_url)
90
+
91
+ if use_controlnet:
92
+ controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
93
+ pipe = StableDiffusionControlNetPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16)
94
+
95
+ image = load_image(img_url)
96
+ image = np.array(image)
97
+
98
+ # get canny image
99
+ image = cv2.Canny(image, 100, 200)
100
+ image = image[:, :, None]
101
+ image = np.concatenate([image, image, image], axis=2)
102
+ canny_image = Image.fromarray(image)
103
 
104
  if use_lora:
105
  pipe.load_lora_weights(lora)
 
123
  if not use_negative_prompt_2:
124
  negative_prompt_2 = None # type: ignore
125
 
126
+ if use_controlnet:
127
+ image = pipe(
128
+ prompt=prompt,
129
+ image=image,
130
+ control_image=canny_image,
131
+ negative_prompt=negative_prompt,
132
+ prompt_2=prompt_2,
133
+ width=width,
134
+ height=height,
135
+ negative_prompt_2=negative_prompt_2,
136
+ guidance_scale=guidance_scale_base,
137
+ num_inference_steps=num_inference_steps_base,
138
+ generator=generator,
139
+ ).images[0]
140
+ return image
141
  if use_inpainting:
142
  image = pipe(
143
  prompt=prompt,
 
216
  result = gr.Image(label="Result", show_label=False)
217
  with gr.Accordion("Advanced options", open=False):
218
  with gr.Row():
219
+ use_controlnet = gr.Checkbox(label='Use Controlnet'), value=False, visible=ENABLE_USE_CONTROLNET)
220
  use_inpainting = gr.Checkbox(label='Use Inpainting', value=False, visible=ENABLE_USE_INPAINTING)
221
  use_img2img = gr.Checkbox(label='Use Img2Img', value=False, visible=ENABLE_USE_IMG2IMG)
222
  use_vae = gr.Checkbox(label='Use VAE', value=False, visible=ENABLE_USE_VAE)
 
351
  queue=False,
352
  api_name=False,
353
  )
354
+ use_controlnet.change(
355
+ fn=lambda x: gr.update(visible=x),
356
+ inputs=use_controlnet,
357
+ outputs=img_url,
358
+ queue=False,
359
+ api_name=False,
360
+ )
361
 
362
  gr.on(
363
  triggers=[
 
399
  url,
400
  img_url,
401
  mask_url,
402
+ use_controlnet,
403
  ],
404
  outputs=result,
405
  api_name="run",