# MotiMeter / session_analysis.py
import streamlit as st
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import google.generativeai as genai
from datetime import datetime, timedelta
import json
import numpy as np
from docx import Document
import re
from prompts import SESSION_EVALUATION_PROMPT, MI_SYSTEM_PROMPT
def show_session_analysis():
st.title("MI Session Analysis Dashboard")
# Initialize session state for analysis results
if 'analysis_results' not in st.session_state:
st.session_state.analysis_results = None
if 'current_transcript' not in st.session_state:
st.session_state.current_transcript = None
# Main layout
col1, col2 = st.columns([1, 2])
with col1:
show_upload_section()
with col2:
if st.session_state.analysis_results:
show_analysis_results()
def show_upload_section():
st.header("Session Data Upload")
upload_type = st.radio(
"Select Input Method:",
["Audio Recording", "Video Recording", "Text Transcript", "Session Notes", "Previous Session Data"]
)
if upload_type in ["Audio Recording", "Video Recording"]:
file = st.file_uploader(
f"Upload {upload_type}",
type=["wav", "mp3", "mp4"] if upload_type == "Audio Recording" else ["mp4", "avi", "mov"]
)
if file:
process_media_file(file, upload_type)
elif upload_type == "Text Transcript":
file = st.file_uploader("Upload Transcript", type=["txt", "doc", "docx", "json"])
if file:
process_text_file(file)
elif upload_type == "Session Notes":
show_manual_input_form()
else: # Previous Session Data
show_previous_sessions_selector()
def process_media_file(file, media_type):
    st.write(f"Processing {media_type}...")
# Add processing status
status = st.empty()
progress_bar = st.progress(0)
try:
# Simulated processing steps
for i in range(5):
status.text(f"Step {i+1}/5: " + get_processing_step_name(i))
progress_bar.progress((i + 1) * 20)
# Generate transcript
transcript = generate_transcript(file)
if transcript:
st.session_state.current_transcript = transcript
analyze_session_content(transcript)
except Exception as e:
st.error(f"Error processing file: {str(e)}")
finally:
status.empty()
progress_bar.empty()
def get_processing_step_name(step):
steps = [
"Loading media file",
"Converting to audio",
"Performing speech recognition",
"Generating transcript",
"Preparing analysis"
]
return steps[step]
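# `generate_transcript` is called in `process_media_file` above but is not
# defined anywhere in this file. A minimal placeholder sketch follows; the real
# speech-to-text backend is an assumption and would need to be wired in here.
def generate_transcript(file):
    """Return a plain-text transcript for the uploaded media file.

    Placeholder: automatic transcription is not implemented in this module, so
    we surface a warning and return None to stop the analysis pipeline.
    """
    st.warning("Automatic transcription is not configured; please upload a text transcript instead.")
    return None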
def process_text_file(file):
try:
if file.name.endswith('.json'):
content = json.loads(file.read().decode())
transcript = extract_transcript_from_json(content)
elif file.name.endswith('.docx'):
doc = Document(file)
transcript = '\n'.join([paragraph.text for paragraph in doc.paragraphs])
else:
transcript = file.read().decode()
if transcript:
st.session_state.current_transcript = transcript
analyze_session_content(transcript)
except Exception as e:
st.error(f"Error processing file: {str(e)}")
def show_manual_input_form():
st.subheader("Session Details")
# Session metadata
session_date = st.date_input("Session Date", datetime.now())
session_duration = st.number_input("Session Duration (minutes)", min_value=1, max_value=180, value=50)
# Client information
client_id = st.text_input("Client ID (optional)")
session_number = st.number_input("Session Number", min_value=1, value=1)
# Session content
session_notes = st.text_area(
"Session Notes",
height=300,
help="Enter detailed session notes including key dialogues, interventions, and observations"
)
# Target behaviors
target_behaviors = st.text_area(
"Target Behaviors/Goals",
height=100,
help="Enter the specific behaviors or goals discussed in the session"
)
# MI specific elements
st.subheader("MI Elements")
change_talk = st.text_area("Observed Change Talk")
sustain_talk = st.text_area("Observed Sustain Talk")
if st.button("Analyze Session"):
session_data = compile_session_data(
session_date, session_duration, client_id, session_number,
session_notes, target_behaviors, change_talk, sustain_talk
)
analyze_session_content(session_data)
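# `compile_session_data` is called above but not defined in this file. A simple
# sketch that flattens the form fields into one analyzable text block; the exact
# layout is an assumption chosen to mirror `format_session_data` further below.
def compile_session_data(session_date, session_duration, client_id, session_number,
                         session_notes, target_behaviors, change_talk, sustain_talk):
    """Combine the manual-entry form fields into a single transcript-like string."""
    return f"""
Session Date: {session_date}
Duration: {session_duration} minutes
Client ID: {client_id or 'N/A'}
Session Number: {session_number}

SESSION NOTES:
{session_notes}

TARGET BEHAVIORS/GOALS:
{target_behaviors}

OBSERVED CHANGE TALK:
{change_talk}

OBSERVED SUSTAIN TALK:
{sustain_talk}
"""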
def analyze_session_content(content):
try:
# Configure Gemini model
model = genai.GenerativeModel('gemini-pro')
# Prepare analysis prompt
analysis_prompt = f"""
Analyze the following therapy session using MI principles and provide a comprehensive evaluation:
Session Content:
{content}
Please provide detailed analysis including:
1. MI Adherence Assessment:
- OARS implementation
- Change talk identification
- Resistance management
- MI spirit adherence
2. Technical Skills Evaluation:
- Reflection quality and frequency
- Question-to-reflection ratio
- Open vs. closed questions
- Affirmations and summaries
3. Client Language Analysis:
- Change talk instances
- Sustain talk patterns
- Commitment language
- Resistance patterns
4. Session Flow Analysis:
- Engagement level
- Focus maintenance
- Evocation quality
- Planning effectiveness
5. Recommendations:
- Strength areas
- Growth opportunities
- Suggested interventions
- Next session planning
Format the analysis with clear sections and specific examples from the session.
"""
# Generate analysis
response = model.generate_content(analysis_prompt)
# Process and structure the analysis results
analysis_results = process_analysis_results(response.text)
# Store results in session state
st.session_state.analysis_results = analysis_results
# Show success message
st.success("Analysis completed successfully!")
except Exception as e:
st.error(f"Error during analysis: {str(e)}")
def process_analysis_results(raw_analysis):
"""Process and structure the analysis results"""
# Parse the raw analysis text and extract structured data
sections = extract_analysis_sections(raw_analysis)
# Calculate metrics
metrics = calculate_mi_metrics(raw_analysis)
return {
"raw_analysis": raw_analysis,
"structured_sections": sections,
"metrics": metrics,
"timestamp": datetime.now().isoformat()
}
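# `extract_analysis_sections` and `calculate_mi_metrics` are used in
# `process_analysis_results` above but not defined in this file. The sketches
# below are assumptions: the section split keys off the numbered headings
# requested in the analysis prompt, and the metrics are crude keyword counts
# rather than validated MI coding scores.
def extract_analysis_sections(raw_analysis):
    """Split the model output into the sections the dashboard expects."""
    section_patterns = {
        'mi_adherence': r"1\.\s*MI Adherence",
        'technical_skills': r"2\.\s*Technical Skills",
        'client_language': r"3\.\s*Client Language",
        'session_flow': r"4\.\s*Session Flow",
        'recommendations': r"5\.\s*Recommendations",
    }
    keys = list(section_patterns)
    sections = {}
    for i, key in enumerate(keys):
        start = re.search(section_patterns[key], raw_analysis)
        if not start:
            sections[key] = ''
            continue
        end = re.search(section_patterns[keys[i + 1]], raw_analysis) if i + 1 < len(keys) else None
        sections[key] = raw_analysis[start.end():end.start() if end else None].strip()
    # Recommendation sub-fields default to empty so the UI can iterate safely.
    sections.setdefault('strengths', [])
    sections.setdefault('growth_areas', [])
    sections.setdefault('suggested_interventions', sections.get('recommendations', ''))
    sections.setdefault('next_session_plan', '')
    return sections

def calculate_mi_metrics(raw_analysis):
    """Very rough keyword-based metrics so the dashboard has values to display."""
    text = raw_analysis.lower()
    change = text.count("change talk")
    sustain = text.count("sustain talk")
    reflections = text.count("reflection")
    questions = text.count("question")
    return {
        'mi_spirit_score': 0.0,
        'change_talk_ratio': change / max(change + sustain, 1),
        'reflection_ratio': reflections / max(questions, 1),
        'overall_adherence': 0.0,
        'oars_metrics': {},
        'mi_spirit_metrics': {},
        'question_metrics': {},
        'reflection_metrics': {},
        'change_talk_timeline': [],
        'language_categories': {},
        'session_flow': [],
        'engagement_metrics': {},
    }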
def show_analysis_results():
"""Display comprehensive analysis results"""
if not st.session_state.analysis_results:
return
results = st.session_state.analysis_results
# Top-level metrics
show_mi_metrics_dashboard(results['metrics'])
# Detailed analysis sections
tabs = st.tabs([
"MI Adherence",
"Technical Skills",
"Client Language",
"Session Flow",
"Recommendations"
])
with tabs[0]:
show_mi_adherence_analysis(results)
with tabs[1]:
show_technical_skills_analysis(results)
with tabs[2]:
show_client_language_analysis(results)
with tabs[3]:
show_session_flow_analysis(results)
with tabs[4]:
show_recommendations(results)
def show_mi_metrics_dashboard(metrics):
st.subheader("MI Performance Dashboard")
col1, col2, col3, col4 = st.columns(4)
with col1:
show_metric_card(
"MI Spirit Score",
metrics.get('mi_spirit_score', 0),
"0-5 scale"
)
with col2:
show_metric_card(
"Change Talk Ratio",
metrics.get('change_talk_ratio', 0),
"Change vs Sustain"
)
with col3:
show_metric_card(
"Reflection Ratio",
metrics.get('reflection_ratio', 0),
"Reflections/Questions"
)
with col4:
show_metric_card(
"Overall Adherence",
metrics.get('overall_adherence', 0),
"Percentage"
)
def show_metric_card(title, value, subtitle):
st.markdown(
f"""
<div style="border:1px solid #ccc; padding:10px; border-radius:5px; text-align:center;">
<h3>{title}</h3>
<h2>{value:.2f}</h2>
<p>{subtitle}</p>
</div>
""",
unsafe_allow_html=True
)
def show_mi_adherence_analysis(results):
st.subheader("MI Adherence Analysis")
# OARS Implementation
st.write("### OARS Implementation")
show_oars_chart(results['metrics'].get('oars_metrics', {}))
# MI Spirit Components
st.write("### MI Spirit Components")
show_mi_spirit_chart(results['metrics'].get('mi_spirit_metrics', {}))
# Detailed breakdown
st.write("### Detailed Analysis")
st.markdown(results['structured_sections'].get('mi_adherence', ''))
def show_technical_skills_analysis(results):
st.subheader("Technical Skills Analysis")
# Question Analysis
col1, col2 = st.columns(2)
with col1:
show_question_type_chart(results['metrics'].get('question_metrics', {}))
with col2:
show_reflection_depth_chart(results['metrics'].get('reflection_metrics', {}))
# Detailed analysis
st.markdown(results['structured_sections'].get('technical_skills', ''))
def show_client_language_analysis(results):
st.subheader("Client Language Analysis")
# Change Talk Timeline
show_change_talk_timeline(results['metrics'].get('change_talk_timeline', []))
# Language Categories
show_language_categories_chart(results['metrics'].get('language_categories', {}))
# Detailed analysis
st.markdown(results['structured_sections'].get('client_language', ''))
def show_session_flow_analysis(results):
st.subheader("Session Flow Analysis")
# Session Flow Timeline
show_session_flow_timeline(results['metrics'].get('session_flow', []))
# Engagement Metrics
show_engagement_metrics(results['metrics'].get('engagement_metrics', {}))
# Detailed analysis
st.markdown(results['structured_sections'].get('session_flow', ''))
def show_recommendations(results):
st.subheader("Recommendations and Next Steps")
col1, col2 = st.columns(2)
with col1:
st.write("### Strengths")
strengths = results['structured_sections'].get('strengths', [])
for strength in strengths:
st.markdown(f"✓ {strength}")
with col2:
st.write("### Growth Areas")
growth_areas = results['structured_sections'].get('growth_areas', [])
for area in growth_areas:
st.markdown(f"→ {area}")
st.write("### Suggested Interventions")
st.markdown(results['structured_sections'].get('suggested_interventions', ''))
st.write("### Next Session Planning")
st.markdown(results['structured_sections'].get('next_session_plan', ''))
# Utility functions for charts and visualizations
def show_oars_chart(oars_metrics):
# Create OARS radar chart using plotly
categories = ['Open Questions', 'Affirmations', 'Reflections', 'Summaries']
values = [
oars_metrics.get('open_questions', 0),
oars_metrics.get('affirmations', 0),
oars_metrics.get('reflections', 0),
oars_metrics.get('summaries', 0)
]
fig = go.Figure(data=go.Scatterpolar(
r=values,
theta=categories,
fill='toself'
))
fig.update_layout(
polar=dict(
radialaxis=dict(
visible=True,
range=[0, max(values) + 1]
)),
showlegend=False
)
st.plotly_chart(fig)
# Add more visualization functions as needed...
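# Two of the chart helpers referenced above (`show_mi_spirit_chart` and
# `show_change_talk_timeline`) are not defined in this file. The sketches below
# assume the metric shapes implied elsewhere in this module: a flat dict of
# component scores, and a list of points with "minute", "change_talk", and
# "sustain_talk" fields (those field names are assumptions).
def show_mi_spirit_chart(mi_spirit_metrics):
    if not mi_spirit_metrics:
        st.info("No MI spirit metrics available.")
        return
    fig = go.Figure(go.Bar(
        x=list(mi_spirit_metrics.keys()),
        y=list(mi_spirit_metrics.values())
    ))
    fig.update_layout(xaxis_title="MI Spirit Component", yaxis_title="Score")
    st.plotly_chart(fig)

def show_change_talk_timeline(timeline):
    if not timeline:
        st.info("No change talk timeline available.")
        return
    df = pd.DataFrame(timeline)
    fig = px.line(
        df, x="minute", y=["change_talk", "sustain_talk"],
        labels={"value": "Count", "minute": "Session Minute", "variable": "Language Type"}
    )
    st.plotly_chart(fig)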
def save_analysis_results():
"""Save analysis results to file"""
if st.session_state.analysis_results:
try:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"analysis_results_{timestamp}.json"
with open(filename, "w") as f:
json.dump(st.session_state.analysis_results, f, indent=4)
st.success(f"Analysis results saved to {filename}")
except Exception as e:
st.error(f"Error saving analysis results: {str(e)}")
def show_export_options():
st.sidebar.subheader("Export Options")
if st.sidebar.button("Export Analysis Report"):
save_analysis_results()
report_format = st.sidebar.selectbox(
"Report Format",
["PDF", "DOCX", "JSON"]
)
if st.sidebar.button("Generate Report"):
generate_report(report_format)
def generate_report(report_format):
    """Generate analysis report in the specified format"""
    # Add report generation logic here
    st.info(f"Generating {report_format} report... (Feature coming soon)")
def show_previous_sessions_selector():
"""Display selector for previous session data"""
st.subheader("Previous Sessions")
# Load or initialize previous sessions data
if 'previous_sessions' not in st.session_state:
st.session_state.previous_sessions = load_previous_sessions()
if not st.session_state.previous_sessions:
st.info("No previous sessions found.")
return
# Create session selector
sessions = st.session_state.previous_sessions
session_dates = [session['date'] for session in sessions]
selected_date = st.selectbox(
"Select Session Date:",
session_dates,
format_func=lambda x: x.strftime("%Y-%m-%d %H:%M")
)
# Show selected session data
if selected_date:
selected_session = next(
(session for session in sessions if session['date'] == selected_date),
None
)
if selected_session:
st.session_state.current_transcript = selected_session['transcript']
analyze_session_content(selected_session['transcript'])
def load_previous_sessions():
"""Load previous session data from storage"""
try:
        # In production, sessions would be loaded from a database or file storage
        # (see the disk-loading sketch after this function). For demonstration,
        # return some sample data instead.
sample_sessions = [
{
'date': datetime.now(),
'transcript': "Sample transcript 1...",
'analysis': "Sample analysis 1..."
},
{
                'date': datetime.now() - timedelta(days=7),
'transcript': "Sample transcript 2...",
'analysis': "Sample analysis 2..."
}
]
return sample_sessions
except Exception as e:
st.error(f"Error loading previous sessions: {str(e)}")
return []
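# One way the "load from storage" step above could work in practice: read back
# the JSON files written by `save_analysis_results` from the working directory.
# The filename pattern and the stored field names are assumptions.
def load_sessions_from_disk(directory="."):
    """Load previously saved analysis result files as session records."""
    from pathlib import Path
    sessions = []
    for path in sorted(Path(directory).glob("analysis_results_*.json")):
        try:
            with open(path) as f:
                data = json.load(f)
            sessions.append({
                'date': datetime.fromtimestamp(path.stat().st_mtime),
                'transcript': data.get('raw_analysis', '') if isinstance(data, dict) else str(data),
                'analysis': data,
            })
        except (OSError, json.JSONDecodeError):
            continue
    return sessions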
def show_manual_input_form():
"""Display form for manual session notes input"""
st.subheader("Session Notes Input")
with st.form("session_notes_form"):
# Basic session information
session_date = st.date_input("Session Date", datetime.now())
session_duration = st.number_input("Duration (minutes)", min_value=15, max_value=120, value=50)
# Session content
session_notes = st.text_area(
"Session Notes",
height=300,
placeholder="Enter detailed session notes here..."
)
# Key themes and observations
key_themes = st.text_area(
"Key Themes",
height=100,
placeholder="Enter key themes identified during the session..."
)
# MI specific elements
mi_techniques_used = st.multiselect(
"MI Techniques Used",
["Open Questions", "Affirmations", "Reflections", "Summaries",
"Change Talk", "Commitment Language", "Planning"]
)
# Submit button
submitted = st.form_submit_button("Analyze Session")
if submitted and session_notes:
# Combine all input into a structured format
session_data = {
'date': session_date,
'duration': session_duration,
'notes': session_notes,
'themes': key_themes,
'techniques': mi_techniques_used
}
# Process the session data
st.session_state.current_transcript = format_session_data(session_data)
analyze_session_content(st.session_state.current_transcript)
def format_session_data(session_data):
"""Format session data into analyzable transcript"""
formatted_text = f"""
Session Date: {session_data['date']}
Duration: {session_data['duration']} minutes
SESSION NOTES:
{session_data['notes']}
KEY THEMES:
{session_data['themes']}
MI TECHNIQUES USED:
{', '.join(session_data['techniques'])}
"""
return formatted_text
def analyze_session_content(transcript):
"""Analyze session content using Gemini AI"""
try:
# Configure Gemini model
model = genai.GenerativeModel('gemini-pro')
# Prepare analysis prompt
analysis_prompt = SESSION_EVALUATION_PROMPT + f"\nTranscript:\n{transcript}"
# Generate analysis
response = model.generate_content(analysis_prompt)
# Store and display results
st.session_state.analysis_results = response.text
show_analysis_results()
except Exception as e:
st.error(f"Error analyzing session content: {str(e)}")
def show_analysis_results():
"""Display session analysis results"""
if not st.session_state.analysis_results:
st.warning("No analysis results available.")
return
st.header("Session Analysis Results")
# Create tabs for different aspects of analysis
tabs = st.tabs([
"MI Adherence",
"Technical Skills",
"Client Language",
"Session Flow",
"Recommendations"
])
# Parse analysis results (assuming structured response from AI)
analysis = parse_analysis_results(st.session_state.analysis_results)
# Display results in respective tabs
with tabs[0]:
show_mi_adherence_analysis(analysis.get('mi_adherence', {}))
with tabs[1]:
show_technical_skills_analysis(analysis.get('technical_skills', {}))
with tabs[2]:
show_client_language_analysis(analysis.get('client_language', {}))
with tabs[3]:
show_session_flow_analysis(analysis.get('session_flow', {}))
with tabs[4]:
show_recommendations(analysis.get('recommendations', {}))
def parse_analysis_results(results_text):
"""Parse the AI analysis results into structured format"""
# This is a placeholder for more sophisticated parsing
# In a real implementation, you'd want to parse the AI response
# into a structured format based on your specific needs
return {
'mi_adherence': {'raw_text': results_text},
'technical_skills': {'raw_text': results_text},
'client_language': {'raw_text': results_text},
'session_flow': {'raw_text': results_text},
'recommendations': {'raw_text': results_text}
}
# Analysis display functions
def show_mi_adherence_analysis(analysis):
st.subheader("MI Adherence Analysis")
st.write(analysis.get('raw_text', 'No analysis available'))
def show_technical_skills_analysis(analysis):
st.subheader("Technical Skills Analysis")
st.write(analysis.get('raw_text', 'No analysis available'))
def show_client_language_analysis(analysis):
st.subheader("Client Language Analysis")
st.write(analysis.get('raw_text', 'No analysis available'))
def show_session_flow_analysis(analysis):
st.subheader("Session Flow Analysis")
st.write(analysis.get('raw_text', 'No analysis available'))
def show_recommendations(analysis):
st.subheader("Recommendations")
st.write(analysis.get('raw_text', 'No recommendations available'))
if __name__ == "__main__":
show_session_analysis()