---
license: creativeml-openrail-m
tags:
- safety-checker
- tensorflow
- node.js
pipeline_tag: image-classification
---
# Google Safesearch Mini Model Card

<a href="https://huggingface.co/FredZhang7/google-safesearch-mini-v2"> <font size="4"> <b> Version 2 is here! </b> </font> </a>

This model was trained on more than 2,220,000 images scraped from Google Images, Reddit, Imgur, and GitHub.
The InceptionV3 and Xception models have been fine-tuned to predict the likelihood of an image falling into one of three categories: nsfw_gore, nsfw_suggestive, and safe.

After 20 epochs in PyTorch, the fine-tuned InceptionV3 model reaches 94% accuracy on both the training and test data. After 3.3 epochs in Keras, the fine-tuned Xception model reaches 94% accuracy on the training set and 92% on the test set.

Besides being accurate, this model is much lighter than Stable Diffusion safety checkers: using it instead saves 1.12GB of RAM and disk space.

<br>

# PyTorch
The PyTorch model runs much slower when loaded through transformers, so downloading the weights file directly is the better option.
```bash
pip install --upgrade torchvision
```
```python
import torch, os, warnings, requests
from io import BytesIO
from PIL import Image
from urllib.request import urlretrieve
from torchvision import transforms

PATH_TO_IMAGE = 'https://images.unsplash.com/photo-1594568284297-7c64464062b1'
USE_CUDA = False

warnings.filterwarnings("ignore")
def download_model():
    print("Downloading google_safesearch_mini.bin...")
    urlretrieve("https://huggingface.co/FredZhang7/google-safesearch-mini/resolve/main/pytorch_model.bin", "google_safesearch_mini.bin")

def eval():
    if not os.path.exists("google_safesearch_mini.bin"):
        download_model()
    model = torch.jit.load('./google_safesearch_mini.bin')
    # load the image from a URL or from a local path
    if PATH_TO_IMAGE.startswith(('http://', 'https://')):
        img = Image.open(BytesIO(requests.get(PATH_TO_IMAGE).content)).convert('RGB')
    else:
        img = Image.open(PATH_TO_IMAGE).convert('RGB')
    transform = transforms.Compose([transforms.Resize(299), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    img = transform(img).unsqueeze(0)
    if USE_CUDA:
        img, model = img.cuda(), model.cuda()
    else:
        img, model = img.cpu(), model.cpu()
    model.eval()
    with torch.no_grad():
        out, _ = model(img)
        _, predicted = torch.max(out.data, 1)
        classes = {0: 'nsfw_gore', 1: 'nsfw_suggestive', 2: 'safe'}

        # account for edge cases
        # (the unused Image.new(...) call that referenced an undefined `image` is dropped)
        if predicted[0] != 2 and abs(out[0][2] - out[0][predicted[0]]) > 0.20:
            print("\033[93m" + "safe" + "\033[0m")
        else:
            print('\n\033[1;31m' + classes[predicted.item()] + '\033[0m' if predicted.item() != 2 else '\033[1;32m' + classes[predicted.item()] + '\033[0m\n')

if __name__ == '__main__':
    eval()
```
Output Example:
![prediction](./output_example.png)
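
The PyTorch example prints only the top class. To see a likelihood for each of the three categories, the raw outputs can be passed through a softmax. The sketch below is a minimal, framework-free illustration. It assumes the model's outputs are unnormalized scores (logits), which the card does not state, and the scores used here are made up rather than real model output.

```python
import math

CLASSES = ["nsfw_gore", "nsfw_suggestive", "safe"]


def softmax(scores):
    """Convert raw scores into probabilities that sum to 1."""
    peak = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def to_likelihoods(scores):
    """Pair each class name with its softmax probability."""
    return dict(zip(CLASSES, softmax(scores)))


# hypothetical scores for illustration only (assumed to be logits)
example = to_likelihoods([0.3, -1.2, 2.1])
for name, p in example.items():
    print(f"{name}: {p:.3f}")
```

Inside the PyTorch example, `torch.softmax(out, dim=1)` would perform the same normalization directly on the output tensor.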

<br>

# Keras
```python
import tensorflow as tf
from PIL import Image
import requests, os

# download the model
url = "https://huggingface.co/FredZhang7/google-safesearch-mini/resolve/main/tensorflow/saved_model.pb"
r = requests.get(url, allow_redirects=True)
if not os.path.exists('tensorflow'):
    os.makedirs('tensorflow')
open('tensorflow/saved_model.pb', 'wb').write(r.content)

# download the variables
url = "https://huggingface.co/FredZhang7/google-safesearch-mini/resolve/main/tensorflow/variables/variables.data-00000-of-00001"
r = requests.get(url, allow_redirects=True)
if not os.path.exists('tensorflow/variables'):
    os.makedirs('tensorflow/variables')
open('tensorflow/variables/variables.data-00000-of-00001', 'wb').write(r.content)

url = "https://huggingface.co/FredZhang7/google-safesearch-mini/resolve/main/tensorflow/variables/variables.index"
r = requests.get(url, allow_redirects=True)
open('tensorflow/variables/variables.index', 'wb').write(r.content)

# load the model
model = tf.saved_model.load('./tensorflow')
image = Image.open('cat.jpg').convert('RGB')  # force 3 channels
image = image.resize((299, 299))
image = tf.convert_to_tensor(image)
image = tf.expand_dims(image, 0)

# run the model
tensor = model(image)
classes = ['nsfw_gore', 'nsfw_suggestive', 'safe']
prediction = classes[tf.argmax(tensor, 1)[0]]
print('\033[1;32m' + prediction + '\033[0m' if prediction == 'safe' else '\033[1;33m' + prediction + '\033[0m')
```
Output Example:
![prediction](./output_example.png)
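
The Keras example above downloads all three SavedModel files on every run. A small helper can skip files that are already on disk, in the same spirit as the `os.path.exists` check in the PyTorch example. This is a sketch, not part of the original card: `fetch_if_missing` and `fetch_saved_model` are hypothetical names, and only the URLs come from the example above.

```python
import os
from urllib.request import urlretrieve

BASE_URL = "https://huggingface.co/FredZhang7/google-safesearch-mini/resolve/main/"
FILES = [
    "tensorflow/saved_model.pb",
    "tensorflow/variables/variables.data-00000-of-00001",
    "tensorflow/variables/variables.index",
]


def fetch_if_missing(url, dest):
    """Download url to dest unless dest already exists.

    Returns True if a download happened, False if the file was already there.
    """
    if os.path.exists(dest):
        return False
    parent = os.path.dirname(dest)
    if parent:
        os.makedirs(parent, exist_ok=True)  # create tensorflow/variables etc.
    urlretrieve(url, dest)
    return True


def fetch_saved_model(base_url=BASE_URL, files=FILES, root="."):
    """Fetch every SavedModel file under root; returns one bool per file."""
    return [fetch_if_missing(base_url + name, os.path.join(root, name)) for name in files]
```

Calling `fetch_saved_model()` once before `tf.saved_model.load('./tensorflow')` would replace the three download blocks; later runs then reuse the local copy.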

<br>

# TensorFlow.js
```bash
npm i @tensorflow/tfjs-node
```
```javascript
const tf = require('@tensorflow/tfjs-node');
const fs = require('fs');
const { pipeline } = require('stream');
const { promisify } = require('util');


const download = async (url, path) => {
    // Taken from https://levelup.gitconnected.com/how-to-download-a-file-with-node-js-e2b88fe55409
    // Relies on the global fetch, which is available in Node.js 18 and later.

    const streamPipeline = promisify(pipeline);
    const response = await fetch(url);

    if (!response.ok) {
        throw new Error(`unexpected response ${response.statusText}`);
    }

    await streamPipeline(response.body, fs.createWriteStream(path));
};


async function run() {
    // download saved model and variables from https://huggingface.co/FredZhang7/google-safesearch-mini/tree/main/tensorflow
    if (!fs.existsSync('tensorflow')) {
        fs.mkdirSync('tensorflow');
        await download('https://huggingface.co/FredZhang7/google-safesearch-mini/resolve/main/tensorflow/saved_model.pb', 'tensorflow/saved_model.pb');
        fs.mkdirSync('tensorflow/variables');
        await download('https://huggingface.co/FredZhang7/google-safesearch-mini/resolve/main/tensorflow/variables/variables.data-00000-of-00001', 'tensorflow/variables/variables.data-00000-of-00001');
        await download('https://huggingface.co/FredZhang7/google-safesearch-mini/resolve/main/tensorflow/variables/variables.index', 'tensorflow/variables/variables.index');
    }

    // load model and image
    const model = await tf.node.loadSavedModel('./tensorflow/');
    const image = tf.node.decodeImage(fs.readFileSync('cat.jpg'), 3);

    // predict
    const input = tf.expandDims(image, 0);
    const tensor = model.predict(input);
    const max = tensor.argMax(1);
    const classes = ['nsfw_gore', 'nsfw_suggestive', 'safe'];
    console.log('\x1b[32m%s\x1b[0m', classes[max.dataSync()[0]], '\n');
}

run();
```
Output Example:
![tfjs output](./tfjs_output.png)

<br>

# Bias and Limitations
Everyone defines "safe" differently. The images in the dataset were labeled safe or unsafe by Google SafeSearch, Reddit, and Imgur, so an image that those sources consider safe may not be safe to you.
In addition, when the model encounters content unlike anything it saw during training, it is likely to make wrong predictions.
This is why the PyTorch example accounts for "edge cases" before printing the prediction.
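
The edge-case check from the PyTorch example can be isolated as a small function so the threshold is easy to adjust. The sketch below mirrors the comparison exactly as it is written in that example: when the top class is not `safe` and its score differs from the `safe` score by more than the margin, the result is reported as `safe`. The function name `label_with_edge_case` and the `margin` parameter are hypothetical, and the card does not explain why the comparison runs in this direction, so treat it as an illustration of the existing rule rather than a recommendation.

```python
CLASSES = ["nsfw_gore", "nsfw_suggestive", "safe"]
SAFE = 2  # index of the "safe" class


def label_with_edge_case(scores, margin=0.20):
    """Return a label for three class scores, applying the card's edge-case rule."""
    # index of the highest score (first one wins on ties)
    predicted = max(range(len(scores)), key=lambda i: scores[i])
    # same condition as in the PyTorch example
    if predicted != SAFE and abs(scores[SAFE] - scores[predicted]) > margin:
        return "safe"
    return CLASSES[predicted]


# hypothetical scores for illustration only
print(label_with_edge_case([0.10, 0.30, 0.25]))
print(label_with_edge_case([0.10, 0.50, 0.20]))
```

Passing a different `margin` shows how sensitive the final label is to the threshold.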