"""Streamlit front-end for ApexID, a ResNet18-based monkey species classifier.

Launch with ``streamlit run <this file>``. Streamlit re-executes this whole
script on every user interaction, so module-level code must be safe to run
repeatedly.
"""

import io
import os
import sys

import streamlit as st
import torch
import torchvision.transforms as transforms
from PIL import Image

# Ensure the parent directory is on the import path. This has to run *before*
# the project-local imports below, otherwise it cannot help resolve them. The
# membership check keeps sys.path from growing by one entry on every rerun.
_PARENT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if _PARENT_DIR not in sys.path:
    sys.path.append(_PARENT_DIR)

from resnet_model import MonkeyResNet  # noqa: E402
# NOTE(review): get_data_loaders is not used in this script; the import is kept
# in case importing data_loader has side effects something relies on. Confirm
# before removing.
from data_loader import get_data_loaders  # noqa: E402,F401

# Set Streamlit page configuration (must be the first Streamlit command).
st.set_page_config(page_title="ApexID: Monkey Species Classifier", layout="wide")

# Constants for model path and class labels.
MODEL_PATH = os.path.join(os.path.dirname(__file__), "monkey_resnet.pth")
CLASS_NAMES = ['n0', 'n1', 'n2', 'n3', 'n4', 'n5', 'n6', 'n7', 'n8', 'n9']
LABEL_MAP = {
    'n0': 'Alouatta Palliata',
    'n1': 'Erythrocebus Patas',
    'n2': 'Cacajao Calvus',
    'n3': 'Macaca Fuscata',
    'n4': 'Cebuella Pygmaea',
    'n5': 'Cebus Capucinus',
    'n6': 'Mico Argentatus',
    'n7': 'Saimiri Sciureus',
    'n8': 'Aotus Nigriceps',
    'n9': 'Trachypithecus Johnii',
}

# Directory containing this script; used as a fallback when looking up assets.
_APP_DIR = os.path.dirname(os.path.abspath(__file__))

# Inference-time preprocessing. It is built once here instead of on every call.
# The Normalize values are the standard ImageNet channel mean/std.
# NOTE(review): these must match the training pipeline. Confirm against the
# training code.
_PREPROCESS = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# (relative path, caption) for each training artefact shown on the details tab.
_PLOT_ASSETS = (
    ("plots/accuracy_plot.png", "Training and Validation Accuracy"),
    ("plots/loss_plot.png", "Training and Validation Loss"),
    ("confusion_matrix.png", "Confusion Matrix"),
)


@st.cache_resource
def load_model():
    """Build MonkeyResNet, load its weights from MODEL_PATH, and return it in eval mode.

    Weights are mapped onto the CPU. The result is cached with
    ``st.cache_resource``, so the file is read once and reused across reruns
    instead of on every interaction.

    Raises:
        FileNotFoundError: if MODEL_PATH does not exist.
    """
    model = MonkeyResNet(num_classes=len(CLASS_NAMES))
    state_dict = torch.load(MODEL_PATH, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()
    return model


def preprocess_image(image: Image.Image) -> torch.Tensor:
    """Resize, tensorize and normalize a PIL image, then add a batch dimension.

    For an RGB input the result has shape (1, 3, 224, 224).
    """
    return _PREPROCESS(image).unsqueeze(0)


def _find_asset(relative_path):
    """Return an existing path for *relative_path*, or None if it is not found.

    The current working directory is checked first (the original behaviour),
    then the directory containing this script. This way the app finds its
    plots no matter where ``streamlit run`` is launched from.
    """
    for candidate in (relative_path, os.path.join(_APP_DIR, relative_path)):
        if os.path.exists(candidate):
            return candidate
    return None


def _predict_species(classifier, image):
    """Run *classifier* on *image* and return the predicted species display name."""
    input_tensor = preprocess_image(image)
    with torch.no_grad():
        outputs = classifier(input_tensor)
    # argmax over the class dimension; ties resolve to the first index, as torch.max does.
    predicted_index = int(outputs.argmax(dim=1).item())
    return LABEL_MAP[CLASS_NAMES[predicted_index]]


# Load the model once (cached). Show a readable error instead of a traceback
# if the weights file is missing, then halt this run of the script.
try:
    model = load_model()
except FileNotFoundError:
    st.error(f"Model weights not found at {MODEL_PATH}.")
    st.stop()

# Title for the app.
st.title("ApexID: Monkey Species Classifier")

# Create tabs for project info, classification, and model details.
tab1, tab2, tab3 = st.tabs(["About the Project", "Image Classification", "Model Details"])

# Tab 1: Project description.
with tab1:
    st.header("About This Project")
    st.write("""
    This project uses a deep learning model based on ResNet18 with transfer learning to classify images of ten monkey species.
    It applies convolutional neural networks (CNNs) for accurate image recognition and is designed for tasks like education, wildlife monitoring, and zoo record management.

    Key Points:
    - Image classification using CNN and transfer learning
    - Built with PyTorch for model training
    - Streamlit used for a user-friendly interface
    """)

# Tab 2: Image classification interface.
with tab2:
    st.header("Classify a Monkey Image")
    st.markdown("""
    Steps to classify:
    1. Upload a clear monkey image.
    2. Supported file types: jpg, png, jpeg.
    3. See prediction below after uploading.
    """)

    uploaded_file = st.file_uploader("Upload an Image", type=["jpg", "png", "jpeg"])
    if uploaded_file is not None:
        image = None
        try:
            # getvalue() returns the full upload regardless of stream position.
            # Decoding happens lazily, so convert() is inside the try as well.
            image = Image.open(io.BytesIO(uploaded_file.getvalue())).convert("RGB")
        except (OSError, Image.DecompressionBombError):
            # UnidentifiedImageError and truncated-file errors are OSError subclasses.
            st.error("Could not read the uploaded file as an image. Please upload a valid jpg or png.")

        if image is not None:
            st.image(image, caption="Uploaded Image", width=300)
            # Add spinner to indicate loading.
            with st.spinner("Classifying... Please wait."):
                species_name = _predict_species(model, image)
            st.success(f"Predicted Monkey Species: {species_name}")

# Tab 3: Model details and performance.
with tab3:
    st.header("Model Information")
    st.markdown("""
    - Model architecture: ResNet18 using transfer learning
    - Framework: PyTorch
    - Final training accuracy: 90.88%
    - Final validation loss: 2.44
    - Test accuracy: 92%
    """)

    # Show each training artefact that exists on disk; skip missing ones silently.
    for relative_path, caption in _PLOT_ASSETS:
        asset_path = _find_asset(relative_path)
        if asset_path is not None:
            st.image(asset_path, caption=caption)