mgbam commited on
Commit
788074d
·
1 Parent(s): 50f1cde

Add application file

Browse files
Files changed (1) hide show
  1. app.py +131 -0
app.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import streamlit as st
3
+ from uuid import uuid4
4
+ import sqlite3
5
+ import plotly.express as px
6
+ from langchain_groq import ChatGroq
7
+ from langgraph.checkpoint.sqlite import SqliteSaver
8
+ from langchain_community.tools.tavily_search import TavilySearchResults
9
+ from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage
10
+
11
+ # Initialize components
12
+ def init_components():
13
+ model = ChatGroq(temperature=0.1, model="Llama-3.3-70b-Specdec")
14
+ tool = TavilySearchResults(max_results=3)
15
+ conn = sqlite3.connect("medical_checkpoints.db", check_same_thread=False)
16
+ memory = SqliteSaver(conn)
17
+ return model, tool, memory
18
+
19
+ # Medical Agent Class
20
+ class AdvancedMedicalAgent:
21
+ def __init__(self, model, tools, memory):
22
+ self.model = model.bind_tools(tools)
23
+ self.tools = {t.name: t for t in tools}
24
+ self.memory = memory
25
+
26
+ def process_query(self, patient_data, thread_id):
27
+ # Initialize or retrieve conversation thread
28
+ if 'history' not in st.session_state:
29
+ st.session_state.history = []
30
+
31
+ # Add patient data visualization
32
+ self.display_patient_dashboard(patient_data)
33
+
34
+ # Process medical inquiry
35
+ response = self.model.invoke([
36
+ SystemMessage(content=self.get_system_prompt()),
37
+ HumanMessage(content=f"Patient Case:\n{patient_data}")
38
+ ])
39
+
40
+ # Handle tool calls
41
+ if hasattr(response, 'tool_calls'):
42
+ self.handle_medical_actions(response.tool_calls, thread_id)
43
+
44
+ st.session_state.history.append(("AI", response.content))
45
+
46
+ def display_patient_dashboard(self, data):
47
+ tab1, tab2, tab3 = st.tabs(["Vitals", "History", "Timeline"])
48
+
49
+ with tab1:
50
+ fig = px.line(
51
+ x=["Temperature", "BP", "Heart Rate"],
52
+ y=[data['vitals']['temp'], 130, 85], # Example data
53
+ title="Vital Signs"
54
+ )
55
+ st.plotly_chart(fig)
56
+
57
+ with tab2:
58
+ st.json(data['history'])
59
+
60
+ with tab3:
61
+ st.vega_lite_chart({
62
+ "mark": {"type": "circle", "tooltip": True},
63
+ "encoding": {
64
+ "x": {"field": "timestamp", "type": "temporal"},
65
+ "y": {"field": "severity", "type": "quantitative"}
66
+ }
67
+ })
68
+
69
+ def handle_medical_actions(self, tool_calls, thread_id):
70
+ for call in tool_calls:
71
+ st.warning(f"🩺 Pending Medical Action: {call['name']}")
72
+
73
+ if st.button("Review Action Details"):
74
+ st.write(f"**Action Type:** {call['name']}")
75
+ st.write(f"**Parameters:** {call['args']}")
76
+
77
+ if st.checkbox("I approve this action"):
78
+ result = self.tools[call['name']].invoke(call['args'])
79
+ st.session_state.history.append(
80
+ ("System", f"Action executed: {result}")
81
+ )
82
+ else:
83
+ st.session_state.history.append(
84
+ ("System", "Action cancelled by clinician")
85
+ )
86
+
87
+ def get_system_prompt(self):
88
+ return """You are an advanced medical AI assistant. Follow these steps:
89
+ 1. Analyze patient data using latest clinical guidelines
90
+ 2. Consider drug interactions and contraindications
91
+ 3. Propose differential diagnosis (3 possibilities)
92
+ 4. Suggest evidence-based treatment options
93
+ 5. Identify necessary lab tests with reasoning
94
+ 6. Flag high-risk factors in RED
95
+ 7. Maintain audit trail of all decisions"""
96
+
97
+ # Streamlit UI
98
+ def main():
99
+ st.set_page_config(page_title="AI Clinical Assistant", layout="wide")
100
+
101
+ with st.sidebar:
102
+ st.header("Patient Intake")
103
+ patient_data = {
104
+ "symptoms": st.multiselect("Symptoms", ["Fever", "Cough", "Chest Pain"]),
105
+ "history": {
106
+ "conditions": st.text_input("Medical History"),
107
+ "medications": st.text_input("Current Medications")
108
+ },
109
+ "vitals": {
110
+ "temp": st.number_input("Temperature (°C)", 36.0, 42.0, 37.0),
111
+ "bp": st.text_input("Blood Pressure (mmHg)", "120/80")
112
+ }
113
+ }
114
+
115
+ st.title("🧠 AI-Powered Clinical Decision Support")
116
+
117
+ model, tool, memory = init_components()
118
+ agent = AdvancedMedicalAgent(model, [tool], memory)
119
+
120
+ if st.button("Start Analysis"):
121
+ with st.spinner("Analyzing patient case..."):
122
+ agent.process_query(patient_data, str(uuid4()))
123
+
124
+ if 'history' in st.session_state:
125
+ st.subheader("Clinical Decision Log")
126
+ for role, content in st.session_state.history:
127
+ with st.chat_message(role):
128
+ st.markdown(content)
129
+
130
+ if __name__ == "__main__":
131
+ main()