---
license: creativeml-openrail-m
tags:
- python
- node.js
---

# Google Safesearch Mini Model Card

This model is trained on 2,224,000 images scraped from Google Images, Reddit, Imgur, and GitHub. It predicts the likelihood of an image being `nsfw_gore`, `nsfw_suggestive`, or `safe`.

After 20 epochs of training on PyTorch's InceptionV3, the model achieves 94% accuracy on both the train and test sets. After 3.3 epochs of training on Keras's Xception, the model achieves 94% accuracy on the train set and 92% on the test set.

Using this model instead of the Stable Diffusion safety checker saves 1.12 GB of RAM and disk space.
# PyTorch ```bash pip install --upgrade transformers torchvision ``` ```python from transformers import AutoModelForImageClassification from torch import cuda model = AutoModelForImageClassification.from_pretrained("FredZhang7/google-safesearch-mini", trust_remote_code=True, revision="d0b4c6be6d908c39c0dd83d25dce50c0e861e46a") PATH_TO_IMAGE = 'https://images.unsplash.com/photo-1594568284297-7c64464062b1' PRINT_TENSOR = False prediction = model.predict(PATH_TO_IMAGE, device="cuda" if cuda.is_available() else "cpu", print_tensor=PRINT_TENSOR) print('\033[1;32m' + prediction + '\033[0m' if prediction == 'safe' else '\033[1;33m' + prediction + '\033[0m') ``` Output Example: ![prediction](./output_example.png)
# Keras ```python import tensorflow as tf from PIL import Image import requests, os # download the model url = "https://huggingface.co/FredZhang7/google-safesearch-mini/resolve/main/tensorflow/saved_model.pb" r = requests.get(url, allow_redirects=True) if not os.path.exists('tensorflow'): os.makedirs('tensorflow') open('tensorflow/saved_model.pb', 'wb').write(r.content) # download the variables url = "https://huggingface.co/FredZhang7/google-safesearch-mini/resolve/main/tensorflow/variables/variables.data-00000-of-00001" r = requests.get(url, allow_redirects=True) if not os.path.exists('tensorflow/variables'): os.makedirs('tensorflow/variables') open('tensorflow/variables/variables.data-00000-of-00001', 'wb').write(r.content) url = "https://huggingface.co/FredZhang7/google-safesearch-mini/resolve/main/tensorflow/variables/variables.index" r = requests.get(url, allow_redirects=True) open('tensorflow/variables/variables.index', 'wb').write(r.content) # load the model model = tf.saved_model.load('./tensorflow') image = Image.open('cat.jpg') image = image.resize((299, 299)) image = tf.convert_to_tensor(image) image = tf.expand_dims(image, 0) # run the model tensor = model(image) classes = ['nsfw_gore', 'nsfw_suggestive', 'safe'] prediction = classes[tf.argmax(tensor, 1)[0]] print('\033[1;32m' + prediction + '\033[0m' if prediction == 'safe' else '\033[1;33m' + prediction + '\033[0m') ``` Output Example: ![prediction](./output_example.png) # Tensorflow.js ```bash npm i @tensorflow/tfjs-node ``` ```javascript const tf = require('@tensorflow/tfjs-node'); const fs = require('fs'); const { pipeline } = require('stream'); const { promisify } = require('util'); const download = async (url, path) => { // Taken from https://levelup.gitconnected.com/how-to-download-a-file-with-node-js-e2b88fe55409 const streamPipeline = promisify(pipeline); const response = await fetch(url); if (!response.ok) { throw new Error(`unexpected response ${response.statusText}`); } await 
streamPipeline(response.body, fs.createWriteStream(path)); }; async function run() { // download saved model and variables from https://huggingface.co/FredZhang7/google-safesearch-mini/tree/main/tensorflow if (!fs.existsSync('tensorflow')) { fs.mkdirSync('tensorflow'); await download('https://huggingface.co/FredZhang7/google-safesearch-mini/resolve/main/tensorflow/saved_model.pb', 'tensorflow/saved_model.pb'); fs.mkdirSync('tensorflow/variables'); await download('https://huggingface.co/FredZhang7/google-safesearch-mini/resolve/main/tensorflow/variables/variables.data-00000-of-00001', 'tensorflow/variables/variables.data-00000-of-00001'); await download('https://huggingface.co/FredZhang7/google-safesearch-mini/resolve/main/tensorflow/variables/variables.index', 'tensorflow/variables/variables.index'); } // load model and image const model = await tf.node.loadSavedModel('./tensorflow/'); const image = tf.node.decodeImage(fs.readFileSync('cat.jpg'), 3); // predict const input = tf.expandDims(image, 0); const tensor = model.predict(input); const max = tensor.argMax(1); const classes = ['nsfw_gore', 'nsfw_suggestive', 'safe']; console.log('\x1b[32m%s\x1b[0m', classes[max.dataSync()[0]], '\n'); } run(); ```
# Bias and Limitations

Each person's definition of "safe" is different. The images in the dataset were labeled safe or unsafe by Google SafeSearch, Reddit, and Imgur, so some images that others consider safe may not be safe to you.