hysts HF staff commited on
Commit
e4d01a7
·
1 Parent(s): 1420de7
Files changed (1) hide show
  1. model.py +61 -0
model.py CHANGED
@@ -12,6 +12,7 @@ from diffusers import (ControlNetModel, DiffusionPipeline,
12
 
13
  from cv_utils import resize_image
14
  from preprocessor import Preprocessor
 
15
 
16
  CONTROLNET_MODEL_IDS = {
17
  'Openpose': 'lllyasviel/control_v11p_sd15_openpose',
@@ -141,6 +142,11 @@ class Model:
141
  low_threshold: int,
142
  high_threshold: int,
143
  ) -> list[PIL.Image.Image]:
 
 
 
 
 
144
  self.preprocessor.load('Canny')
145
  control_image = self.preprocessor(image=image,
146
  low_threshold=low_threshold,
@@ -175,6 +181,11 @@ class Model:
175
  value_threshold: float,
176
  distance_threshold: float,
177
  ) -> list[PIL.Image.Image]:
 
 
 
 
 
178
  self.preprocessor.load('MLSD')
179
  control_image = self.preprocessor(
180
  image=image,
@@ -210,6 +221,11 @@ class Model:
210
  seed: int,
211
  preprocessor_name: str,
212
  ) -> list[PIL.Image.Image]:
 
 
 
 
 
213
  if preprocessor_name == 'None':
214
  image = HWC3(image)
215
  image = resize_image(image, resolution=image_resolution)
@@ -255,6 +271,11 @@ class Model:
255
  guidance_scale: float,
256
  seed: int,
257
  ) -> list[PIL.Image.Image]:
 
 
 
 
 
258
  image = image_and_mask['mask']
259
  image = HWC3(image)
260
  image = resize_image(image, resolution=image_resolution)
@@ -287,6 +308,11 @@ class Model:
287
  seed: int,
288
  preprocessor_name: str,
289
  ) -> list[PIL.Image.Image]:
 
 
 
 
 
290
  if preprocessor_name == 'None':
291
  image = HWC3(image)
292
  image = resize_image(image, resolution=image_resolution)
@@ -338,6 +364,11 @@ class Model:
338
  seed: int,
339
  preprocessor_name: str,
340
  ) -> list[PIL.Image.Image]:
 
 
 
 
 
341
  if preprocessor_name == 'None':
342
  image = HWC3(image)
343
  image = resize_image(image, resolution=image_resolution)
@@ -377,6 +408,11 @@ class Model:
377
  seed: int,
378
  preprocessor_name: str,
379
  ) -> list[PIL.Image.Image]:
 
 
 
 
 
380
  if preprocessor_name == 'None':
381
  image = HWC3(image)
382
  image = resize_image(image, resolution=image_resolution)
@@ -415,6 +451,11 @@ class Model:
415
  seed: int,
416
  preprocessor_name: str,
417
  ) -> list[PIL.Image.Image]:
 
 
 
 
 
418
  if preprocessor_name == 'None':
419
  image = HWC3(image)
420
  image = resize_image(image, resolution=image_resolution)
@@ -453,6 +494,11 @@ class Model:
453
  seed: int,
454
  preprocessor_name: str,
455
  ) -> list[PIL.Image.Image]:
 
 
 
 
 
456
  if preprocessor_name == 'None':
457
  image = HWC3(image)
458
  image = resize_image(image, resolution=image_resolution)
@@ -491,6 +537,11 @@ class Model:
491
  seed: int,
492
  preprocessor_name: str,
493
  ) -> list[PIL.Image.Image]:
 
 
 
 
 
494
  if preprocessor_name in ['None', 'None (anime)']:
495
  image = HWC3(image)
496
  image = resize_image(image, resolution=image_resolution)
@@ -540,6 +591,11 @@ class Model:
540
  seed: int,
541
  preprocessor_name: str,
542
  ) -> list[PIL.Image.Image]:
 
 
 
 
 
543
  if preprocessor_name == 'None':
544
  image = HWC3(image)
545
  image = resize_image(image, resolution=image_resolution)
@@ -575,6 +631,11 @@ class Model:
575
  guidance_scale: float,
576
  seed: int,
577
  ) -> list[PIL.Image.Image]:
 
 
 
 
 
578
  image = HWC3(image)
579
  image = resize_image(image, resolution=image_resolution)
580
  control_image = PIL.Image.fromarray(image)
 
12
 
13
  from cv_utils import resize_image
14
  from preprocessor import Preprocessor
15
+ from settings import MAX_IMAGE_RESOLUTION, MAX_NUM_IMAGES
16
 
17
  CONTROLNET_MODEL_IDS = {
18
  'Openpose': 'lllyasviel/control_v11p_sd15_openpose',
 
142
  low_threshold: int,
143
  high_threshold: int,
144
  ) -> list[PIL.Image.Image]:
145
+ if image_resolution > MAX_IMAGE_RESOLUTION:
146
+ raise ValueError
147
+ if num_images > MAX_NUM_IMAGES:
148
+ raise ValueError
149
+
150
  self.preprocessor.load('Canny')
151
  control_image = self.preprocessor(image=image,
152
  low_threshold=low_threshold,
 
181
  value_threshold: float,
182
  distance_threshold: float,
183
  ) -> list[PIL.Image.Image]:
184
+ if image_resolution > MAX_IMAGE_RESOLUTION:
185
+ raise ValueError
186
+ if num_images > MAX_NUM_IMAGES:
187
+ raise ValueError
188
+
189
  self.preprocessor.load('MLSD')
190
  control_image = self.preprocessor(
191
  image=image,
 
221
  seed: int,
222
  preprocessor_name: str,
223
  ) -> list[PIL.Image.Image]:
224
+ if image_resolution > MAX_IMAGE_RESOLUTION:
225
+ raise ValueError
226
+ if num_images > MAX_NUM_IMAGES:
227
+ raise ValueError
228
+
229
  if preprocessor_name == 'None':
230
  image = HWC3(image)
231
  image = resize_image(image, resolution=image_resolution)
 
271
  guidance_scale: float,
272
  seed: int,
273
  ) -> list[PIL.Image.Image]:
274
+ if image_resolution > MAX_IMAGE_RESOLUTION:
275
+ raise ValueError
276
+ if num_images > MAX_NUM_IMAGES:
277
+ raise ValueError
278
+
279
  image = image_and_mask['mask']
280
  image = HWC3(image)
281
  image = resize_image(image, resolution=image_resolution)
 
308
  seed: int,
309
  preprocessor_name: str,
310
  ) -> list[PIL.Image.Image]:
311
+ if image_resolution > MAX_IMAGE_RESOLUTION:
312
+ raise ValueError
313
+ if num_images > MAX_NUM_IMAGES:
314
+ raise ValueError
315
+
316
  if preprocessor_name == 'None':
317
  image = HWC3(image)
318
  image = resize_image(image, resolution=image_resolution)
 
364
  seed: int,
365
  preprocessor_name: str,
366
  ) -> list[PIL.Image.Image]:
367
+ if image_resolution > MAX_IMAGE_RESOLUTION:
368
+ raise ValueError
369
+ if num_images > MAX_NUM_IMAGES:
370
+ raise ValueError
371
+
372
  if preprocessor_name == 'None':
373
  image = HWC3(image)
374
  image = resize_image(image, resolution=image_resolution)
 
408
  seed: int,
409
  preprocessor_name: str,
410
  ) -> list[PIL.Image.Image]:
411
+ if image_resolution > MAX_IMAGE_RESOLUTION:
412
+ raise ValueError
413
+ if num_images > MAX_NUM_IMAGES:
414
+ raise ValueError
415
+
416
  if preprocessor_name == 'None':
417
  image = HWC3(image)
418
  image = resize_image(image, resolution=image_resolution)
 
451
  seed: int,
452
  preprocessor_name: str,
453
  ) -> list[PIL.Image.Image]:
454
+ if image_resolution > MAX_IMAGE_RESOLUTION:
455
+ raise ValueError
456
+ if num_images > MAX_NUM_IMAGES:
457
+ raise ValueError
458
+
459
  if preprocessor_name == 'None':
460
  image = HWC3(image)
461
  image = resize_image(image, resolution=image_resolution)
 
494
  seed: int,
495
  preprocessor_name: str,
496
  ) -> list[PIL.Image.Image]:
497
+ if image_resolution > MAX_IMAGE_RESOLUTION:
498
+ raise ValueError
499
+ if num_images > MAX_NUM_IMAGES:
500
+ raise ValueError
501
+
502
  if preprocessor_name == 'None':
503
  image = HWC3(image)
504
  image = resize_image(image, resolution=image_resolution)
 
537
  seed: int,
538
  preprocessor_name: str,
539
  ) -> list[PIL.Image.Image]:
540
+ if image_resolution > MAX_IMAGE_RESOLUTION:
541
+ raise ValueError
542
+ if num_images > MAX_NUM_IMAGES:
543
+ raise ValueError
544
+
545
  if preprocessor_name in ['None', 'None (anime)']:
546
  image = HWC3(image)
547
  image = resize_image(image, resolution=image_resolution)
 
591
  seed: int,
592
  preprocessor_name: str,
593
  ) -> list[PIL.Image.Image]:
594
+ if image_resolution > MAX_IMAGE_RESOLUTION:
595
+ raise ValueError
596
+ if num_images > MAX_NUM_IMAGES:
597
+ raise ValueError
598
+
599
  if preprocessor_name == 'None':
600
  image = HWC3(image)
601
  image = resize_image(image, resolution=image_resolution)
 
631
  guidance_scale: float,
632
  seed: int,
633
  ) -> list[PIL.Image.Image]:
634
+ if image_resolution > MAX_IMAGE_RESOLUTION:
635
+ raise ValueError
636
+ if num_images > MAX_NUM_IMAGES:
637
+ raise ValueError
638
+
639
  image = HWC3(image)
640
  image = resize_image(image, resolution=image_resolution)
641
  control_image = PIL.Image.fromarray(image)