"""Streamlit app that classifies an image fetched from a URL as NSFW or not.

Uses the ``AdamCodd/vit-base-nsfw-detector`` ViT model from the Hugging Face Hub.
"""
from io import BytesIO

import requests
import streamlit as st
import torch
from PIL import Image
from transformers import AutoModelForImageClassification, ViTImageProcessor

MODEL_NAME = "AdamCodd/vit-base-nsfw-detector"
# Seconds to wait for the remote server before giving up on the download.
REQUEST_TIMEOUT = 10


@st.cache_resource
def load_model(model_name=MODEL_NAME):
    """Load the image processor and classification model once and cache them.

    Streamlit re-executes the whole script on every widget interaction, so
    without caching the model would be reloaded on each rerun.

    Returns:
        A ``(processor, model)`` tuple.
    """
    processor = ViTImageProcessor.from_pretrained(model_name)
    model = AutoModelForImageClassification.from_pretrained(model_name)
    return processor, model


processor, model = load_model()


def predict_image(image):
    """Return the predicted class label for a PIL image.

    Args:
        image: A ``PIL.Image.Image``. It is converted to RGB first because the
            ViT processor expects three-channel input, while URLs commonly
            serve RGBA, palette, or grayscale images.

    Returns:
        The label string from ``model.config.id2label`` for the highest logit.

    Raises:
        Any exception from preprocessing or inference propagates to the
        caller, so failures are reported as errors rather than shown as a
        predicted "class".
    """
    inputs = processor(images=image.convert("RGB"), return_tensors="pt")
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class_idx = logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]


def load_image_from_url(url, timeout=REQUEST_TIMEOUT):
    """Download ``url`` and decode the response body as a PIL image.

    Raises:
        requests.RequestException: On network errors, timeouts, or a non-2xx
            HTTP status.
        PIL.UnidentifiedImageError: If the response body is not an image.
    """
    response = requests.get(url, timeout=timeout)
    # Fail on HTTP errors instead of handing an HTML error page to PIL.
    response.raise_for_status()
    image = Image.open(BytesIO(response.content))
    # Image.open is lazy; decode now so corrupt data fails here.
    image.load()
    return image


# Streamlit app
st.title("NSFW Image Classifier")

# URL input
image_url = st.text_input("Enter Image URL", placeholder="Enter image URL here")

if image_url:
    try:
        image = load_image_from_url(image_url.strip())
        st.image(image, caption='Image from URL', use_column_width=True)
        with st.spinner("Classifying..."):
            prediction = predict_image(image)
        st.write(f"Predicted Class: {prediction}")
    except Exception as e:
        # Top-level UI boundary: report any failure to the user.
        st.error(f"Error: {e}")