K00B404 commited on
Commit
6347d0a
·
verified ·
1 Parent(s): dfcc271

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -10
app.py CHANGED
@@ -33,13 +33,13 @@ pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, times
33
  from safety_checker import StableDiffusionSafetyChecker
34
  from transformers import CLIPFeatureExtractor
35
 
36
- safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker").to(device)
37
- feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")
38
 
39
- def check_nsfw_images(images: list[Image.Image]) -> list[bool]:
40
- safety_checker_input = feature_extractor(images, return_tensors="pt").to(device)
41
- has_nsfw_concepts = safety_checker(images=[images], clip_input=safety_checker_input.pixel_values.to(device))
42
- return has_nsfw_concepts
43
 
44
  # Function
45
  @spaces.GPU(enable_queue=True)
@@ -72,10 +72,10 @@ def generate_image(prompt, base, motion, step, progress=gr.Progress()):
72
 
73
  output = pipe(prompt=prompt, guidance_scale=1.0, num_inference_steps=step, callback=progress_callback, callback_steps=1)
74
 
75
- has_nsfw_concepts = check_nsfw_images([output.frames[0][0]])
76
- if has_nsfw_concepts[0]:
77
- gr.Warning("NSFW content detected.")
78
- return None
79
 
80
  name = str(uuid.uuid4()).replace("-", "")
81
  path = f"/tmp/{name}.mp4"
 
33
  from safety_checker import StableDiffusionSafetyChecker
34
  from transformers import CLIPFeatureExtractor
35
 
36
+ #safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker").to(device)
37
+ #feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")
38
 
39
+ #def check_nsfw_images(images: list[Image.Image]) -> list[bool]:
40
+ # safety_checker_input = feature_extractor(images, return_tensors="pt").to(device)
41
+ # has_nsfw_concepts = safety_checker(images=[images], clip_input=safety_checker_input.pixel_values.to(device))
42
+ # return has_nsfw_concepts
43
 
44
  # Function
45
  @spaces.GPU(enable_queue=True)
 
72
 
73
  output = pipe(prompt=prompt, guidance_scale=1.0, num_inference_steps=step, callback=progress_callback, callback_steps=1)
74
 
75
+ #has_nsfw_concepts = check_nsfw_images([output.frames[0][0]])
76
+ #if has_nsfw_concepts[0]:
77
+ # gr.Warning("NSFW content detected.")
78
+ # return None
79
 
80
  name = str(uuid.uuid4()).replace("-", "")
81
  path = f"/tmp/{name}.mp4"