Jiaaaaaaax commited on
Commit
82ec579
·
verified ·
1 Parent(s): 500a392

Update session_analysis.py

Browse files
Files changed (1) hide show
  1. session_analysis.py +230 -111
session_analysis.py CHANGED
@@ -1,5 +1,4 @@
1
  import streamlit as st
2
- from google.cloud import speech_v1
3
  import io
4
  import pandas as pd
5
  import plotly.express as px
@@ -66,20 +65,40 @@ def process_media_file(file, type):
66
  progress_bar = st.progress(0)
67
 
68
  try:
69
- # Convert file to audio if needed
70
- if type == "Video Recording":
71
- status.text("Converting video to audio...")
72
- progress_bar.progress(20)
73
- # Add video to audio conversion here if needed
74
- audio_content = convert_video_to_audio(file)
75
- else:
76
- audio_content = file.read()
77
-
78
- # Generate transcript
79
  status.text("Generating transcript...")
80
- progress_bar.progress(60)
81
 
82
- transcript = generate_transcript(audio_content)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
  if transcript:
85
  st.session_state.current_transcript = transcript
@@ -97,6 +116,7 @@ def process_media_file(file, type):
97
  progress_bar.empty()
98
 
99
 
 
100
  def get_processing_step_name(step):
101
  steps = [
102
  "Loading media file",
@@ -128,39 +148,49 @@ def process_text_file(file):
128
  def show_manual_input_form():
129
  st.subheader("Session Details")
130
 
131
- # Session metadata
132
- session_date = st.date_input("Session Date", datetime.now())
133
- session_duration = st.number_input("Session Duration (minutes)", min_value=1, max_value=180, value=50)
134
-
135
- # Client information
136
- client_id = st.text_input("Client ID (optional)")
137
- session_number = st.number_input("Session Number", min_value=1, value=1)
138
-
139
- # Session content
140
- session_notes = st.text_area(
141
- "Session Notes",
142
- height=300,
143
- help="Enter detailed session notes including key dialogues, interventions, and observations"
144
- )
145
-
146
- # Target behaviors
147
- target_behaviors = st.text_area(
148
- "Target Behaviors/Goals",
149
- height=100,
150
- help="Enter the specific behaviors or goals discussed in the session"
151
- )
152
-
153
- # MI specific elements
154
- st.subheader("MI Elements")
155
- change_talk = st.text_area("Observed Change Talk")
156
- sustain_talk = st.text_area("Observed Sustain Talk")
157
-
158
- if st.button("Analyze Session"):
159
- session_data = compile_session_data(
160
- session_date, session_duration, client_id, session_number,
161
- session_notes, target_behaviors, change_talk, sustain_talk
162
  )
163
- analyze_session_content(session_data)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
 
165
  def analyze_session_content(content):
166
  try:
@@ -466,7 +496,6 @@ def show_oars_chart(oars_metrics):
466
 
467
  st.plotly_chart(fig)
468
 
469
- # Add more visualization functions as needed...
470
 
471
  def save_analysis_results():
472
  """Save analysis results to file"""
@@ -483,6 +512,42 @@ def save_analysis_results():
483
  except Exception as e:
484
  st.error(f"Error saving analysis results: {str(e)}")
485
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
486
  def show_export_options():
487
  st.sidebar.subheader("Export Options")
488
 
@@ -559,52 +624,6 @@ def load_previous_sessions():
559
  st.error(f"Error loading previous sessions: {str(e)}")
560
  return []
561
 
562
- def show_manual_input_form():
563
- """Display form for manual session notes input"""
564
- st.subheader("Session Notes Input")
565
-
566
- with st.form("session_notes_form"):
567
- # Basic session information
568
- session_date = st.date_input("Session Date", datetime.now())
569
- session_duration = st.number_input("Duration (minutes)", min_value=15, max_value=120, value=50)
570
-
571
- # Session content
572
- session_notes = st.text_area(
573
- "Session Notes",
574
- height=300,
575
- placeholder="Enter detailed session notes here..."
576
- )
577
-
578
- # Key themes and observations
579
- key_themes = st.text_area(
580
- "Key Themes",
581
- height=100,
582
- placeholder="Enter key themes identified during the session..."
583
- )
584
-
585
- # MI specific elements
586
- mi_techniques_used = st.multiselect(
587
- "MI Techniques Used",
588
- ["Open Questions", "Affirmations", "Reflections", "Summaries",
589
- "Change Talk", "Commitment Language", "Planning"]
590
- )
591
-
592
- # Submit button
593
- submitted = st.form_submit_button("Analyze Session")
594
-
595
- if submitted and session_notes:
596
- # Combine all input into a structured format
597
- session_data = {
598
- 'date': session_date,
599
- 'duration': session_duration,
600
- 'notes': session_notes,
601
- 'themes': key_themes,
602
- 'techniques': mi_techniques_used
603
- }
604
-
605
- # Process the session data
606
- st.session_state.current_transcript = format_session_data(session_data)
607
- analyze_session_content(st.session_state.current_transcript)
608
 
609
  def format_session_data(session_data):
610
  """Format session data into analyzable transcript"""
@@ -624,20 +643,29 @@ def format_session_data(session_data):
624
  return formatted_text
625
 
626
  def analyze_session_content(transcript):
627
- """Analyze session content using Gemini AI"""
628
  try:
629
- # Configure Gemini model
630
  model = genai.GenerativeModel('gemini-pro')
631
 
632
- # Prepare analysis prompt
633
- analysis_prompt = SESSION_EVALUATION_PROMPT + f"\nTranscript:\n{transcript}"
 
 
 
 
 
 
 
 
634
 
635
  # Generate analysis
636
  response = model.generate_content(analysis_prompt)
637
 
638
- # Store and display results
639
- st.session_state.analysis_results = response.text
640
- show_analysis_results()
 
 
641
 
642
  except Exception as e:
643
  st.error(f"Error analyzing session content: {str(e)}")
@@ -674,20 +702,111 @@ def show_analysis_results():
674
  with tabs[4]:
675
  show_recommendations(analysis.get('recommendations', {}))
676
 
677
- def parse_analysis_results(results_text):
678
- """Parse the AI analysis results into structured format"""
679
- # This is a placeholder for more sophisticated parsing
680
- # In a real implementation, you'd want to parse the AI response
681
- # into a structured format based on your specific needs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
682
 
683
- return {
684
- 'mi_adherence': {'raw_text': results_text},
685
- 'technical_skills': {'raw_text': results_text},
686
- 'client_language': {'raw_text': results_text},
687
- 'session_flow': {'raw_text': results_text},
688
- 'recommendations': {'raw_text': results_text}
689
- }
 
 
 
 
 
 
 
 
 
 
690
 
 
 
 
 
 
 
691
  # Analysis display functions
692
  def show_mi_adherence_analysis(analysis):
693
  st.subheader("MI Adherence Analysis")
 
1
  import streamlit as st
 
2
  import io
3
  import pandas as pd
4
  import plotly.express as px
 
65
  progress_bar = st.progress(0)
66
 
67
  try:
68
+ # Read file content
69
+ file_content = file.read()
70
+
 
 
 
 
 
 
 
71
  status.text("Generating transcript...")
72
+ progress_bar.progress(50)
73
 
74
+ # Generate transcript using Gemini
75
+ model = genai.GenerativeModel('gemini-pro')
76
+
77
+ # Convert file content to text
78
+ if type == "Audio Recording":
79
+ # For audio files, create a prompt that describes the audio
80
+ prompt = f"""
81
+ This is an audio recording of a therapy session.
82
+ Please transcribe the conversation and include speaker labels where possible.
83
+ Focus on capturing:
84
+ 1. The therapist's questions and reflections
85
+ 2. The client's responses and statements
86
+ 3. Any significant pauses or non-verbal sounds
87
+ """
88
+ else: # Video Recording
89
+ # For video files, create a prompt that describes the video
90
+ prompt = f"""
91
+ This is a video recording of a therapy session.
92
+ Please transcribe the conversation and include:
93
+ 1. Speaker labels
94
+ 2. Verbal communication
95
+ 3. Relevant non-verbal cues and body language
96
+ 4. Significant pauses or interactions
97
+ """
98
+
99
+ # Generate transcript
100
+ response = model.generate_content(prompt)
101
+ transcript = response.text
102
 
103
  if transcript:
104
  st.session_state.current_transcript = transcript
 
116
  progress_bar.empty()
117
 
118
 
119
+
120
  def get_processing_step_name(step):
121
  steps = [
122
  "Loading media file",
 
148
  def show_manual_input_form():
149
  st.subheader("Session Details")
150
 
151
+ with st.form("session_notes_form"):
152
+ # Basic session information
153
+ session_date = st.date_input("Session Date", datetime.now())
154
+ session_duration = st.number_input("Duration (minutes)", min_value=15, max_value=120, value=50)
155
+
156
+ # Session content
157
+ session_notes = st.text_area(
158
+ "Session Notes",
159
+ height=300,
160
+ placeholder="Enter detailed session notes here..."
161
+ )
162
+
163
+ # Key themes and observations
164
+ key_themes = st.text_area(
165
+ "Key Themes",
166
+ height=100,
167
+ placeholder="Enter key themes identified during the session..."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  )
169
+
170
+ # MI specific elements
171
+ mi_techniques_used = st.multiselect(
172
+ "MI Techniques Used",
173
+ ["Open Questions", "Affirmations", "Reflections", "Summaries",
174
+ "Change Talk", "Commitment Language", "Planning"]
175
+ )
176
+
177
+ # Submit button
178
+ submitted = st.form_submit_button("Analyze Session")
179
+
180
+ if submitted and session_notes:
181
+ # Combine all input into a structured format
182
+ session_data = {
183
+ 'date': session_date,
184
+ 'duration': session_duration,
185
+ 'notes': session_notes,
186
+ 'themes': key_themes,
187
+ 'techniques': mi_techniques_used
188
+ }
189
+
190
+ # Process the session data
191
+ st.session_state.current_transcript = format_session_data(session_data)
192
+ analyze_session_content(st.session_state.current_transcript)
193
+
194
 
195
  def analyze_session_content(content):
196
  try:
 
496
 
497
  st.plotly_chart(fig)
498
 
 
499
 
500
  def save_analysis_results():
501
  """Save analysis results to file"""
 
512
  except Exception as e:
513
  st.error(f"Error saving analysis results: {str(e)}")
514
 
515
+ def show_upload_section():
516
+ st.header("Session Data Upload")
517
+
518
+ upload_type = st.radio(
519
+ "Select Input Method:",
520
+ ["Text Transcript", "Session Notes", "Previous Session Data"] # Removed Audio/Video options
521
+ )
522
+
523
+ if upload_type == "Text Transcript":
524
+ file = st.file_uploader("Upload Transcript", type=["txt", "doc", "docx", "json"])
525
+ if file:
526
+ process_text_file(file)
527
+
528
+ elif upload_type == "Session Notes":
529
+ show_manual_input_form()
530
+
531
+ else: # Previous Session Data
532
+ show_previous_sessions_selector()
533
+
534
+ def process_text_file(file):
535
+ try:
536
+ if file.name.endswith('.json'):
537
+ content = json.loads(file.read().decode())
538
+ transcript = extract_transcript_from_json(content)
539
+ elif file.name.endswith('.docx'):
540
+ doc = Document(file)
541
+ transcript = '\n'.join([paragraph.text for paragraph in doc.paragraphs])
542
+ else:
543
+ transcript = file.read().decode()
544
+
545
+ if transcript:
546
+ st.session_state.current_transcript = transcript
547
+ analyze_session_content(transcript)
548
+
549
+ except Exception as e:
550
+ st.error(f"Error processing file: {str(e)}")
551
  def show_export_options():
552
  st.sidebar.subheader("Export Options")
553
 
 
624
  st.error(f"Error loading previous sessions: {str(e)}")
625
  return []
626
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
627
 
628
  def format_session_data(session_data):
629
  """Format session data into analyzable transcript"""
 
643
  return formatted_text
644
 
645
  def analyze_session_content(transcript):
 
646
  try:
647
+ # Initialize Gemini
648
  model = genai.GenerativeModel('gemini-pro')
649
 
650
+ # Prepare the analysis prompt
651
+ analysis_prompt = f"""
652
+ {MI_SYSTEM_PROMPT}
653
+
654
+ Please analyze the following therapy session transcript:
655
+
656
+ {transcript}
657
+
658
+ {SESSION_EVALUATION_PROMPT}
659
+ """
660
 
661
  # Generate analysis
662
  response = model.generate_content(analysis_prompt)
663
 
664
+ # Parse the response
665
+ analysis_results = parse_analysis_response(response.text)
666
+
667
+ # Store results in session state
668
+ st.session_state.analysis_results = analysis_results
669
 
670
  except Exception as e:
671
  st.error(f"Error analyzing session content: {str(e)}")
 
702
  with tabs[4]:
703
  show_recommendations(analysis.get('recommendations', {}))
704
 
705
+ def parse_analysis_response(response_text):
706
+ """Parse the AI response into structured analysis results"""
707
+ try:
708
+ # Initialize default structure for analysis results
709
+ analysis = {
710
+ 'mi_adherence_score': 0.0,
711
+ 'key_themes': [],
712
+ 'technique_usage': {},
713
+ 'strengths': [],
714
+ 'areas_for_improvement': [],
715
+ 'recommendations': [],
716
+ 'change_talk_instances': [],
717
+ 'session_summary': ""
718
+ }
719
+
720
+ # Extract MI adherence score
721
+ score_match = re.search(r'MI Adherence Score:\s*(\d+\.?\d*)', response_text)
722
+ if score_match:
723
+ analysis['mi_adherence_score'] = float(score_match.group(1))
724
+
725
+ # Extract key themes
726
+ themes_section = re.search(r'Key Themes:(.*?)(?=\n\n|\Z)', response_text, re.DOTALL)
727
+ if themes_section:
728
+ themes = themes_section.group(1).strip().split('\n')
729
+ analysis['key_themes'] = [theme.strip('- ') for theme in themes if theme.strip()]
730
+
731
+ # Extract technique usage
732
+ technique_section = re.search(r'Technique Usage:(.*?)(?=\n\n|\Z)', response_text, re.DOTALL)
733
+ if technique_section:
734
+ techniques = technique_section.group(1).strip().split('\n')
735
+ for technique in techniques:
736
+ if ':' in technique:
737
+ name, count = technique.split(':')
738
+ analysis['technique_usage'][name.strip()] = int(count.strip())
739
+
740
+ # Extract strengths
741
+ strengths_section = re.search(r'Strengths:(.*?)(?=\n\n|\Z)', response_text, re.DOTALL)
742
+ if strengths_section:
743
+ strengths = strengths_section.group(1).strip().split('\n')
744
+ analysis['strengths'] = [s.strip('- ') for s in strengths if s.strip()]
745
+
746
+ # Extract areas for improvement
747
+ improvements_section = re.search(r'Areas for Improvement:(.*?)(?=\n\n|\Z)', response_text, re.DOTALL)
748
+ if improvements_section:
749
+ improvements = improvements_section.group(1).strip().split('\n')
750
+ analysis['areas_for_improvement'] = [i.strip('- ') for i in improvements if i.strip()]
751
+
752
+ # Extract session summary
753
+ summary_section = re.search(r'Session Summary:(.*?)(?=\n\n|\Z)', response_text, re.DOTALL)
754
+ if summary_section:
755
+ analysis['session_summary'] = summary_section.group(1).strip()
756
+
757
+ return analysis
758
+
759
+ except Exception as e:
760
+ st.error(f"Error parsing analysis response: {str(e)}")
761
+ return None
762
+
763
+ def create_gauge_chart(score):
764
+ """Create a gauge chart for MI Adherence Score"""
765
+ fig = go.Figure(go.Indicator(
766
+ mode = "gauge+number",
767
+ value = score,
768
+ domain = {'x': [0, 1], 'y': [0, 1]},
769
+ title = {'text': "MI Adherence"},
770
+ gauge = {
771
+ 'axis': {'range': [0, 100]},
772
+ 'bar': {'color': "darkblue"},
773
+ 'steps': [
774
+ {'range': [0, 40], 'color': "lightgray"},
775
+ {'range': [40, 70], 'color': "gray"},
776
+ {'range': [70, 100], 'color': "darkgray"}
777
+ ],
778
+ 'threshold': {
779
+ 'line': {'color': "red", 'width': 4},
780
+ 'thickness': 0.75,
781
+ 'value': 90
782
+ }
783
+ }
784
+ ))
785
 
786
+ st.plotly_chart(fig)
787
+
788
+ def create_technique_usage_chart(technique_usage):
789
+ """Create a bar chart for MI technique usage"""
790
+ df = pd.DataFrame(list(technique_usage.items()), columns=['Technique', 'Count'])
791
+ fig = px.bar(
792
+ df,
793
+ x='Technique',
794
+ y='Count',
795
+ title='MI Technique Usage Frequency'
796
+ )
797
+ fig.update_layout(
798
+ xaxis_title="Technique",
799
+ yaxis_title="Frequency",
800
+ showlegend=False
801
+ )
802
+ st.plotly_chart(fig)
803
 
804
+ def extract_transcript_from_json(content):
805
+ """Extract transcript from JSON content"""
806
+ if isinstance(content, dict):
807
+ return json.dumps(content, indent=2)
808
+ return str(content)
809
+
810
  # Analysis display functions
811
  def show_mi_adherence_analysis(analysis):
812
  st.subheader("MI Adherence Analysis")