xtlyxt commited on
Commit
1423812
·
verified ·
1 Parent(s): 9d99378

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -0
app.py CHANGED
@@ -1,6 +1,8 @@
1
  import streamlit as st
2
  from PIL import Image
3
  from transformers import pipeline
 
 
4
 
5
  # Create an image classification pipeline with scores
6
  pipe = pipeline("image-classification", model="trpakov/vit-face-expression", top_k=None)
@@ -28,6 +30,7 @@ if uploaded_images:
28
  selected_images.append(image)
29
 
30
  if st.button("Predict Emotions") and selected_images:
 
31
  if len(selected_images) == 2:
32
  # Predict emotion for each selected image using the pipeline
33
  results = [pipe(image) for image in selected_images]
@@ -37,6 +40,7 @@ if st.button("Predict Emotions") and selected_images:
37
  for i in range(2):
38
  predicted_class = results[i][0]["label"]
39
  predicted_emotion = predicted_class.split("_")[-1].capitalize()
 
40
  col = col1 if i == 0 else col2
41
  col.image(selected_images[i], caption=f"Predicted emotion: {predicted_emotion}", use_column_width=True)
42
  col.write(f"Emotion Scores: {predicted_emotion}: {results[i][0]['score']:.4f}")
@@ -61,8 +65,28 @@ if st.button("Predict Emotions") and selected_images:
61
  for i, (image, result) in enumerate(zip(selected_images, results)):
62
  predicted_class = result[0]["label"]
63
  predicted_emotion = predicted_class.split("_")[-1].capitalize()
 
64
  st.image(image, caption=f"Predicted emotion: {predicted_emotion}", use_column_width=True)
65
  st.write(f"Emotion Scores for #{i+1} Image")
66
  st.write(f"{predicted_emotion}: {result[0]['score']:.4f}")
67
  # Use the index to get the corresponding filename
68
  st.write(f"Original File Name: {uploaded_images[i].name if i < len(uploaded_images) else 'Unknown'}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  from PIL import Image
3
  from transformers import pipeline
4
+ import pandas as pd
5
+ import matplotlib.pyplot as plt
6
 
7
  # Create an image classification pipeline with scores
8
  pipe = pipeline("image-classification", model="trpakov/vit-face-expression", top_k=None)
 
30
  selected_images.append(image)
31
 
32
  if st.button("Predict Emotions") and selected_images:
33
+ emotions = []
34
  if len(selected_images) == 2:
35
  # Predict emotion for each selected image using the pipeline
36
  results = [pipe(image) for image in selected_images]
 
40
  for i in range(2):
41
  predicted_class = results[i][0]["label"]
42
  predicted_emotion = predicted_class.split("_")[-1].capitalize()
43
+ emotions.append(predicted_emotion)
44
  col = col1 if i == 0 else col2
45
  col.image(selected_images[i], caption=f"Predicted emotion: {predicted_emotion}", use_column_width=True)
46
  col.write(f"Emotion Scores: {predicted_emotion}: {results[i][0]['score']:.4f}")
 
65
  for i, (image, result) in enumerate(zip(selected_images, results)):
66
  predicted_class = result[0]["label"]
67
  predicted_emotion = predicted_class.split("_")[-1].capitalize()
68
+ emotions.append(predicted_emotion)
69
  st.image(image, caption=f"Predicted emotion: {predicted_emotion}", use_column_width=True)
70
  st.write(f"Emotion Scores for #{i+1} Image")
71
  st.write(f"{predicted_emotion}: {result[0]['score']:.4f}")
72
  # Use the index to get the corresponding filename
73
  st.write(f"Original File Name: {uploaded_images[i].name if i < len(uploaded_images) else 'Unknown'}")
74
+
75
+ # Calculate emotion statistics
76
+ emotion_counts = pd.Series(emotions).value_counts()
77
+
78
+ # Plot pie chart
79
+ st.write("Emotion Distribution (Pie Chart):")
80
+ plt.figure(figsize=(8, 6))
81
+ plt.pie(emotion_counts, labels=emotion_counts.index, autopct='%1.1f%%', startangle=140)
82
+ plt.axis('equal') # Equal aspect ratio ensures that pie is drawn as a circle.
83
+ st.pyplot()
84
+
85
+ # Plot bar chart
86
+ st.write("Emotion Distribution (Bar Chart):")
87
+ plt.figure(figsize=(10, 6))
88
+ emotion_counts.plot(kind='bar', color='skyblue')
89
+ plt.xlabel('Emotion')
90
+ plt.ylabel('Count')
91
+ plt.title('Emotion Distribution')
92
+ st.pyplot()