mgbam commited on
Commit
b8dee1d
·
verified ·
1 Parent(s): e49f6e1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -111
app.py CHANGED
@@ -1,13 +1,7 @@
1
- # app.py
2
  import streamlit as st
3
- import sqlite3
4
- import plotly.express as px
5
- from uuid import uuid4
6
- from langgraph.graph import START, StateGraph, END
7
  from langchain_groq import ChatGroq
8
- from langgraph.checkpoint.sqlite import SqliteSaver
9
  from langchain_community.tools.tavily_search import TavilySearchResults
10
- from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage
11
  from typing import TypedDict, Annotated, List
12
 
13
  # Configuration
@@ -18,19 +12,12 @@ class MedicalConfig:
18
  3. Suggest tests only when necessary
19
  4. Use structured actions:
20
  - order_lab_test: {test_name, reason}
21
- - prescribe_medication: {name, dosage, frequency}
22
- 5. Research using tavily_search when uncertain"""
23
-
24
- RED_FLAGS = {
25
- 'symptoms': ['chest pain', 'shortness of breath', 'severe headache'],
26
- 'vitals': {'temp': (38.5, '°C'), 'hr': (120, 'bpm'), 'bp': ('180/120', 'mmHg')}
27
- }
28
 
29
  # State Management
30
  class AgentState(TypedDict):
31
  messages: Annotated[List[dict], lambda l, r: l + r]
32
  patient_data: dict
33
- approvals: Annotated[List[dict], lambda l, r: l + r]
34
 
35
  class MedicalAgent:
36
  def __init__(self):
@@ -42,55 +29,18 @@ class MedicalAgent:
42
  },
43
  "research": TavilySearchResults(max_results=3)
44
  }
45
- self.checkpointer = SqliteSaver(sqlite3.connect("medical.db"))
46
- self._build_graph()
47
-
48
- def _build_graph(self):
49
- graph = StateGraph(AgentState)
50
- graph.add_node("analyze", self.analyze_patient)
51
- graph.add_node("execute", self.execute_actions)
52
- graph.add_node("safety_check", self.safety_checks)
53
-
54
- graph.add_edge(START, "analyze")
55
- graph.add_conditional_edges(
56
- "analyze",
57
- self.route_actions,
58
- {"medical": "safety_check", "research": "execute"}
59
- )
60
- graph.add_edge("safety_check", "execute")
61
- graph.add_edge("execute", "analyze")
62
-
63
- graph.set_entry_point("analyze")
64
- self.graph = graph.compile(checkpointer=self.checkpointer)
65
 
66
- def analyze_patient(self, state: AgentState):
67
  response = self.model.bind_tools(list(self.tools['medical_actions'].keys())).invoke([
68
  SystemMessage(content=MedicalConfig.SYSTEM_PROMPT),
69
- HumanMessage(content=f"Patient Data: {state['patient_data']}")
70
  ])
71
- return {"messages": [response]}
72
-
73
- def safety_checks(self, state: AgentState):
74
- current_action = state['messages'][-1].tool_calls[0]
75
- risk = self.check_contraindications(state['patient_data'], current_action)
76
- return {"approvals": [{"action": current_action, "risk": risk}]}
77
 
78
- def execute_actions(self, state: AgentState):
79
- results = []
80
- for action in state.get('approvals', []):
81
- if action['risk']:
82
- results.append(f"Blocked dangerous action: {action['action']['name']}")
83
- else:
84
- tool = self.tools['medical_actions'][action['action']['name']]
85
- results.append(tool(action['action']['args']))
86
- return {"messages": [{"content": "\n".join(results)}]}
87
-
88
- def check_contraindications(self, patient_data, action):
89
- # Implement actual medical safety checks
90
- if 'prescribe' in action['name']:
91
- return any(drug in patient_data['medications']
92
- for drug in ['warfarin', 'insulin'])
93
- return False
94
 
95
  def order_lab_test(self, params):
96
  return f"Lab ordered: {params['test_name']} ({params['reason']})"
@@ -124,58 +74,16 @@ def main():
124
  # Main interface
125
  st.title("AI-Powered Clinical Support System")
126
 
127
- col1, col2 = st.columns([2, 1])
128
-
129
- with col1:
130
- st.subheader("Patient Overview")
131
- self.render_patient_dashboard(patient_data)
132
-
133
- if st.button("Start Analysis"):
134
- thread_id = f"case_{uuid4().hex[:8]}"
135
- with st.spinner("Analyzing..."):
136
- st.session_state.agent.graph.invoke(
137
- {"patient_data": patient_data},
138
- {"configurable": {"thread_id": thread_id}}
139
- )
140
-
141
- with col2:
142
- st.subheader("Action Center")
143
- self.render_approval_interface()
144
-
145
- def render_patient_dashboard(self, data):
146
- tab1, tab2 = st.tabs(["Vitals", "Timeline"])
147
-
148
- with tab1:
149
- fig = px.line(
150
- x=["Temperature", "Blood Pressure"],
151
- y=[data['vitals']['temp'], 120],
152
- title="Vital Signs"
153
- )
154
- st.plotly_chart(fig)
155
-
156
- with tab2:
157
- st.vega_lite_chart({
158
- "mark": {"type": "line", "point": True},
159
- "encoding": {
160
- "x": {"field": "hour", "type": "ordinal"},
161
- "y": {"field": "temp", "type": "quantitative"}
162
- },
163
- "data": {"values": [{"hour": i, "temp": 36.5 + i/10} for i in range(24)]}
164
- })
165
-
166
- def render_approval_interface(self):
167
- if 'current_action' in st.session_state:
168
- action = st.session_state.current_action
169
- st.warning("Action Requires Approval")
170
- st.write(f"**Type:** {action['name'].replace('_', ' ').title()}")
171
- st.json(action['args'])
172
-
173
- if st.button("Approve"):
174
- self.process_approval(True)
175
- if st.button("Reject"):
176
- self.process_approval(False)
177
- else:
178
- st.info("No pending actions")
179
 
180
  if __name__ == "__main__":
181
  main()
 
 
1
  import streamlit as st
 
 
 
 
2
  from langchain_groq import ChatGroq
 
3
  from langchain_community.tools.tavily_search import TavilySearchResults
4
+ from langchain_core.messages import HumanMessage, SystemMessage
5
  from typing import TypedDict, Annotated, List
6
 
7
  # Configuration
 
12
  3. Suggest tests only when necessary
13
  4. Use structured actions:
14
  - order_lab_test: {test_name, reason}
15
+ - prescribe_medication: {name, dosage, frequency}"""
 
 
 
 
 
 
16
 
17
  # State Management
18
  class AgentState(TypedDict):
19
  messages: Annotated[List[dict], lambda l, r: l + r]
20
  patient_data: dict
 
21
 
22
  class MedicalAgent:
23
  def __init__(self):
 
29
  },
30
  "research": TavilySearchResults(max_results=3)
31
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
+ def analyze_patient(self, patient_data):
34
  response = self.model.bind_tools(list(self.tools['medical_actions'].keys())).invoke([
35
  SystemMessage(content=MedicalConfig.SYSTEM_PROMPT),
36
+ HumanMessage(content=f"Patient Data: {patient_data}")
37
  ])
38
+ return response
 
 
 
 
 
39
 
40
+ def process_action(self, action):
41
+ if action['name'] in self.tools['medical_actions']:
42
+ return self.tools['medical_actions'][action['name']](action['args'])
43
+ return "Unknown action"
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
  def order_lab_test(self, params):
46
  return f"Lab ordered: {params['test_name']} ({params['reason']})"
 
74
  # Main interface
75
  st.title("AI-Powered Clinical Support System")
76
 
77
+ if st.button("Start Analysis"):
78
+ with st.spinner("Analyzing patient data..."):
79
+ response = st.session_state.agent.analyze_patient(patient_data)
80
+
81
+ if hasattr(response, 'tool_calls'):
82
+ for action in response.tool_calls:
83
+ result = st.session_state.agent.process_action(action)
84
+ st.success(result)
85
+ else:
86
+ st.info(response.content)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
  if __name__ == "__main__":
89
  main()