ctizzzy0 committed on
Commit 13af33f · verified · 1 Parent(s): 2a0fd08

Update app.py

Files changed (1)
  1. app.py +344 -87
app.py CHANGED
@@ -1,106 +1,363 @@
- import gradio as gr
- from transformers import pipeline, Wav2Vec2Processor, Wav2Vec2ForCTC
- import torch
- from PIL import Image
- import cv2
  import numpy as np
- import matplotlib.pyplot as plt
  import pandas as pd
  from fpdf import FPDF
- import os
-
- # ---------------- MODELS ----------------
- TEXT_EMO_MODEL = "j-hartmann/emotion-english-distilroberta-base"
- VOICE_EMO_MODEL = "superb/wav2vec2-base-superb-er"
- FACE_EMO_MODEL = "trpakov/vit-face-expression"  # public model
-
- # Pipelines
- text_emo = pipeline("text-classification", model=TEXT_EMO_MODEL, top_k=None)
- voice_emo = pipeline("audio-classification", model=VOICE_EMO_MODEL, top_k=None)
- face_emo = pipeline("image-classification", model=FACE_EMO_MODEL, top_k=None)
-
- # ---------------- HELPERS ----------------
- def detect_face_and_crop(image_path):
-     img = cv2.imread(image_path)
-     gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
-     face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
-     faces = face_cascade.detectMultiScale(gray, 1.3, 5)
-     if len(faces) > 0:
-         (x, y, w, h) = faces[0]
-         img = img[y:y+h, x:x+w]
-     img_pil = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
-     return img_pil
-
- def analyze(text, audio_path, image_path):
-     results = {}
-
-     if text:
-         text_res = text_emo(text)[0]
-         results["Text"] = text_res
-     else:
-         text_res = []
-
-     if audio_path:
-         voice_res = voice_emo(audio_path)[0]
-         results["Voice"] = voice_res
-     else:
-         voice_res = []
-
-     if image_path:
-         cropped_face = detect_face_and_crop(image_path)
-         face_res = face_emo(cropped_face)[0]
-         results["Face"] = face_res
-     else:
-         face_res = []
-
-     # Create plot
-     plt.figure(figsize=(8,5))
-     for modality, data in results.items():
-         labels = [d['label'] for d in data]
-         scores = [d['score']*100 for d in data]
-         plt.bar(labels, scores, alpha=0.6, label=modality)
-     plt.xticks(rotation=45, ha="right")
-     plt.ylabel("Probability (%)")
-     plt.legend()
      plt.tight_layout()
-     plot_path = "emotion_plot.png"
-     plt.savefig(plot_path)
-     plt.close()

-     # Create PDF
      pdf = FPDF()
      pdf.add_page()
      pdf.set_font("Arial", size=16)
-     pdf.cell(200, 10, "Multi-Modal Emotion Analysis", ln=True, align='C')
-     pdf.image(plot_path, x=10, y=30, w=180)
-     pdf.ln(120)
      pdf.set_font("Arial", size=12)
-     for modality, data in results.items():
-         pdf.cell(200, 10, f"{modality}:", ln=True)
-         for d in data:
-             pdf.cell(200, 8, f"{d['label']}: {d['score']*100:.2f}%", ln=True)
-     pdf_output = "emotion_report.pdf"
-     pdf.output(pdf_output)

-     return plot_path, pdf_output

- # ---------------- UI ----------------
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
-     gr.Markdown("## 🧠 Multi-Modal Emotion AI (Text + Voice + Face)")
-     gr.Markdown("Analyze emotions from your **words**, **voice**, and **face**, then download a PDF report.")

-     with gr.Row():
-         text_in = gr.Textbox(label="Enter Text", placeholder="Type something meaningful...")
-         audio_in = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Upload or record voice (optional)")
-         img_in = gr.Image(type="filepath", label="Upload a face image (optional)")

-     run_btn = gr.Button("Analyze Emotions", variant="primary")

-     with gr.Row():
-         plot_out = gr.Image(label="Emotion Plot")
-         pdf_out = gr.File(label="Download Report")

-     run_btn.click(analyze, inputs=[text_in, audio_in, img_in], outputs=[plot_out, pdf_out])

  app = demo
+ # app.py  Multi-Modal Emotion AI (Text + Voice + Face)
+ # Features: per-modality analysis, fusion (weighted), safety screen, CBT distortions,
+ # PDF report with charts, trends logging, face auto-crop. CPU-friendly for HF Spaces.
+
+ import os, io, json, datetime
+ from typing import Dict, List, Optional, Tuple
+
  import numpy as np
  import pandas as pd
+ import matplotlib.pyplot as plt
+ from PIL import Image
+ import cv2
+
+ import gradio as gr
  from fpdf import FPDF
+ from transformers import pipeline
+
+ # -----------------------------
+ # Public, lightweight models
+ # -----------------------------
+ TEXT_MODEL = "SamLowe/roberta-base-go_emotions"  # 27 emotions
+ VOICE_MODEL = "superb/wav2vec2-base-superb-er"   # speech emotion recognition
+ FACE_MODEL = "trpakov/vit-face-expression"       # facial expression (ViT)
+
+ text_pipe = pipeline("text-classification", model=TEXT_MODEL, top_k=None)
+ voice_pipe = pipeline("audio-classification", model=VOICE_MODEL, top_k=None)
+ face_pipe = pipeline("image-classification", model=FACE_MODEL, top_k=None)
+
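+ # Example (illustrative): with top_k=None a single call such as text_pipe("I feel great")
+ # returns roughly [[{"label": "joy", "score": 0.9}, {"label": "optimism", "score": 0.05}, ...]]
+ # (one inner list per input); to_probs() below flattens that nesting and renormalizes the scores.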
+ # -----------------------------
+ # Files / persistence
+ # -----------------------------
+ RUN_LOG = "runs.csv"
+ if not os.path.exists(RUN_LOG):
+     pd.DataFrame(columns=["timestamp","text","text_top","voice_top","face_top","fused_top","pos_index"]).to_csv(RUN_LOG, index=False)
+
+ os.makedirs("charts", exist_ok=True)
+
+ # -----------------------------
+ # Safety & CBT
+ # -----------------------------
+ RISK_TERMS = {
+     "self_harm": ["kill myself","end it","suicide","self harm","cutting","overdose"],
+     "violence": ["hurt them","attack","kill them","shoot","stab","revenge"]
+ }
+
+ DISTORTIONS = {
+     "catastrophizing": ["ruined","disaster","worst ever","nothing will work","everything is over"],
+     "all_or_nothing": ["always","never","completely","totally","entirely"],
+     "mind_reading": ["they think","everyone thinks","people will think"],
+     "fortune_telling": ["will fail","will go wrong","i'm doomed"],
+     "labeling": ["i'm a failure","i'm useless","i'm stupid"],
+     "should_statements": ["should","must","have to"],
+     "discount_positive": ["doesn't count","just luck","not a big deal"]
+ }
+ # NOTE: keep these strings latin-1-safe; they are written into the PDF via FPDF's core Arial font.
+ REFRAMES = {
+     "catastrophizing": "Zoom out: list 3 realistic outcomes besides worst-case.",
+     "all_or_nothing": "Find the gray: what % went right vs wrong?",
+     "mind_reading": "Check evidence: what did they actually say/do?",
+     "fortune_telling": "Run a small test that could disconfirm your prediction.",
+     "labeling": "Describe the behavior, not your identity.",
+     "should_statements": "Swap 'should' -> 'I prefer / I will try'.",
+     "discount_positive": "Write 3 things you handled well and why they matter."
+ }
+
+ def safety_screen(text: str) -> Tuple[str, Dict[str, List[str]]]:
+     t = (text or "").lower()
+     hits = {k:[w for w in v if w in t] for k,v in RISK_TERMS.items()}
+     hits = {k:v for k,v in hits.items() if v}
+     return ("high" if hits else "low"), hits
+
+ def detect_distortions(text: str) -> List[str]:
+     t = (text or "").lower()
+     found = []
+     for name, cues in DISTORTIONS.items():
+         if any(cue in t for cue in cues):
+             found.append(name)
+     return sorted(set(found))
+
+ def reframe_tips(names: List[str]) -> List[str]:
+     return [REFRAMES[n] for n in names if n in REFRAMES]
+
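+ # Example (illustrative): detect_distortions("I always fail and I'm a failure")
+ # -> ["all_or_nothing", "labeling"]; reframe_tips() then maps each name to its REFRAMES entry.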
+ # -----------------------------
+ # Emotion utilities
+ # -----------------------------
+ POSITIVE = set(["admiration","amusement","approval","gratitude","joy","love","optimism","relief","pride","excitement"])
+ NEGATIVE = set(["anger","annoyance","disappointment","disapproval","disgust","embarrassment","fear","grief","nervousness","remorse","sadness"])
+
+ def to_probs(outputs) -> Dict[str,float]:
+     # pipelines return list[list[{"label","score"}]] when top_k=None
+     if isinstance(outputs, list) and outputs and isinstance(outputs[0], list):
+         outputs = outputs[0]
+     d = {o["label"]: float(o["score"]) for o in outputs}
+     s = sum(d.values()) or 1.0
+     return {k: v/s for k,v in d.items()}
+
+ def top_item(prob: Optional[Dict[str,float]]) -> str:
+     if not prob: return ""
+     k = max(prob, key=prob.get)
+     return f"{k} ({prob[k]*100:.1f}%)"
+
+ def positivity_index(prob: Optional[Dict[str,float]]) -> float:
+     if not prob: return 0.5
+     pos = sum(prob.get(k,0.0) for k in POSITIVE)
+     neg = sum(prob.get(k,0.0) for k in NEGATIVE)
+     return round((pos - neg + 1)/2, 4)  # [-1,1] -> [0,1]
+
+ def union_merge(dicts: List[Optional[Dict[str,float]]], weights: List[float]) -> Dict[str,float]:
+     labels = set()
+     for d in dicts:
+         if d: labels |= set(d.keys())
+     merged = {l:0.0 for l in labels}
+     for d, w in zip(dicts, weights):
+         if not d: continue
+         for l in labels:
+             merged[l] += w * d.get(l, 0.0)
+     s = sum(merged.values()) or 1.0
+     return {k:v/s for k,v in merged.items()}
+
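+ # Worked example (illustrative): positivity_index({"joy": 0.8, "sadness": 0.2}) = (0.8 - 0.2 + 1) / 2 = 0.80,
+ # and union_merge([{"joy": 0.7, "sadness": 0.3}, {"neutral": 1.0}], [0.6, 0.4])
+ # -> {"joy": 0.42, "sadness": 0.18, "neutral": 0.40}: a weighted sum over the label union, renormalized to 1.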
+ def bar_fig(prob: Dict[str,float], title: str):
+     labels = list(prob.keys())
+     vals = [prob[k]*100 for k in labels]
+     fig, ax = plt.subplots(figsize=(7.0, 3.6))
+     ax.bar(labels, vals)
+     ax.set_ylim(0, 100)
+     ax.set_ylabel("Probability (%)")
+     ax.set_title(title)
+     for i, v in enumerate(vals):
+         ax.text(i, v + 1, f"{v:.1f}%", ha="center", fontsize=8)
+     plt.xticks(rotation=28, ha="right")
      plt.tight_layout()
+     return fig
+
+ def save_chart(prob: Dict[str,float], title: str, path: str):
+     fig = bar_fig(prob, title)
+     fig.savefig(path, dpi=160, bbox_inches="tight")
+     plt.close(fig)
+
+ # -----------------------------
+ # Computer vision: face crop
+ # -----------------------------
+ HAAR = cv2.CascadeClassifier(cv2.data.haarcascades + "haarcascade_frontalface_default.xml")
+ def crop_face(image_path: str) -> Image.Image:
+     try:
+         img = cv2.imread(image_path)
+         if img is None:  # fallback
+             return Image.open(image_path).convert("RGB")
+         gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+         faces = HAAR.detectMultiScale(gray, scaleFactor=1.2, minNeighbors=5, minSize=(80,80))
+         if len(faces) > 0:
+             x,y,w,h = sorted(faces, key=lambda b:b[2]*b[3], reverse=True)[0]
+             img = img[y:y+h, x:x+w]
+         return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
+     except Exception:
+         return Image.open(image_path).convert("RGB")
+
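+ # crop_face keeps the largest detected face (largest bounding-box area) and falls back to the
+ # full image when OpenCV cannot read the file or no face is detected.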
+ # -----------------------------
+ # Per-modality inference
+ # -----------------------------
+ def analyze_text(text: str):
+     if not text or not text.strip():
+         raise gr.Error("Please enter text.")  # gr.Error must be raised, not returned
+     probs = to_probs(text_pipe(text))
+     msg = f"**Top Text Emotion:** {top_item(probs)} | **Positivity Index:** {positivity_index(probs):.2f}"
+     fig = bar_fig(probs, "Text Emotions")
+     return msg, fig, json.dumps(probs)
+
+ def analyze_voice(audio_path: Optional[str]):
+     if not audio_path:
+         return "No audio provided.", None, None
+     probs = to_probs(voice_pipe(audio_path))
+     msg = f"**Top Voice Emotion:** {top_item(probs)}"
+     fig = bar_fig(probs, "Voice Emotions")
+     return msg, fig, json.dumps(probs)
+
+ def analyze_face(image_path: Optional[str]):
+     if not image_path:
+         return "No image provided.", None, None
+     face_img = crop_face(image_path)
+     probs = to_probs(face_pipe(face_img))
+     msg = f"**Top Face Emotion:** {top_item(probs)}"
+     fig = bar_fig(probs, "Face Emotions")
+     return msg, fig, json.dumps(probs)
+
+ # -----------------------------
+ # PDF Report
+ # -----------------------------
+ def build_pdf(text_in: str,
+               text_prob: Optional[Dict[str,float]],
+               voice_prob: Optional[Dict[str,float]],
+               face_prob: Optional[Dict[str,float]],
+               fused_prob: Optional[Dict[str,float]],
+               safety_level: str, safety_hits: Dict[str,List[str]],
+               distortions: List[str], tips: List[str]) -> str:
+
+     # save charts
+     paths = []
+     if text_prob: save_chart(text_prob, "Text Emotions", "charts/text.png"); paths.append("charts/text.png")
+     if voice_prob: save_chart(voice_prob, "Voice Emotions", "charts/voice.png"); paths.append("charts/voice.png")
+     if face_prob: save_chart(face_prob, "Face Emotions", "charts/face.png"); paths.append("charts/face.png")
+     if fused_prob: save_chart(fused_prob, "Fused Profile", "charts/fused.png"); paths.append("charts/fused.png")

      pdf = FPDF()
      pdf.add_page()
      pdf.set_font("Arial", size=16)
+     pdf.cell(0, 10, "Multi-Modal Emotion Report", ln=True, align="C")
+
      pdf.set_font("Arial", size=12)
+     pdf.cell(0, 8, f"Timestamp: {datetime.datetime.now().isoformat(sep=' ', timespec='seconds')}", ln=True)
+     pdf.multi_cell(0, 8, f"Input Text: {text_in or '(none)'}")
+     pdf.ln(2)
+
+     if safety_level == "high":
+         pdf.set_text_color(220,0,0)
+         # ASCII-only strings here: FPDF's core Arial font is latin-1 and cannot encode warning/bullet symbols
+         pdf.multi_cell(0, 8, "WARNING: High-risk language detected. If you're in immediate danger, contact local emergency services.")
+         pdf.multi_cell(0, 8, "US: 988 (Suicide & Crisis Lifeline)")
+         if safety_hits:
+             pdf.multi_cell(0, 8, f"Matched terms: {json.dumps(safety_hits)}")
+         pdf.set_text_color(0,0,0)
+         pdf.ln(2)
+
+     if distortions:
+         pdf.cell(0, 8, f"Cognitive distortions: {', '.join(distortions)}", ln=True)
+     if tips:
+         pdf.cell(0, 8, "Reframe suggestions:", ln=True)
+         for t in tips:
+             pdf.multi_cell(0, 7, f" - {t}")
+     pdf.ln(2)

+     for p in paths:
+         if os.path.exists(p):
+             pdf.image(p, w=180)
+             pdf.ln(4)

+     out = "emotion_report.pdf"
+     pdf.output(out)
+     return out
+
+ # -----------------------------
+ # Trends
+ # -----------------------------
+ def log_run(row: dict):
+     df = pd.read_csv(RUN_LOG)
+     df.loc[len(df)] = row
+     df.to_csv(RUN_LOG, index=False)
+
+ def plot_trends():
+     if not os.path.exists(RUN_LOG) or os.path.getsize(RUN_LOG) == 0:
+         return None
+     df = pd.read_csv(RUN_LOG)
+     if df.empty: return None
+     df["date"] = pd.to_datetime(df["timestamp"]).dt.date
+     daily = df.groupby("date")["pos_index"].mean().reset_index()
+     fig, ax = plt.subplots(figsize=(7,3.2))
+     ax.plot(daily["date"], daily["pos_index"], marker="o")
+     ax.set_ylim(0,1)
+     ax.set_ylabel("Positivity Index (0-1)")
+     ax.set_title("Positivity Trend")
+     plt.xticks(rotation=25, ha="right"); plt.tight_layout()
+     return fig
+
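+ # plot_trends aggregates runs.csv by calendar date and plots the mean pos_index per day,
+ # so several runs on the same day collapse into a single point on the trend line.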
+ # -----------------------------
+ # Fusion handler
+ # -----------------------------
+ def fuse_and_report(text_json, voice_json, face_json, text_raw, w_text, w_voice, w_face):
+     te = json.loads(text_json) if text_json else None
+     ve = json.loads(voice_json) if voice_json else None
+     fe = json.loads(face_json) if face_json else None
+     weights = [w_text, w_voice, w_face]
+     s = sum(weights) or 1.0
+     weights = [w/s for w in weights]
+     fused = union_merge([te, ve, fe], weights) if (te or ve or fe) else None
+
+     # safety + CBT
+     safety_level, safety_hits = safety_screen(text_raw or "")
+     distos = detect_distortions(text_raw or "")
+     tips = reframe_tips(distos)
+
+     # pdf
+     pdf_path = build_pdf(text_raw, te, ve, fe, fused, safety_level, safety_hits, distos, tips)
+
+     # log
+     pi_val = positivity_index(te)
+     log_run({
+         "timestamp": datetime.datetime.now().isoformat(sep=" ", timespec="seconds"),
+         "text": text_raw or "",
+         "text_top": top_item(te),
+         "voice_top": top_item(ve),
+         "face_top": top_item(fe),
+         "fused_top": top_item(fused),
+         "pos_index": pi_val
+     })
+
+     msg = f"**Fused Top:** {top_item(fused) or '(insufficient inputs)'} | Weights → Text:{weights[0]:.2f}, Voice:{weights[1]:.2f}, Face:{weights[2]:.2f}"
+     plot = bar_fig(fused, "Fused Emotional Profile") if fused else None
+     return msg, plot, pdf_path
+
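+ # Note: the slider weights are renormalized to sum to 1, and union_merge renormalizes again after
+ # dropping missing modalities, so the fused profile is a convex combination of whichever
+ # per-modality distributions were actually analyzed.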
+ # -----------------------------
+ # Gradio UI
+ # -----------------------------
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
+     gr.Markdown("# 🧠 Multi-Modal Emotion AI (Text + Voice + Face)")
+     gr.Markdown("Analyze emotions across **text, voice, and face**, detect **safety risks** and **cognitive distortions**, "
+                 "tune **fusion weights**, and download a **PDF report**. Audio/image are optional.")
+
+     # state holders
+     st_text_json = gr.State()
+     st_voice_json = gr.State()
+     st_face_json = gr.State()
+     st_text_raw = gr.State()
+
+     with gr.Tab("📝 Text"):
+         t_in = gr.Textbox(label="Your text", lines=3, placeholder="How are you feeling today?")
+         t_btn = gr.Button("Analyze Text", variant="primary")
+         t_msg = gr.Markdown()
+         t_plot = gr.Plot()
+         def _t_chain(txt):
+             msg, fig, j = analyze_text(txt)
+             return msg, fig, j, txt
+         t_btn.click(_t_chain, inputs=t_in, outputs=[t_msg, t_plot, st_text_json, st_text_raw])
+
+     with gr.Tab("🎤 Voice"):
+         a_in = gr.Audio(sources=["microphone","upload"], type="filepath", label="Record or upload audio (optional)")
+         a_btn = gr.Button("Analyze Voice", variant="primary")
+         a_msg = gr.Markdown()
+         a_plot = gr.Plot()
+         a_btn.click(analyze_voice, inputs=a_in, outputs=[a_msg, a_plot, st_voice_json])

+     with gr.Tab("📷 Face"):
+         f_in = gr.Image(type="filepath", label="Upload a face image (optional)")
+         f_btn = gr.Button("Analyze Face", variant="primary")
+         f_msg = gr.Markdown()
+         f_plot = gr.Plot()
+         f_btn.click(analyze_face, inputs=f_in, outputs=[f_msg, f_plot, st_face_json])

+     with gr.Tab("🧩 Fusion + Report"):
+         with gr.Row():
+             w_text = gr.Slider(0, 1, value=0.5, step=0.05, label="Text weight")
+             w_voice = gr.Slider(0, 1, value=0.3, step=0.05, label="Voice weight")
+             w_face = gr.Slider(0, 1, value=0.2, step=0.05, label="Face weight")
+         fuse_btn = gr.Button("Fuse & Generate PDF", variant="primary")
+         fuse_msg = gr.Markdown()
+         fuse_plot = gr.Plot()
+         fuse_pdf = gr.File(label="Download Report")
+         fuse_btn.click(
+             fuse_and_report,
+             inputs=[st_text_json, st_voice_json, st_face_json, st_text_raw, w_text, w_voice, w_face],
+             outputs=[fuse_msg, fuse_plot, fuse_pdf]
+         )

+     with gr.Tab("📈 Trends"):
+         tr_btn = gr.Button("Refresh Positivity Trend")
+         tr_plot = gr.Plot()
+         tr_btn.click(plot_trends, inputs=None, outputs=tr_plot)

+     with gr.Tab("ℹ️ About"):
+         gr.Markdown(
+             "Models: **GoEmotions (text)**, **Wav2Vec2-ER (audio)**, **ViT-Face-Expression (image)**. "
+             "Privacy: inputs are processed in-session; reports are generated on this Space when you request them. "
+             "This is an educational demo, not medical advice."
+         )

  app = demo