Leo Liu commited on
Commit
241398f
·
verified ·
1 Parent(s): 12e5235

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -0
app.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import requests
3
+ from io import BytesIO
4
+ from PIL import Image
5
+ from transformers import pipeline
6
+
7
+ @st.cache_data(show_spinner=False)
8
+ def load_age_classifier():
9
+ # Load and cache the image-classification pipeline for the age classifier
10
+ return pipeline("image-classification", model="nateraw/vit-age-classifier")
11
+
12
+ def classify_age(image: Image.Image):
13
+ """
14
+ Classify the age of a person in an image using the nateraw/vit-age-classifier model.
15
+
16
+ Args:
17
+ image (PIL.Image): The image to classify.
18
+
19
+ Returns:
20
+ list: Predictions with labels and corresponding confidence scores.
21
+ """
22
+ age_classifier = load_age_classifier()
23
+ return age_classifier(image)
24
+
25
+ def main():
26
+ st.title("Age Classification with ViT Age Classifier")
27
+ st.write("This demo uses the `nateraw/vit-age-classifier` model from Hugging Face to predict age categories from facial images.")
28
+
29
+ # Let the user choose the input method
30
+ input_method = st.radio("Select input method:", ("Image URL", "Upload an Image"))
31
+
32
+ image = None
33
+
34
+ if input_method == "Image URL":
35
+ image_url = st.text_input(
36
+ "Enter the Image URL",
37
+ "https://github.com/dchen236/FairFace/blob/master/detected_faces/race_Asian_face0.jpg?raw=true"
38
+ )
39
+ if image_url:
40
+ try:
41
+ response = requests.get(image_url)
42
+ image = Image.open(BytesIO(response.content)).convert("RGB")
43
+ st.image(image, caption="Input Image from URL", use_column_width=True)
44
+ except Exception as e:
45
+ st.error(f"Error loading image from URL: {e}")
46
+ else:
47
+ uploaded_file = st.file_uploader("Upload an Image", type=["jpg", "jpeg", "png"])
48
+ if uploaded_file is not None:
49
+ try:
50
+ image = Image.open(uploaded_file).convert("RGB")
51
+ st.image(image, caption="Uploaded Image", use_column_width=True)
52
+ except Exception as e:
53
+ st.error(f"Error processing uploaded image: {e}")
54
+
55
+ if image is not None:
56
+ if st.button("Classify Age"):
57
+ with st.spinner("Classifying..."):
58
+ predictions = classify_age(image)
59
+ st.write("### Classification Results:")
60
+ for pred in predictions:
61
+ st.write(f"**Label:** {pred['label']} | **Confidence:** {pred['score']:.2f}")
62
+
63
+ if __name__ == "__main__":
64
+ main()