File size: 4,890 Bytes
788074d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# app.py
import streamlit as st
from uuid import uuid4
import sqlite3
import plotly.express as px
from langchain_groq import ChatGroq
from langgraph.checkpoint.sqlite import SqliteSaver
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage

# Initialize components
@st.cache_resource
def init_components():
    """Create the chat model, web-search tool, and SQLite checkpointer.

    Wrapped in ``st.cache_resource`` because Streamlit re-executes the whole
    script on every interaction. Without caching, every rerun built a new
    model/tool and opened a fresh ``sqlite3`` connection that was never
    closed. The cached objects are shared by all sessions of the app.

    Returns:
        tuple: ``(model, tool, memory)``:
        - the ``ChatGroq`` model;
        - a ``TavilySearchResults`` tool limited to 3 results;
        - a ``SqliteSaver`` backed by ``medical_checkpoints.db``.
    """
    # NOTE(review): confirm this model ID is still served by Groq.
    model = ChatGroq(temperature=0.1, model="Llama-3.3-70b-Specdec")
    tool = TavilySearchResults(max_results=3)
    # check_same_thread=False: the single cached connection is reused from
    # whichever thread Streamlit runs a given session's script in.
    conn = sqlite3.connect("medical_checkpoints.db", check_same_thread=False)
    memory = SqliteSaver(conn)
    return model, tool, memory

# Medical Agent Class
class AdvancedMedicalAgent:
    """Clinical decision-support agent rendered through Streamlit.

    Binds tools to a chat model, draws a patient dashboard, sends the case to
    the model once per analysis, and surfaces any tool calls the model
    proposes for clinician review. Log entries are appended to
    ``st.session_state.history`` as ``(role, text)`` tuples.
    """

    def __init__(self, model, tools, memory):
        """Bind *tools* to *model* and index them by name.

        Args:
            model: Chat model exposing ``bind_tools`` (e.g. ``ChatGroq``).
            tools: Iterable of tools, each with a ``name`` attribute.
            memory: Checkpointer supplied by the caller. NOTE(review): stored
                but not read by any method of this class yet.
        """
        # Materialize once so a one-shot iterable feeds both uses below.
        tools = list(tools)
        self.model = model.bind_tools(tools)
        self.tools = {t.name: t for t in tools}
        self.memory = memory

    def process_query(self, patient_data, thread_id):
        """Render the dashboard, query the model once, and log its reply.

        Args:
            patient_data: Intake dict (see ``display_patient_dashboard`` for
                the keys read). Its ``str()`` form is sent to the model.
            thread_id: Identifier used to give this analysis' review widgets
                unique keys. NOTE(review): not used for checkpointing yet.
        """
        if 'history' not in st.session_state:
            st.session_state.history = []

        self.display_patient_dashboard(patient_data)

        response = self.model.invoke([
            SystemMessage(content=self.get_system_prompt()),
            HumanMessage(content=f"Patient Case:\n{patient_data}")
        ])

        # LangChain's AIMessage defines ``tool_calls`` even when it is empty,
        # so test whether there are calls, not whether the attribute exists.
        tool_calls = getattr(response, 'tool_calls', None)
        if tool_calls:
            self.handle_medical_actions(tool_calls, thread_id)

        # A reply that only requests tools may carry no text; skip blank logs.
        if response.content:
            st.session_state.history.append(("AI", response.content))

    @staticmethod
    def _parse_systolic(bp):
        """Return the systolic part of a ``"SYS/DIA"`` reading, or ``None``.

        >>> AdvancedMedicalAgent._parse_systolic("120/80")
        120.0
        """
        head = str(bp).partition('/')[0].strip()
        try:
            return float(head)
        except ValueError:
            return None

    def display_patient_dashboard(self, data):
        """Render Vitals / History / Timeline tabs for the intake dict *data*.

        Reads ``data['vitals']`` (``temp``, ``bp`` as ``"SYS/DIA"``, optional
        ``heart_rate``), ``data['history']``, and an optional
        ``data['timeline']``. Only values actually present are plotted.
        """
        vitals = data.get('vitals', {})
        tab1, tab2, tab3 = st.tabs(["Vitals", "History", "Timeline"])

        with tab1:
            # Never substitute placeholder numbers for unmeasured vitals.
            readings = [
                ("Temperature", vitals.get('temp')),
                ("Systolic BP", self._parse_systolic(vitals.get('bp'))),
                ("Heart Rate", vitals.get('heart_rate')),
            ]
            measured = [(label, value) for label, value in readings
                        if value is not None]
            if measured:
                labels, values = zip(*measured)
                fig = px.bar(x=list(labels), y=list(values), title="Vital Signs")
                st.plotly_chart(fig)
            else:
                st.info("No vital signs recorded.")

        with tab2:
            st.json(data.get('history', {}))

        with tab3:
            # NOTE(review): assumes events are dicts with ``timestamp`` and
            # ``severity`` keys; nothing in this file produces them yet.
            events = data.get('timeline')
            if events:
                st.vega_lite_chart({
                    "data": {"values": list(events)},
                    "mark": {"type": "circle", "tooltip": True},
                    "encoding": {
                        "x": {"field": "timestamp", "type": "temporal"},
                        "y": {"field": "severity", "type": "quantitative"}
                    }
                })
            else:
                st.info("No timeline events recorded.")

    def handle_medical_actions(self, tool_calls, thread_id):
        """Show each proposed tool call and run it only with clinician approval.

        Args:
            tool_calls: Tool-call dicts with ``name`` and ``args`` keys.
            thread_id: Prefix for widget keys, so several calls (and several
                analyses) never produce colliding widget IDs.

        NOTE(review): Streamlit reruns the whole script when a widget is
        used. ``main`` calls ``process_query`` only on the run in which
        "Start Analysis" was clicked, so these review widgets disappear on
        the rerun triggered by clicking them. A working approval flow needs
        the pending calls kept in ``st.session_state`` and rendered on
        every run.
        """
        for idx, call in enumerate(tool_calls):
            name = call['name']
            st.warning(f"🩺 Pending Medical Action: {name}")

            # Identical labels inside a loop collide unless every widget gets
            # an explicit, unique key.
            key = f"{thread_id}_{idx}"
            if not st.button("Review Action Details", key=f"review_{key}"):
                continue

            st.write(f"**Action Type:** {name}")
            st.write(f"**Parameters:** {call['args']}")

            tool = self.tools.get(name)
            if tool is None:
                st.error(f"Unknown tool requested by the model: {name}")
                continue

            # Log an outcome only on an explicit decision: an unchecked box
            # means "not approved yet", not "cancelled".
            if st.checkbox("I approve this action", key=f"approve_{key}"):
                result = tool.invoke(call['args'])
                st.session_state.history.append(
                    ("System", f"Action executed: {result}")
                )
            elif st.button("Reject action", key=f"reject_{key}"):
                st.session_state.history.append(
                    ("System", "Action cancelled by clinician")
                )

    def get_system_prompt(self):
        """Return the fixed system prompt that frames every medical query."""
        return """You are an advanced medical AI assistant. Follow these steps:
1. Analyze patient data using latest clinical guidelines
2. Consider drug interactions and contraindications
3. Propose differential diagnosis (3 possibilities)
4. Suggest evidence-based treatment options
5. Identify necessary lab tests with reasoning
6. Flag high-risk factors in RED
7. Maintain audit trail of all decisions"""

# Streamlit UI
def main():
    """Streamlit entry point: intake sidebar, on-demand analysis, decision log."""
    st.set_page_config(page_title="AI Clinical Assistant", layout="wide")

    with st.sidebar:
        st.header("Patient Intake")
        patient_data = {
            "symptoms": st.multiselect("Symptoms", ["Fever", "Cough", "Chest Pain"]),
            "history": {
                "conditions": st.text_input("Medical History"),
                "medications": st.text_input("Current Medications"),
            },
            "vitals": {
                # A 30 °C floor (rather than 36 °C) lets hypothermic readings
                # below 35 °C be recorded; the default stays at 37.0.
                "temp": st.number_input(
                    "Temperature (°C)", min_value=30.0, max_value=42.0, value=37.0
                ),
                "bp": st.text_input("Blood Pressure (mmHg)", "120/80"),
            },
        }

    st.title("🧠 AI-Powered Clinical Decision Support")

    if st.button("Start Analysis"):
        # Build model/tool/checkpointer only when an analysis is requested:
        # Streamlit reruns this script on every sidebar change.
        model, tool, memory = init_components()
        agent = AdvancedMedicalAgent(model, [tool], memory)
        with st.spinner("Analyzing patient case..."):
            agent.process_query(patient_data, str(uuid4()))

    if 'history' in st.session_state:
        st.subheader("Clinical Decision Log")
        for role, content in st.session_state.history:
            with st.chat_message(role):
                st.markdown(content)


if __name__ == "__main__":
    main()