|
import streamlit as st |
|
from transformers import pipeline |
|
from PIL import Image |
|
import requests |
|
from io import BytesIO |
|
import logging |
|
import torch |
|
|
|
|
|
# Module-wide logging: root logger at INFO with the default format.
# basicConfig does nothing when the root logger already has handlers, so
# executing this module more than once does not stack extra handlers.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
|
|
|
class AgeClassifier:
    """Wrap a Hugging Face image-classification pipeline that predicts age ranges.

    The pipeline is built once in ``__init__``. ``classify_image`` runs it on an
    image and returns ``None`` instead of raising when inference fails.
    """

    # Hugging Face Hub id of the model used when no other model is requested.
    DEFAULT_MODEL = "nateraw/vit-age-classifier"

    def __init__(self, model_name=DEFAULT_MODEL, device=None):
        """Load the classification pipeline.

        Args:
            model_name: Model id (or local path) passed to ``transformers.pipeline``.
            device: Inference device such as ``"cpu"`` or ``"cuda"``. ``None``
                selects ``"cuda"`` when ``torch.cuda.is_available()`` is true,
                otherwise ``"cpu"``.

        Raises:
            Exception: Whatever ``transformers.pipeline`` raises when the model
                cannot be loaded; it is logged with its traceback and re-raised.
        """
        self.model_name = model_name
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        try:
            self.pipe = pipeline("image-classification", model=model_name, device=self.device)
        except Exception as e:
            # logger.exception keeps the traceback that logger.error dropped.
            logger.exception("Failed to initialize pipeline: %s", e)
            raise
        logger.info("Model loaded successfully on %s", self.device)

    def classify_image(self, image, top_k=None):
        """Classify ``image`` into age ranges.

        Args:
            image: Input accepted by the pipeline (e.g. a ``PIL.Image.Image``).
            top_k: Number of labels to return. ``None`` (default) returns every
                label the model defines.

        Returns:
            The pipeline output, a list of ``{"label": ..., "score": ...}``
            dicts, or ``None`` if inference raised (the error is logged).
        """
        try:
            if top_k is None:
                # The image-classification pipeline defaults to top_k=5, which
                # would silently drop labels from a model with more classes.
                top_k = self.pipe.model.config.num_labels
            return self.pipe(image, top_k=top_k)
        except Exception as e:
            logger.exception("Classification failed: %s", e)
            return None

    @staticmethod
    def format_results(results):
        """Return ``results`` unchanged, or ``"No valid results"`` if it is falsy (``None`` or empty)."""
        if not results:
            return "No valid results"
        return results
|
|
|
def load_image_from_url(url, timeout=10):
    """Download an image and decode it with PIL.

    Failures (network error, non-2xx status, undecodable content) are reported
    to the user with ``st.error`` rather than raised.

    Args:
        url: Image address as typed by the user; surrounding whitespace is ignored.
        timeout: Value passed to ``requests.get(..., timeout=...)``, in seconds.

    Returns:
        The decoded ``PIL.Image.Image``, or ``None`` on any failure.
    """
    # NOTE(review): the URL is user-supplied and fetched server-side with no
    # host allow-list or response-size cap. If this app is exposed publicly,
    # consider restricting targets (SSRF) and bounding the download size.
    try:
        response = requests.get(url.strip(), timeout=timeout)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        # Image.open only parses the header. Decode the pixel data now so a
        # truncated or corrupt download is reported here instead of failing
        # later in st.image or in the model.
        image.load()
        return image
    except Exception as e:
        st.error(f"Error loading image from URL: {e}")
        return None
|
|
|
def _open_uploaded_image(uploaded_file):
    """Decode an uploaded file into a PIL image, or show ``st.error`` and return ``None``."""
    try:
        image = Image.open(uploaded_file)
        # Image.open is lazy. Decode now so a corrupt upload is reported here
        # instead of raising an uncaught error in st.image or the model.
        image.load()
        return image
    except Exception as e:
        st.error(f"Error loading uploaded image: {e}")
        return None


def _render_results(results):
    """Show the top prediction, a confidence bar chart, and per-label scores.

    ``results`` is a non-empty list of dicts with ``'label'`` and a numeric
    ``'score'`` (a fraction displayed as a percentage).
    """
    st.subheader("Classification Results")

    labels = [r['label'] for r in results]
    scores = [r['score'] * 100 for r in results]

    most_likely = max(results, key=lambda r: r['score'])
    st.success(f"Most likely age range: {most_likely['label']} ({most_likely['score']*100:.1f}%)")

    chart_data = {
        'Age Range': labels,
        'Confidence (%)': scores,
    }
    st.bar_chart(chart_data, x='Age Range', y='Confidence (%)')

    with st.expander("See detailed results"):
        # The pipeline may return only its top-k labels, so do not claim this
        # list covers every age range the model knows.
        st.write("Confidence scores by age range:")
        for result in results:
            st.write(f"{result['label']}: {result['score']*100:.1f}%")


def _render_sidebar(device):
    """Render the sidebar describing the model and the device it runs on."""
    with st.sidebar:
        st.header("About")
        st.write("""
        This app uses the ViT (Vision Transformer) model trained for age classification.

        The model classifies images into the following age ranges:
        - 0-2 years
        - 3-9 years
        - 10-19 years
        - 20-29 years
        - 30-39 years
        - 40-49 years
        - 50-59 years
        - 60-69 years
        - 70+ years
        """)
        st.write("Model: nateraw/vit-age-classifier")
        st.write(f"Running on: {device}")


def main():
    """Build the page: image inputs, preview, on-demand classification, sidebar.

    The classifier is wrapped in ``st.cache_resource`` so the model is built
    once and reused across reruns of the script.
    """
    # NOTE(review): "π€" below looks like a mis-encoded emoji; confirm the
    # intended icon and restore it (page_icon and title).
    st.set_page_config(
        page_title="Age Classification App",
        page_icon="π€",
        layout="wide",
    )

    st.title("Age Classification App π€")
    st.write("Upload an image or provide a URL to classify the age range of people in the image.")

    @st.cache_resource
    def get_classifier():
        return AgeClassifier()

    try:
        classifier = get_classifier()
    except Exception as e:
        # AgeClassifier logs and re-raises load failures. Show a readable
        # message instead of a raw traceback, and halt this run.
        st.error(f"Failed to load the age classification model: {e}")
        st.stop()

    col1, col2 = st.columns(2)
    with col1:
        st.subheader("Upload Image")
        uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    with col2:
        st.subheader("Image URL")
        image_url = st.text_input("Enter image URL")

    # An uploaded file takes precedence over a URL when both are provided.
    image = None
    if uploaded_file is not None:
        image = _open_uploaded_image(uploaded_file)
    elif image_url:
        image = load_image_from_url(image_url)

    if image is not None:
        # NOTE(review): newer Streamlit releases deprecate use_column_width in
        # favour of use_container_width; switch once the minimum supported
        # Streamlit version is confirmed.
        st.image(image, caption="Input Image", use_column_width=True)

        if st.button("Classify Age"):
            with st.spinner("Classifying..."):
                results = classifier.classify_image(image)

            if results:
                _render_results(results)
            else:
                st.error("Could not classify the image. Please try another image.")

    _render_sidebar(classifier.device)
|
|
|
# Script entry point.
# NOTE(review): presumably launched via `streamlit run <this file>`, which
# executes the module as __main__; confirm the deployment command.
if __name__ == "__main__":
    main()