charleselena committed
Commit b674557 (verified)
1 Parent(s): bacee68

processor after generated image

Files changed (1)
  1. handler.py +22 -9
handler.py CHANGED
@@ -5,8 +5,9 @@ from io import BytesIO
 from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
 #from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, StableDiffusionSafetyChecker
 # import Safety Checker
-# from transformers import AutoProcessor, SafetyChecker
-from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
+from transformers import AutoProcessor, SafetyChecker
+#from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
+
 
 import torch
 
@@ -64,6 +65,9 @@ class EndpointHandler():
         # define default controlnet id and load controlnet
         self.control_type = "depth"
         self.controlnet = ControlNetModel.from_pretrained(CONTROLNET_MAPPING[self.control_type]["model_id"],torch_dtype=dtype).to(device)
+
+        processor = AutoProcessor.from_pretrained("CompVis/stable-diffusion-safety-checker")
+
 
         # Load StableDiffusionControlNetPipeline
         #self.stable_diffusion_id = "runwayml/stable-diffusion-v1-5"
@@ -80,17 +84,15 @@ class EndpointHandler():
         # self.stable_diffusion_id,
         # controlnet=self.controlnet,
         # torch_dtype=dtype,
-        # safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
+        # safety_checker = SafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
         # ).to(device)
 
 
-        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
-            self.stable_diffusion_id,
-            controlnet=self.controlnet,
-            safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
-        ).to(device)
+        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(self.stable_diffusion_id,
+            controlnet=self.controlnet,
+            torch_dtype=dtype,
+            safety_checker=None).to(device)
 
-
         # Define Generator with seed
         self.generator = torch.Generator(device="cpu").manual_seed(3)
 
@@ -128,6 +130,17 @@ class EndpointHandler():
         # process image
         image = self.decode_base64_image(image)
         #control_image = CONTROLNET_MAPPING[self.control_type]["hinter"](image)
+
+
+        processor = AutoProcessor.from_pretrained("CompVis/stable-diffusion-safety-checker")
+        safety_checker = SafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
+
+        safety_features = processor(image)
+        safety_check_result = safety_checker(images=image, features=safety_features)
+
+        print(f'An error occurred: {safety_check_result}')
+        print(f'An error occurred: {safety_check_result["nsfw_content_detected"]}')
+
 
         # run inference pipeline
         out = self.pipe(
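
Note on the new safety-check block: current transformers releases do not ship a `SafetyChecker` class, so the added `from transformers import AutoProcessor, SafetyChecker` line and the `safety_checker(images=image, features=safety_features)` call are unlikely to run as written; the checker used by Stable Diffusion lives in diffusers as `StableDiffusionSafetyChecker` (the import this commit comments out). Below is a minimal sketch, not part of this commit, of how that checker is typically run as a post-processing step on a generated image when the pipeline is built with `safety_checker=None`; the use of `AutoFeatureExtractor` with the CompVis repo id and the exact array shapes are assumptions.

# Sketch only (not part of this commit): post-generation NSFW check with the
# CompVis safety checker, assuming current diffusers/transformers APIs.
import numpy as np
from transformers import AutoFeatureExtractor
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker

# Assumption: the CompVis repo ships a preprocessor config usable by AutoFeatureExtractor.
feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")

def run_safety_check(pil_image):
    """Return (possibly blacked-out image array, nsfw flag) for one PIL image."""
    # CLIP features for the checker come from the feature extractor; the raw
    # pixels are passed separately so flagged images can be zeroed out.
    clip_input = feature_extractor(images=pil_image, return_tensors="pt").pixel_values
    images = np.array(pil_image, dtype=np.float32)[None] / 255.0  # (1, H, W, C) in [0, 1]
    checked_images, has_nsfw = safety_checker(images=images, clip_input=clip_input)
    return checked_images[0], has_nsfw[0]

With the pipeline created with `safety_checker=None`, a check like this would be applied to the pipeline output (e.g. `out.images[0]`) after generation, which matches the "processor after generated image" intent of the commit.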