slliac commited on
Commit
2f34ae4
·
verified ·
1 Parent(s): b4f5b7a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +139 -0
app.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import pipeline
3
+ from PIL import Image
4
+ import requests
5
+ from io import BytesIO
6
+ import logging
7
+ import torch
8
+
9
+ # Setup logging
10
+ logging.basicConfig(level=logging.INFO)
11
+ logger = logging.getLogger(__name__)
12
+
13
+ class AgeClassifier:
14
+ def __init__(self):
15
+ try:
16
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
17
+ self.pipe = pipeline("image-classification", model="nateraw/vit-age-classifier", device=self.device)
18
+ logger.info(f"Model loaded successfully on {self.device}")
19
+ except Exception as e:
20
+ logger.error(f"Failed to initialize pipeline: {e}")
21
+ raise
22
+
23
+ def classify_image(self, image):
24
+ try:
25
+ return self.pipe(image)
26
+ except Exception as e:
27
+ logger.error(f"Classification failed: {e}")
28
+ return None
29
+
30
+ @staticmethod
31
+ def format_results(results):
32
+ if not results:
33
+ return "No valid results"
34
+ return results
35
+
36
+ def load_image_from_url(url):
37
+ try:
38
+ response = requests.get(url, timeout=10)
39
+ response.raise_for_status()
40
+ image = Image.open(BytesIO(response.content))
41
+ return image
42
+ except Exception as e:
43
+ st.error(f"Error loading image from URL: {e}")
44
+ return None
45
+
46
+ def main():
47
+ st.set_page_config(
48
+ page_title="Age Classification App",
49
+ page_icon="👤",
50
+ layout="wide"
51
+ )
52
+
53
+ st.title("Age Classification App 👤")
54
+ st.write("Upload an image or provide a URL to classify the age range of people in the image.")
55
+
56
+ # Initialize the classifier
57
+ @st.cache_resource
58
+ def get_classifier():
59
+ return AgeClassifier()
60
+
61
+ classifier = get_classifier()
62
+
63
+ # Create two columns for input methods
64
+ col1, col2 = st.columns(2)
65
+
66
+ with col1:
67
+ st.subheader("Upload Image")
68
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
69
+
70
+ with col2:
71
+ st.subheader("Image URL")
72
+ image_url = st.text_input("Enter image URL")
73
+
74
+ # Process the image
75
+ image = None
76
+ if uploaded_file is not None:
77
+ image = Image.open(uploaded_file)
78
+ elif image_url:
79
+ image = load_image_from_url(image_url)
80
+
81
+ if image:
82
+ # Display the image
83
+ st.image(image, caption="Input Image", use_column_width=True)
84
+
85
+ # Add a classify button
86
+ if st.button("Classify Age"):
87
+ with st.spinner("Classifying..."):
88
+ results = classifier.classify_image(image)
89
+
90
+ if results:
91
+ # Create a bar chart
92
+ st.subheader("Classification Results")
93
+
94
+ # Convert results to format suitable for bar chart
95
+ labels = [r['label'] for r in results]
96
+ scores = [r['score'] * 100 for r in results]
97
+
98
+ # Display most likely age range
99
+ most_likely = max(results, key=lambda x: x['score'])
100
+ st.success(f"Most likely age range: {most_likely['label']} ({most_likely['score']*100:.1f}%)")
101
+
102
+ # Create bar chart
103
+ chart_data = {
104
+ 'Age Range': labels,
105
+ 'Confidence (%)': scores
106
+ }
107
+ st.bar_chart(chart_data, x='Age Range', y='Confidence (%)')
108
+
109
+ # Display detailed results in an expander
110
+ with st.expander("See detailed results"):
111
+ st.write("Confidence scores for all age ranges:")
112
+ for result in results:
113
+ st.write(f"{result['label']}: {result['score']*100:.1f}%")
114
+ else:
115
+ st.error("Could not classify the image. Please try another image.")
116
+
117
+ # Add information about the model
118
+ with st.sidebar:
119
+ st.header("About")
120
+ st.write("""
121
+ This app uses the ViT (Vision Transformer) model trained for age classification.
122
+
123
+ The model classifies images into the following age ranges:
124
+ - 0-2 years
125
+ - 3-9 years
126
+ - 10-19 years
127
+ - 20-29 years
128
+ - 30-39 years
129
+ - 40-49 years
130
+ - 50-59 years
131
+ - 60-69 years
132
+ - 70+ years
133
+ """)
134
+
135
+ st.write("Model: nateraw/vit-age-classifier")
136
+ st.write(f"Running on: {classifier.device}")
137
+
138
+ if __name__ == "__main__":
139
+ main()