RohitGandikota commited on
Commit
d8de6d4
·
1 Parent(s): 500d414

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -8
app.py CHANGED
@@ -22,13 +22,13 @@ def numpy_to_pil(images):
22
 
23
  return pil_images
24
 
25
- def run_safety_checker(self, image, device, dtype):
26
 
27
  feature_extractor = CLIPFeatureExtractor()
28
  safety_checker = StableDiffusionSafetyChecker()
29
  safety_checker_input = feature_extractor(numpy_to_pil(image), return_tensors="pt").to('cuda')
30
  image, has_nsfw_concept = safety_checker(
31
- images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
32
  )
33
  return image, has_nsfw_concept
34
 
@@ -258,9 +258,7 @@ class Demo:
258
  n_steps=50,
259
  generator=generator
260
  )
261
- images, has_nsfw_concept = StableDiffusionSafetyChecker(
262
- images=images
263
- )
264
 
265
  orig_image = images[0][0]
266
 
@@ -275,9 +273,7 @@ class Demo:
275
  n_steps=50,
276
  generator=generator
277
  )
278
- images, has_nsfw_concept = StableDiffusionSafetyChecker(
279
- images=images
280
- )
281
  edited_image = images[0][0]
282
 
283
  del finetuner
 
22
 
23
  return pil_images
24
 
25
+ def run_safety_checker(image):
26
 
27
  feature_extractor = CLIPFeatureExtractor()
28
  safety_checker = StableDiffusionSafetyChecker()
29
  safety_checker_input = feature_extractor(numpy_to_pil(image), return_tensors="pt").to('cuda')
30
  image, has_nsfw_concept = safety_checker(
31
+ images=image, clip_input=safety_checker_input.pixel_values
32
  )
33
  return image, has_nsfw_concept
34
 
 
258
  n_steps=50,
259
  generator=generator
260
  )
261
+ images, has_nsfw_concept = run_safety_checker(images)
 
 
262
 
263
  orig_image = images[0][0]
264
 
 
273
  n_steps=50,
274
  generator=generator
275
  )
276
+ images, has_nsfw_concept = run_safety_checker(images)
 
 
277
  edited_image = images[0][0]
278
 
279
  del finetuner