Jiaaaaaaax commited on
Commit
baeff27
·
verified ·
1 Parent(s): 7ec11b6

Update session_analysis.py

Browse files
Files changed (1) hide show
  1. session_analysis.py +433 -31
session_analysis.py CHANGED
@@ -1,44 +1,446 @@
1
  import streamlit as st
2
  import pandas as pd
3
  import plotly.express as px
 
 
 
 
 
 
 
 
4
 
5
  def show_session_analysis():
6
- st.title("Session Analysis")
7
 
8
- # File upload section
9
- st.header("Upload Session Data")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  upload_type = st.radio(
12
- "Select upload type:",
13
- ["Audio", "Video", "Text", "Manual Input"]
14
  )
15
 
16
- if upload_type in ["Audio", "Video"]:
17
- file = st.file_uploader(f"Upload {upload_type} file", type=["wav", "mp3", "mp4"])
 
 
 
18
  if file:
19
- analyze_media_file(file, upload_type)
20
 
21
- elif upload_type == "Text":
22
- file = st.file_uploader("Upload text file", type=["txt", "doc", "docx"])
23
  if file:
24
- analyze_text_file(file)
25
-
26
- else: # Manual Input
27
- text_input = st.text_area("Enter session notes or transcript:")
28
- if text_input:
29
- analyze_text_input(text_input)
30
-
31
- def analyze_media_file(file, type):
32
- # Implement media file analysis
33
- st.write(f"Analyzing {type} file...")
34
- # Use your MI analysis prompts here
35
-
36
- def analyze_text_file(file):
37
- # Implement text file analysis
38
- st.write("Analyzing text file...")
39
- # Use your MI analysis prompts here
40
-
41
- def analyze_text_input(text):
42
- # Implement text input analysis
43
- st.write("Analyzing input...")
44
- # Use your MI analysis prompts here
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  import pandas as pd
3
  import plotly.express as px
4
+ import plotly.graph_objects as go
5
+ import google.generativeai as genai
6
+ from datetime import datetime
7
+ import json
8
+ import numpy as np
9
+ from docx import Document
10
+ import re
11
+ from prompts import SESSION_EVALUATION_PROMPT, MI_SYSTEM_PROMPT
12
 
13
  def show_session_analysis():
14
+ st.title("MI Session Analysis Dashboard")
15
 
16
+ # Initialize session state for analysis results
17
+ if 'analysis_results' not in st.session_state:
18
+ st.session_state.analysis_results = None
19
+ if 'current_transcript' not in st.session_state:
20
+ st.session_state.current_transcript = None
21
+
22
+ # Main layout
23
+ col1, col2 = st.columns([1, 2])
24
+
25
+ with col1:
26
+ show_upload_section()
27
+
28
+ with col2:
29
+ if st.session_state.analysis_results:
30
+ show_analysis_results()
31
+
32
+ def show_upload_section():
33
+ st.header("Session Data Upload")
34
 
35
  upload_type = st.radio(
36
+ "Select Input Method:",
37
+ ["Audio Recording", "Video Recording", "Text Transcript", "Session Notes", "Previous Session Data"]
38
  )
39
 
40
+ if upload_type in ["Audio Recording", "Video Recording"]:
41
+ file = st.file_uploader(
42
+ f"Upload {upload_type}",
43
+ type=["wav", "mp3", "mp4"] if upload_type == "Audio Recording" else ["mp4", "avi", "mov"]
44
+ )
45
  if file:
46
+ process_media_file(file, upload_type)
47
 
48
+ elif upload_type == "Text Transcript":
49
+ file = st.file_uploader("Upload Transcript", type=["txt", "doc", "docx", "json"])
50
  if file:
51
+ process_text_file(file)
52
+
53
+ elif upload_type == "Session Notes":
54
+ show_manual_input_form()
55
+
56
+ else: # Previous Session Data
57
+ show_previous_sessions_selector()
58
+
59
+ def process_media_file(file, type):
60
+ st.write(f"Processing {type}...")
61
+
62
+ # Add processing status
63
+ status = st.empty()
64
+ progress_bar = st.progress(0)
65
+
66
+ try:
67
+ # Simulated processing steps
68
+ for i in range(5):
69
+ status.text(f"Step {i+1}/5: " + get_processing_step_name(i))
70
+ progress_bar.progress((i + 1) * 20)
71
+
72
+ # Generate transcript
73
+ transcript = generate_transcript(file)
74
+ if transcript:
75
+ st.session_state.current_transcript = transcript
76
+ analyze_session_content(transcript)
77
+
78
+ except Exception as e:
79
+ st.error(f"Error processing file: {str(e)}")
80
+ finally:
81
+ status.empty()
82
+ progress_bar.empty()
83
+
84
+ def get_processing_step_name(step):
85
+ steps = [
86
+ "Loading media file",
87
+ "Converting to audio",
88
+ "Performing speech recognition",
89
+ "Generating transcript",
90
+ "Preparing analysis"
91
+ ]
92
+ return steps[step]
93
+
94
+ def process_text_file(file):
95
+ try:
96
+ if file.name.endswith('.json'):
97
+ content = json.loads(file.read().decode())
98
+ transcript = extract_transcript_from_json(content)
99
+ elif file.name.endswith('.docx'):
100
+ doc = Document(file)
101
+ transcript = '\n'.join([paragraph.text for paragraph in doc.paragraphs])
102
+ else:
103
+ transcript = file.read().decode()
104
+
105
+ if transcript:
106
+ st.session_state.current_transcript = transcript
107
+ analyze_session_content(transcript)
108
+
109
+ except Exception as e:
110
+ st.error(f"Error processing file: {str(e)}")
111
+
112
+ def show_manual_input_form():
113
+ st.subheader("Session Details")
114
+
115
+ # Session metadata
116
+ session_date = st.date_input("Session Date", datetime.now())
117
+ session_duration = st.number_input("Session Duration (minutes)", min_value=1, max_value=180, value=50)
118
+
119
+ # Client information
120
+ client_id = st.text_input("Client ID (optional)")
121
+ session_number = st.number_input("Session Number", min_value=1, value=1)
122
+
123
+ # Session content
124
+ session_notes = st.text_area(
125
+ "Session Notes",
126
+ height=300,
127
+ help="Enter detailed session notes including key dialogues, interventions, and observations"
128
+ )
129
+
130
+ # Target behaviors
131
+ target_behaviors = st.text_area(
132
+ "Target Behaviors/Goals",
133
+ height=100,
134
+ help="Enter the specific behaviors or goals discussed in the session"
135
+ )
136
+
137
+ # MI specific elements
138
+ st.subheader("MI Elements")
139
+ change_talk = st.text_area("Observed Change Talk")
140
+ sustain_talk = st.text_area("Observed Sustain Talk")
141
+
142
+ if st.button("Analyze Session"):
143
+ session_data = compile_session_data(
144
+ session_date, session_duration, client_id, session_number,
145
+ session_notes, target_behaviors, change_talk, sustain_talk
146
+ )
147
+ analyze_session_content(session_data)
148
+
149
+ def analyze_session_content(content):
150
+ try:
151
+ # Configure Gemini model
152
+ model = genai.GenerativeModel('gemini-pro')
153
+
154
+ # Prepare analysis prompt
155
+ analysis_prompt = f"""
156
+ Analyze the following therapy session using MI principles and provide a comprehensive evaluation:
157
+
158
+ Session Content:
159
+ {content}
160
+
161
+ Please provide detailed analysis including:
162
+ 1. MI Adherence Assessment:
163
+ - OARS implementation
164
+ - Change talk identification
165
+ - Resistance management
166
+ - MI spirit adherence
167
+
168
+ 2. Technical Skills Evaluation:
169
+ - Reflection quality and frequency
170
+ - Question-to-reflection ratio
171
+ - Open vs. closed questions
172
+ - Affirmations and summaries
173
+
174
+ 3. Client Language Analysis:
175
+ - Change talk instances
176
+ - Sustain talk patterns
177
+ - Commitment language
178
+ - Resistance patterns
179
+
180
+ 4. Session Flow Analysis:
181
+ - Engagement level
182
+ - Focus maintenance
183
+ - Evocation quality
184
+ - Planning effectiveness
185
+
186
+ 5. Recommendations:
187
+ - Strength areas
188
+ - Growth opportunities
189
+ - Suggested interventions
190
+ - Next session planning
191
+
192
+ Format the analysis with clear sections and specific examples from the session.
193
+ """
194
+
195
+ # Generate analysis
196
+ response = model.generate_content(analysis_prompt)
197
+
198
+ # Process and structure the analysis results
199
+ analysis_results = process_analysis_results(response.text)
200
+
201
+ # Store results in session state
202
+ st.session_state.analysis_results = analysis_results
203
+
204
+ # Show success message
205
+ st.success("Analysis completed successfully!")
206
+
207
+ except Exception as e:
208
+ st.error(f"Error during analysis: {str(e)}")
209
+
210
+ def process_analysis_results(raw_analysis):
211
+ """Process and structure the analysis results"""
212
+ # Parse the raw analysis text and extract structured data
213
+ sections = extract_analysis_sections(raw_analysis)
214
+
215
+ # Calculate metrics
216
+ metrics = calculate_mi_metrics(raw_analysis)
217
+
218
+ return {
219
+ "raw_analysis": raw_analysis,
220
+ "structured_sections": sections,
221
+ "metrics": metrics,
222
+ "timestamp": datetime.now().isoformat()
223
+ }
224
+
225
+ def show_analysis_results():
226
+ """Display comprehensive analysis results"""
227
+ if not st.session_state.analysis_results:
228
+ return
229
+
230
+ results = st.session_state.analysis_results
231
+
232
+ # Top-level metrics
233
+ show_mi_metrics_dashboard(results['metrics'])
234
+
235
+ # Detailed analysis sections
236
+ tabs = st.tabs([
237
+ "MI Adherence",
238
+ "Technical Skills",
239
+ "Client Language",
240
+ "Session Flow",
241
+ "Recommendations"
242
+ ])
243
+
244
+ with tabs[0]:
245
+ show_mi_adherence_analysis(results)
246
+
247
+ with tabs[1]:
248
+ show_technical_skills_analysis(results)
249
+
250
+ with tabs[2]:
251
+ show_client_language_analysis(results)
252
+
253
+ with tabs[3]:
254
+ show_session_flow_analysis(results)
255
+
256
+ with tabs[4]:
257
+ show_recommendations(results)
258
+
259
+ def show_mi_metrics_dashboard(metrics):
260
+ st.subheader("MI Performance Dashboard")
261
+
262
+ col1, col2, col3, col4 = st.columns(4)
263
+
264
+ with col1:
265
+ show_metric_card(
266
+ "MI Spirit Score",
267
+ metrics.get('mi_spirit_score', 0),
268
+ "0-5 scale"
269
+ )
270
+
271
+ with col2:
272
+ show_metric_card(
273
+ "Change Talk Ratio",
274
+ metrics.get('change_talk_ratio', 0),
275
+ "Change vs Sustain"
276
+ )
277
+
278
+ with col3:
279
+ show_metric_card(
280
+ "Reflection Ratio",
281
+ metrics.get('reflection_ratio', 0),
282
+ "Reflections/Questions"
283
+ )
284
+
285
+ with col4:
286
+ show_metric_card(
287
+ "Overall Adherence",
288
+ metrics.get('overall_adherence', 0),
289
+ "Percentage"
290
+ )
291
+
292
+ def show_metric_card(title, value, subtitle):
293
+ st.markdown(
294
+ f"""
295
+ <div style="border:1px solid #ccc; padding:10px; border-radius:5px; text-align:center;">
296
+ <h3>{title}</h3>
297
+ <h2>{value:.2f}</h2>
298
+ <p>{subtitle}</p>
299
+ </div>
300
+ """,
301
+ unsafe_allow_html=True
302
+ )
303
+
304
+ def show_mi_adherence_analysis(results):
305
+ st.subheader("MI Adherence Analysis")
306
+
307
+ # OARS Implementation
308
+ st.write("### OARS Implementation")
309
+ show_oars_chart(results['metrics'].get('oars_metrics', {}))
310
+
311
+ # MI Spirit Components
312
+ st.write("### MI Spirit Components")
313
+ show_mi_spirit_chart(results['metrics'].get('mi_spirit_metrics', {}))
314
+
315
+ # Detailed breakdown
316
+ st.write("### Detailed Analysis")
317
+ st.markdown(results['structured_sections'].get('mi_adherence', ''))
318
+
319
+ def show_technical_skills_analysis(results):
320
+ st.subheader("Technical Skills Analysis")
321
+
322
+ # Question Analysis
323
+ col1, col2 = st.columns(2)
324
+
325
+ with col1:
326
+ show_question_type_chart(results['metrics'].get('question_metrics', {}))
327
+
328
+ with col2:
329
+ show_reflection_depth_chart(results['metrics'].get('reflection_metrics', {}))
330
+
331
+ # Detailed analysis
332
+ st.markdown(results['structured_sections'].get('technical_skills', ''))
333
+
334
+ def show_client_language_analysis(results):
335
+ st.subheader("Client Language Analysis")
336
+
337
+ # Change Talk Timeline
338
+ show_change_talk_timeline(results['metrics'].get('change_talk_timeline', []))
339
+
340
+ # Language Categories
341
+ show_language_categories_chart(results['metrics'].get('language_categories', {}))
342
+
343
+ # Detailed analysis
344
+ st.markdown(results['structured_sections'].get('client_language', ''))
345
+
346
+ def show_session_flow_analysis(results):
347
+ st.subheader("Session Flow Analysis")
348
+
349
+ # Session Flow Timeline
350
+ show_session_flow_timeline(results['metrics'].get('session_flow', []))
351
+
352
+ # Engagement Metrics
353
+ show_engagement_metrics(results['metrics'].get('engagement_metrics', {}))
354
+
355
+ # Detailed analysis
356
+ st.markdown(results['structured_sections'].get('session_flow', ''))
357
+
358
+ def show_recommendations(results):
359
+ st.subheader("Recommendations and Next Steps")
360
+
361
+ col1, col2 = st.columns(2)
362
+
363
+ with col1:
364
+ st.write("### Strengths")
365
+ strengths = results['structured_sections'].get('strengths', [])
366
+ for strength in strengths:
367
+ st.markdown(f"✓ {strength}")
368
+
369
+ with col2:
370
+ st.write("### Growth Areas")
371
+ growth_areas = results['structured_sections'].get('growth_areas', [])
372
+ for area in growth_areas:
373
+ st.markdown(f"→ {area}")
374
+
375
+ st.write("### Suggested Interventions")
376
+ st.markdown(results['structured_sections'].get('suggested_interventions', ''))
377
+
378
+ st.write("### Next Session Planning")
379
+ st.markdown(results['structured_sections'].get('next_session_plan', ''))
380
+
381
+ # Utility functions for charts and visualizations
382
+ def show_oars_chart(oars_metrics):
383
+ # Create OARS radar chart using plotly
384
+ categories = ['Open Questions', 'Affirmations', 'Reflections', 'Summaries']
385
+ values = [
386
+ oars_metrics.get('open_questions', 0),
387
+ oars_metrics.get('affirmations', 0),
388
+ oars_metrics.get('reflections', 0),
389
+ oars_metrics.get('summaries', 0)
390
+ ]
391
+
392
+ fig = go.Figure(data=go.Scatterpolar(
393
+ r=values,
394
+ theta=categories,
395
+ fill='toself'
396
+ ))
397
+
398
+ fig.update_layout(
399
+ polar=dict(
400
+ radialaxis=dict(
401
+ visible=True,
402
+ range=[0, max(values) + 1]
403
+ )),
404
+ showlegend=False
405
+ )
406
+
407
+ st.plotly_chart(fig)
408
+
409
+ # Add more visualization functions as needed...
410
+
411
+ def save_analysis_results():
412
+ """Save analysis results to file"""
413
+ if st.session_state.analysis_results:
414
+ try:
415
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
416
+ filename = f"analysis_results_{timestamp}.json"
417
+
418
+ with open(filename, "w") as f:
419
+ json.dump(st.session_state.analysis_results, f, indent=4)
420
+
421
+ st.success(f"Analysis results saved to {filename}")
422
+
423
+ except Exception as e:
424
+ st.error(f"Error saving analysis results: {str(e)}")
425
+
426
+ def show_export_options():
427
+ st.sidebar.subheader("Export Options")
428
+
429
+ if st.sidebar.button("Export Analysis Report"):
430
+ save_analysis_results()
431
+
432
+ report_format = st.sidebar.selectbox(
433
+ "Report Format",
434
+ ["PDF", "DOCX", "JSON"]
435
+ )
436
+
437
+ if st.sidebar.button("Generate Report"):
438
+ generate_report(report_format)
439
+
440
+ def generate_report(format):
441
+ """Generate analysis report in specified format"""
442
+ # Add report generation logic here
443
+ st.info(f"Generating {format} report... (Feature coming soon)")
444
+
445
+ if __name__ == "__main__":
446
+ show_session_analysis()