"""Streamlit app that predicts an age category for a person in an uploaded image.

Classification is done with the ``nateraw/vit-age-classifier`` model through a
Hugging Face ``image-classification`` pipeline.
"""

import streamlit as st
from io import BytesIO
from PIL import Image
from transformers import pipeline

# Hugging Face model id used by the image-classification pipeline.
MODEL_NAME = "nateraw/vit-age-classifier"


@st.cache_resource(show_spinner=False)
def load_age_classifier():
    """Load the age-classification pipeline once and reuse it.

    ``st.cache_resource`` (not ``st.cache_data``) is used because the return
    value is a model object: ``cache_data`` would try to pickle and copy it on
    every call, whereas ``cache_resource`` keeps one shared instance across
    script reruns.

    Returns:
        transformers.Pipeline: An ``image-classification`` pipeline for
        ``MODEL_NAME``.
    """
    return pipeline("image-classification", model=MODEL_NAME)


def classify_age(image: Image.Image):
    """
    Classify the age of a person in an image using the nateraw/vit-age-classifier model.

    Args:
        image (PIL.Image.Image): The image to classify.

    Returns:
        list: Predictions with labels and corresponding confidence scores,
        as returned by the pipeline (dicts with ``"label"`` and ``"score"``).
    """
    age_classifier = load_age_classifier()
    return age_classifier(image)


def _render_predictions(predictions):
    """Write each prediction's label and confidence score to the page."""
    st.write("### Classification Results:")
    for pred in predictions:
        st.write(f"**Label:** {pred['label']} | **Confidence:** {pred['score']:.2f}")


def main():
    """Render the upload widget, preview the image, and classify it on demand."""
    st.title("Age Classification with ViT Age Classifier")
    st.write("Upload an image to predict the age category using the `nateraw/vit-age-classifier` model.")

    # Upload an image
    uploaded_file = st.file_uploader("Upload an Image", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    # Decoding failures are reported separately from model failures so the
    # user sees an accurate message. UnidentifiedImageError subclasses OSError;
    # DecompressionBombError guards against oversized images.
    try:
        image = Image.open(uploaded_file).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        st.error(f"Error processing uploaded image: {err}")
        return

    # NOTE(review): `use_column_width` is deprecated in newer Streamlit releases
    # in favour of `use_container_width`; kept for compatibility with older versions.
    st.image(image, caption="Uploaded Image", use_column_width=True)

    if not st.button("Classify Age"):
        return

    try:
        with st.spinner("Classifying..."):
            predictions = classify_age(image)
    except Exception as err:  # UI boundary: surface model/runtime failures to the user.
        st.error(f"Error classifying image: {err}")
        return

    _render_predictions(predictions)


if __name__ == "__main__":
    main()