Spaces:
Sleeping
Sleeping
import streamlit as st | |
from PIL import Image | |
import torch | |
from transformers import AutoModelForImageClassification, AutoFeatureExtractor | |
# Load a pre-trained model and feature extractor.
# Streamlit re-executes this whole script on every widget interaction, so the
# loading is wrapped in a cached function: the checkpoint is loaded once per
# server process instead of once per rerun.
model_name = "facebook/wide_resnet50_2"  # Using a general model
# NOTE(review): confirm this checkpoint id resolves on the Hugging Face Hub.


@st.cache_resource
def _load_model_and_extractor(name):
    """Load the image-classification model and its preprocessor, cached.

    Args:
        name: Model id or local path passed to ``from_pretrained``.

    Returns:
        A ``(model, feature_extractor)`` tuple. Streamlit's resource cache keys
        it on ``name``, so repeated reruns reuse the same objects.
    """
    loaded_model = AutoModelForImageClassification.from_pretrained(name)
    loaded_extractor = AutoFeatureExtractor.from_pretrained(name)
    return loaded_model, loaded_extractor


model, feature_extractor = _load_model_and_extractor(model_name)
# Define the main function for the Streamlit app
def main():
    """Render the upload-and-classify Streamlit page.

    The page shows a file uploader for JPEG/PNG images. When a file is
    uploaded, it is displayed, preprocessed with the module-level
    ``feature_extractor`` and classified by the module-level ``model``. The
    top-1 class index and its ``model.config.id2label`` label are then written
    to the page.

    Files that Pillow cannot open or decode are reported with ``st.error``
    instead of raising out of the script.

    NOTE(review): the title promises a hot/not verdict, but the label shown
    is whatever label set the loaded model's ``config.id2label`` defines.
    """
    st.title("Hot or Not Image Classifier")
    st.write("Upload an image to classify it.")

    # Image upload
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    try:
        image = Image.open(uploaded_file)
        # Image.open is lazy; force the decode here so truncated/corrupt files
        # fail inside this try rather than later in the preprocessor.
        image.load()
    except OSError:  # Pillow's UnidentifiedImageError subclasses OSError
        st.error("The uploaded file could not be read as an image.")
        return

    # Display the uploaded image as-is (keeps any transparency for display).
    # NOTE(review): newer Streamlit releases deprecate use_column_width in
    # favour of use_container_width; switch once the minimum version allows.
    st.image(image, caption="Uploaded Image", use_column_width=True)

    # Preprocess the image. PNG uploads may be RGBA, palette ("P") or
    # grayscale ("L"); the preprocessor normalises per RGB channel, so
    # convert to 3-channel RGB first.
    inputs = feature_extractor(images=image.convert("RGB"), return_tensors="pt")

    # Make predictions; no_grad skips autograd bookkeeping during inference.
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    class_idx = logits.argmax(-1).item()  # index of the highest-scoring class

    # Display results based on class index
    st.write(f"Predicted class index: {class_idx}")
    st.write(f"Predicted class label: {model.config.id2label[class_idx]}")
# Script entry point: render the app only when this file is executed as the
# main module (e.g. via `streamlit run`), not when it is imported.
if __name__ == "__main__":
    main()