# Hugging Face Space status at time of capture: Runtime error
import io
import threading

import gradio as gr
import requests
import torch
from flask import Flask, jsonify, request, url_for
from PIL import Image
from transformers import AutoModelForImageClassification, ViTImageProcessor
# Flask application that exposes the JSON classification endpoint.
app = Flask(__name__)

# Hugging Face Hub checkpoint shared by the processor and the model.
MODEL_ID = "AdamCodd/vit-base-nsfw-detector"

# Loaded once at import time so every request (Flask or Gradio) reuses the
# same processor and weights instead of reloading them per call.
processor = ViTImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
@app.route("/classify", methods=["POST"])
def classify_image():
    """Classify the image at a URL given in a JSON POST body.

    Expects a JSON object such as ``{"image_url": "https://..."}``.

    Returns:
        200 with ``{"image_url": ..., "predicted_class": ...}`` on success;
        400 with ``{"error": ...}`` when the body is not a JSON object or
        ``image_url`` is missing/empty; 500 with ``{"error": ...}`` for any
        other failure (download, HTTP error status, image decoding, inference).
    """
    # silent=True yields None instead of raising for a missing/invalid JSON
    # body, so malformed requests get a 400 rather than the generic 500 below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400
    image_url = data.get("image_url")
    if not image_url:
        return jsonify({"error": "Image URL not provided"}), 400

    try:
        # NOTE(review): this fetches an arbitrary client-supplied URL from the
        # server (SSRF risk); consider restricting allowed hosts/schemes.
        # A timeout keeps a slow host from hanging the worker, and
        # raise_for_status stops HTTP error pages from reaching PIL.
        response = requests.get(image_url, timeout=10)
        response.raise_for_status()
        # ``.content`` is fully read and Content-Encoding-decoded (``.raw`` is
        # not). Converting to RGB normalizes grayscale/palette/RGBA images;
        # the processor is assumed to expect 3-channel input.
        image = Image.open(io.BytesIO(response.content)).convert("RGB")

        inputs = processor(images=image, return_tensors="pt")
        # Inference only: don't build the autograd graph.
        with torch.no_grad():
            logits = model(**inputs).logits

        predicted_class_idx = logits.argmax(-1).item()
        predicted_class = model.config.id2label[predicted_class_idx]
        return jsonify({
            "image_url": image_url,
            "predicted_class": predicted_class,
        })
    except Exception as e:
        return jsonify({"error": str(e)}), 500
def run_flask():
    """Serve ``app`` on port 5000; blocks, so it is run in a background thread.

    Debug mode is off, and the reloader is disabled because this does not
    run on the main thread.
    """
    app.run(use_reloader=False, debug=False, port=5000)
# Start the Flask API in a background thread so the Gradio UI below can own
# the main thread. daemon=True lets the process exit once the main thread
# finishes; a non-daemon thread running app.run() (which serves forever)
# would keep the interpreter alive indefinitely after Gradio shuts down.
flask_thread = threading.Thread(target=run_flask, name="flask-api", daemon=True)
flask_thread.start()
# Gradio interface
def predict_image(image_url):
    """Classify the image at *image_url* for the Gradio text-box UI.

    Returns the predicted label from ``model.config.id2label`` on success,
    or a human-readable error message string on failure (the result is
    displayed as text either way).
    """
    image_url = (image_url or "").strip()
    if not image_url:
        return "Image URL not provided"
    try:
        # Timeout prevents hanging on slow hosts; raise_for_status keeps HTTP
        # error pages away from PIL. ``.content`` is fully read and
        # Content-Encoding-decoded, unlike ``.raw``.
        response = requests.get(image_url, timeout=10)
        response.raise_for_status()
        # Normalize grayscale/palette/RGBA images to RGB; the processor is
        # assumed to expect 3-channel input.
        image = Image.open(io.BytesIO(response.content)).convert("RGB")

        inputs = processor(images=image, return_tensors="pt")
        # Inference only: don't build the autograd graph.
        with torch.no_grad():
            logits = model(**inputs).logits

        predicted_class_idx = logits.argmax(-1).item()
        return model.config.id2label[predicted_class_idx]
    except Exception as e:
        return str(e)
# Endpoint advertised in the UI. NOTE: Flask's app.run() binds to 127.0.0.1
# by default, so this URL is only reachable from the machine running the app.
api_url = "http://127.0.0.1:5000/classify"

# Usage text shown under the title: an intro line followed by a curl example.
_api_intro = (
    "You can get your image classification by sending an API request to: "
    f"{api_url}. Example:\n"
)
_curl_example = (
    f"curl -X POST {api_url} "
    "-H 'Content-Type: application/json' "
    "-d '{\"image_url\": \"YOUR_IMAGE_URL\"}'"
)

# Text-in / text-out UI backed by predict_image.
iface = gr.Interface(
    fn=predict_image,
    inputs=gr.Textbox(label="Image URL", placeholder="Enter image URL here"),
    outputs=gr.Textbox(label="Predicted Class"),
    title="NSFW Image Detection",
    description=_api_intro + _curl_example,
)

# Start the Gradio server.
iface.launch()