# Last change: "Update app.py" by ruidanwang (commit 4f83afd, verified).
import torch
from PIL import Image
from transformers import AutoModelForImageClassification, ViTImageProcessor
import gradio as gr
# Hugging Face Hub repo id shared by the classifier weights and its image processor.
model_name = "Falconsai/nsfw_image_detection"
# Load both artifacts for that repo id via from_pretrained: the image-classification
# model, and the ViT image processor that turns images into model input tensors.
model, processor = (
    AutoModelForImageClassification.from_pretrained(model_name),
    ViTImageProcessor.from_pretrained(model_name),
)
def classify_image(img):
    """Classify an image with the NSFW model and return per-label probabilities.

    Args:
        img: Image as an array supporting ``.astype`` (presumably a NumPy array
            from ``gr.Image()``; confirm if the component's ``type`` changes).
            Grayscale (2-D), RGB (3 channels) and RGBA (4 channels) arrays are
            all accepted. ``None`` means no image was submitted.

    Returns:
        A dict mapping every label in ``model.config.id2label`` to its softmax
        probability as a Python float, suitable for ``gr.Label``; or ``None``
        when ``img`` is ``None`` (no prediction to show).
    """
    # Nothing to classify when the user submits without an image.
    if img is None:
        return None
    # Let Pillow infer the mode from the array shape and convert afterwards.
    # Forcing mode "RGB" reinterprets the raw buffer, which garbles 4-channel
    # input and fails on 2-D input; the `mode` argument of Image.fromarray is
    # also deprecated in recent Pillow releases.
    pil_image = Image.fromarray(img.astype("uint8")).convert("RGB")
    inputs = processor(images=pil_image, return_tensors="pt")
    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(**inputs).logits
    # A single image was passed, so take the first (only) row of the batch.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0]
    return {model.config.id2label[i]: p for i, p in enumerate(probs.tolist())}
# Build the Gradio UI: an image input feeding classify_image, whose per-label
# probabilities are rendered by a Label component.
image_input = gr.Image()
# Display one entry per class in the model's label map rather than a hard-coded
# count, so a different checkpoint in model_name still shows all its classes.
label_output = gr.Label(num_top_classes=len(model.config.id2label))
interface = gr.Interface(fn=classify_image, inputs=image_input, outputs=label_output)

# Start the web server only when executed as a script (e.g. `python app.py`),
# so importing this module to reuse `interface` does not launch it.
if __name__ == "__main__":
    interface.launch()