Sagnik1750 committed on
Commit 26e431e · verified · 1 Parent(s): ae9c4b0

Update app.py

Files changed (1):
  1. app.py +44 -83
app.py CHANGED
@@ -9,7 +9,7 @@ from facenet_pytorch import MTCNN
 from transformers import AutoFeatureExtractor, AutoModelForImageClassification
 from PIL import Image
 import os
-import time
+from collections import Counter
 
 # Load models
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -18,13 +18,13 @@ model = AutoModelForImageClassification.from_pretrained("trpakov/vit-face-expres
 extractor = AutoFeatureExtractor.from_pretrained("trpakov/vit-face-expression")
 
 # Emotion labels
-emotion_labels = {
+affectnet_labels = {
     0: "neutral", 1: "happy", 2: "sad", 3: "surprise", 4: "fear",
     5: "disgust", 6: "anger", 7: "contempt"
 }
 
-# Emotion detection function
-def detect_emotion(frame):
+def detect_emotions(frame):
+    """Detects facial emotions in a given frame."""
     img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
     faces, _ = mtcnn.detect(img)
     if faces is None or len(faces) == 0:
@@ -34,29 +34,27 @@ def detect_emotion(frame):
     inputs = extractor(images=face, return_tensors="pt").to(device)
     outputs = model(**inputs)
     probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
+
     return model.config.id2label[torch.argmax(probs).item()]
 
-# Process Video
-def process_video(video_path, progress=gr.Progress()):
-    cap = cv2.VideoCapture(video_path)
+def process_video(input_path):
+    """Processes video, overlays emotions, and creates a summary chart."""
+    cap = cv2.VideoCapture(input_path)
     fps = int(cap.get(cv2.CAP_PROP_FPS))
     frame_width, frame_height = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
-    out_path = "output_video.mp4"
-    out = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (frame_width, frame_height))
-
-    emotion_counts = {}
-    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-    processed_frames = 0
+    out = cv2.VideoWriter("output_video.mp4", cv2.VideoWriter_fourcc(*'mp4v'), fps, (frame_width, frame_height))
 
+    emotion_counts = []
+
     while cap.isOpened():
         ret, frame = cap.read()
         if not ret:
             break
 
-        emotion = detect_emotion(frame)
-        emotion_counts[emotion] = emotion_counts.get(emotion, 0) + 1
+        emotion = detect_emotions(frame)
+        emotion_counts.append(emotion)
 
-        # Overlay emotion on frame
+        # Overlay emotion
         overlay = frame.copy()
         cv2.rectangle(overlay, (10, 10), (350, 80), (255, 255, 255), -1)
         cv2.putText(overlay, f'Emotion: {emotion}', (20, 50), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
@@ -64,78 +62,41 @@ def process_video(video_path, progress=gr.Progress()):
 
         out.write(frame)
 
-        processed_frames += 1
-        progress((processed_frames / total_frames) * 100)
-
     cap.release()
     out.release()
+    cv2.destroyAllWindows()
+
+    # Find major emotion
+    emotion_counter = Counter(emotion_counts)
+    major_emotion = emotion_counter.most_common(1)[0][0] if emotion_counter else "No Face Detected"
 
-    # Generate Pie Chart
-    plt.figure(figsize=(6, 6))
-    labels, sizes = zip(*emotion_counts.items())
+    # Generate emotion distribution pie chart
+    plt.figure(figsize=(5, 5))
+    labels, sizes = zip(*emotion_counter.items())
     plt.pie(sizes, labels=labels, autopct='%1.1f%%', colors=sns.color_palette('pastel'))
     plt.title("Emotion Distribution")
     plt.savefig("emotion_distribution.jpg")
     plt.close()
 
-    return emotion, "emotion_distribution.jpg", out_path
-
-# Custom CSS for styling
-css = """
-h1 {
-    text-align: center;
-    color: #ffffff;
-    font-size: 32px;
-    font-weight: bold;
-}
-
-.gradio-container {
-    background-color: #1E1E1E;
-    color: #ffffff;
-    padding: 20px;
-    font-family: 'Arial', sans-serif;
-}
-
-button {
-    font-size: 18px !important;
-    padding: 10px 15px !important;
-    background-color: #00BFFF !important;
-    color: white !important;
-    border-radius: 10px !important;
-}
-
-.gr-text-input {
-    background-color: #2E2E2E;
-    color: white;
-    border: 1px solid #00BFFF;
-}
-"""
-
-# Gradio Interface with Enhanced UI
-with gr.Blocks(css=css) as demo:
-    with gr.Row():
-        gr.Markdown("<h1>🎭 Emotion Analysis from Video 🎥</h1>")
-
-    with gr.Row():
-        video_input = gr.File(label="📤 Upload your video", type="filepath")
-
-    with gr.Row():
-        analyze_button = gr.Button("🚀 Analyze Video")
-
-    with gr.Row():
-        result_text = gr.Textbox(label="Detected Emotion", interactive=False)
-
-    with gr.Row():
-        emotion_chart = gr.Image(label="📊 Emotion Distribution", interactive=False)
-
-    with gr.Row():
-        processed_video = gr.Video(label="🎞 Processed Video with Emotion Detection")
-
-    analyze_button.click(
-        process_video,
-        inputs=[video_input],
-        outputs=[result_text, emotion_chart, processed_video]
-    )
-
-# Launch Gradio app
-demo.launch()
+    return "output_video.mp4", "emotion_distribution.jpg", major_emotion
+
+# Gradio Web Interface
+gr.Interface(
+    fn=process_video,
+    inputs=gr.File(type="filepath"),
+    outputs=[
+        gr.File(label="Processed Video"),
+        gr.File(label="Emotion Distribution Chart"),
+        gr.Textbox(label="Major Emotion Detected")
+    ],
+    title="Emotion Detection from Video",
+    description="Upload a video, and the AI will detect emotions in each frame, providing a processed video, an emotion distribution chart, and the major detected emotion.",
+    css="""
+    .gradio-container { max-width: 800px !important; margin: auto; }
+    .gradio-container h1 { font-size: 22px; }
+    @media screen and (max-width: 768px) {
+        .gradio-container { width: 100%; padding: 10px; }
+        .gradio-container h1 { font-size: 18px; }
+    }
+    """
+).launch()
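
Note: the updated process_video can also be exercised without launching the Gradio UI. The snippet below is a minimal sketch and is not part of this commit; it assumes the imports and model setup from app.py above have already run, and "sample.mp4" is a placeholder path for a local test clip.

# Sketch: run the new pipeline headlessly (sample.mp4 is a hypothetical local file).
if __name__ == "__main__":
    video_path, chart_path, major_emotion = process_video("sample.mp4")
    print("Major emotion:", major_emotion)
    print("Annotated video written to:", video_path)
    print("Distribution chart written to:", chart_path)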