Deadmon commited on
Commit
153d40e
·
verified ·
1 Parent(s): 3713c54

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +215 -53
app.py CHANGED
@@ -1,25 +1,140 @@
1
  import os
2
  import random
 
3
  import gradio as gr
 
 
4
  import torch
5
- from PIL import Image
 
6
  from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL
7
- from diffusers import EulerAncestralDiscreteScheduler
8
- from huggingface_hub import snapshot_download
9
- import numpy as np
10
- from gradio_imageslider import ImageSlider
11
- from controlnet_aux import HEDdetector
 
12
  import cv2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
- MAX_SEED = np.iinfo(np.int32).max
15
 
16
- DESCRIPTION = '''This is a demo for generating images using Stable Diffusion XL models and ControlNet.'''
17
 
18
- # Setup models
 
 
19
  ckpt_dir_pony = snapshot_download(repo_id="John6666/pony-realism-v21main-sdxl")
20
  ckpt_dir_cyber = snapshot_download(repo_id="John6666/cyberrealistic-pony-v61-sdxl")
21
  ckpt_dir_stallion = snapshot_download(repo_id="John6666/stallion-dreams-pony-realistic-v1-sdxl")
22
 
 
23
  vae_pony = AutoencoderKL.from_pretrained(os.path.join(ckpt_dir_pony, "vae"), torch_dtype=torch.float16)
24
  vae_cyber = AutoencoderKL.from_pretrained(os.path.join(ckpt_dir_cyber, "vae"), torch_dtype=torch.float16)
25
  vae_stallion = AutoencoderKL.from_pretrained(os.path.join(ckpt_dir_stallion, "vae"), torch_dtype=torch.float16)
@@ -29,84 +144,128 @@ controlnet_cyber = ControlNetModel.from_pretrained("xinsir/controlnet-union-sdxl
29
  controlnet_stallion = ControlNetModel.from_pretrained("xinsir/controlnet-union-sdxl-1.0", torch_dtype=torch.float16)
30
 
31
  pipe_pony = StableDiffusionXLControlNetPipeline.from_pretrained(
32
- ckpt_dir_pony, controlnet=controlnet_pony, vae=vae_pony, torch_dtype=torch.float16, scheduler=EulerAncestralDiscreteScheduler.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler")
33
  )
34
  pipe_cyber = StableDiffusionXLControlNetPipeline.from_pretrained(
35
- ckpt_dir_cyber, controlnet=controlnet_cyber, vae=vae_cyber, torch_dtype=torch.float16, scheduler=EulerAncestralDiscreteScheduler.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler")
36
  )
37
  pipe_stallion = StableDiffusionXLControlNetPipeline.from_pretrained(
38
- ckpt_dir_stallion, controlnet=controlnet_stallion, vae=vae_stallion, torch_dtype=torch.float16, scheduler=EulerAncestralDiscreteScheduler.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler")
39
  )
40
 
41
- styles = {
42
- "(No style)": ("{prompt}", "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality"),
43
- "Cinematic": ("cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy", "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured"),
44
- "3D Model": ("professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting", "ugly, deformed, noisy, low poly, blurry, painting"),
45
- "Anime": ("anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed", "photo, deformed, black and white, realism, disfigured, low contrast"),
46
- "Digital Art": ("concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed", "photo, photorealistic, realism, ugly"),
47
- "Photographic": ("cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed", "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly"),
48
- "Pixel art": ("pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics", "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic"),
49
- "Fantasy art": ("ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy", "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white"),
50
- "Neonpunk": ("neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional", "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured"),
51
- "Manga": ("manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style", "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style"),
52
- }
53
 
54
- DEFAULT_STYLE_NAME = "(No style)"
 
 
 
55
 
56
- def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]:
57
- p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
58
- return p.replace("{prompt}", positive), n + negative
59
 
60
- def generate_weighted_prompt(prompt: str) -> str:
61
- from compel import Compel
62
- comp = Compel(prompt)
63
- return comp.with_strength(0.8)
 
 
64
 
65
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
66
  if randomize_seed:
67
- return random.randint(0, MAX_SEED)
68
  return seed
69
 
70
- def run(image, prompt, negative_prompt, model_choice, style_name, num_steps, guidance_scale, controlnet_conditioning_scale, seed, use_hed, use_canny):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  composite_image = image['composite']
72
  width, height = composite_image.size
73
 
 
74
  max_size = 1024
75
  ratio = min(max_size / width, max_size / height)
76
  new_width = int(width * ratio)
77
  new_height = int(height * ratio)
 
 
78
  resized_image = composite_image.resize((new_width, new_height), Image.LANCZOS)
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)
81
- prompt = generate_weighted_prompt(prompt)
82
 
83
- generator = torch.Generator(device="cuda" if torch.cuda.is_available() else "cpu").manual_seed(seed)
84
 
 
85
  if model_choice == "Pony Realism v21":
86
  pipe = pipe_pony
87
  elif model_choice == "Cyber Realistic Pony v61":
88
  pipe = pipe_cyber
89
- else:
90
  pipe = pipe_stallion
91
 
92
- pipe.to("cuda" if torch.cuda.is_available() else "cpu")
93
-
94
- out = pipe(
95
- prompt=prompt,
96
- negative_prompt=negative_prompt,
97
- image=resized_image,
98
- num_inference_steps=num_steps,
99
- generator=generator,
100
- controlnet_conditioning_scale=controlnet_conditioning_scale,
101
- guidance_scale=guidance_scale,
102
- width=new_width,
103
- height=new_height,
104
- ).images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
  pipe.to("cpu")
107
  torch.cuda.empty_cache()
108
 
109
- return out
110
 
111
  with gr.Blocks(css="style.css", js=js_func) as demo:
112
  gr.Markdown(DESCRIPTION, elem_id="description")
@@ -121,7 +280,7 @@ with gr.Blocks(css="style.css", js=js_func) as demo:
121
  with gr.Group():
122
  image = gr.ImageEditor(type="pil", label="Sketch your image or upload one", width=512, height=512)
123
  prompt = gr.Textbox(label="Prompt")
124
- style = gr.Dropdown(label="Style", choices=list(styles.keys()), value=DEFAULT_STYLE_NAME)
125
  model_choice = gr.Dropdown(
126
  ["Pony Realism v21", "Cyber Realistic Pony v61", "Stallion Dreams Pony Realistic v1"],
127
  label="Model Choice",
@@ -169,6 +328,7 @@ with gr.Blocks(css="style.css", js=js_func) as demo:
169
  with gr.Group():
170
  image_slider = ImageSlider(position=0.5)
171
 
 
172
  inputs = [
173
  image,
174
  prompt,
@@ -193,4 +353,6 @@ with gr.Blocks(css="style.css", js=js_func) as demo:
193
  fn=run, inputs=inputs, outputs=outputs
194
  )
195
 
196
- demo.queue().launch(show_error=True, ssl_verify=False)
 
 
 
1
  import os
2
  import random
3
+ import spaces
4
  import gradio as gr
5
+ import numpy as np
6
+ import PIL.Image
7
  import torch
8
+ import torchvision.transforms.functional as TF
9
+
10
  from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL
11
+ from diffusers import DDIMScheduler, EulerAncestralDiscreteScheduler
12
+ from controlnet_aux import PidiNetDetector, HEDdetector
13
+ from diffusers.utils import load_image
14
+ from huggingface_hub import HfApi, snapshot_download
15
+ from pathlib import Path
16
+ from PIL import Image, ImageOps
17
  import cv2
18
+ from gradio_imageslider import ImageSlider
19
+
20
+ js_func = """
21
+ function refresh() {
22
+ const url = new URL(window.location);
23
+ }
24
+ """
25
+ def nms(x, t, s):
26
+ x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)
27
+
28
+ f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
29
+ f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
30
+ f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
31
+ f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
32
+
33
+ y = np.zeros_like(x)
34
+
35
+ for f in [f1, f2, f3, f4]:
36
+ np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
37
+
38
+ z = np.zeros_like(y, dtype=np.uint8)
39
+ z[y > t] = 255
40
+ return z
41
+
42
+ def HWC3(x):
43
+ assert x.dtype == np.uint8
44
+ if x.ndim == 2:
45
+ x = x[:, :, None]
46
+ assert x.ndim == 3
47
+ H, W, C = x.shape
48
+ assert C == 1 or C == 3 or C == 4
49
+ if C == 3:
50
+ return x
51
+ if C == 1:
52
+ return np.concatenate([x, x, x], axis=2)
53
+ if C == 4:
54
+ color = x[:, :, 0:3].astype(np.float32)
55
+ alpha = x[:, :, 3:4].astype(np.float32) / 255.0
56
+ y = color * alpha + 255.0 * (1.0 - alpha)
57
+ y = y.clip(0, 255).astype(np.uint8)
58
+ return y
59
+
60
+ DESCRIPTION = ''''''
61
+
62
+ if not torch.cuda.is_available():
63
+ DESCRIPTION += ""
64
+
65
+ style_list = [
66
+ {
67
+ "name": "(No style)",
68
+ "prompt": "{prompt}",
69
+ "negative_prompt": "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
70
+ },
71
+ {
72
+ "name": "Cinematic",
73
+ "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
74
+ "negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
75
+ },
76
+ {
77
+ "name": "3D Model",
78
+ "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
79
+ "negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting",
80
+ },
81
+ {
82
+ "name": "Anime",
83
+ "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
84
+ "negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast",
85
+ },
86
+ {
87
+ "name": "Digital Art",
88
+ "prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
89
+ "negative_prompt": "photo, photorealistic, realism, ugly",
90
+ },
91
+ {
92
+ "name": "Photographic",
93
+ "prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
94
+ "negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
95
+ },
96
+ {
97
+ "name": "Pixel art",
98
+ "prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
99
+ "negative_prompt": "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
100
+ },
101
+ {
102
+ "name": "Fantasy art",
103
+ "prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
104
+ "negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white",
105
+ },
106
+ {
107
+ "name": "Neonpunk",
108
+ "prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
109
+ "negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
110
+ },
111
+ {
112
+ "name": "Manga",
113
+ "prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
114
+ "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
115
+ },
116
+ ]
117
+
118
+ styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
119
+ STYLE_NAMES = list(styles.keys())
120
+ DEFAULT_STYLE_NAME = "(No style)"
121
+
122
+
123
+ def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]:
124
+ p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
125
+ return p.replace("{prompt}", positive), n + negative
126
 
 
127
 
128
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
129
 
130
+ eulera_scheduler = EulerAncestralDiscreteScheduler.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler")
131
+
132
+ # Download the model files
133
  ckpt_dir_pony = snapshot_download(repo_id="John6666/pony-realism-v21main-sdxl")
134
  ckpt_dir_cyber = snapshot_download(repo_id="John6666/cyberrealistic-pony-v61-sdxl")
135
  ckpt_dir_stallion = snapshot_download(repo_id="John6666/stallion-dreams-pony-realistic-v1-sdxl")
136
 
137
+ # Load the models
138
  vae_pony = AutoencoderKL.from_pretrained(os.path.join(ckpt_dir_pony, "vae"), torch_dtype=torch.float16)
139
  vae_cyber = AutoencoderKL.from_pretrained(os.path.join(ckpt_dir_cyber, "vae"), torch_dtype=torch.float16)
140
  vae_stallion = AutoencoderKL.from_pretrained(os.path.join(ckpt_dir_stallion, "vae"), torch_dtype=torch.float16)
 
144
  controlnet_stallion = ControlNetModel.from_pretrained("xinsir/controlnet-union-sdxl-1.0", torch_dtype=torch.float16)
145
 
146
  pipe_pony = StableDiffusionXLControlNetPipeline.from_pretrained(
147
+ ckpt_dir_pony, controlnet=controlnet_pony, vae=vae_pony, torch_dtype=torch.float16, scheduler=eulera_scheduler
148
  )
149
  pipe_cyber = StableDiffusionXLControlNetPipeline.from_pretrained(
150
+ ckpt_dir_cyber, controlnet=controlnet_cyber, vae=vae_cyber, torch_dtype=torch.float16, scheduler=eulera_scheduler
151
  )
152
  pipe_stallion = StableDiffusionXLControlNetPipeline.from_pretrained(
153
+ ckpt_dir_stallion, controlnet=controlnet_stallion, vae=vae_stallion, torch_dtype=torch.float16, scheduler=eulera_scheduler
154
  )
155
 
156
+ MAX_SEED = np.iinfo(np.int32).max
157
+ processor = HEDdetector.from_pretrained('lllyasviel/Annotators')
158
+ def nms(x, t, s):
159
+ x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)
 
 
 
 
 
 
 
 
160
 
161
+ f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
162
+ f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
163
+ f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
164
+ f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
165
 
166
+ y = np.zeros_like(x)
 
 
167
 
168
+ for f in [f1, f2, f3, f4]:
169
+ np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
170
+
171
+ z = np.zeros_like(y, dtype=np.uint8)
172
+ z[y > t] = 255
173
+ return z
174
 
175
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
176
  if randomize_seed:
177
+ seed = random.randint(0, MAX_SEED)
178
  return seed
179
 
180
+ @spaces.GPU(duration=120)
181
+ def run(
182
+ image: dict,
183
+ prompt: str,
184
+ negative_prompt: str,
185
+ model_choice: str, # Add this new input
186
+ style_name: str = DEFAULT_STYLE_NAME,
187
+ num_steps: int = 25,
188
+ guidance_scale: float = 5,
189
+ controlnet_conditioning_scale: float = 1.0,
190
+ seed: int = 0,
191
+ use_hed: bool = False,
192
+ use_canny: bool = False,
193
+ progress=gr.Progress(track_tqdm=True),
194
+ ) -> PIL.Image.Image:
195
+ # Get the composite image from the EditorValue dict
196
  composite_image = image['composite']
197
  width, height = composite_image.size
198
 
199
+ # Calculate new dimensions to fit within 1024x1024 while maintaining aspect ratio
200
  max_size = 1024
201
  ratio = min(max_size / width, max_size / height)
202
  new_width = int(width * ratio)
203
  new_height = int(height * ratio)
204
+
205
+ # Resize the image
206
  resized_image = composite_image.resize((new_width, new_height), Image.LANCZOS)
207
 
208
+ if use_canny:
209
+ controlnet_img = np.array(resized_image)
210
+ controlnet_img = cv2.Canny(controlnet_img, 100, 200)
211
+ controlnet_img = HWC3(controlnet_img)
212
+ image = Image.fromarray(controlnet_img)
213
+ elif not use_hed:
214
+ controlnet_img = resized_image
215
+ image = resized_image
216
+ else:
217
+ controlnet_img = processor(resized_image, scribble=False)
218
+ controlnet_img = np.array(controlnet_img)
219
+ controlnet_img = nms(controlnet_img, 127, 3)
220
+ controlnet_img = cv2.GaussianBlur(controlnet_img, (0, 0), 3)
221
+ random_val = int(round(random.uniform(0.01, 0.10), 2) * 255)
222
+ controlnet_img[controlnet_img > random_val] = 255
223
+ controlnet_img[controlnet_img < 255] = 0
224
+ image = Image.fromarray(controlnet_img)
225
+
226
  prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)
 
227
 
228
+ generator = torch.Generator(device=device).manual_seed(seed)
229
 
230
+ # Select the appropriate pipe based on the model choice
231
  if model_choice == "Pony Realism v21":
232
  pipe = pipe_pony
233
  elif model_choice == "Cyber Realistic Pony v61":
234
  pipe = pipe_cyber
235
+ else: # "Stallion Dreams Pony Realistic v1"
236
  pipe = pipe_stallion
237
 
238
+ pipe.to(device)
239
+
240
+ if use_canny:
241
+ out = pipe(
242
+ prompt=prompt,
243
+ negative_prompt=negative_prompt,
244
+ image=image,
245
+ num_inference_steps=num_steps,
246
+ generator=generator,
247
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
248
+ guidance_scale=guidance_scale,
249
+ width=new_width,
250
+ height=new_height,
251
+ ).images[0]
252
+ else:
253
+ out = pipe(
254
+ prompt=prompt,
255
+ negative_prompt=negative_prompt,
256
+ image=image,
257
+ num_inference_steps=num_steps,
258
+ generator=generator,
259
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
260
+ guidance_scale=guidance_scale,
261
+ width=new_width,
262
+ height=new_height,
263
+ ).images[0]
264
 
265
  pipe.to("cpu")
266
  torch.cuda.empty_cache()
267
 
268
+ return (controlnet_img, out)
269
 
270
  with gr.Blocks(css="style.css", js=js_func) as demo:
271
  gr.Markdown(DESCRIPTION, elem_id="description")
 
280
  with gr.Group():
281
  image = gr.ImageEditor(type="pil", label="Sketch your image or upload one", width=512, height=512)
282
  prompt = gr.Textbox(label="Prompt")
283
+ style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME)
284
  model_choice = gr.Dropdown(
285
  ["Pony Realism v21", "Cyber Realistic Pony v61", "Stallion Dreams Pony Realistic v1"],
286
  label="Model Choice",
 
328
  with gr.Group():
329
  image_slider = ImageSlider(position=0.5)
330
 
331
+
332
  inputs = [
333
  image,
334
  prompt,
 
353
  fn=run, inputs=inputs, outputs=outputs
354
  )
355
 
356
+
357
+
358
+ demo.queue().launch(show_error=True, ssl_verify=False)