MotiMeter / session_analysis.py
Jiaaaaaaax's picture
Update session_analysis.py
baeff27 verified
raw
history blame
13.6 kB
from datetime import datetime
import html
import json
import re

import google.generativeai as genai
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
from docx import Document

from prompts import SESSION_EVALUATION_PROMPT, MI_SYSTEM_PROMPT
def show_session_analysis():
    """Render the MI session analysis page.

    Makes sure the session-state keys this page reads exist, then lays the page
    out as a narrow upload column next to a wider results column. The results
    column stays empty until an analysis has been stored in session state.
    """
    st.title("MI Session Analysis Dashboard")

    # Seed per-session state on the first render so later reads find the keys.
    for state_key in ('analysis_results', 'current_transcript'):
        if state_key not in st.session_state:
            st.session_state[state_key] = None

    upload_column, results_column = st.columns([1, 2])
    with upload_column:
        show_upload_section()
    with results_column:
        if st.session_state.analysis_results:
            show_analysis_results()
def show_upload_section():
    """Render the input-method picker and hand the chosen input to its handler.

    Audio/video uploads go to ``process_media_file``, transcripts to
    ``process_text_file``, manual notes to ``show_manual_input_form`` and the
    remaining option to ``show_previous_sessions_selector``.
    """
    st.header("Session Data Upload")
    upload_type = st.radio(
        "Select Input Method:",
        ["Audio Recording", "Video Recording", "Text Transcript", "Session Notes", "Previous Session Data"]
    )

    # Accepted file extensions per recording kind.
    media_extensions = {
        "Audio Recording": ["wav", "mp3", "mp4"],
        "Video Recording": ["mp4", "avi", "mov"],
    }

    if upload_type in media_extensions:
        media_upload = st.file_uploader(
            f"Upload {upload_type}",
            type=media_extensions[upload_type]
        )
        if media_upload:
            process_media_file(media_upload, upload_type)
    elif upload_type == "Text Transcript":
        transcript_upload = st.file_uploader("Upload Transcript", type=["txt", "doc", "docx", "json"])
        if transcript_upload:
            process_text_file(transcript_upload)
    elif upload_type == "Session Notes":
        show_manual_input_form()
    else:  # "Previous Session Data"
        show_previous_sessions_selector()
def process_media_file(file, type):
    """Walk through the simulated media pipeline, then transcribe and analyse.

    Args:
        file: The uploaded recording as returned by ``st.file_uploader``.
        type: Display label for the upload kind (e.g. "Audio Recording"), used
            only in status text. The name shadows the builtin but is kept so
            existing callers are unaffected.

    Errors are shown with ``st.error``. The temporary status line and progress
    bar are always cleared afterwards.
    """
    st.write(f"Processing {type}...")
    status = st.empty()
    progress_bar = st.progress(0)
    try:
        # These steps only update the UI; no work happens between them.
        for step in range(5):
            status.text(f"Step {step + 1}/5: " + get_processing_step_name(step))
            progress_bar.progress(20 * (step + 1))

        transcript = generate_transcript(file)
        if not transcript:
            return
        st.session_state.current_transcript = transcript
        analyze_session_content(transcript)
    except Exception as e:
        st.error(f"Error processing file: {str(e)}")
    finally:
        status.empty()
        progress_bar.empty()
def get_processing_step_name(step):
    """Return the display label for simulated processing step ``step``.

    ``step`` is a plain sequence index: 0-4 map to the five stages in order,
    negative values index from the end, and anything else raises IndexError.
    """
    stage_labels = (
        "Loading media file",
        "Converting to audio",
        "Performing speech recognition",
        "Generating transcript",
        "Preparing analysis",
    )
    return stage_labels[step]
def process_text_file(file):
    """Extract transcript text from an uploaded file and analyse it.

    The format is chosen from the file extension, ignoring case:

    * ``.json``: decoded, parsed with ``json.loads`` and passed to
      ``extract_transcript_from_json``.
    * ``.docx``: opened with ``Document``; paragraph texts are joined by newlines.
    * ``.doc``: rejected with an explicit message. The legacy binary format
      cannot be decoded as text and is not a ``.docx`` package.
    * anything else: decoded as UTF-8 text.

    A non-empty transcript is stored in ``st.session_state.current_transcript``
    and sent to ``analyze_session_content``. Errors are shown with ``st.error``.

    Args:
        file: Uploaded file object exposing ``name`` and ``read()`` (as
            returned by ``st.file_uploader``).
    """
    try:
        # Compare extensions case-insensitively so "Notes.DOCX" is not decoded as text.
        name = file.name.lower()
        if name.endswith('.json'):
            # 'utf-8-sig' also accepts a leading byte-order mark (common in
            # Windows-saved files), which json.loads would otherwise reject.
            content = json.loads(file.read().decode('utf-8-sig'))
            transcript = extract_transcript_from_json(content)
        elif name.endswith('.docx'):
            doc = Document(file)
            transcript = '\n'.join(paragraph.text for paragraph in doc.paragraphs)
        elif name.endswith('.doc'):
            st.error(
                "Legacy .doc files are not supported. "
                "Please save the transcript as .docx or .txt and upload it again."
            )
            return
        else:
            transcript = file.read().decode('utf-8-sig')

        if transcript:
            st.session_state.current_transcript = transcript
            analyze_session_content(transcript)
    except Exception as e:
        st.error(f"Error processing file: {str(e)}")
def show_manual_input_form():
    """Collect session details by hand and analyse them on request.

    Renders metadata inputs (date, duration, client ID, session number), free-text
    session notes, target behaviours and observed change/sustain talk. When
    "Analyze Session" is pressed, the values are combined by
    ``compile_session_data`` and passed to ``analyze_session_content``.
    Analysis is refused while the session notes are blank.
    """
    st.subheader("Session Details")
    # Session metadata
    session_date = st.date_input("Session Date", datetime.now())
    session_duration = st.number_input("Session Duration (minutes)", min_value=1, max_value=180, value=50)
    # Client information
    client_id = st.text_input("Client ID (optional)")
    session_number = st.number_input("Session Number", min_value=1, value=1)
    # Session content
    session_notes = st.text_area(
        "Session Notes",
        height=300,
        help="Enter detailed session notes including key dialogues, interventions, and observations"
    )
    # Target behaviors
    target_behaviors = st.text_area(
        "Target Behaviors/Goals",
        height=100,
        help="Enter the specific behaviors or goals discussed in the session"
    )
    # MI specific elements
    st.subheader("MI Elements")
    change_talk = st.text_area("Observed Change Talk")
    sustain_talk = st.text_area("Observed Sustain Talk")

    if st.button("Analyze Session"):
        # With no notes, the model has nothing to evaluate and could only
        # fabricate an analysis, so stop before spending an API call.
        if not session_notes.strip():
            st.warning("Please enter session notes before analyzing.")
            return
        session_data = compile_session_data(
            session_date, session_duration, client_id, session_number,
            session_notes, target_behaviors, change_talk, sustain_talk
        )
        analyze_session_content(session_data)
def analyze_session_content(content):
    """Ask Gemini for an MI evaluation of ``content`` and store the result.

    Args:
        content: The session material to evaluate. This is either a transcript
            string from the upload handlers or whatever ``compile_session_data``
            returns for manual notes. It is interpolated into the prompt with an
            f-string.

    On success, the dict from ``process_analysis_results`` is stored in
    ``st.session_state.analysis_results`` and a success message is shown.
    ``None`` or blank string content is rejected with a warning and no API call.
    Any exception raised while calling the model or processing its reply is
    reported with ``st.error`` instead of propagating.
    """
    # An empty session gives the model nothing to evaluate; don't spend a call on it.
    if content is None or (isinstance(content, str) and not content.strip()):
        st.warning("No session content to analyze.")
        return
    try:
        # NOTE(review): confirm the 'gemini-pro' model name is still served by
        # the API. No API key is configured here; presumably genai.configure is
        # called elsewhere.
        model = genai.GenerativeModel('gemini-pro')
        # Prepare analysis prompt
        analysis_prompt = f"""
Analyze the following therapy session using MI principles and provide a comprehensive evaluation:
Session Content:
{content}
Please provide detailed analysis including:
1. MI Adherence Assessment:
- OARS implementation
- Change talk identification
- Resistance management
- MI spirit adherence
2. Technical Skills Evaluation:
- Reflection quality and frequency
- Question-to-reflection ratio
- Open vs. closed questions
- Affirmations and summaries
3. Client Language Analysis:
- Change talk instances
- Sustain talk patterns
- Commitment language
- Resistance patterns
4. Session Flow Analysis:
- Engagement level
- Focus maintenance
- Evocation quality
- Planning effectiveness
5. Recommendations:
- Strength areas
- Growth opportunities
- Suggested interventions
- Next session planning
Format the analysis with clear sections and specific examples from the session.
"""
        # The model round-trip is slow; show a spinner so the page doesn't look frozen.
        with st.spinner("Analyzing session..."):
            response = model.generate_content(analysis_prompt)
            analysis_results = process_analysis_results(response.text)
        st.session_state.analysis_results = analysis_results
        st.success("Analysis completed successfully!")
    except Exception as e:
        st.error(f"Error during analysis: {str(e)}")
def process_analysis_results(raw_analysis):
    """Bundle the model's raw analysis text with its parsed sections and metrics.

    Returns a dict with the keys ``raw_analysis`` (the input unchanged),
    ``structured_sections`` (from ``extract_analysis_sections``), ``metrics``
    (from ``calculate_mi_metrics``) and ``timestamp`` (naive local time of
    processing, in ISO-8601 format).
    """
    bundle = {"raw_analysis": raw_analysis}
    bundle["structured_sections"] = extract_analysis_sections(raw_analysis)
    bundle["metrics"] = calculate_mi_metrics(raw_analysis)
    bundle["timestamp"] = datetime.now().isoformat()
    return bundle
def show_analysis_results():
    """Render the stored analysis: headline metrics, then one tab per topic.

    Does nothing when ``st.session_state.analysis_results`` is empty.
    """
    results = st.session_state.analysis_results
    if not results:
        return

    show_mi_metrics_dashboard(results['metrics'])

    # (tab label, renderer) pairs, in display order.
    tab_specs = (
        ("MI Adherence", show_mi_adherence_analysis),
        ("Technical Skills", show_technical_skills_analysis),
        ("Client Language", show_client_language_analysis),
        ("Session Flow", show_session_flow_analysis),
        ("Recommendations", show_recommendations),
    )
    tab_containers = st.tabs([label for label, _ in tab_specs])
    for container, (_, render) in zip(tab_containers, tab_specs):
        with container:
            render(results)
def show_mi_metrics_dashboard(metrics):
    """Render the four headline MI metrics as side-by-side cards.

    Args:
        metrics: Mapping read with ``.get``; any missing metric is shown as 0.
    """
    st.subheader("MI Performance Dashboard")

    # (card title, metrics key, caption) for each column, left to right.
    card_specs = (
        ("MI Spirit Score", 'mi_spirit_score', "0-5 scale"),
        ("Change Talk Ratio", 'change_talk_ratio', "Change vs Sustain"),
        ("Reflection Ratio", 'reflection_ratio', "Reflections/Questions"),
        ("Overall Adherence", 'overall_adherence', "Percentage"),
    )
    columns = st.columns(4)
    for column, (title, metric_key, caption) in zip(columns, card_specs):
        with column:
            show_metric_card(title, metrics.get(metric_key, 0), caption)
def format_metric_value(value):
    """Return ``value`` as display text for a metric card.

    Anything that supports the ``.2f`` format spec (ints, floats, bools,
    Decimals, numeric scalars) gets two decimals. ``None`` becomes "N/A".
    Any other value falls back to ``str(value)`` instead of raising.
    """
    if value is None:
        return "N/A"
    try:
        return f"{value:.2f}"
    except (TypeError, ValueError):
        # e.g. a label such as "high" parsed out of model output
        return str(value)


def show_metric_card(title, value, subtitle):
    """Render one bordered, centred metric card as raw HTML.

    Args:
        title: Card heading.
        value: Metric value; formatted with ``format_metric_value`` so
            non-numeric or missing values do not crash the dashboard.
        subtitle: Caption under the value.

    Every piece of text is HTML-escaped, because the markup is rendered with
    ``unsafe_allow_html=True`` and values may come from model-generated output.
    """
    safe_title = html.escape(str(title))
    safe_value = html.escape(format_metric_value(value))
    safe_subtitle = html.escape(str(subtitle))
    st.markdown(
        f"""
<div style="border:1px solid #ccc; padding:10px; border-radius:5px; text-align:center;">
<h3>{safe_title}</h3>
<h2>{safe_value}</h2>
<p>{safe_subtitle}</p>
</div>
""",
        unsafe_allow_html=True
    )
def show_mi_adherence_analysis(results):
    """Render the MI adherence tab: OARS chart, MI-spirit chart, then the narrative.

    Args:
        results: Analysis dict with ``metrics`` and ``structured_sections``
            mappings; missing entries fall back to empty defaults.
    """
    st.subheader("MI Adherence Analysis")
    metrics = results['metrics']

    st.write("### OARS Implementation")
    show_oars_chart(metrics.get('oars_metrics', {}))

    st.write("### MI Spirit Components")
    show_mi_spirit_chart(metrics.get('mi_spirit_metrics', {}))

    st.write("### Detailed Analysis")
    narrative = results['structured_sections'].get('mi_adherence', '')
    st.markdown(narrative)
def show_technical_skills_analysis(results):
    """Render the technical-skills tab: question and reflection charts side by side, then the narrative.

    Args:
        results: Analysis dict with ``metrics`` and ``structured_sections``
            mappings; missing entries fall back to empty defaults.
    """
    st.subheader("Technical Skills Analysis")
    question_column, reflection_column = st.columns(2)
    metrics = results['metrics']
    with question_column:
        show_question_type_chart(metrics.get('question_metrics', {}))
    with reflection_column:
        show_reflection_depth_chart(metrics.get('reflection_metrics', {}))
    narrative = results['structured_sections'].get('technical_skills', '')
    st.markdown(narrative)
def show_client_language_analysis(results):
    """Render the client-language tab: change-talk timeline, language categories, then the narrative.

    Args:
        results: Analysis dict with ``metrics`` and ``structured_sections``
            mappings; missing entries fall back to empty defaults.
    """
    st.subheader("Client Language Analysis")
    metrics = results['metrics']
    timeline = metrics.get('change_talk_timeline', [])
    show_change_talk_timeline(timeline)
    categories = metrics.get('language_categories', {})
    show_language_categories_chart(categories)
    narrative = results['structured_sections'].get('client_language', '')
    st.markdown(narrative)
def show_session_flow_analysis(results):
    """Render the session-flow tab: flow timeline, engagement metrics, then the narrative.

    Args:
        results: Analysis dict with ``metrics`` and ``structured_sections``
            mappings; missing entries fall back to empty defaults.
    """
    st.subheader("Session Flow Analysis")
    metrics = results['metrics']
    flow_points = metrics.get('session_flow', [])
    show_session_flow_timeline(flow_points)
    engagement = metrics.get('engagement_metrics', {})
    show_engagement_metrics(engagement)
    narrative = results['structured_sections'].get('session_flow', '')
    st.markdown(narrative)
def as_bullet_items(section):
    """Normalise a strengths/growth-areas section into a list of item strings.

    A falsy section gives ``[]``. A string is split into its non-blank lines,
    each stripped of surrounding whitespace. Iterating the string directly
    would render one bullet per character. Any other iterable is returned as
    a list, unchanged.
    """
    if not section:
        return []
    if isinstance(section, str):
        return [line.strip() for line in section.splitlines() if line.strip()]
    return list(section)


def show_recommendations(results):
    """Render the recommendations tab.

    Strengths (✓) and growth areas (→) are listed in two columns. The suggested
    interventions and the next-session plan follow as markdown.

    Args:
        results: Analysis dict with a ``structured_sections`` mapping; missing
            entries fall back to empty defaults. The strengths and growth-areas
            entries may be lists or multi-line strings.
    """
    st.subheader("Recommendations and Next Steps")
    col1, col2 = st.columns(2)
    with col1:
        st.write("### Strengths")
        strengths = as_bullet_items(results['structured_sections'].get('strengths', []))
        for strength in strengths:
            st.markdown(f"✓ {strength}")
    with col2:
        st.write("### Growth Areas")
        growth_areas = as_bullet_items(results['structured_sections'].get('growth_areas', []))
        for area in growth_areas:
            st.markdown(f"→ {area}")
    st.write("### Suggested Interventions")
    st.markdown(results['structured_sections'].get('suggested_interventions', ''))
    st.write("### Next Session Planning")
    st.markdown(results['structured_sections'].get('next_session_plan', ''))
# Utility functions for charts and visualizations
def show_oars_chart(oars_metrics):
    """Plot the four OARS metrics as a filled plotly radar chart.

    Args:
        oars_metrics: Mapping read with ``.get``; a missing metric is plotted
            as 0. The radial axis spans 0 to one above the largest value.
    """
    # (axis label, metrics key) in clockwise display order.
    axes = (
        ('Open Questions', 'open_questions'),
        ('Affirmations', 'affirmations'),
        ('Reflections', 'reflections'),
        ('Summaries', 'summaries'),
    )
    labels = [label for label, _ in axes]
    scores = [oars_metrics.get(key, 0) for _, key in axes]

    radar = go.Scatterpolar(r=scores, theta=labels, fill='toself')
    fig = go.Figure(data=radar)
    radial_axis = dict(visible=True, range=[0, max(scores) + 1])
    fig.update_layout(polar=dict(radialaxis=radial_axis), showlegend=False)
    st.plotly_chart(fig)
# Add more visualization functions as needed...
def save_analysis_results():
    """Write the stored analysis results to a timestamped JSON file.

    Does nothing when no results are stored. Otherwise writes
    ``analysis_results_YYYYMMDD_HHMMSS.json`` in the current working directory,
    indented by 4. Success or failure is reported with ``st.success`` or
    ``st.error`` rather than raised.
    """
    results = st.session_state.analysis_results
    if not results:
        return
    try:
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"analysis_results_{stamp}.json"
        with open(filename, "w") as handle:
            json.dump(results, handle, indent=4)
        st.success(f"Analysis results saved to {filename}")
    except Exception as e:
        st.error(f"Error saving analysis results: {str(e)}")
def show_export_options():
    """Render the sidebar export controls.

    "Export Analysis Report" saves the current results as JSON via
    ``save_analysis_results``. "Generate Report" passes the format selected in
    the "Report Format" box to ``generate_report``.
    """
    sidebar = st.sidebar
    sidebar.subheader("Export Options")

    if sidebar.button("Export Analysis Report"):
        save_analysis_results()

    chosen_format = sidebar.selectbox(
        "Report Format",
        ["PDF", "DOCX", "JSON"]
    )
    if sidebar.button("Generate Report"):
        generate_report(chosen_format)
def generate_report(format):
    """Placeholder for report generation in the requested ``format``.

    It currently only shows a "coming soon" notice. The parameter name shadows
    the builtin ``format`` but is kept for caller compatibility.
    """
    notice = f"Generating {format} report... (Feature coming soon)"
    st.info(notice)
# Render the dashboard when this module is executed directly as a script.
if __name__ == "__main__":
    show_session_analysis()