nsfw_api2 / app.py
yeftakun's picture
Update app.py
ef5ff29 verified
raw
history blame
3.12 kB
import streamlit as st
from transformers import ViTImageProcessor, AutoModelForImageClassification
from PIL import Image
import requests
from io import BytesIO
import json
# Load the model and processor once per server process. Streamlit re-executes
# this whole script on every widget interaction, so without caching the
# weights would be re-read from disk on every rerun.
# NOTE(review): st.cache_resource requires Streamlit >= 1.18 — confirm the
# Space's sdk_version.
MODEL_ID = 'AdamCodd/vit-base-nsfw-detector'


@st.cache_resource
def _load_model(model_id=MODEL_ID):
    """Return ``(processor, model)`` for *model_id*, cached across reruns.

    Args:
        model_id: Hugging Face Hub id of the image-classification checkpoint.

    Returns:
        A ``(ViTImageProcessor, AutoModelForImageClassification)`` pair.
    """
    return (
        ViTImageProcessor.from_pretrained(model_id),
        AutoModelForImageClassification.from_pretrained(model_id),
    )


processor, model = _load_model()
# Define prediction function
def predict_image(image):
    """Classify *image* with the module-level NSFW model and return its label.

    Args:
        image: The image to classify. The callers in this file pass a PIL
            image; any object without a ``mode`` attribute is handed to
            ``processor`` unchanged.

    Returns:
        The predicted label string from ``model.config.id2label``. On any
        failure, the exception message is returned instead of being raised.
        This is kept for backward compatibility, so callers cannot tell an
        error from a label by type alone.
    """
    try:
        import torch  # local import: only needed for the no-grad context

        # The ViT processor normalises 3-channel input. Palette ("P"), RGBA
        # and greyscale images, which are common for PNG/GIF URLs, must be
        # converted first or preprocessing fails.
        if hasattr(image, "mode") and image.mode != "RGB":
            image = image.convert("RGB")
        inputs = processor(images=image, return_tensors="pt")
        # Inference only: don't build an autograd graph (saves memory and time).
        with torch.no_grad():
            logits = model(**inputs).logits
        predicted_class_idx = logits.argmax(-1).item()
        return model.config.id2label[predicted_class_idx]
    except Exception as e:
        return str(e)
# --- Streamlit UI -----------------------------------------------------------
# Page heading, then a text box whose current value drives the classification
# section further down the script.
st.title(body="NSFW Image Classifier")
image_url_ui = st.text_input(
    label="Enter Image URL",
    placeholder="Enter image URL here",
)
# API handler for classification (POST request).
# NOTE(review): Streamlit has no HTTP routing. `st.experimental_next_router`
# does not exist, so registering through it raised AttributeError on every run
# and stopped the UI below from rendering. The handler is kept as a plain,
# framework-agnostic function that a real web framework (e.g. Flask/FastAPI)
# can mount, passing in its request object. It is deliberately NOT cached:
# caching a request handler would replay the first response to every caller.
def api_endpoint(request=None):
    """Handle a classification request and return ``(json_body, status_code)``.

    Args:
        request: A request-like object exposing ``.method`` and ``.json``
            (e.g. a Flask ``request``). ``None`` is treated as a non-POST
            request.

    Returns:
        A ``(str, int)`` tuple:
        - 200 with ``{"predicted_class": ...}`` on success.
        - 400 if the body is not a JSON object containing ``image_url``.
        - 405 for any non-POST request.
        - 500 if fetching, decoding or classifying the image fails.
    """
    if request is None or getattr(request, "method", None) != 'POST':
        return json.dumps({'error': 'Only POST requests are allowed'}), 405  # Method Not Allowed

    data = getattr(request, "json", None)
    # A missing or non-object JSON body would otherwise make `in` raise TypeError.
    if not isinstance(data, dict) or 'image_url' not in data:
        return json.dumps({'error': 'Missing "image_url" in request body'}), 400  # Bad Request

    try:
        # SECURITY(review): this fetches an arbitrary caller-supplied URL
        # (SSRF risk). Consider an allow-list or blocking private address
        # ranges before exposing this publicly.
        response = requests.get(data['image_url'], timeout=10)
        # Surface HTTP errors instead of feeding an HTML error page to PIL.
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        prediction = predict_image(image)
    except Exception as e:
        return json.dumps({'error': str(e)}), 500  # Internal Server Error

    return json.dumps({'predicted_class': prediction}), 200
# Classify the image at the URL entered in the UI, if any.
if image_url_ui:
    try:
        # The timeout stops a slow or unresponsive host from hanging the app.
        # raise_for_status() reports HTTP errors (404, 403, ...) directly,
        # instead of letting Image.open fail on an HTML error page with a
        # confusing "cannot identify image file" message.
        response = requests.get(image_url_ui, timeout=10)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        st.image(image, caption='Image from URL', use_column_width=True)
        st.write("")
        st.write("Classifying...")
        # Predict and display result (for UI)
        prediction = predict_image(image)
        st.write(f"Predicted Class: {prediction}")
    except Exception as e:
        st.write(f"Error: {e}")
# Display API endpoint information.
# NOTE(review): nothing in this file stores 'huggingface_space_url' in
# st.session_state, so this section only renders if something else sets it.
# Streamlit also does not serve a /api/classify route by itself. Verify the
# endpoint is actually mounted before advertising it.
space_url = st.session_state.get('huggingface_space_url')
if space_url:
    api_endpoint_url = f"{space_url}/api/classify"  # Construct the URL based on Space URL
    st.write("You can also use this API endpoint to classify images:")
    # Use a single st.code call so the command renders as one code block.
    # Separate st.write calls render each ``` fence as its own element, so
    # the fences never enclose the command.
    st.code(
        "curl -X POST -H 'Content-Type: application/json' "
        f"-d '{{ \"image_url\": \"https://example.jpg\" }}' {api_endpoint_url}",
        language="bash",
    )
    st.write("This will return the predicted class in JSON format.")