ApexID / src /streamlit_app.py
Michael Rey
added documentation and modified existing code
b0ae40a
import os
import sys
import streamlit as st
import torch
import torchvision.transforms as transforms
from PIL import Image
from resnet_model import MonkeyResNet
from data_loader import get_data_loaders
import io
# Make the parent of this script's directory importable.
# Streamlit re-executes the script from top to bottom within the same Python
# process, so an unconditional append would add one more duplicate entry to
# sys.path on every rerun; only add the directory when it is not already there.
# NOTE(review): this runs after the project imports at the top of the file, so
# it cannot influence them -- it only affects imports executed later.
PARENT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if PARENT_DIR not in sys.path:
    sys.path.append(PARENT_DIR)
# Configure the page: title text and the wide layout.
# NOTE(review): Streamlit has historically required set_page_config to be the
# first Streamlit command in a script -- keep it ahead of other st.* calls.
st.set_page_config(
    layout="wide",
    page_title="ApexID: Monkey Species Classifier",
)
# Trained weights are expected in the same directory as this script.
_SCRIPT_DIR = os.path.dirname(__file__)
MODEL_PATH = os.path.join(_SCRIPT_DIR, "monkey_resnet.pth")

# Species display names, ordered so that position i belongs to class code
# "n<i>". The order matters: CLASS_NAMES is meant to be indexed by a class
# index, and LABEL_MAP translates the resulting code into a readable name.
_SPECIES = (
    'Alouatta Palliata',
    'Erythrocebus Patas',
    'Cacajao Calvus',
    'Macaca Fuscata',
    'Cebuella Pygmea',
    'Cebus Capucinus',
    'Mico Argentatus',
    'Saimiri Sciureus',
    'Aotus Nigriceps',
    'Trachypithecus Johnii',
)

# Class codes "n0" .. "n9", one per species above.
CLASS_NAMES = [f"n{idx}" for idx in range(len(_SPECIES))]

# Class code -> species display name.
LABEL_MAP = dict(zip(CLASS_NAMES, _SPECIES))
# Cached with st.cache_resource so the model is not reloaded every time
@st.cache_resource
def load_model():
    """Build the classifier and restore its trained weights on the CPU.

    A ``MonkeyResNet`` with one output per entry in ``CLASS_NAMES`` is
    created, the state dict stored at ``MODEL_PATH`` is loaded into it
    (mapped to the CPU), and the network is switched to evaluation mode.

    Returns:
        The ready-to-use ``MonkeyResNet`` instance.
    """
    network = MonkeyResNet(num_classes=len(CLASS_NAMES))
    # map_location forces every stored tensor onto the CPU while loading.
    weights = torch.load(MODEL_PATH, map_location=torch.device('cpu'))
    network.load_state_dict(weights)
    network.eval()
    return network


# Module-level model instance used by the classification tab.
model = load_model()
# Turn an uploaded picture into the tensor layout the model is fed
def preprocess_image(image):
    """Convert an image into a normalized, batched tensor.

    The image is resized to 224x224, converted to a tensor, normalized per
    channel with three-element mean/std vectors (so a three-channel image is
    expected), and given a leading batch dimension of size one.

    Args:
        image: The image to convert; the caller in this file passes a PIL
            image already converted to RGB.

    Returns:
        The transformed tensor with an extra batch axis at position 0.
    """
    # These values match the commonly used ImageNet channel statistics.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    pipeline = transforms.Compose(steps)

    single = pipeline(image)
    # The model is called on a batch, so wrap the single image in one.
    return single.unsqueeze(0)
# Page heading displayed above the tabs
st.title("ApexID: Monkey Species Classifier")

# Three tabs: project overview, the classifier itself, and model details
_TAB_LABELS = ["About the Project", "Image Classification", "Model Details"]
tab1, tab2, tab3 = st.tabs(_TAB_LABELS)

# Overview text for the first tab; kept flush-left so the rendered text is
# exactly what the literal contains.
_ABOUT_TEXT = """
This project uses a deep learning model based on ResNet18 with transfer learning to classify images of ten monkey species. It applies convolutional neural networks (CNNs) for accurate image recognition and is designed for tasks like education, wildlife monitoring, and zoo record management.
Key Points:
- Image classification using CNN and transfer learning
- Built with PyTorch for model training
- Streamlit used for a user-friendly interface
"""

# Tab 1: static description of the project
with tab1:
    st.header("About This Project")
    st.write(_ABOUT_TEXT)
# Usage instructions for the classification tab; kept flush-left so the
# Markdown numbered list renders as a list.
_CLASSIFY_STEPS = """
Steps to classify:
1. Upload a clear monkey image.
2. Supported file types: jpg, png, jpeg.
3. See prediction below after uploading.
"""

# Tab 2: upload an image, run the model on it, and show the predicted species
with tab2:
    st.header("Classify a Monkey Image")
    st.markdown(_CLASSIFY_STEPS)

    uploaded_file = st.file_uploader("Upload an Image", type=["jpg", "png", "jpeg"])
    if uploaded_file is not None:
        raw_bytes = uploaded_file.read()

        # The extension filter does not guarantee the bytes are a decodable
        # image. Pillow reports unreadable or truncated data with OSError
        # (UnidentifiedImageError is a subclass), so show a message instead
        # of letting the exception escape.
        try:
            image = Image.open(io.BytesIO(raw_bytes)).convert("RGB")
        except OSError:
            image = None
            st.error(
                "Could not read the uploaded file as an image. "
                "Please upload a valid JPG or PNG file."
            )

        if image is not None:
            st.image(image, caption="Uploaded Image", width=300)

            # Preprocess once (the original repeated this call needlessly).
            input_tensor = preprocess_image(image)

            # Spinner signals that inference is in progress
            with st.spinner("Classifying... Please wait."):
                # Inference only: no gradients are needed.
                with torch.no_grad():
                    outputs = model(input_tensor)

            # Index of the highest score along the class axis; the batch
            # holds a single image, so .item() yields one Python int.
            predicted_index = torch.argmax(outputs, dim=1).item()
            predicted_label = CLASS_NAMES[predicted_index]
            species_name = LABEL_MAP[predicted_label]
            st.success(f"Predicted Monkey Species: {species_name}")
# Summary shown at the top of the model details tab; kept flush-left so the
# Markdown bullet list renders as a list.
_MODEL_SUMMARY = """
- Model architecture: ResNet18 using transfer learning
- Framework: PyTorch
- Final training accuracy: 90.88%
- Final validation loss: 2.44
- Test accuracy: 92%
"""

# (image path, caption) pairs in display order. The paths are relative, so
# they are resolved against the current working directory of the process.
_FIGURES = (
    ("plots/accuracy_plot.png", "Training and Validation Accuracy"),
    ("plots/loss_plot.png", "Training and Validation Loss"),
    ("confusion_matrix.png", "Confusion Matrix"),
)

# Tab 3: model summary followed by whichever figures exist on disk
with tab3:
    st.header("Model Information")
    st.markdown(_MODEL_SUMMARY)

    for figure_path, figure_caption in _FIGURES:
        # Missing figures are skipped silently.
        if os.path.exists(figure_path):
            st.image(figure_path, caption=figure_caption)