nsfw_detection / app.py
yeftakun's picture
Update app.py
b9ed281 verified
raw
history blame
2.95 kB
import io
import threading

import gradio as gr
import requests
import torch
from flask import Flask, request, jsonify, url_for
from PIL import Image
from transformers import ViTImageProcessor, AutoModelForImageClassification
# Flask application that serves the JSON classification endpoint.
app = Flask(__name__)

# Hugging Face Hub repository id of the NSFW classifier checkpoint; both the
# image processor and the model weights are loaded from it.
MODEL_ID = 'AdamCodd/vit-base-nsfw-detector'

# Loaded once at import time and shared by every request handler, so nothing
# is re-downloaded or re-initialised per request.
processor = ViTImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
@app.route('/classify', methods=['POST'])
def classify_image():
    """Classify the image found at a caller-supplied URL.

    Expects a JSON request body of the form ``{"image_url": "<url>"}``.

    Returns:
        200 with ``{"image_url": ..., "predicted_class": ...}`` on success;
        400 with ``{"error": ...}`` when the body is not a JSON object or
        carries no usable ``image_url``; 500 with ``{"error": ...}`` for any
        failure while fetching, decoding or classifying the image.
    """
    # silent=True yields None for a missing/unparseable JSON body instead of
    # raising, so malformed requests get a 400 rather than a generic 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400
    image_url = data.get('image_url')
    if not image_url or not isinstance(image_url, str):
        return jsonify({"error": "Image URL not provided"}), 400
    try:
        # NOTE(security): this fetches an arbitrary, caller-controlled URL
        # from the server (SSRF risk); restrict schemes/hosts before exposing
        # the endpoint publicly.
        # The timeout stops a slow or unresponsive host from hanging the
        # worker; the context manager releases the connection afterwards.
        with requests.get(image_url, timeout=10) as response:
            response.raise_for_status()
            payload = response.content
        # Decode from the fully read (content-decoded) body and normalise to
        # 3-channel RGB so RGBA/palette/grayscale files reach the processor
        # with the same channel layout as ordinary photos.
        image = Image.open(io.BytesIO(payload)).convert("RGB")
        inputs = processor(images=image, return_tensors="pt")
        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            logits = model(**inputs).logits
        predicted_class_idx = logits.argmax(-1).item()
        predicted_class = model.config.id2label[predicted_class_idx]
        return jsonify({
            "image_url": image_url,
            "predicted_class": predicted_class
        })
    except Exception as e:
        return jsonify({"error": str(e)}), 500
# Function to run the Flask app in a separate thread
def run_flask(host=None, port=5000):
    """Serve the Flask app with Flask's development server.

    Args:
        host: Interface to bind. ``None`` keeps Flask's default
            (127.0.0.1, i.e. local connections only, unless SERVER_NAME is
            configured); pass "0.0.0.0" to accept external connections.
        port: TCP port to listen on.
    """
    # The Werkzeug reloader installs signal handlers, which Python only
    # allows in the main thread, and this function runs in a worker thread.
    app.run(host=host, port=port, debug=False, use_reloader=False)


# Start the API server in the background so the main thread stays free for
# the Gradio UI.
flask_thread = threading.Thread(target=run_flask)
flask_thread.start()
# Gradio interface
def predict_image(image_url):
    """Return the model's predicted label for the image at *image_url*.

    Used as the Gradio callback: failures are returned as text instead of
    being raised, so they appear in the output textbox.
    """
    if not image_url:
        return "Image URL not provided"
    try:
        # The timeout keeps a slow host from hanging the UI worker; the
        # context manager releases the connection once the body is read.
        with requests.get(image_url, timeout=10) as response:
            response.raise_for_status()
            payload = response.content
        # Normalise to 3-channel RGB so RGBA/palette/grayscale files reach
        # the processor with the same channel layout as ordinary photos.
        image = Image.open(io.BytesIO(payload)).convert("RGB")
        inputs = processor(images=image, return_tensors="pt")
        # Inference only: no gradients needed.
        with torch.no_grad():
            logits = model(**inputs).logits
        predicted_class_idx = logits.argmax(-1).item()
        return model.config.id2label[predicted_class_idx]
    except Exception as e:
        return str(e)
# Construct API endpoint URL. 127.0.0.1 is a loopback address, so this URL
# only works from the machine running the app.
api_url = "http://127.0.0.1:5000/classify"
# Example request shown to users; the doubled braces in the f-string produce
# the literal JSON braces.
curl_example = (
    f"curl -X POST {api_url} -H 'Content-Type: application/json' "
    f"-d '{{\"image_url\": \"YOUR_IMAGE_URL\"}}'"
)
# Create Gradio interface with API info. The description is rendered as
# Markdown, where a single "\n" collapses into a space; a fenced code block
# puts the curl command on its own line and shows it verbatim.
iface = gr.Interface(
    fn=predict_image,
    inputs=gr.Textbox(label="Image URL", placeholder="Enter image URL here"),
    outputs=gr.Textbox(label="Predicted Class"),
    title="NSFW Image Detection",
    description=(
        f"You can get your image classification by sending an API request "
        f"to: {api_url}. Example:\n\n```\n{curl_example}\n```"
    ),
)
# Launch Gradio interface
iface.launch()