mgbam commited on
Commit
fc636ce
·
verified ·
1 Parent(s): 1807654

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +145 -95
app.py CHANGED
@@ -1,103 +1,112 @@
1
  # app.py
2
  import streamlit as st
3
- from uuid import uuid4
4
  import sqlite3
5
  import plotly.express as px
 
 
6
  from langchain_groq import ChatGroq
7
  from langgraph.checkpoint.sqlite import SqliteSaver
8
  from langchain_community.tools.tavily_search import TavilySearchResults
9
  from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
- # Initialize components
12
- def init_components():
13
- model = ChatGroq(temperature=0.1, model="Llama-3.3-70b-Specdec")
14
- tool = TavilySearchResults(max_results=3)
15
- conn = sqlite3.connect("medical_checkpoints.db", check_same_thread=False)
16
- memory = SqliteSaver(conn)
17
- return model, tool, memory
 
 
 
 
 
18
 
19
- # Medical Agent Class
20
- class AdvancedMedicalAgent:
21
- def __init__(self, model, tools, memory):
22
- self.model = model.bind_tools(tools)
23
- self.tools = {t.name: t for t in tools}
24
- self.memory = memory
25
 
26
- def process_query(self, patient_data, thread_id):
27
- # Initialize or retrieve conversation thread
28
- if 'history' not in st.session_state:
29
- st.session_state.history = []
30
-
31
- # Add patient data visualization
32
- self.display_patient_dashboard(patient_data)
 
33
 
34
- # Process medical inquiry
35
- response = self.model.invoke([
36
- SystemMessage(content=self.get_system_prompt()),
37
- HumanMessage(content=f"Patient Case:\n{patient_data}")
 
 
 
38
  ])
39
-
40
- # Handle tool calls
41
- if hasattr(response, 'tool_calls'):
42
- self.handle_medical_actions(response.tool_calls, thread_id)
43
-
44
- st.session_state.history.append(("AI", response.content))
45
-
46
- def display_patient_dashboard(self, data):
47
- tab1, tab2, tab3 = st.tabs(["Vitals", "History", "Timeline"])
48
-
49
- with tab1:
50
- fig = px.line(
51
- x=["Temperature", "BP", "Heart Rate"],
52
- y=[data['vitals']['temp'], 130, 85], # Example data
53
- title="Vital Signs"
54
- )
55
- st.plotly_chart(fig)
56
-
57
- with tab2:
58
- st.json(data['history'])
59
-
60
- with tab3:
61
- st.vega_lite_chart({
62
- "mark": {"type": "circle", "tooltip": True},
63
- "encoding": {
64
- "x": {"field": "timestamp", "type": "temporal"},
65
- "y": {"field": "severity", "type": "quantitative"}
66
- }
67
- })
68
 
69
- def handle_medical_actions(self, tool_calls, thread_id):
70
- for call in tool_calls:
71
- st.warning(f"🩺 Pending Medical Action: {call['name']}")
72
-
73
- if st.button("Review Action Details"):
74
- st.write(f"**Action Type:** {call['name']}")
75
- st.write(f"**Parameters:** {call['args']}")
76
-
77
- if st.checkbox("I approve this action"):
78
- result = self.tools[call['name']].invoke(call['args'])
79
- st.session_state.history.append(
80
- ("System", f"Action executed: {result}")
81
- )
82
- else:
83
- st.session_state.history.append(
84
- ("System", "Action cancelled by clinician")
85
- )
86
 
87
- def get_system_prompt(self):
88
- return """You are an advanced medical AI assistant. Follow these steps:
89
- 1. Analyze patient data using latest clinical guidelines
90
- 2. Consider drug interactions and contraindications
91
- 3. Propose differential diagnosis (3 possibilities)
92
- 4. Suggest evidence-based treatment options
93
- 5. Identify necessary lab tests with reasoning
94
- 6. Flag high-risk factors in RED
95
- 7. Maintain audit trail of all decisions"""
 
 
 
96
 
97
  # Streamlit UI
98
  def main():
99
- st.set_page_config(page_title="AI Clinical Assistant", layout="wide")
 
 
 
 
100
 
 
101
  with st.sidebar:
102
  st.header("Patient Intake")
103
  patient_data = {
@@ -107,25 +116,66 @@ def main():
107
  "medications": st.text_input("Current Medications")
108
  },
109
  "vitals": {
110
- "temp": st.number_input("Temperature (°C)", 36.0, 42.0, 37.0),
111
- "bp": st.text_input("Blood Pressure (mmHg)", "120/80")
112
  }
113
  }
114
 
115
- st.title("🧠 AI-Powered Clinical Decision Support")
 
116
 
117
- model, tool, memory = init_components()
118
- agent = AdvancedMedicalAgent(model, [tool], memory)
119
 
120
- if st.button("Start Analysis"):
121
- with st.spinner("Analyzing patient case..."):
122
- agent.process_query(patient_data, str(uuid4()))
123
-
124
- if 'history' in st.session_state:
125
- st.subheader("Clinical Decision Log")
126
- for role, content in st.session_state.history:
127
- with st.chat_message(role):
128
- st.markdown(content)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
  if __name__ == "__main__":
131
  main()
 
1
  # app.py
2
  import streamlit as st
 
3
  import sqlite3
4
  import plotly.express as px
5
+ from uuid import uuid4
6
+ from langgraph.graph import START, StateGraph, END
7
  from langchain_groq import ChatGroq
8
  from langgraph.checkpoint.sqlite import SqliteSaver
9
  from langchain_community.tools.tavily_search import TavilySearchResults
10
  from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage
11
+ from typing import TypedDict, Annotated, List
12
+
13
+ # Configuration
14
+ class MedicalConfig:
15
+ SYSTEM_PROMPT = """You are an AI clinical assistant. Follow these rules:
16
+ 1. Analyze patient data using latest medical guidelines
17
+ 2. Always check for drug interactions
18
+ 3. Suggest tests only when necessary
19
+ 4. Use structured actions:
20
+ - order_lab_test: {test_name, reason}
21
+ - prescribe_medication: {name, dosage, frequency}
22
+ 5. Research using tavily_search when uncertain"""
23
+
24
+ RED_FLAGS = {
25
+ 'symptoms': ['chest pain', 'shortness of breath', 'severe headache'],
26
+ 'vitals': {'temp': (38.5, '°C'), 'hr': (120, 'bpm'), 'bp': ('180/120', 'mmHg')}
27
+ }
28
+
29
+ # State Management
30
+ class AgentState(TypedDict):
31
+ messages: Annotated[List[dict], lambda l, r: l + r]
32
+ patient_data: dict
33
+ approvals: Annotated[List[dict], lambda l, r: l + r]
34
 
35
+ class MedicalAgent:
36
+ def __init__(self):
37
+ self.model = ChatGroq(temperature=0.1, model="Llama-3.3-70b-Specdec")
38
+ self.tools = {
39
+ "medical_actions": {
40
+ "order_lab_test": self.order_lab_test,
41
+ "prescribe_medication": self.prescribe_medication
42
+ },
43
+ "research": TavilySearchResults(max_results=3)
44
+ }
45
+ self.checkpointer = SqliteSaver(sqlite3.connect("medical.db"))
46
+ self._build_graph()
47
 
48
+ def _build_graph(self):
49
+ graph = StateGraph(AgentState)
50
+ graph.add_node("analyze", self.analyze_patient)
51
+ graph.add_node("execute", self.execute_actions)
52
+ graph.add_node("safety_check", self.safety_checks)
 
53
 
54
+ graph.add_edge(START, "analyze")
55
+ graph.add_conditional_edges(
56
+ "analyze",
57
+ self.route_actions,
58
+ {"medical": "safety_check", "research": "execute"}
59
+ )
60
+ graph.add_edge("safety_check", "execute")
61
+ graph.add_edge("execute", "analyze")
62
 
63
+ graph.set_entry_point("analyze")
64
+ self.graph = graph.compile(checkpointer=self.checkpointer)
65
+
66
+ def analyze_patient(self, state: AgentState):
67
+ response = self.model.bind_tools(list(self.tools['medical_actions'].keys())).invoke([
68
+ SystemMessage(content=MedicalConfig.SYSTEM_PROMPT),
69
+ HumanMessage(content=f"Patient Data: {state['patient_data']}")
70
  ])
71
+ return {"messages": [response]}
72
+
73
+ def safety_checks(self, state: AgentState):
74
+ current_action = state['messages'][-1].tool_calls[0]
75
+ risk = self.check_contraindications(state['patient_data'], current_action)
76
+ return {"approvals": [{"action": current_action, "risk": risk}]}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
+ def execute_actions(self, state: AgentState):
79
+ results = []
80
+ for action in state.get('approvals', []):
81
+ if action['risk']:
82
+ results.append(f"Blocked dangerous action: {action['action']['name']}")
83
+ else:
84
+ tool = self.tools['medical_actions'][action['action']['name']]
85
+ results.append(tool(action['action']['args']))
86
+ return {"messages": [{"content": "\n".join(results)}]}
 
 
 
 
 
 
 
 
87
 
88
+ def check_contraindications(self, patient_data, action):
89
+ # Implement actual medical safety checks
90
+ if 'prescribe' in action['name']:
91
+ return any(drug in patient_data['medications']
92
+ for drug in ['warfarin', 'insulin'])
93
+ return False
94
+
95
+ def order_lab_test(self, params):
96
+ return f"Lab ordered: {params['test_name']} ({params['reason']})"
97
+
98
+ def prescribe_medication(self, params):
99
+ return f"Prescribed: {params['name']} {params['dosage']} {params['frequency']}"
100
 
101
  # Streamlit UI
102
  def main():
103
+ st.set_page_config(page_title="AI Clinic", layout="wide")
104
+
105
+ # Initialize agent
106
+ if 'agent' not in st.session_state:
107
+ st.session_state.agent = MedicalAgent()
108
 
109
+ # Patient intake
110
  with st.sidebar:
111
  st.header("Patient Intake")
112
  patient_data = {
 
116
  "medications": st.text_input("Current Medications")
117
  },
118
  "vitals": {
119
+ "temp": st.number_input("Temp (°C)", 35.0, 42.0, 37.0),
120
+ "bp": st.text_input("BP (mmHg)", "120/80")
121
  }
122
  }
123
 
124
+ # Main interface
125
+ st.title("AI-Powered Clinical Support System")
126
 
127
+ col1, col2 = st.columns([2, 1])
 
128
 
129
+ with col1:
130
+ st.subheader("Patient Overview")
131
+ self.render_patient_dashboard(patient_data)
132
+
133
+ if st.button("Start Analysis"):
134
+ thread_id = f"case_{uuid4().hex[:8]}"
135
+ with st.spinner("Analyzing..."):
136
+ st.session_state.agent.graph.invoke(
137
+ {"patient_data": patient_data},
138
+ {"configurable": {"thread_id": thread_id}}
139
+ )
140
+
141
+ with col2:
142
+ st.subheader("Action Center")
143
+ self.render_approval_interface()
144
+
145
+ def render_patient_dashboard(self, data):
146
+ tab1, tab2 = st.tabs(["Vitals", "Timeline"])
147
+
148
+ with tab1:
149
+ fig = px.line(
150
+ x=["Temperature", "Blood Pressure"],
151
+ y=[data['vitals']['temp'], 120],
152
+ title="Vital Signs"
153
+ )
154
+ st.plotly_chart(fig)
155
+
156
+ with tab2:
157
+ st.vega_lite_chart({
158
+ "mark": {"type": "line", "point": True},
159
+ "encoding": {
160
+ "x": {"field": "hour", "type": "ordinal"},
161
+ "y": {"field": "temp", "type": "quantitative"}
162
+ },
163
+ "data": {"values": [{"hour": i, "temp": 36.5 + i/10} for i in range(24)]}
164
+ })
165
+
166
+ def render_approval_interface(self):
167
+ if 'current_action' in st.session_state:
168
+ action = st.session_state.current_action
169
+ st.warning("Action Requires Approval")
170
+ st.write(f"**Type:** {action['name'].replace('_', ' ').title()}")
171
+ st.json(action['args'])
172
+
173
+ if st.button("Approve"):
174
+ self.process_approval(True)
175
+ if st.button("Reject"):
176
+ self.process_approval(False)
177
+ else:
178
+ st.info("No pending actions")
179
 
180
  if __name__ == "__main__":
181
  main()