File size: 1,406 Bytes
39c482f
 
81c25c0
 
39c482f
81c25c0
 
 
 
39c482f
 
 
 
 
81c25c0
39c482f
 
 
 
 
 
 
 
 
81c25c0
 
39c482f
81c25c0
 
 
 
 
 
 
 
 
39c482f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import streamlit as st
from PIL import Image
import torch
from transformers import AutoModelForImageClassification, AutoFeatureExtractor

# Load a pre-trained model and feature extractor.
# NOTE(review): confirm that "facebook/wide_resnet50_2" resolves on the Hugging
# Face Hub; it is a general-purpose classifier, not one trained for "hot or not".
model_name = "facebook/wide_resnet50_2"  # Using a general model


@st.cache_resource
def _load_model(name):
    """Load the image classifier and its feature extractor for *name*.

    Streamlit re-executes this whole script on every widget interaction.
    ``st.cache_resource`` keeps a single loaded copy across reruns instead of
    re-reading the weights each time.

    Returns:
        A ``(model, feature_extractor)`` tuple.
    """
    return (
        AutoModelForImageClassification.from_pretrained(name),
        AutoFeatureExtractor.from_pretrained(name),
    )


model, feature_extractor = _load_model(model_name)

def _predict(image):
    """Classify *image* with the module-level ``model`` / ``feature_extractor``.

    Returns:
        ``(class_idx, label)``: the index of the highest-scoring logit and its
        label from ``model.config.id2label``. If the config has no entry for
        that index, the index itself is used as the label (as a string).
    """
    inputs = feature_extractor(images=image, return_tensors="pt")
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**inputs).logits
    class_idx = logits.argmax(-1).item()  # index of the highest logit
    label = model.config.id2label.get(class_idx, str(class_idx))
    return class_idx, label


# Define the main function for the Streamlit app
def main():
    """Render the app: upload an image, classify it, and show the prediction."""
    st.title("Hot or Not Image Classifier")
    st.write("Upload an image to classify it.")

    # Image upload
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    try:
        # Uploaded PNGs may be RGBA, palette or grayscale; the classifier
        # expects 3-channel RGB input, so normalise the mode up front.
        image = Image.open(uploaded_file).convert("RGB")
    except OSError:  # includes PIL.UnidentifiedImageError for non-image data
        st.error("Could not read the uploaded file as an image.")
        return

    # Display the uploaded image
    st.image(image, caption="Uploaded Image", use_column_width=True)

    class_idx, label = _predict(image)

    # Display results based on class index
    st.write(f"Predicted class index: {class_idx}")
    st.write(f"Predicted class label: {label}")


if __name__ == "__main__":
    main()