File size: 4,540 Bytes
2f34ae4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import streamlit as st
from transformers import pipeline
from PIL import Image
import requests
from io import BytesIO
import logging
import torch

# Configure root logging at import time at INFO level (basicConfig is a no-op
# if the root logger already has handlers), then create the module-scoped
# logger that the classifier below writes its messages to.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)

class AgeClassifier:
    """Age-range image classifier backed by a Hugging Face ``pipeline``.

    The pipeline is built once in ``__init__`` on ``"cuda"`` when
    ``torch.cuda.is_available()`` is true, otherwise on ``"cpu"``.
    """

    # Model id used when the caller does not pass one (the original hard-coded value).
    DEFAULT_MODEL = "nateraw/vit-age-classifier"

    def __init__(self, model_name=DEFAULT_MODEL):
        """Create the image-classification pipeline.

        Args:
            model_name: Hugging Face model id to load. Defaults to
                ``DEFAULT_MODEL``.

        Raises:
            Exception: any error from device detection or ``pipeline`` is
                logged together with its traceback and re-raised unchanged.
        """
        self.model_name = model_name
        try:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            self.pipe = pipeline("image-classification", model=model_name, device=self.device)
            logger.info("Model loaded successfully on %s", self.device)
        except Exception as e:
            logger.exception("Failed to initialize pipeline: %s", e)
            raise

    def classify_image(self, image, top_k=None):
        """Classify *image* and return the pipeline's predictions.

        Args:
            image: Input accepted by the pipeline (the app passes a PIL image).
            top_k: Number of predictions to return. ``None`` (the default)
                requests one prediction per label the model knows. Left to its
                own default, the pipeline would return only the top 5.

        Returns:
            The pipeline output: a list of dicts with ``"label"`` and
            ``"score"`` keys, or ``None`` if classification raised.
        """
        try:
            if top_k is None:
                # Ask for every label so callers that show "all age ranges"
                # really receive all of them, not a truncated top-5.
                top_k = self.pipe.model.config.num_labels
            return self.pipe(image, top_k=top_k)
        except Exception as e:
            # Best-effort API: log with traceback and signal failure via None.
            logger.exception("Classification failed: %s", e)
            return None

    @staticmethod
    def format_results(results):
        """Return *results* unchanged, or ``"No valid results"`` when it is empty or falsy."""
        if not results:
            return "No valid results"
        return results

def load_image_from_url(url):
    """Download an image from *url* and return it as a fully decoded PIL image.

    Any failure is reported to the user with ``st.error`` and ``None`` is
    returned. Failures include network errors, an HTTP error status, and
    non-image or corrupt data.

    Args:
        url: The address entered by the user; surrounding whitespace is ignored.

    Returns:
        A PIL image on success, otherwise ``None``.
    """
    try:
        response = requests.get(url.strip(), timeout=10)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        # Image.open is lazy and only parses the header. Decode the pixel data
        # now so corrupt or truncated downloads fail here, inside the try,
        # rather than later when the image is displayed or classified.
        image.load()
        return image
    except Exception as e:
        st.error(f"Error loading image from URL: {e}")
        return None

def _open_uploaded_image(uploaded_file):
    """Open an uploaded file as a decoded PIL image; on failure show an error and return None."""
    try:
        image = Image.open(uploaded_file)
        # Decode now so corrupt or mislabelled uploads are caught here
        # instead of crashing st.image() or the classifier later.
        image.load()
        return image
    except Exception as e:
        st.error(f"Error loading uploaded image: {e}")
        return None


def _show_results(results):
    """Render the most likely age range, a confidence bar chart and per-label scores."""
    st.subheader("Classification Results")

    # Convert results to format suitable for bar chart
    labels = [r['label'] for r in results]
    scores = [r['score'] * 100 for r in results]

    # Display most likely age range
    most_likely = max(results, key=lambda x: x['score'])
    st.success(f"Most likely age range: {most_likely['label']} ({most_likely['score']*100:.1f}%)")

    chart_data = {
        'Age Range': labels,
        'Confidence (%)': scores
    }
    st.bar_chart(chart_data, x='Age Range', y='Confidence (%)')

    # Display detailed results in an expander
    with st.expander("See detailed results"):
        st.write("Confidence scores for all age ranges:")
        for result in results:
            st.write(f"{result['label']}: {result['score']*100:.1f}%")


def _show_sidebar(device):
    """Render the sidebar describing the model and the device it runs on."""
    with st.sidebar:
        st.header("About")
        st.write("""
        This app uses the ViT (Vision Transformer) model trained for age classification.
        
        The model classifies images into the following age ranges:
        - 0-2 years
        - 3-9 years
        - 10-19 years
        - 20-29 years
        - 30-39 years
        - 40-49 years
        - 50-59 years
        - 60-69 years
        - 70+ years
        """)

        st.write("Model: nateraw/vit-age-classifier")
        st.write(f"Running on: {device}")


def main():
    """Build the Streamlit page: image inputs, on-demand classification and an info sidebar."""
    st.set_page_config(
        page_title="Age Classification App",
        page_icon="👀",
        layout="wide"
    )

    st.title("Age Classification App 👀")
    st.write("Upload an image or provide a URL to classify the age range of people in the image.")

    # st.cache_resource keeps the classifier across script reruns so the
    # model is not reloaded on every widget interaction.
    @st.cache_resource
    def get_classifier():
        return AgeClassifier()

    try:
        classifier = get_classifier()
    except Exception as e:
        st.error(f"Failed to load the age classification model: {e}")
        st.stop()
        return  # st.stop() halts the script; return keeps the intent explicit.

    # Create two columns for input methods
    col1, col2 = st.columns(2)

    with col1:
        st.subheader("Upload Image")
        uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

    with col2:
        st.subheader("Image URL")
        image_url = st.text_input("Enter image URL")

    # An uploaded file takes precedence over a URL when both are provided.
    image = None
    if uploaded_file is not None:
        image = _open_uploaded_image(uploaded_file)
    elif image_url:
        image = load_image_from_url(image_url)

    if image is not None:
        st.image(image, caption="Input Image", use_column_width=True)

        if st.button("Classify Age"):
            with st.spinner("Classifying..."):
                results = classifier.classify_image(image)

            if results:
                _show_results(results)
            else:
                st.error("Could not classify the image. Please try another image.")

    _show_sidebar(classifier.device)

# Entry point: build the app only when this file runs as the main module,
# not when it is imported.
if "__main__" == __name__:
    main()