# Lab4 / app.py
# File-view header: commit 2f34ae4 "Create app.py" by slliac (verified), 4.54 kB.
import streamlit as st
from transformers import pipeline
from PIL import Image
import requests
from io import BytesIO
import logging
import torch
# Configure the root logger at INFO level so the model-loading messages
# emitted by AgeClassifier are shown, then grab this module's own logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
class AgeClassifier:
    """Wrap a Hugging Face image-classification pipeline that predicts age ranges.

    Attributes:
        device: ``"cuda"`` when ``torch.cuda.is_available()`` is true, else ``"cpu"``.
        pipe: the ``transformers`` image-classification pipeline running on ``device``.
    """

    def __init__(self, model_name="nateraw/vit-age-classifier"):
        """Build the classification pipeline on the best available device.

        Args:
            model_name: model id handed to ``transformers.pipeline``; defaults to
                the ViT age classifier this app was written for.

        Raises:
            Exception: whatever ``pipeline`` raises is logged and re-raised unchanged.
        """
        try:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            self.pipe = pipeline("image-classification", model=model_name, device=self.device)
            logger.info("Model loaded successfully on %s", self.device)
        except Exception as e:
            logger.error("Failed to initialize pipeline: %s", e)
            raise

    def classify_image(self, image, top_k=None):
        """Run the pipeline on *image* and return its predictions.

        Args:
            image: pipeline input (the app passes a PIL image).
            top_k: maximum number of labels to return. ``None`` (default) requests
                every label in the model config. Without an explicit value the
                image-classification pipeline returns only its top 5, which would
                hide some of the advertised age ranges.

        Returns:
            The pipeline output (a list of dicts with ``"label"`` and ``"score"``
            keys, as consumed by ``main``), or ``None`` if anything raised. Failures
            are logged with their traceback rather than propagated, so the UI can
            show a friendly message.
        """
        try:
            if top_k is None:
                top_k = self.pipe.model.config.num_labels
            return self.pipe(image, top_k=top_k)
        except Exception as e:
            logger.exception("Classification failed: %s", e)
            return None

    @staticmethod
    def format_results(results):
        """Return *results* unchanged, or ``"No valid results"`` when it is falsy."""
        if not results:
            return "No valid results"
        return results
def load_image_from_url(url):
    """Download *url* and return it as a fully decoded PIL image.

    Failures are reported to the user with ``st.error`` (and logged) instead of
    being raised, so callers only need to check for ``None``.

    Args:
        url: address typed by the user; surrounding whitespace is ignored.

    Returns:
        The decoded image on success, otherwise ``None``.
    """
    try:
        response = requests.get(url.strip(), timeout=10)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        # Image.open only identifies the file; pixel data is read lazily. Decode
        # now so corrupt/truncated downloads are caught here, not later in the model.
        image.load()
        return image
    except Exception as e:
        logger.warning("Failed to load image from URL %r: %s", url, e)
        st.error(f"Error loading image from URL: {e}")
        return None
def _open_uploaded_image(uploaded_file):
    """Decode an uploaded file into a PIL image, reporting failures via ``st.error``.

    Returns the decoded image, or ``None`` if the upload is not a readable image.
    """
    try:
        image = Image.open(uploaded_file)
        # Force decoding now (Image.open is lazy) so a corrupt upload is reported
        # here instead of crashing the app during classification.
        image.load()
        return image
    except Exception as e:
        st.error(f"Error reading uploaded image: {e}")
        return None


def _show_results(results):
    """Render the top prediction, a confidence bar chart and a per-label list.

    Args:
        results: sequence of dicts with ``"label"`` and ``"score"`` (0-1) keys.
    """
    st.subheader("Classification Results")
    labels = [r["label"] for r in results]
    scores = [r["score"] * 100 for r in results]

    most_likely = max(results, key=lambda r: r["score"])
    st.success(f"Most likely age range: {most_likely['label']} ({most_likely['score']*100:.1f}%)")

    chart_data = {
        "Age Range": labels,
        "Confidence (%)": scores,
    }
    st.bar_chart(chart_data, x="Age Range", y="Confidence (%)")

    with st.expander("See detailed results"):
        st.write("Confidence scores for all age ranges:")
        for label, score in zip(labels, scores):
            st.write(f"{label}: {score:.1f}%")


def main():
    """Build the Streamlit page: inputs, optional classification, results and sidebar."""
    st.set_page_config(
        page_title="Age Classification App",
        page_icon="👀",
        layout="wide",
    )
    st.title("Age Classification App 👀")
    st.write("Upload an image or provide a URL to classify the age range of people in the image.")

    # st.cache_resource keeps a single AgeClassifier across reruns so the model
    # is not reloaded on every widget interaction.
    @st.cache_resource
    def get_classifier():
        return AgeClassifier()

    classifier = get_classifier()

    # Two side-by-side input methods.
    col1, col2 = st.columns(2)
    with col1:
        st.subheader("Upload Image")
        uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    with col2:
        st.subheader("Image URL")
        image_url = st.text_input("Enter image URL")

    # An uploaded file takes precedence over the URL field.
    image = None
    if uploaded_file is not None:
        image = _open_uploaded_image(uploaded_file)
    elif image_url:
        image = load_image_from_url(image_url)

    if image is not None:
        # NOTE(review): use_column_width is deprecated in newer Streamlit releases;
        # switch to use_container_width once the deployed version supports it.
        st.image(image, caption="Input Image", use_column_width=True)
        if st.button("Classify Age"):
            with st.spinner("Classifying..."):
                results = classifier.classify_image(image)
            if results:
                _show_results(results)
            else:
                st.error("Could not classify the image. Please try another image.")

    # Static information about the model.
    with st.sidebar:
        st.header("About")
        st.write("""
This app uses the ViT (Vision Transformer) model trained for age classification.
The model classifies images into the following age ranges:
- 0-2 years
- 3-9 years
- 10-19 years
- 20-29 years
- 30-39 years
- 40-49 years
- 50-59 years
- 60-69 years
- 70+ years
""")
        st.write("Model: nateraw/vit-age-classifier")
        st.write(f"Running on: {classifier.device}")
if __name__ == "__main__":
    # Run the app only when this file is executed as the main script
    # (not when it is imported as a module).
    main()