import streamlit as st
from transformers import pipeline
from PIL import Image
import requests
from io import BytesIO
import logging
import torch

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class AgeClassifier:
    """Thin wrapper around the ``nateraw/vit-age-classifier`` image-classification pipeline."""

    def __init__(self):
        """Load the pipeline on CUDA when available, otherwise on CPU.

        Raises:
            Exception: whatever ``pipeline(...)`` raised; it is logged, then re-raised.
        """
        try:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            self.pipe = pipeline(
                "image-classification",
                model="nateraw/vit-age-classifier",
                device=self.device,
            )
            logger.info(f"Model loaded successfully on {self.device}")
        except Exception as e:
            logger.error(f"Failed to initialize pipeline: {e}")
            raise

    def classify_image(self, image):
        """Classify *image* and return one ``{"label", "score"}`` dict per model label.

        Returns:
            The pipeline's result list, or None if classification raised
            (the error is logged with its traceback).
        """
        try:
            # The image-classification pipeline returns only the top 5 labels by
            # default; request every label so callers really see all age ranges.
            num_labels = self.pipe.model.config.num_labels
            return self.pipe(image, top_k=num_labels)
        except Exception as e:
            logger.exception(f"Classification failed: {e}")
            return None

    @staticmethod
    def format_results(results):
        """Return *results* unchanged, or a placeholder string when it is empty/None."""
        if not results:
            return "No valid results"
        return results


def load_image_from_url(url):
    """Download *url* and decode it as a PIL image.

    On any failure (network error, non-2xx status, undecodable payload) a
    Streamlit error is shown and None is returned.
    """
    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        # Image.open is lazy; force the decode here so corrupt or truncated data
        # is reported by this handler instead of failing later in the pipeline.
        image.load()
        return image
    except Exception as e:
        st.error(f"Error loading image from URL: {e}")
        return None


def _load_uploaded_image(uploaded_file):
    """Decode an uploaded file as a PIL image; show an error and return None on failure."""
    try:
        image = Image.open(uploaded_file)
        # Force the (lazy) decode so bad uploads are caught here.
        image.load()
        return image
    except Exception as e:
        st.error(f"Error loading uploaded image: {e}")
        return None


def _render_results(results):
    """Show the top prediction, a confidence bar chart, and a per-label breakdown."""
    st.subheader("Classification Results")

    # Convert results to format suitable for bar chart
    labels = [r['label'] for r in results]
    scores = [r['score'] * 100 for r in results]

    # Display most likely age range
    most_likely = max(results, key=lambda x: x['score'])
    st.success(f"Most likely age range: {most_likely['label']} ({most_likely['score']*100:.1f}%)")

    chart_data = {
        'Age Range': labels,
        'Confidence (%)': scores,
    }
    st.bar_chart(chart_data, x='Age Range', y='Confidence (%)')

    # Display detailed results in an expander
    with st.expander("See detailed results"):
        st.write("Confidence scores for all age ranges:")
        for result in results:
            st.write(f"{result['label']}: {result['score']*100:.1f}%")


def _render_sidebar(classifier):
    """Show model information and the device the classifier runs on."""
    with st.sidebar:
        st.header("About")
        st.write("""
        This app uses the ViT (Vision Transformer) model trained for age classification.
        The model classifies images into the following age ranges:
        - 0-2 years
        - 3-9 years
        - 10-19 years
        - 20-29 years
        - 30-39 years
        - 40-49 years
        - 50-59 years
        - 60-69 years
        - 70+ years
        """)
        st.write("Model: nateraw/vit-age-classifier")
        st.write(f"Running on: {classifier.device}")


def main():
    """Streamlit entry point: collect an image, classify it on demand, render results."""
    st.set_page_config(
        page_title="Age Classification App",
        page_icon="👤",
        layout="wide"
    )

    st.title("Age Classification App 👤")
    st.write("Upload an image or provide a URL to classify the age range of people in the image.")

    # Cached across Streamlit reruns so the model is loaded only once.
    @st.cache_resource
    def get_classifier():
        return AgeClassifier()

    classifier = get_classifier()

    # Create two columns for input methods
    col1, col2 = st.columns(2)

    with col1:
        st.subheader("Upload Image")
        uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

    with col2:
        st.subheader("Image URL")
        image_url = st.text_input("Enter image URL")

    # An uploaded file takes precedence over a URL when both are provided;
    # a blank or whitespace-only URL is ignored.
    image = None
    url = image_url.strip() if image_url else ""
    if uploaded_file is not None:
        image = _load_uploaded_image(uploaded_file)
    elif url:
        image = load_image_from_url(url)

    if image is not None:
        # Display the image
        st.image(image, caption="Input Image", use_column_width=True)

        if st.button("Classify Age"):
            with st.spinner("Classifying..."):
                results = classifier.classify_image(image)

            if results:
                _render_results(results)
            else:
                st.error("Could not classify the image. Please try another image.")

    # Add information about the model
    _render_sidebar(classifier)


if __name__ == "__main__":
    main()