# (Captured page header, not code) Spaces: Sleeping
"""Streamlit demo: classify an uploaded image with a pre-trained ViT.

The user uploads a JPEG/PNG image and presses "Predict". The app then shows the
index and label of the highest-scoring class of the loaded checkpoint.
"""
import streamlit as st
from transformers import ViTFeatureExtractor, ViTForImageClassification
from PIL import Image
import requests
import numpy as np
import torch

# Load pre-trained model and feature extractor
model_name = "google/vit-base-patch16-224"


@st.cache_resource
def _load_model(name):
    """Load the feature extractor and classifier for checkpoint *name* once.

    Streamlit re-executes this script on every widget interaction. Caching the
    loaded objects avoids re-reading the weights on each rerun.
    """
    # NOTE(review): ViTFeatureExtractor is deprecated in recent transformers
    # releases in favour of ViTImageProcessor; switch once the pinned
    # transformers version is confirmed.
    extractor = ViTFeatureExtractor.from_pretrained(name)
    classifier = ViTForImageClassification.from_pretrained(name)
    return extractor, classifier


feature_extractor, model = _load_model(model_name)

# CIFAR-10 class names.
# NOTE(review): these do NOT describe this checkpoint's output head. The
# predicted index refers to `model.config.id2label` (ImageNet-1k classes for
# this checkpoint, per its model card). Indexing this 10-entry list with an
# index >= 10 would raise IndexError. Kept only so existing references to
# `class_names` keep working; consider retitling the app accordingly.
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# Streamlit app
st.title("CIFAR-10 Image Classification with Pre-trained Vision Transformer")

# Prediction on uploaded image
st.subheader("Make Predictions")
uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    try:
        image = Image.open(uploaded_file).convert("RGB")
    except OSError:
        # The uploader's `type` filter is extension-based. PIL raises OSError
        # (including UnidentifiedImageError) for unreadable or truncated data.
        st.error("Could not read the uploaded file as an image.")
        st.stop()
    st.image(image, caption='Uploaded Image', use_column_width=True)
    if st.button("Predict"):
        with st.spinner("Classifying..."):
            # Preprocess only when a prediction is actually requested.
            inputs = feature_extractor(images=image, return_tensors="pt")
            # Inference only: skip building the autograd graph.
            with torch.no_grad():
                logits = model(**inputs).logits
            predicted_class_idx = logits.argmax(-1).item()
            # Use the checkpoint's own label table, which is sized to match
            # the model's output head.
            predicted_label = model.config.id2label.get(
                predicted_class_idx, str(predicted_class_idx)
            )
            st.write(f"Predicted Class: {predicted_class_idx} ({predicted_label})")