# new-pet-pet / app.py
# Author: mandali8686 (last commit 4e96e31, "Update app.py")
import streamlit as st
import torch
from PIL import Image
from torchvision import transforms
from typing import List, Tuple
# Checkpoint file holding the fully pickled model; read by ``load_model`` below.
pretrained_vit_path: str = "pretrained_vit_model_full.pth"
# Model output index -> human-readable label. Presumably this must match the
# label order used at training time (the list is in ASCII-sorted order, which
# is what torchvision's ImageFolder would produce) — confirm before renaming
# or reordering entries.
class_names: List[str] = ["Angry", "Other", "Sad", "happy"]
# Load your model once per server process and reuse it across reruns and sessions.
# ``st.cache_resource`` superseded the deprecated ``st.experimental_singleton``
# (Streamlit 1.18); fall back to the old name so older Streamlit installs still work.
_cache_resource = getattr(st, "cache_resource", None) or st.experimental_singleton


@_cache_resource
def load_model():
    """Load the pickled classifier at ``pretrained_vit_path`` onto the CPU.

    The checkpoint is a whole pickled module (``model.eval()`` is called on the
    loaded object, so it is not a bare ``state_dict``). PyTorch 2.6 changed the
    default of ``torch.load(..., weights_only=...)`` to ``True``, which refuses
    such files, so ``weights_only=False`` is passed explicitly when supported.

    Security: ``weights_only=False`` unpickles arbitrary objects — only point
    ``pretrained_vit_path`` at a checkpoint you trust.

    Returns:
        The loaded model, switched to evaluation mode.
    """
    import inspect

    load_kwargs = {"map_location": torch.device("cpu")}
    # Older PyTorch releases (before 1.13) have no ``weights_only`` parameter,
    # so only pass it when this torch.load accepts it.
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    model = torch.load(pretrained_vit_path, **load_kwargs)
    model.eval()
    return model


model = load_model()
# Function to apply transforms to the image
def transform_image(image, size=(224, 224)):
    """Turn a PIL image into a single-item batch tensor for the model.

    Args:
        image: A PIL image. Non-RGB modes (e.g. RGBA, grayscale, palette) are
            converted to RGB first so the result always has 3 channels.
        size: Target ``(height, width)`` passed to ``Resize`` and ``CenterCrop``.

    Returns:
        A float tensor of shape ``(1, 3, height, width)`` with values in ``[0, 1]``.
    """
    # ToTensor keeps every image channel, so an RGBA PNG would otherwise yield
    # 4 channels. Inputs without a ``mode`` attribute are left as-is.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")
    transform = transforms.Compose([
        # Resize with an (h, w) tuple already produces exactly ``size``, so the
        # CenterCrop below is a no-op; presumably kept to mirror the training
        # pipeline — confirm before changing either step.
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        # NOTE(review): no Normalize step, so inputs stay in [0, 1]. Confirm this
        # matches the checkpoint's training preprocessing before adding
        # ImageNet mean/std normalization.
    ])
    return transform(image).unsqueeze(0)  # Add batch dimension
# Prediction function
def predict(model, image_tensor, labels=None):
    """Classify a batch and report the top label and probability of its first item.

    Args:
        model: Callable mapping ``image_tensor`` to logits indexed as
            ``(batch, num_classes)``.
        image_tensor: Input batch for ``model`` (e.g. from ``transform_image``).
        labels: Sequence of class names indexed by logit position. Defaults to
            the module-level ``class_names``.

    Returns:
        ``(label, probability)`` for the first item of the batch, where
        ``probability`` is the softmax probability (a Python float) of the
        predicted class.
    """
    if labels is None:
        labels = class_names
    with torch.no_grad():
        logits = model(image_tensor)
        # One softmax + max yields both the winning index and its probability,
        # so the reported label and probability come from the same computation.
        probabilities = torch.nn.functional.softmax(logits, dim=1)
        top_prob, top_idx = probabilities.max(dim=1)
    return labels[top_idx[0].item()], top_prob[0].item()
# Streamlit interface
st.title("Animal Facial Expression Recognition")

# Two equal-width columns: uploader and preview on the left, results on the right.
col1, col2 = st.columns([1, 1])

with col1:
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])

with col2:
    # A placeholder lets the status line be replaced (or cleared) once the
    # prediction is available, instead of "Classifying..." staying on screen.
    status = st.empty()
    if uploaded_file is None:
        status.write("Upload an image to see the classification result.")
    else:
        status.write("Classifying...")

if uploaded_file is not None:
    try:
        # convert('RGB') drops alpha / expands grayscale so the model gets 3 channels.
        image = Image.open(uploaded_file).convert('RGB')
    except OSError:
        # PIL raises UnidentifiedImageError (an OSError subclass) for files it
        # cannot decode; show a message instead of a traceback.
        status.error("Could not read the uploaded file as an image.")
    else:
        # Display the uploaded image in the first column.
        with col1:
            # NOTE(review): newer Streamlit deprecates use_column_width in favour
            # of use_container_width; kept for compatibility with older versions.
            st.image(image, caption='Uploaded Image', use_column_width=True)
        # Transform the image and show the prediction in the second column.
        with col2:
            image_tensor = transform_image(image)
            predicted_class, probability = predict(model, image_tensor)
            status.empty()
            st.write(f'Predicted class: {predicted_class}')
            st.write(f'Probability: {probability:.3f}')