# Page-header residue captured together with this file ("Spaces: Sleeping Sleeping");
# it is not part of the program.
import streamlit as st | |
import pandas as pd | |
import plotly.express as px | |
import plotly.graph_objects as go | |
import google.generativeai as genai | |
from datetime import datetime | |
import json | |
import numpy as np | |
from docx import Document | |
import re | |
from prompts import SESSION_EVALUATION_PROMPT, MI_SYSTEM_PROMPT | |
def show_session_analysis():
    """Render the top-level MI session analysis page.

    Seeds the session-state slots used by this page, then lays out a narrow
    upload column next to a wider results column. The results column is only
    populated once an analysis has been stored in session state.
    """
    st.title("MI Session Analysis Dashboard")

    # Make sure both session-state slots exist before anything reads them.
    for state_key in ('analysis_results', 'current_transcript'):
        if state_key not in st.session_state:
            st.session_state[state_key] = None

    # 1:2 width ratio between the upload column and the results column.
    upload_column, results_column = st.columns([1, 2])

    with upload_column:
        show_upload_section()

    with results_column:
        if st.session_state.analysis_results:
            show_analysis_results()
def show_upload_section():
    """Render the input-method selector and dispatch to the matching handler.

    Audio/video uploads go to ``process_media_file``, transcript uploads go to
    ``process_text_file``, "Session Notes" opens the manual form, and any other
    selection falls through to the previous-sessions selector.
    """
    st.header("Session Data Upload")

    upload_type = st.radio(
        "Select Input Method:",
        ["Audio Recording", "Video Recording", "Text Transcript", "Session Notes", "Previous Session Data"]
    )

    # Accepted file extensions for each media input method.
    media_extensions = {
        "Audio Recording": ["wav", "mp3", "mp4"],
        "Video Recording": ["mp4", "avi", "mov"],
    }

    if upload_type in media_extensions:
        uploaded = st.file_uploader(
            f"Upload {upload_type}",
            type=media_extensions[upload_type]
        )
        if uploaded:
            process_media_file(uploaded, upload_type)
        return

    if upload_type == "Text Transcript":
        uploaded = st.file_uploader("Upload Transcript", type=["txt", "doc", "docx", "json"])
        if uploaded:
            process_text_file(uploaded)
        return

    if upload_type == "Session Notes":
        show_manual_input_form()
        return

    # Remaining option: "Previous Session Data".
    show_previous_sessions_selector()
def process_media_file(file, type):
    """Show processing progress for a media upload, then transcribe and analyze it.

    Args:
        file: The uploaded media file, passed through to ``generate_transcript``.
        type: Label of the selected input method, used only in the status text.
            (The name shadows the ``type`` builtin; it is kept as-is so existing
            callers are unaffected.)

    Any exception is reported with ``st.error``; the status text and progress
    bar placeholders are cleared whether or not processing succeeds.
    """
    st.write(f"Processing {type}...")

    # Placeholders for the step label and the progress bar.
    status = st.empty()
    progress_bar = st.progress(0)

    total_steps = 5
    try:
        # Simulated processing steps: each one advances the bar by 20%.
        for step_index in range(total_steps):
            step_number = step_index + 1
            step_name = get_processing_step_name(step_index)
            status.text(f"Step {step_number}/{total_steps}: " + step_name)
            progress_bar.progress(step_number * 20)

        # Generate transcript
        transcript = generate_transcript(file)
        if transcript:
            st.session_state.current_transcript = transcript
            analyze_session_content(transcript)
    except Exception as e:
        st.error(f"Error processing file: {str(e)}")
    finally:
        status.empty()
        progress_bar.empty()
def get_processing_step_name(step):
    """Return the display name of the media-processing step at index ``step``.

    Args:
        step: Zero-based index into the five processing steps.

    Raises:
        IndexError: If ``step`` is outside the range of defined steps.
    """
    step_names = (
        "Loading media file",
        "Converting to audio",
        "Performing speech recognition",
        "Generating transcript",
        "Preparing analysis",
    )
    return step_names[step]
def process_text_file(file):
    """Extract a transcript from an uploaded text-like file and analyze it.

    Supported inputs, chosen by file extension (case-insensitive):
      * ``.json`` - parsed and handed to ``extract_transcript_from_json``.
      * ``.docx`` - paragraph texts joined with newlines.
      * anything else (e.g. ``.txt``) - decoded as UTF-8 text.

    Legacy ``.doc`` files are rejected with an explicit message instead of
    being decoded as text.

    Args:
        file: Uploaded file object exposing ``name`` and ``read()``.

    Any exception raised while reading or analyzing is reported via
    ``st.error`` rather than propagated.
    """
    try:
        # Compare extensions case-insensitively so "NOTES.JSON" is handled
        # the same way as "notes.json".
        filename = file.name.lower()

        if filename.endswith('.json'):
            # "utf-8-sig" also accepts plain UTF-8, but strips a leading BOM
            # that would otherwise make json.loads fail.
            content = json.loads(file.read().decode('utf-8-sig'))
            transcript = extract_transcript_from_json(content)
        elif filename.endswith('.docx'):
            doc = Document(file)
            transcript = '\n'.join(paragraph.text for paragraph in doc.paragraphs)
        elif filename.endswith('.doc'):
            # Binary Word 97-2003 documents are not text; decoding them as
            # UTF-8 only produces a confusing codec error.
            st.error(
                "Legacy .doc files are not supported. "
                "Please save the document as .docx or .txt and upload it again."
            )
            return
        else:
            transcript = file.read().decode('utf-8-sig')

        if transcript:
            st.session_state.current_transcript = transcript
            analyze_session_content(transcript)
        else:
            # Tell the user why nothing happened instead of failing silently.
            st.warning("No transcript text was found in the uploaded file.")
    except Exception as e:
        st.error(f"Error processing file: {str(e)}")
def show_manual_input_form():
    """Render the manual session-notes form and analyze it on request.

    Collects session metadata, client details, free-text notes, target
    behaviors and observed change/sustain talk. When "Analyze Session" is
    pressed, the values are combined by ``compile_session_data`` and the
    result is passed to ``analyze_session_content``.
    """
    st.subheader("Session Details")

    # Session metadata
    date_value = st.date_input("Session Date", datetime.now())
    duration_minutes = st.number_input("Session Duration (minutes)", min_value=1, max_value=180, value=50)

    # Client information
    client = st.text_input("Client ID (optional)")
    session_index = st.number_input("Session Number", min_value=1, value=1)

    # Session content
    notes = st.text_area(
        "Session Notes",
        height=300,
        help="Enter detailed session notes including key dialogues, interventions, and observations"
    )

    # Target behaviors
    targets = st.text_area(
        "Target Behaviors/Goals",
        height=100,
        help="Enter the specific behaviors or goals discussed in the session"
    )

    # MI specific elements
    st.subheader("MI Elements")
    observed_change_talk = st.text_area("Observed Change Talk")
    observed_sustain_talk = st.text_area("Observed Sustain Talk")

    # Positional order expected by compile_session_data.
    form_values = (
        date_value, duration_minutes, client, session_index,
        notes, targets, observed_change_talk, observed_sustain_talk,
    )

    if st.button("Analyze Session"):
        session_data = compile_session_data(*form_values)
        analyze_session_content(session_data)
# Prompt sent to the model; ``{content}`` is replaced with the session content.
_ANALYSIS_PROMPT_TEMPLATE = """
Analyze the following therapy session using MI principles and provide a comprehensive evaluation:
Session Content:
{content}
Please provide detailed analysis including:
1. MI Adherence Assessment:
- OARS implementation
- Change talk identification
- Resistance management
- MI spirit adherence
2. Technical Skills Evaluation:
- Reflection quality and frequency
- Question-to-reflection ratio
- Open vs. closed questions
- Affirmations and summaries
3. Client Language Analysis:
- Change talk instances
- Sustain talk patterns
- Commitment language
- Resistance patterns
4. Session Flow Analysis:
- Engagement level
- Focus maintenance
- Evocation quality
- Planning effectiveness
5. Recommendations:
- Strength areas
- Growth opportunities
- Suggested interventions
- Next session planning
Format the analysis with clear sections and specific examples from the session.
"""


def analyze_session_content(content, model_name='gemini-pro'):
    """Send session content to a Gemini model and store the structured analysis.

    Args:
        content: Session material to analyze (a transcript string or the
            compiled session data); it is interpolated into the prompt.
        model_name: Name of the generative model to use. Defaults to the
            previously hard-coded ``'gemini-pro'``.

    On success the processed results are written to
    ``st.session_state.analysis_results`` and a success message is shown.
    Blank content is rejected with a warning before any request is made.
    Any exception raised during generation or post-processing is reported
    via ``st.error`` rather than propagated.
    """
    # Do not spend an API call on (and then display an analysis of) nothing.
    is_blank = not content or (isinstance(content, str) and not content.strip())
    if is_blank:
        st.warning("There is no session content to analyze.")
        return

    try:
        # Configure Gemini model
        model = genai.GenerativeModel(model_name)

        # Prepare analysis prompt
        analysis_prompt = _ANALYSIS_PROMPT_TEMPLATE.format(content=content)

        # Generate analysis
        response = model.generate_content(analysis_prompt)

        # Process and structure the analysis results
        analysis_results = process_analysis_results(response.text)

        # Store results in session state
        st.session_state.analysis_results = analysis_results

        # Show success message
        st.success("Analysis completed successfully!")
    except Exception as e:
        st.error(f"Error during analysis: {str(e)}")
def process_analysis_results(raw_analysis):
    """Bundle the raw model output with its derived sections and metrics.

    Args:
        raw_analysis: The analysis text returned by the model.

    Returns:
        A dict with the keys ``raw_analysis``, ``structured_sections`` (from
        ``extract_analysis_sections``), ``metrics`` (from
        ``calculate_mi_metrics``) and ``timestamp`` (ISO-8601 string of the
        current local time).
    """
    structured = {"raw_analysis": raw_analysis}
    # Parse the raw analysis text into named sections.
    structured["structured_sections"] = extract_analysis_sections(raw_analysis)
    # Derive the dashboard metrics from the same text.
    structured["metrics"] = calculate_mi_metrics(raw_analysis)
    structured["timestamp"] = datetime.now().isoformat()
    return structured
def show_analysis_results():
    """Display the stored analysis: a metrics dashboard plus one tab per section.

    Does nothing when no analysis results are present in session state.
    """
    results = st.session_state.analysis_results
    if not results:
        return

    # Top-level metrics
    show_mi_metrics_dashboard(results['metrics'])

    # Tab label paired with the function that renders that tab's content.
    tab_renderers = (
        ("MI Adherence", show_mi_adherence_analysis),
        ("Technical Skills", show_technical_skills_analysis),
        ("Client Language", show_client_language_analysis),
        ("Session Flow", show_session_flow_analysis),
        ("Recommendations", show_recommendations),
    )

    # Detailed analysis sections
    tabs = st.tabs([label for label, _ in tab_renderers])
    for tab, (_, render) in zip(tabs, tab_renderers):
        with tab:
            render(results)
def show_mi_metrics_dashboard(metrics):
    """Render the four headline MI metrics as a row of cards.

    Args:
        metrics: Mapping of metric names to values; a missing metric is
            displayed as 0.
    """
    st.subheader("MI Performance Dashboard")

    # (card title, key in ``metrics``, caption under the value)
    cards = (
        ("MI Spirit Score", 'mi_spirit_score', "0-5 scale"),
        ("Change Talk Ratio", 'change_talk_ratio', "Change vs Sustain"),
        ("Reflection Ratio", 'reflection_ratio', "Reflections/Questions"),
        ("Overall Adherence", 'overall_adherence', "Percentage"),
    )

    for column, (title, metric_key, caption) in zip(st.columns(4), cards):
        with column:
            show_metric_card(title, metrics.get(metric_key, 0), caption)
def show_metric_card(title, value, subtitle):
    """Render one metric as a bordered, centred HTML card.

    Args:
        title: Heading text shown at the top of the card.
        value: Metric value. Shown with two decimal places when it can be
            converted to ``float``; otherwise shown as "N/A".
        subtitle: Caption shown beneath the value.
    """
    # Imported locally so this function is self-contained and does not rely
    # on changes to the module's import header.
    from html import escape

    # A missing or non-numeric metric (None, "n/a", ...) must not take the
    # whole dashboard down with a formatting error.
    try:
        value_text = f"{float(value):.2f}"
    except (TypeError, ValueError, OverflowError):
        value_text = "N/A"

    # The markup is rendered with unsafe_allow_html=True, so any text placed
    # into it is escaped first.
    card_html = "\n".join([
        '<div style="border:1px solid #ccc; padding:10px; border-radius:5px; text-align:center;">',
        f"<h3>{escape(str(title))}</h3>",
        f"<h2>{value_text}</h2>",
        f"<p>{escape(str(subtitle))}</p>",
        "</div>",
    ])
    st.markdown(card_html, unsafe_allow_html=True)
def show_mi_adherence_analysis(results):
    """Render the "MI Adherence" tab: OARS chart, MI-spirit chart and narrative.

    Args:
        results: Analysis results dict providing ``metrics`` and
            ``structured_sections``.
    """
    st.subheader("MI Adherence Analysis")

    # OARS Implementation
    st.write("### OARS Implementation")
    oars_metrics = results['metrics'].get('oars_metrics', {})
    show_oars_chart(oars_metrics)

    # MI Spirit Components
    st.write("### MI Spirit Components")
    spirit_metrics = results['metrics'].get('mi_spirit_metrics', {})
    show_mi_spirit_chart(spirit_metrics)

    # Detailed breakdown
    st.write("### Detailed Analysis")
    narrative = results['structured_sections'].get('mi_adherence', '')
    st.markdown(narrative)
def show_technical_skills_analysis(results):
    """Render the "Technical Skills" tab: two side-by-side charts and narrative.

    Args:
        results: Analysis results dict providing ``metrics`` and
            ``structured_sections``.
    """
    st.subheader("Technical Skills Analysis")

    # Question analysis on the left, reflection analysis on the right.
    questions_column, reflections_column = st.columns(2)

    with questions_column:
        question_metrics = results['metrics'].get('question_metrics', {})
        show_question_type_chart(question_metrics)

    with reflections_column:
        reflection_metrics = results['metrics'].get('reflection_metrics', {})
        show_reflection_depth_chart(reflection_metrics)

    # Detailed analysis
    narrative = results['structured_sections'].get('technical_skills', '')
    st.markdown(narrative)
def show_client_language_analysis(results):
    """Render the "Client Language" tab: timeline, category chart and narrative.

    Args:
        results: Analysis results dict providing ``metrics`` and
            ``structured_sections``.
    """
    st.subheader("Client Language Analysis")

    # Change Talk Timeline
    timeline = results['metrics'].get('change_talk_timeline', [])
    show_change_talk_timeline(timeline)

    # Language Categories
    categories = results['metrics'].get('language_categories', {})
    show_language_categories_chart(categories)

    # Detailed analysis
    narrative = results['structured_sections'].get('client_language', '')
    st.markdown(narrative)
def show_session_flow_analysis(results):
    """Render the "Session Flow" tab: flow timeline, engagement metrics, narrative.

    Args:
        results: Analysis results dict providing ``metrics`` and
            ``structured_sections``.
    """
    st.subheader("Session Flow Analysis")

    # Session Flow Timeline
    flow = results['metrics'].get('session_flow', [])
    show_session_flow_timeline(flow)

    # Engagement Metrics
    engagement = results['metrics'].get('engagement_metrics', {})
    show_engagement_metrics(engagement)

    # Detailed analysis
    narrative = results['structured_sections'].get('session_flow', '')
    st.markdown(narrative)
def show_recommendations(results):
    """Render the "Recommendations" tab.

    Shows strengths and growth areas as two side-by-side bullet lists,
    followed by the suggested-interventions and next-session-planning text.

    Args:
        results: Analysis results dict providing ``structured_sections``.
            NOTE(review): the bullet lists iterate over the stored values, so
            they are presumably lists of strings - confirm against
            ``extract_analysis_sections``.
    """
    st.subheader("Recommendations and Next Steps")

    # (heading, key in structured_sections, bullet marker) for each column.
    bullet_lists = (
        ("### Strengths", 'strengths', "✓"),
        ("### Growth Areas", 'growth_areas', "→"),
    )
    for column, (heading, section_key, marker) in zip(st.columns(2), bullet_lists):
        with column:
            st.write(heading)
            for item in results['structured_sections'].get(section_key, []):
                st.markdown(f"{marker} {item}")

    # Free-text sections rendered below the two columns.
    text_sections = (
        ("### Suggested Interventions", 'suggested_interventions'),
        ("### Next Session Planning", 'next_session_plan'),
    )
    for heading, section_key in text_sections:
        st.write(heading)
        st.markdown(results['structured_sections'].get(section_key, ''))
# Utility functions for charts and visualizations | |
def show_oars_chart(oars_metrics):
    """Plot the four OARS counts as a filled radar (polar) chart.

    Args:
        oars_metrics: Mapping that may contain ``open_questions``,
            ``affirmations``, ``reflections`` and ``summaries``; a missing
            entry is plotted as 0.
    """
    # (axis label, key in ``oars_metrics``), in plotting order.
    axes = (
        ('Open Questions', 'open_questions'),
        ('Affirmations', 'affirmations'),
        ('Reflections', 'reflections'),
        ('Summaries', 'summaries'),
    )
    labels = [label for label, _ in axes]
    counts = [oars_metrics.get(metric_key, 0) for _, metric_key in axes]

    radar_trace = go.Scatterpolar(r=counts, theta=labels, fill='toself')
    fig = go.Figure(data=radar_trace)

    # Radial axis runs from 0 to one unit beyond the largest plotted value.
    radial_axis = dict(visible=True, range=[0, max(counts) + 1])
    fig.update_layout(polar=dict(radialaxis=radial_axis), showlegend=False)

    st.plotly_chart(fig)
# Add more visualization functions as needed... | |
def save_analysis_results():
    """Save the current analysis results to a timestamped JSON file.

    The file is written to the current working directory as
    ``analysis_results_<YYYYMMDD_HHMMSS>.json``. A warning is shown when
    there are no results to save; any failure while serializing or writing
    is reported via ``st.error`` rather than propagated.
    """
    results = st.session_state.analysis_results
    if not results:
        # Give feedback instead of silently ignoring the export request.
        st.warning("No analysis results available to save. Run an analysis first.")
        return

    try:
        # Serialize before opening the file so that a serialization failure
        # cannot leave a truncated JSON file behind.
        payload = json.dumps(results, indent=4)

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"analysis_results_{timestamp}.json"
        with open(filename, "w", encoding="utf-8") as f:
            f.write(payload)
        st.success(f"Analysis results saved to {filename}")
    except Exception as e:
        st.error(f"Error saving analysis results: {str(e)}")
def show_export_options():
    """Render the export controls in the sidebar.

    "Export Analysis Report" saves the current results via
    ``save_analysis_results``; "Generate Report" passes the selected format
    to ``generate_report``.
    """
    sidebar = st.sidebar
    sidebar.subheader("Export Options")

    export_clicked = sidebar.button("Export Analysis Report")
    if export_clicked:
        save_analysis_results()

    available_formats = ["PDF", "DOCX", "JSON"]
    report_format = sidebar.selectbox("Report Format", available_formats)

    generate_clicked = sidebar.button("Generate Report")
    if generate_clicked:
        generate_report(report_format)
def generate_report(format):
    """Announce report generation for the requested format.

    Report generation is not implemented yet; this only shows an
    informational placeholder message.

    Args:
        format: Name of the requested report format, shown in the message.
            (The name shadows the ``format`` builtin; it is kept as-is so
            existing callers are unaffected.)
    """
    # Add report generation logic here
    placeholder_message = f"Generating {format} report... (Feature coming soon)"
    st.info(placeholder_message)
# Script entry point: render the dashboard only when this file is executed as
# the main script, not when it is imported as a module.
if __name__ == "__main__":
    show_session_analysis()