# modified version of https://github.com/AUTOMATIC1111/stable-diffusion-webui-nsfw-censor/blob/master/scripts/censor.py
import numpy as np

from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor
from PIL import Image
import modules.config

safety_model_id = "CompVis/stable-diffusion-safety-checker"
safety_feature_extractor = None
safety_checker = None


def numpy_to_pil(image):
    # convert a float image in [0, 1] to an 8-bit PIL image
    image = (image * 255).round().astype("uint8")
    pil_image = Image.fromarray(image)

    return pil_image


# run the stable diffusion safety checker on a single image; returns the
# checked image and a list of per-image nsfw flags
def check_safety(x_image):
    global safety_feature_extractor, safety_checker

    # lazily load the feature extractor and safety checker on first use
    if safety_feature_extractor is None:
        safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id, cache_dir=modules.config.path_safety_checker_models)
        safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id, cache_dir=modules.config.path_safety_checker_models)

    safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
    x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)

    return x_checked_image, has_nsfw_concept


def censor_single(x):
    x_checked_image, has_nsfw_concept = check_safety(x)

    # replace a flagged image with black pixels, keeping its height and width
    # (workaround: the safety checker's own blanking assumes a batched
    # numpy / pytorch image layout rather than a single HWC array)
    if has_nsfw_concept[0]:
        imageshape = x_checked_image.shape
        x_checked_image = np.zeros((imageshape[0], imageshape[1], 3), dtype=np.uint8)

    return x_checked_image


def censor_batch(images):
    images = [censor_single(image) for image in images]

    return images
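

# Illustrative usage (a minimal sketch; the import path and image sizes below
# are assumptions): given float32 HWC images in [0, 1], a batch can be
# screened like this:
#
#     import numpy as np
#     from extras.censor import censor_batch  # hypothetical module path
#
#     imgs = [np.random.rand(512, 512, 3).astype(np.float32) for _ in range(2)]
#     checked = censor_batch(imgs)
#     # any image flagged by the safety checker comes back as an all-black
#     # uint8 array with the original height and width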