Spaces:
Sleeping
Sleeping
Update session_analysis.py
Browse files- session_analysis.py +230 -111
session_analysis.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1 |
import streamlit as st
|
2 |
-
from google.cloud import speech_v1
|
3 |
import io
|
4 |
import pandas as pd
|
5 |
import plotly.express as px
|
@@ -66,20 +65,40 @@ def process_media_file(file, type):
|
|
66 |
progress_bar = st.progress(0)
|
67 |
|
68 |
try:
|
69 |
-
#
|
70 |
-
|
71 |
-
|
72 |
-
progress_bar.progress(20)
|
73 |
-
# Add video to audio conversion here if needed
|
74 |
-
audio_content = convert_video_to_audio(file)
|
75 |
-
else:
|
76 |
-
audio_content = file.read()
|
77 |
-
|
78 |
-
# Generate transcript
|
79 |
status.text("Generating transcript...")
|
80 |
-
progress_bar.progress(
|
81 |
|
82 |
-
transcript
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
83 |
|
84 |
if transcript:
|
85 |
st.session_state.current_transcript = transcript
|
@@ -97,6 +116,7 @@ def process_media_file(file, type):
|
|
97 |
progress_bar.empty()
|
98 |
|
99 |
|
|
|
100 |
def get_processing_step_name(step):
|
101 |
steps = [
|
102 |
"Loading media file",
|
@@ -128,39 +148,49 @@ def process_text_file(file):
|
|
128 |
def show_manual_input_form():
|
129 |
st.subheader("Session Details")
|
130 |
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
"Target Behaviors/Goals",
|
149 |
-
height=100,
|
150 |
-
help="Enter the specific behaviors or goals discussed in the session"
|
151 |
-
)
|
152 |
-
|
153 |
-
# MI specific elements
|
154 |
-
st.subheader("MI Elements")
|
155 |
-
change_talk = st.text_area("Observed Change Talk")
|
156 |
-
sustain_talk = st.text_area("Observed Sustain Talk")
|
157 |
-
|
158 |
-
if st.button("Analyze Session"):
|
159 |
-
session_data = compile_session_data(
|
160 |
-
session_date, session_duration, client_id, session_number,
|
161 |
-
session_notes, target_behaviors, change_talk, sustain_talk
|
162 |
)
|
163 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
164 |
|
165 |
def analyze_session_content(content):
|
166 |
try:
|
@@ -466,7 +496,6 @@ def show_oars_chart(oars_metrics):
|
|
466 |
|
467 |
st.plotly_chart(fig)
|
468 |
|
469 |
-
# Add more visualization functions as needed...
|
470 |
|
471 |
def save_analysis_results():
|
472 |
"""Save analysis results to file"""
|
@@ -483,6 +512,42 @@ def save_analysis_results():
|
|
483 |
except Exception as e:
|
484 |
st.error(f"Error saving analysis results: {str(e)}")
|
485 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
486 |
def show_export_options():
|
487 |
st.sidebar.subheader("Export Options")
|
488 |
|
@@ -559,52 +624,6 @@ def load_previous_sessions():
|
|
559 |
st.error(f"Error loading previous sessions: {str(e)}")
|
560 |
return []
|
561 |
|
562 |
-
def show_manual_input_form():
|
563 |
-
"""Display form for manual session notes input"""
|
564 |
-
st.subheader("Session Notes Input")
|
565 |
-
|
566 |
-
with st.form("session_notes_form"):
|
567 |
-
# Basic session information
|
568 |
-
session_date = st.date_input("Session Date", datetime.now())
|
569 |
-
session_duration = st.number_input("Duration (minutes)", min_value=15, max_value=120, value=50)
|
570 |
-
|
571 |
-
# Session content
|
572 |
-
session_notes = st.text_area(
|
573 |
-
"Session Notes",
|
574 |
-
height=300,
|
575 |
-
placeholder="Enter detailed session notes here..."
|
576 |
-
)
|
577 |
-
|
578 |
-
# Key themes and observations
|
579 |
-
key_themes = st.text_area(
|
580 |
-
"Key Themes",
|
581 |
-
height=100,
|
582 |
-
placeholder="Enter key themes identified during the session..."
|
583 |
-
)
|
584 |
-
|
585 |
-
# MI specific elements
|
586 |
-
mi_techniques_used = st.multiselect(
|
587 |
-
"MI Techniques Used",
|
588 |
-
["Open Questions", "Affirmations", "Reflections", "Summaries",
|
589 |
-
"Change Talk", "Commitment Language", "Planning"]
|
590 |
-
)
|
591 |
-
|
592 |
-
# Submit button
|
593 |
-
submitted = st.form_submit_button("Analyze Session")
|
594 |
-
|
595 |
-
if submitted and session_notes:
|
596 |
-
# Combine all input into a structured format
|
597 |
-
session_data = {
|
598 |
-
'date': session_date,
|
599 |
-
'duration': session_duration,
|
600 |
-
'notes': session_notes,
|
601 |
-
'themes': key_themes,
|
602 |
-
'techniques': mi_techniques_used
|
603 |
-
}
|
604 |
-
|
605 |
-
# Process the session data
|
606 |
-
st.session_state.current_transcript = format_session_data(session_data)
|
607 |
-
analyze_session_content(st.session_state.current_transcript)
|
608 |
|
609 |
def format_session_data(session_data):
|
610 |
"""Format session data into analyzable transcript"""
|
@@ -624,20 +643,29 @@ def format_session_data(session_data):
|
|
624 |
return formatted_text
|
625 |
|
626 |
def analyze_session_content(transcript):
|
627 |
-
"""Analyze session content using Gemini AI"""
|
628 |
try:
|
629 |
-
#
|
630 |
model = genai.GenerativeModel('gemini-pro')
|
631 |
|
632 |
-
# Prepare analysis prompt
|
633 |
-
analysis_prompt =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
634 |
|
635 |
# Generate analysis
|
636 |
response = model.generate_content(analysis_prompt)
|
637 |
|
638 |
-
#
|
639 |
-
|
640 |
-
|
|
|
|
|
641 |
|
642 |
except Exception as e:
|
643 |
st.error(f"Error analyzing session content: {str(e)}")
|
@@ -674,20 +702,111 @@ def show_analysis_results():
|
|
674 |
with tabs[4]:
|
675 |
show_recommendations(analysis.get('recommendations', {}))
|
676 |
|
677 |
-
def
|
678 |
-
"""Parse the AI
|
679 |
-
|
680 |
-
|
681 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
682 |
|
683 |
-
|
684 |
-
|
685 |
-
|
686 |
-
|
687 |
-
|
688 |
-
|
689 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
690 |
|
|
|
|
|
|
|
|
|
|
|
|
|
691 |
# Analysis display functions
|
692 |
def show_mi_adherence_analysis(analysis):
|
693 |
st.subheader("MI Adherence Analysis")
|
|
|
1 |
import streamlit as st
|
|
|
2 |
import io
|
3 |
import pandas as pd
|
4 |
import plotly.express as px
|
|
|
65 |
progress_bar = st.progress(0)
|
66 |
|
67 |
try:
|
68 |
+
# Read file content
|
69 |
+
file_content = file.read()
|
70 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
status.text("Generating transcript...")
|
72 |
+
progress_bar.progress(50)
|
73 |
|
74 |
+
# Generate transcript using Gemini
|
75 |
+
model = genai.GenerativeModel('gemini-pro')
|
76 |
+
|
77 |
+
# Convert file content to text
|
78 |
+
if type == "Audio Recording":
|
79 |
+
# For audio files, create a prompt that describes the audio
|
80 |
+
prompt = f"""
|
81 |
+
This is an audio recording of a therapy session.
|
82 |
+
Please transcribe the conversation and include speaker labels where possible.
|
83 |
+
Focus on capturing:
|
84 |
+
1. The therapist's questions and reflections
|
85 |
+
2. The client's responses and statements
|
86 |
+
3. Any significant pauses or non-verbal sounds
|
87 |
+
"""
|
88 |
+
else: # Video Recording
|
89 |
+
# For video files, create a prompt that describes the video
|
90 |
+
prompt = f"""
|
91 |
+
This is a video recording of a therapy session.
|
92 |
+
Please transcribe the conversation and include:
|
93 |
+
1. Speaker labels
|
94 |
+
2. Verbal communication
|
95 |
+
3. Relevant non-verbal cues and body language
|
96 |
+
4. Significant pauses or interactions
|
97 |
+
"""
|
98 |
+
|
99 |
+
# Generate transcript
|
100 |
+
response = model.generate_content(prompt)
|
101 |
+
transcript = response.text
|
102 |
|
103 |
if transcript:
|
104 |
st.session_state.current_transcript = transcript
|
|
|
116 |
progress_bar.empty()
|
117 |
|
118 |
|
119 |
+
|
120 |
def get_processing_step_name(step):
|
121 |
steps = [
|
122 |
"Loading media file",
|
|
|
148 |
def show_manual_input_form():
    """Render the manual session-notes form and analyze the notes on submit.

    The form collects a session date, a duration in minutes, free-text
    session notes, free-text key themes and a multiselect of MI techniques.
    When the form is submitted with non-blank notes, the values are packed
    into a dict (keys: 'date', 'duration', 'notes', 'themes', 'techniques'),
    formatted with ``format_session_data``, stored in
    ``st.session_state.current_transcript`` and handed to
    ``analyze_session_content``.

    Submitting with blank (empty or whitespace-only) notes shows a warning
    instead of silently doing nothing.

    Returns:
        None. All output goes to the Streamlit page / session state.
    """
    st.subheader("Session Details")

    with st.form("session_notes_form"):
        # Basic session information
        session_date = st.date_input("Session Date", datetime.now())
        session_duration = st.number_input(
            "Duration (minutes)", min_value=15, max_value=120, value=50
        )

        # Session content
        session_notes = st.text_area(
            "Session Notes",
            height=300,
            placeholder="Enter detailed session notes here..."
        )

        # Key themes and observations
        key_themes = st.text_area(
            "Key Themes",
            height=100,
            placeholder="Enter key themes identified during the session..."
        )

        # MI specific elements
        mi_techniques_used = st.multiselect(
            "MI Techniques Used",
            ["Open Questions", "Affirmations", "Reflections", "Summaries",
             "Change Talk", "Commitment Language", "Planning"]
        )

        # Submit button
        submitted = st.form_submit_button("Analyze Session")

    if not submitted:
        return

    # Notes are the only mandatory field; tell the user why nothing happened
    # rather than ignoring the click.
    if not session_notes or not session_notes.strip():
        st.warning("Please enter session notes before analyzing.")
        return

    # Combine all input into a structured format
    session_data = {
        'date': session_date,
        'duration': session_duration,
        'notes': session_notes,
        'themes': key_themes,
        'techniques': mi_techniques_used
    }

    # Process the session data
    st.session_state.current_transcript = format_session_data(session_data)
    analyze_session_content(st.session_state.current_transcript)
|
193 |
+
|
194 |
|
195 |
def analyze_session_content(content):
|
196 |
try:
|
|
|
496 |
|
497 |
st.plotly_chart(fig)
|
498 |
|
|
|
499 |
|
500 |
def save_analysis_results():
|
501 |
"""Save analysis results to file"""
|
|
|
512 |
except Exception as e:
|
513 |
st.error(f"Error saving analysis results: {str(e)}")
|
514 |
|
515 |
+
def show_upload_section():
    """Show the input-method selector and dispatch to the matching handler.

    The user picks one of three input methods with a radio button:

    * "Text Transcript" - shows a file uploader and passes an uploaded file
      to ``process_text_file``.
    * "Session Notes" - shows the manual input form.
    * "Previous Session Data" - shows the previous-sessions selector.
    """
    st.header("Session Data Upload")

    # Removed Audio/Video options
    input_methods = ["Text Transcript", "Session Notes", "Previous Session Data"]
    upload_type = st.radio("Select Input Method:", input_methods)

    if upload_type == "Session Notes":
        show_manual_input_form()
        return

    if upload_type != "Text Transcript":
        # Previous Session Data
        show_previous_sessions_selector()
        return

    file = st.file_uploader(
        "Upload Transcript",
        type=["txt", "doc", "docx", "json"],
    )
    if file:
        process_text_file(file)
|
533 |
+
|
534 |
+
def process_text_file(file):
    """Turn an uploaded transcript file into text and run the analysis.

    The file type is chosen from the (case-insensitive) file name:

    * ``.json`` - parsed with ``json.loads`` and converted to text by
      ``extract_transcript_from_json``.
    * ``.docx`` - opened with ``Document`` (presumably python-docx; confirm
      against the file's imports) and its paragraphs joined with newlines.
    * ``.doc``  - rejected with an explanatory message; the legacy binary
      Word format cannot be decoded as text.
    * anything else - decoded as UTF-8 text.

    Text is decoded as ``utf-8-sig`` so a leading byte-order mark (common in
    files saved on Windows) is dropped instead of breaking ``json.loads``.

    Args:
        file: Uploaded file object exposing ``name`` and ``read()``.

    Returns:
        None. On success the transcript is stored in
        ``st.session_state.current_transcript`` and analyzed; problems are
        reported on the page via ``st.error`` / ``st.warning``.
    """
    try:
        filename = file.name.lower()

        if filename.endswith('.doc'):
            st.error(
                "Legacy .doc files are not supported. "
                "Please save the document as .docx or .txt and upload it again."
            )
            return

        if filename.endswith('.json'):
            content = json.loads(file.read().decode('utf-8-sig'))
            transcript = extract_transcript_from_json(content)
        elif filename.endswith('.docx'):
            doc = Document(file)
            transcript = '\n'.join(paragraph.text for paragraph in doc.paragraphs)
        else:
            transcript = file.read().decode('utf-8-sig')

        if transcript and transcript.strip():
            st.session_state.current_transcript = transcript
            analyze_session_content(transcript)
        else:
            st.warning("The uploaded file does not contain any text to analyze.")

    except Exception as e:
        # UI boundary: surface any failure to the user instead of crashing the app.
        st.error(f"Error processing file: {str(e)}")
|
551 |
def show_export_options():
|
552 |
st.sidebar.subheader("Export Options")
|
553 |
|
|
|
624 |
st.error(f"Error loading previous sessions: {str(e)}")
|
625 |
return []
|
626 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
627 |
|
628 |
def format_session_data(session_data):
|
629 |
"""Format session data into analyzable transcript"""
|
|
|
643 |
return formatted_text
|
644 |
|
645 |
def analyze_session_content(transcript):
|
|
|
646 |
try:
|
647 |
+
# Initialize Gemini
|
648 |
model = genai.GenerativeModel('gemini-pro')
|
649 |
|
650 |
+
# Prepare the analysis prompt
|
651 |
+
analysis_prompt = f"""
|
652 |
+
{MI_SYSTEM_PROMPT}
|
653 |
+
|
654 |
+
Please analyze the following therapy session transcript:
|
655 |
+
|
656 |
+
{transcript}
|
657 |
+
|
658 |
+
{SESSION_EVALUATION_PROMPT}
|
659 |
+
"""
|
660 |
|
661 |
# Generate analysis
|
662 |
response = model.generate_content(analysis_prompt)
|
663 |
|
664 |
+
# Parse the response
|
665 |
+
analysis_results = parse_analysis_response(response.text)
|
666 |
+
|
667 |
+
# Store results in session state
|
668 |
+
st.session_state.analysis_results = analysis_results
|
669 |
|
670 |
except Exception as e:
|
671 |
st.error(f"Error analyzing session content: {str(e)}")
|
|
|
702 |
with tabs[4]:
|
703 |
show_recommendations(analysis.get('recommendations', {}))
|
704 |
|
705 |
+
def parse_analysis_response(response_text):
    """Parse the AI response into structured analysis results.

    The response is expected to contain plain-text sections introduced by
    the headings ``MI Adherence Score:``, ``Key Themes:``,
    ``Technique Usage:``, ``Strengths:``, ``Areas for Improvement:`` and
    ``Session Summary:``. Each section runs until the next blank line or
    the end of the text. Missing sections keep their default value.

    Technique lines are read as ``<name>: <count>``. The name is everything
    before the *last* colon and the count is the first integer after it, so
    lines such as ``Ratio (2:1): 7`` or ``Reflections: 3 times`` are
    accepted. Lines without a usable count are skipped rather than aborting
    the whole parse.

    Args:
        response_text: Raw text returned by the model.

    Returns:
        dict with keys 'mi_adherence_score' (float), 'key_themes' (list),
        'technique_usage' (dict name -> int), 'strengths' (list),
        'areas_for_improvement' (list), 'recommendations' (list),
        'change_talk_instances' (list) and 'session_summary' (str);
        ``None`` if parsing fails (the error is shown with ``st.error``).
    """
    try:
        # Normalize Windows line endings so the blank-line section
        # terminator ("\n\n") is recognized.
        text = response_text.replace('\r\n', '\n')

        # Initialize default structure for analysis results
        # NOTE(review): 'recommendations' and 'change_talk_instances' are
        # never populated by this parser; they are always returned empty.
        analysis = {
            'mi_adherence_score': 0.0,
            'key_themes': [],
            'technique_usage': {},
            'strengths': [],
            'areas_for_improvement': [],
            'recommendations': [],
            'change_talk_instances': [],
            'session_summary': ""
        }

        # Extract MI adherence score
        score_match = re.search(r'MI Adherence Score:\s*(\d+\.?\d*)', text)
        if score_match:
            analysis['mi_adherence_score'] = float(score_match.group(1))

        # Extract the bullet-list sections
        analysis['key_themes'] = _section_items(text, 'Key Themes')
        analysis['strengths'] = _section_items(text, 'Strengths')
        analysis['areas_for_improvement'] = _section_items(text, 'Areas for Improvement')

        # Extract technique usage
        for line in _section_text(text, 'Technique Usage').split('\n'):
            name, separator, value = line.rpartition(':')
            count_match = re.search(r'\d+', value)
            name = _strip_bullet(name)
            if not separator or not name or count_match is None:
                continue
            analysis['technique_usage'][name] = int(count_match.group())

        # Extract session summary
        analysis['session_summary'] = _section_text(text, 'Session Summary')

        return analysis

    except Exception as e:
        st.error(f"Error parsing analysis response: {str(e)}")
        return None


def _section_text(text, heading):
    """Return the stripped body of the section starting at ``heading:``, or ""."""
    match = re.search(re.escape(heading) + r':(.*?)(?=\n\n|\Z)', text, re.DOTALL)
    return match.group(1).strip() if match else ""


def _strip_bullet(line):
    """Remove surrounding whitespace and list-bullet characters from a line."""
    return line.strip('-*\u2022 \t')


def _section_items(text, heading):
    """Return the non-empty, bullet-stripped lines of the section ``heading``."""
    lines = _section_text(text, heading).split('\n')
    return [item for item in (_strip_bullet(line) for line in lines) if item]
|
762 |
+
|
763 |
+
def create_gauge_chart(score):
    """Render a gauge indicator for the MI adherence score.

    The gauge axis runs from 0 to 100 with three shaded bands (0-40, 40-70,
    70-100) and a red threshold line at 90. The figure is drawn on the page
    with ``st.plotly_chart``; nothing is returned.

    Args:
        score: Value shown on the gauge.
    """
    # Shaded background bands, lowest to highest.
    color_bands = [
        {'range': [0, 40], 'color': "lightgray"},
        {'range': [40, 70], 'color': "gray"},
        {'range': [70, 100], 'color': "darkgray"},
    ]

    # Red marker line drawn at 90 on the axis.
    threshold_marker = {
        'line': {'color': "red", 'width': 4},
        'thickness': 0.75,
        'value': 90,
    }

    gauge_config = {
        'axis': {'range': [0, 100]},
        'bar': {'color': "darkblue"},
        'steps': color_bands,
        'threshold': threshold_marker,
    }

    indicator = go.Indicator(
        mode="gauge+number",
        value=score,
        domain={'x': [0, 1], 'y': [0, 1]},
        title={'text': "MI Adherence"},
        gauge=gauge_config,
    )

    st.plotly_chart(go.Figure(indicator))
|
787 |
+
|
788 |
+
def create_technique_usage_chart(technique_usage):
    """Render a bar chart of how often each MI technique was used.

    Args:
        technique_usage: Mapping of technique name to count; each item
            becomes one bar.

    The chart is drawn on the page with ``st.plotly_chart``; nothing is
    returned.
    """
    usage_rows = list(technique_usage.items())
    usage_frame = pd.DataFrame(usage_rows, columns=['Technique', 'Count'])

    chart = px.bar(
        usage_frame,
        x='Technique',
        y='Count',
        title='MI Technique Usage Frequency',
    )
    chart.update_layout(
        xaxis_title="Technique",
        yaxis_title="Frequency",
        showlegend=False,
    )

    st.plotly_chart(chart)
|
803 |
|
804 |
+
def extract_transcript_from_json(content):
    """Convert parsed JSON content into transcript text.

    * ``None`` (JSON ``null``) yields an empty string, so callers that test
      the result for truthiness treat it as "no transcript".
    * Strings are returned unchanged.
    * Dicts and lists are pretty-printed as JSON with a two-space indent.
      Non-ASCII characters are kept as-is rather than escaped to ``\\uXXXX``
      so the text stays readable.
    * Any other value (number, bool) is converted with ``str``.

    Args:
        content: Value produced by ``json.loads``.

    Returns:
        str: Text representation of ``content``.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, (dict, list)):
        return json.dumps(content, indent=2, ensure_ascii=False)
    return str(content)
|
809 |
+
|
810 |
# Analysis display functions
|
811 |
def show_mi_adherence_analysis(analysis):
|
812 |
st.subheader("MI Adherence Analysis")
|