ahmedmbutt commited on
Commit
60d26fc
·
1 Parent(s): d6d2c2a

added safety checker

Browse files
Files changed (1) hide show
  1. main.py +21 -9
main.py CHANGED
@@ -39,7 +39,6 @@ async def lifespan(app: FastAPI):
39
  del inpaint
40
  del img2img
41
  del text2img
42
-
43
  del safety_checker
44
  del feature_extractor
45
 
@@ -68,9 +67,14 @@ async def text_to_image(
68
  prompt: str = Form(...),
69
  num_inference_steps: int = Form(1),
70
  ):
71
- image = request.state.text2img(
72
  prompt=prompt, num_inference_steps=num_inference_steps, guidance_scale=0.0
73
- ).images[0]
 
 
 
 
 
74
 
75
  bytes = BytesIO()
76
  image.save(bytes, "PNG")
@@ -91,14 +95,18 @@ async def image_to_image(
91
  init_width, init_height = init_image.size
92
  init_image = init_image.convert("RGB").resize((512, 512))
93
 
94
- image = request.state.img2img(
95
  prompt,
96
  image=init_image,
97
  num_inference_steps=num_inference_steps,
98
  strength=strength,
99
  guidance_scale=0.0,
100
- ).images[0]
101
- image = image.resize((init_width, init_height))
 
 
 
 
102
 
103
  bytes = BytesIO()
104
  image.save(bytes, "PNG")
@@ -123,15 +131,19 @@ async def inpainting(
123
  mask_image = Image.open(BytesIO(mask_bytes))
124
  mask_image = mask_image.convert("RGB").resize((512, 512))
125
 
126
- image = request.state.inpaint(
127
  prompt,
128
  image=init_image,
129
  mask_image=mask_image,
130
  num_inference_steps=num_inference_steps,
131
  strength=strength,
132
  guidance_scale=0.0,
133
- ).images[0]
134
- image = image.resize((init_width, init_height))
 
 
 
 
135
 
136
  bytes = BytesIO()
137
  image.save(bytes, "PNG")
 
39
  del inpaint
40
  del img2img
41
  del text2img
 
42
  del safety_checker
43
  del feature_extractor
44
 
 
67
  prompt: str = Form(...),
68
  num_inference_steps: int = Form(1),
69
  ):
70
+ results = request.state.text2img(
71
  prompt=prompt, num_inference_steps=num_inference_steps, guidance_scale=0.0
72
+ )
73
+
74
+ if not results.nsfw_content_detected[0]:
75
+ image = results.images[0]
76
+ else:
77
+ image = Image.new("RGB", (512, 512), "black")
78
 
79
  bytes = BytesIO()
80
  image.save(bytes, "PNG")
 
95
  init_width, init_height = init_image.size
96
  init_image = init_image.convert("RGB").resize((512, 512))
97
 
98
+ results = request.state.img2img(
99
  prompt,
100
  image=init_image,
101
  num_inference_steps=num_inference_steps,
102
  strength=strength,
103
  guidance_scale=0.0,
104
+ )
105
+
106
+ if not results.nsfw_content_detected[0]:
107
+ image = results.images[0].resize((init_width, init_height))
108
+ else:
109
+ image = Image.new("RGB", (512, 512), "black")
110
 
111
  bytes = BytesIO()
112
  image.save(bytes, "PNG")
 
131
  mask_image = Image.open(BytesIO(mask_bytes))
132
  mask_image = mask_image.convert("RGB").resize((512, 512))
133
 
134
+ results = request.state.inpaint(
135
  prompt,
136
  image=init_image,
137
  mask_image=mask_image,
138
  num_inference_steps=num_inference_steps,
139
  strength=strength,
140
  guidance_scale=0.0,
141
+ )
142
+
143
+ if not results.nsfw_content_detected[0]:
144
+ image = results.images[0].resize((init_width, init_height))
145
+ else:
146
+ image = Image.new("RGB", (512, 512), "black")
147
 
148
  bytes = BytesIO()
149
  image.save(bytes, "PNG")