FredZhang7's picture
Update README.md
312519c
|
raw
history blame
5.11 kB
metadata
license: creativeml-openrail-m
tags:
  - tensorflow.js
  - node.js

Google Safesearch Mini Model Card

Initially, the training data consisted of 278,000 images, and the model achieved 99% accuracy on both the training and test sets. The model is now trained on 2,220,000+ images scraped from Google Images, Reddit, Imgur, and GitHub. It predicts the likelihood of an image being nsfw_gore, nsfw_suggestive, or safe.

After 20 epochs in PyTorch, the fine-tuned InceptionV3 model achieves 94% accuracy on both the training and test data. After 3.3 epochs in Keras, the fine-tuned Xception model scores 94% accuracy on the training set and 92% on the test set.

Using this model instead of the Stable Diffusion safety checker saves 1.12 GB of RAM and disk space.


PyTorch

pip install --upgrade transformers torchvision
from transformers import AutoModelForImageClassification
from torch import cuda

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if cuda.is_available() else "cpu"

# Load a pinned revision of the classifier. trust_remote_code lets transformers
# use the repo's own model code (presumably where the predict() used below lives).
model = AutoModelForImageClassification.from_pretrained(
    "FredZhang7/google-safesearch-mini",
    trust_remote_code=True,
    revision="6fcab6a27595a5f008625ec88c77c10739f3c219",
)

# Image to classify (a URL here) and the print_tensor flag passed to predict()
# (presumably toggles printing the raw model output — confirm in the repo code).
PATH_TO_IMAGE = 'https://images.unsplash.com/photo-1594568284297-7c64464062b1'
PRINT_TENSOR = False

prediction = model.predict(PATH_TO_IMAGE, device=device, print_tensor=PRINT_TENSOR)

# Bold green for "safe", bold yellow for every other class.
highlight = '\033[1;32m' if prediction == 'safe' else '\033[1;33m'
print(highlight + prediction + '\033[0m')

Output Example: prediction


Keras

import tensorflow as tf
from PIL import Image
import requests, os

# Download the SavedModel graph and its variables from the Hugging Face repo,
# mirroring the repo layout under ./tensorflow.
BASE_URL = "https://huggingface.co/FredZhang7/google-safesearch-mini/resolve/main/"
MODEL_FILES = [
    "tensorflow/saved_model.pb",
    "tensorflow/variables/variables.data-00000-of-00001",
    "tensorflow/variables/variables.index",
]

os.makedirs('tensorflow/variables', exist_ok=True)
for file_path in MODEL_FILES:
    r = requests.get(BASE_URL + file_path, allow_redirects=True)
    # Fail loudly instead of writing an HTTP error page to disk as model data.
    r.raise_for_status()
    # `with` guarantees the file is flushed and closed before the model is loaded.
    with open(file_path, 'wb') as f:
        f.write(r.content)

# load the model
model = tf.saved_model.load('./tensorflow')
# convert('RGB') guarantees 3 channels even for RGBA, grayscale or palette images.
image = Image.open('cat.jpg').convert('RGB')
image = image.resize((299, 299))
image = tf.convert_to_tensor(image)
# add a batch dimension of size 1
image = tf.expand_dims(image, 0)

# run the model
tensor = model(image)
classes = ['nsfw_gore', 'nsfw_suggestive', 'safe']
prediction = classes[tf.argmax(tensor, 1)[0]]
# Bold green for "safe", bold yellow for every other class.
print('\033[1;32m' + prediction + '\033[0m' if prediction == 'safe' else '\033[1;33m' + prediction + '\033[0m')

Output Example: prediction


Tensorflow.js

npm i @tensorflow/tfjs-node
const tf = require('@tensorflow/tfjs-node');
const fs = require('fs');
const { pipeline } = require('stream');
const { promisify } = require('util');


/**
 * Downloads `url` and streams the response body to the file at `path`.
 *
 * The body is first written to `${path}.part` and only renamed to `path` once
 * the transfer completes. An interrupted download therefore never leaves a
 * truncated file at `path` that a later existence check would mistake for a
 * finished one.
 *
 * Adapted from https://levelup.gitconnected.com/how-to-download-a-file-with-node-js-e2b88fe55409
 *
 * @param {string} url - Address to fetch (redirects are followed by `fetch`).
 * @param {string} path - Destination file path.
 * @throws {Error} If the response status is not 2xx, or if streaming/writing fails.
 */
const download = async (url, path) => {
    const streamPipeline = promisify(pipeline);
    const response = await fetch(url);

    if (!response.ok) {
        // statusText may be empty, so always include the numeric status code.
        throw new Error(`unexpected response ${response.status} ${response.statusText} for ${url}`);
    }

    const tmpPath = `${path}.part`;
    try {
        await streamPipeline(response.body, fs.createWriteStream(tmpPath));
        await fs.promises.rename(tmpPath, path);
    } catch (err) {
        // Remove any partial data, then surface the original failure.
        await fs.promises.rm(tmpPath, { force: true });
        throw err;
    }
};


/**
 * Fetches any missing SavedModel files, classifies `cat.jpg`, and prints the
 * predicted class: green when it is "safe", yellow otherwise.
 */
async function run() {
    // saved model and variables come from https://huggingface.co/FredZhang7/google-safesearch-mini/tree/main/tensorflow
    const baseUrl = 'https://huggingface.co/FredZhang7/google-safesearch-mini/resolve/main/';
    const modelFiles = [
        'tensorflow/saved_model.pb',
        'tensorflow/variables/variables.data-00000-of-00001',
        'tensorflow/variables/variables.index',
    ];

    // Check each file, not just the `tensorflow` folder. A run interrupted
    // mid-download then fetches the missing pieces next time, instead of
    // failing forever on an incomplete model.
    fs.mkdirSync('tensorflow/variables', { recursive: true });
    for (const file of modelFiles) {
        if (!fs.existsSync(file)) {
            await download(`${baseUrl}${file}`, file);
        }
    }

    // load model
    const model = await tf.node.loadSavedModel('./tensorflow/');
    const classes = ['nsfw_gore', 'nsfw_suggestive', 'safe'];

    // Predict inside tf.tidy so the image, batch and score tensors are released
    // once the winning class index has been read back as a plain number.
    const classIndex = tf.tidy(() => {
        // decode as 3-channel RGB
        // NOTE(review): the Keras example resizes to 299x299 before inference;
        // confirm whether this saved model accepts arbitrary image sizes.
        const image = tf.node.decodeImage(fs.readFileSync('cat.jpg'), 3);
        const input = tf.expandDims(image, 0);
        const scores = model.predict(input);
        return scores.argMax(1).dataSync()[0];
    });

    const prediction = classes[classIndex];
    // Green for "safe", yellow for the NSFW classes, matching the Python examples.
    const color = prediction === 'safe' ? '\x1b[32m' : '\x1b[33m';
    console.log(`${color}%s\x1b[0m`, prediction, '\n');
}

// Entry point: report any failure (download, model load, inference) and exit non-zero.
run().catch((err) => {
    console.error(err);
    process.exitCode = 1;
});

Output Example: tfjs output


Bias and Limitations

Each person's definition of "safe" is different. The images in the dataset were labeled safe or unsafe by Google SafeSearch, Reddit, and Imgur, so some images that others consider safe may not be safe for you.