charleselena commited on
Commit
c19e295
·
verified ·
1 Parent(s): e75aeb2

test integrate safer

Browse files
Files changed (1) hide show
  1. handler.py +2 -14
handler.py CHANGED
@@ -5,10 +5,7 @@ from io import BytesIO
5
  from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
6
  #from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, StableDiffusionSafetyChecker
7
  # import Safety Checker
8
- from transformers import AutoProcessor, SafetyChecker
9
- #from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
10
-
11
- import pprint
12
 
13
  import torch
14
 
@@ -91,7 +88,7 @@ class EndpointHandler():
91
  self.pipe = StableDiffusionControlNetPipeline.from_pretrained(self.stable_diffusion_id,
92
  controlnet=self.controlnet,
93
  torch_dtype=dtype,
94
- safety_checker=None).to(device)
95
 
96
 
97
  # Define Generator with seed
@@ -132,15 +129,6 @@ class EndpointHandler():
132
  image = self.decode_base64_image(image)
133
  #control_image = CONTROLNET_MAPPING[self.control_type]["hinter"](image)
134
 
135
-
136
- processor = AutoProcessor.from_pretrained("CompVis/stable-diffusion-safety-checker")
137
- safety_checker = SafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
138
-
139
- safety_features = processor(image)
140
- safety_check_result = safety_checker(images=image, features=safety_features)
141
-
142
- pprint(safety_check_result)
143
-
144
  # run inference pipeline
145
  out = self.pipe(
146
  prompt=prompt,
 
5
  from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
6
  #from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, StableDiffusionSafetyChecker
7
  # import Safety Checker
8
+ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
 
 
 
9
 
10
  import torch
11
 
 
88
  self.pipe = StableDiffusionControlNetPipeline.from_pretrained(self.stable_diffusion_id,
89
  controlnet=self.controlnet,
90
  torch_dtype=dtype,
91
+ safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")).to(device)
92
 
93
 
94
  # Define Generator with seed
 
129
  image = self.decode_base64_image(image)
130
  #control_image = CONTROLNET_MAPPING[self.control_type]["hinter"](image)
131
 
 
 
 
 
 
 
 
 
 
132
  # run inference pipeline
133
  out = self.pipe(
134
  prompt=prompt,