Spaces:
Sleeping
Sleeping
import streamlit as st | |
import torch | |
from PIL import Image | |
from torchvision import transforms | |
from typing import List, Tuple | |
# Model checkpoint location and the labels reported for each output index.
# NOTE(review): the label order must match the model's output indices. The
# mixed casing ('happy' sorted after 'Sad') looks like a case-sensitive
# alphabetical sort of class-folder names at training time -- confirm against
# the training script before renaming or reordering any entry.
pretrained_vit_path = "pretrained_vit_model_full.pth"
class_names = [
    "Angry",
    "Other",
    "Sad",
    "happy",
]
# Load the model once per Streamlit server process.
# Streamlit re-executes this whole script on every widget interaction; without
# caching, the checkpoint would be deserialized again on each rerun.
# st.cache_resource exists from Streamlit 1.18 onward; on older versions this
# falls back to a plain, uncached call (the original behaviour).
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def load_model(path: str = pretrained_vit_path):
    """Load a whole pickled model from ``path`` onto the CPU in eval mode.

    Args:
        path: Checkpoint file. It must unpickle to an object with an
            ``eval()`` method (i.e. a full module saved with
            ``torch.save(model)``, not a bare ``state_dict``).

    Returns:
        The deserialized model, switched to evaluation mode.
    """
    device = torch.device("cpu")
    try:
        # weights_only=False is needed to unpickle a full module object: since
        # PyTorch 2.6, torch.load defaults to weights_only=True, which rejects
        # arbitrary pickled classes such as a whole model.
        # SECURITY: full unpickling can execute arbitrary code -- only load
        # trusted checkpoint files, never user-supplied ones.
        model = torch.load(path, map_location=device, weights_only=False)
    except TypeError:
        # torch < 1.13 does not accept the weights_only keyword.
        model = torch.load(path, map_location=device)
    model.eval()
    return model


model = load_model()
# Function to apply transforms to the image
def transform_image(image, size=(224, 224), mean=None, std=None):
    """Convert a PIL image into a single-item batch tensor for the model.

    Args:
        image: PIL image to preprocess.
        size: Target size passed to both Resize and CenterCrop. A
            ``(height, width)`` tuple resizes to exactly that size, so the
            centre crop is then a no-op; an int resizes the shorter side
            (keeping aspect ratio) and the crop then cuts the centre square.
        mean: Optional per-channel means for ``transforms.Normalize``.
        std: Optional per-channel standard deviations for
            ``transforms.Normalize``. Leave both ``mean`` and ``std`` as None
            (the default) to skip normalization, as the original pipeline did.
            NOTE(review): if the model was trained on normalized inputs
            (e.g. ImageNet statistics), pass the same values here.

    Returns:
        Tensor with a leading batch dimension of 1, i.e. ``(1, C, H, W)``.

    Raises:
        ValueError: if only one of ``mean`` / ``std`` is given.
    """
    if (mean is None) != (std is None):
        raise ValueError("mean and std must be provided together")

    steps = [
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
    ]
    if mean is not None:
        steps.append(transforms.Normalize(mean=mean, std=std))

    pipeline = transforms.Compose(steps)
    return pipeline(image).unsqueeze(0)  # add batch dimension
# Prediction function
def predict(model, image_tensor, labels=None) -> Tuple[str, float]:
    """Run the model and return the top label and its softmax probability.

    Only the first item of the batch is reported.

    Args:
        model: Callable returning class scores shaped ``(batch, num_classes)``.
            Softmax is applied here, so these are presumably raw logits.
        image_tensor: Input batch passed straight to ``model``.
        labels: Class names indexed by output position. Defaults to the
            module-level ``class_names``.

    Returns:
        ``(label, probability)`` for the highest-scoring class of item 0.
    """
    if labels is None:
        labels = class_names
    with torch.no_grad():
        outputs = model(image_tensor)
        # A single softmax + max yields both the winning index and its
        # probability, so label and probability always come from the same
        # entry (the original computed the argmax twice).
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        top_prob, top_idx = probabilities[0].max(dim=0)
    return labels[top_idx.item()], top_prob.item()
# Streamlit interface
st.title("Animal Facial Expression Recognition")

# Two equal-width columns: uploader (and the uploaded image) on the left,
# prediction results on the right.
col1, col2 = st.columns([1, 1])

with col1:
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])

if uploaded_file is None:
    with col2:
        st.write("Upload an image to see the classification result.")
else:
    try:
        # convert() forces a full decode, so truncated files fail here too.
        image = Image.open(uploaded_file).convert('RGB')
    except OSError as err:
        # PIL raises UnidentifiedImageError (an OSError subclass) for files it
        # cannot decode; report it instead of crashing the app with a traceback.
        with col2:
            st.error(f"Could not read the uploaded file as an image: {err}")
    else:
        with col1:
            st.image(image, caption='Uploaded Image', use_column_width=True)
        with col2:
            # The spinner is visible only while inference runs, instead of a
            # static "Classifying..." line that stayed above the results.
            with st.spinner("Classifying..."):
                image_tensor = transform_image(image)
                predicted_class, probability = predict(model, image_tensor)
            st.write(f'Predicted class: {predicted_class}')
            st.write(f'Probability: {probability:.3f}')