# app.py
"""Streamlit clinical decision-support demo backed by a tool-calling LLM.

The sidebar collects a patient case, the model is asked to analyse it, and
any tool calls the model requests are held for explicit clinician approval
before they are executed.
"""
import sqlite3
from uuid import uuid4

import plotly.express as px
import streamlit as st
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage
from langchain_groq import ChatGroq
from langgraph.checkpoint.sqlite import SqliteSaver


# Initialize components
@st.cache_resource
def init_components():
    """Build the chat model, the search tool and the checkpoint store.

    Streamlit re-executes the whole script on every widget interaction, so
    the result is cached with ``st.cache_resource``; otherwise each rerun
    would open another SQLite connection and rebuild the clients.

    Returns:
        A ``(model, tool, memory)`` tuple.
    """
    # NOTE(review): confirm this model identifier (and its casing) is still
    # accepted by the Groq API.
    model = ChatGroq(temperature=0.1, model="Llama-3.3-70b-Specdec")
    tool = TavilySearchResults(max_results=3)
    # check_same_thread=False lets the cached connection be used from
    # whichever thread runs the script.
    conn = sqlite3.connect("medical_checkpoints.db", check_same_thread=False)
    memory = SqliteSaver(conn)
    return model, tool, memory


# Medical Agent Class
class AdvancedMedicalAgent:
    """Wraps a tool-enabled chat model with a Streamlit review workflow."""

    def __init__(self, model, tools, memory):
        """Bind ``tools`` to ``model`` and index them by name.

        Args:
            model: Chat model exposing ``bind_tools``.
            tools: Iterable of tools, each with a ``name`` and ``invoke``.
            memory: Checkpoint store; kept on the instance but not read by
                any method of this class.
        """
        self.model = model.bind_tools(tools)
        self.tools = {t.name: t for t in tools}
        self.memory = memory

    def process_query(self, patient_data, thread_id):
        """Show the patient dashboard and run one model analysis.

        The model's text answer (if any) is appended to
        ``st.session_state.history``. Tool calls requested by the model are
        stored in ``st.session_state.pending_actions`` instead of being
        rendered here: this method runs inside a one-shot button branch, and
        widgets created there disappear on the next rerun before the
        clinician can interact with them.

        Args:
            patient_data: Dict with at least ``vitals`` and ``history`` keys.
            thread_id: Identifier of this analysis; used to namespace the
                approval widgets.
        """
        # Initialize or retrieve conversation thread
        if 'history' not in st.session_state:
            st.session_state.history = []

        # Add patient data visualization
        self.display_patient_dashboard(patient_data)

        # Process medical inquiry
        response = self.model.invoke([
            SystemMessage(content=self.get_system_prompt()),
            HumanMessage(content=f"Patient Case:\n{patient_data}"),
        ])

        # A tool-calling reply may carry no text; skip the empty log entry.
        if response.content:
            st.session_state.history.append(("AI", response.content))

        # The attribute can exist but be an empty list, so test truthiness
        # rather than mere presence.
        tool_calls = getattr(response, 'tool_calls', None) or []
        st.session_state.pending_actions = {
            "thread_id": thread_id,
            "tool_calls": list(tool_calls),
        }

    def display_patient_dashboard(self, data):
        """Render the Vitals / History / Timeline tabs for ``data``.

        Args:
            data: Dict providing ``data['vitals']['temp']`` and
                ``data['history']``.
        """
        tab1, tab2, tab3 = st.tabs(["Vitals", "History", "Timeline"])

        with tab1:
            fig = px.line(
                x=["Temperature", "BP", "Heart Rate"],
                # Only the temperature comes from the patient record; the
                # other two values are example data.
                y=[data['vitals']['temp'], 130, 85],
                title="Vital Signs",
            )
            st.plotly_chart(fig)

        with tab2:
            st.json(data['history'])

        with tab3:
            # NOTE(review): the spec carries no dataset, so this presumably
            # renders an empty chart until timeline data is supplied.
            st.vega_lite_chart({
                "mark": {"type": "circle", "tooltip": True},
                "encoding": {
                    "x": {"field": "timestamp", "type": "temporal"},
                    "y": {"field": "severity", "type": "quantitative"},
                },
            })

    def handle_medical_actions(self, tool_calls, thread_id):
        """Render an approval UI for each unresolved tool call.

        Safe to call on every rerun: each widget has a stable key derived
        from ``thread_id`` and the call id, and calls that were already
        executed or cancelled are skipped. A tool only runs when the
        clinician has ticked the approval checkbox and pressed
        "Execute action"; "Cancel action" records a cancellation.

        Args:
            tool_calls: Sequence of tool-call dicts with ``name`` and
                ``args`` keys and an optional ``id``.
            thread_id: Identifier used to namespace widget keys.
        """
        if 'history' not in st.session_state:
            st.session_state.history = []
        if 'resolved_actions' not in st.session_state:
            st.session_state.resolved_actions = set()

        for index, call in enumerate(tool_calls):
            # Fall back to the position when the call carries no id.
            call_key = f"{thread_id}:{call.get('id') or index}"
            if call_key in st.session_state.resolved_actions:
                continue

            st.warning(f"🩺 Pending Medical Action: {call['name']}")
            with st.expander("Review Action Details"):
                st.write(f"**Action Type:** {call['name']}")
                st.write(f"**Parameters:** {call['args']}")

            approved = st.checkbox(
                "I approve this action", key=f"approve:{call_key}"
            )
            execute_col, cancel_col = st.columns(2)
            execute = execute_col.button(
                "Execute action",
                key=f"execute:{call_key}",
                disabled=not approved,
            )
            cancel = cancel_col.button(
                "Cancel action", key=f"cancel:{call_key}"
            )

            if execute and approved:
                tool = self.tools.get(call['name'])
                if tool is None:
                    # The model asked for a tool that was never bound.
                    outcome = f"Action failed: unknown tool '{call['name']}'"
                else:
                    result = tool.invoke(call['args'])
                    outcome = f"Action executed: {result}"
            elif cancel:
                outcome = "Action cancelled by clinician"
            else:
                # No decision yet; keep the action pending.
                continue

            st.session_state.history.append(("System", outcome))
            st.session_state.resolved_actions.add(call_key)

    def get_system_prompt(self):
        """Return the system prompt that frames every analysis request."""
        return (
            "You are an advanced medical AI assistant. Follow these steps:\n"
            "1. Analyze patient data using latest clinical guidelines\n"
            "2. Consider drug interactions and contraindications\n"
            "3. Propose differential diagnosis (3 possibilities)\n"
            "4. Suggest evidence-based treatment options\n"
            "5. Identify necessary lab tests with reasoning\n"
            "6. Flag high-risk factors in RED\n"
            "7. Maintain audit trail of all decisions"
        )


# Streamlit UI
def main():
    """Lay out the page: intake sidebar, analysis trigger, approvals, log."""
    st.set_page_config(page_title="AI Clinical Assistant", layout="wide")

    with st.sidebar:
        st.header("Patient Intake")
        patient_data = {
            "symptoms": st.multiselect(
                "Symptoms", ["Fever", "Cough", "Chest Pain"]
            ),
            "history": {
                "conditions": st.text_input("Medical History"),
                "medications": st.text_input("Current Medications"),
            },
            "vitals": {
                "temp": st.number_input(
                    "Temperature (°C)", 36.0, 42.0, 37.0
                ),
                "bp": st.text_input("Blood Pressure (mmHg)", "120/80"),
            },
        }

    st.title("🧠 AI-Powered Clinical Decision Support")

    model, tool, memory = init_components()
    agent = AdvancedMedicalAgent(model, [tool], memory)

    if st.button("Start Analysis"):
        with st.spinner("Analyzing patient case..."):
            agent.process_query(patient_data, str(uuid4()))

    # Rendered outside the button branch so the approval widgets are still
    # present on the reruns triggered by interacting with them.
    if 'pending_actions' in st.session_state:
        pending = st.session_state.pending_actions
        if pending["tool_calls"]:
            agent.handle_medical_actions(
                pending["tool_calls"], pending["thread_id"]
            )

    if 'history' in st.session_state:
        st.subheader("Clinical Decision Log")
        for role, content in st.session_state.history:
            with st.chat_message(role):
                st.markdown(content)


if __name__ == "__main__":
    main()