|
|
|
import streamlit as st |
|
import sqlite3 |
|
import plotly.express as px |
|
from uuid import uuid4 |
|
from langgraph.graph import START, StateGraph, END |
|
from langchain_groq import ChatGroq |
|
from langgraph.checkpoint.sqlite import SqliteSaver |
|
from langchain_community.tools.tavily_search import TavilySearchResults |
|
from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage |
|
from typing import TypedDict, Annotated, List |
|
|
|
|
|
class MedicalConfig:
    """Static configuration for the clinical agent: model instructions and alert criteria."""

    # Instructions passed as the system message when the agent queries the model.
    # NOTE(review): the Tavily tool's registered name may differ from the
    # "tavily_search" wording used here — confirm the model still finds it.
    SYSTEM_PROMPT = """You are an AI clinical assistant. Follow these rules:
1. Analyze patient data using latest medical guidelines
2. Always check for drug interactions
3. Suggest tests only when necessary
4. Use structured actions:
   - order_lab_test: {test_name, reason}
   - prescribe_medication: {name, dosage, frequency}
5. Research using tavily_search when uncertain"""

    # Alert criteria: lower-case symptom names, and per-vital (value, unit) pairs —
    # presumably alert thresholds. Blood pressure is stored as a "systolic/diastolic"
    # string rather than a number.
    # NOTE(review): not referenced by the agent code visible in this file.
    RED_FLAGS = dict(
        symptoms=["chest pain", "shortness of breath", "severe headache"],
        vitals=dict(
            temp=(38.5, "°C"),
            hr=(120, "bpm"),
            bp=("180/120", "mmHg"),
        ),
    )
|
|
|
|
|
def _append_updates(existing, incoming):
    """Reducer for append-only state channels: keep what is there, add the new entries."""
    return existing + incoming


class AgentState(TypedDict):
    """State shared between the nodes of the clinical agent graph."""

    # Conversation log. Nodes return only their new entries (model responses and
    # node outputs), and the reducer appends them to what is already stored.
    messages: Annotated[List[dict], _append_updates]
    # Intake data supplied as part of the graph's initial input.
    patient_data: dict
    # Safety-check verdicts, accumulated with the same append-only reducer.
    approvals: Annotated[List[dict], _append_updates]
|
|
|
class MedicalAgent:
    """Clinical decision-support agent built on a LangGraph state machine.

    Flow: ``analyze`` asks the tool-bound model what to do next. A reply that
    requests a medical action goes through ``safety_check`` before ``execute``;
    research-only replies go straight to ``execute``. ``execute`` answers every
    tool call with a ``ToolMessage`` and loops back to ``analyze``. The run ends
    when the model stops requesting tools, or after ``max_rounds`` tool-using turns.

    Safety verdicts are automatic: nothing in this graph pauses for a human.
    """

    # OpenAI-format function schemas for the structured actions named in the
    # system prompt. ``bind_tools`` needs schemas or tool objects, not bare names.
    ACTION_SCHEMAS = (
        {
            "type": "function",
            "function": {
                "name": "order_lab_test",
                "description": "Order a laboratory test for the patient.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "test_name": {"type": "string", "description": "Name of the lab test."},
                        "reason": {"type": "string", "description": "Clinical justification."},
                    },
                    "required": ["test_name", "reason"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "prescribe_medication",
                "description": "Prescribe a medication for the patient.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "name": {"type": "string", "description": "Drug name."},
                        "dosage": {"type": "string", "description": "Dose, e.g. '500 mg'."},
                        "frequency": {"type": "string", "description": "How often, e.g. 'twice daily'."},
                    },
                    "required": ["name", "dosage", "frequency"],
                },
            },
        },
    )

    # If the patient already takes any of these, every new prescription is blocked
    # (a deliberately crude, fail-safe interaction guard).
    HIGH_RISK_MEDICATIONS = ("warfarin", "insulin")

    def __init__(self, model_name="Llama-3.3-70b-Specdec", db_path="medical.db", max_rounds=5):
        """Create the model client, tools, checkpointer and compiled graph.

        Args:
            model_name: Groq model id. NOTE(review): Groq ids are normally
                lower-case; confirm this id is still served.
            db_path: SQLite file used by the LangGraph checkpointer.
            max_rounds: Maximum number of tool-using model turns per run.
        """
        self.model = ChatGroq(temperature=0.1, model=model_name)
        self.tools = {
            "medical_actions": {
                "order_lab_test": self.order_lab_test,
                "prescribe_medication": self.prescribe_medication,
            },
            "research": TavilySearchResults(max_results=3),
        }
        # Bind once: the action schemas plus the Tavily tool object itself.
        self.llm_with_tools = self.model.bind_tools(
            [*self.ACTION_SCHEMAS, self.tools["research"]]
        )
        self.max_rounds = max_rounds
        # The agent lives in Streamlit session state and may be invoked from a
        # different thread than the one that created it, so don't pin the
        # connection to the creating thread (as in the LangGraph SqliteSaver docs).
        self.checkpointer = SqliteSaver(sqlite3.connect(db_path, check_same_thread=False))
        self._build_graph()

    def _build_graph(self):
        """Wire analyze -> (safety_check ->) execute -> analyze, with an exit to END."""
        graph = StateGraph(AgentState)
        graph.add_node("analyze", self.analyze_patient)
        graph.add_node("safety_check", self.safety_checks)
        graph.add_node("execute", self.execute_actions)

        graph.add_edge(START, "analyze")
        graph.add_conditional_edges(
            "analyze",
            self.route_actions,
            {"medical": "safety_check", "research": "execute", "end": END},
        )
        graph.add_edge("safety_check", "execute")
        graph.add_edge("execute", "analyze")

        self.graph = graph.compile(checkpointer=self.checkpointer)

    def route_actions(self, state: AgentState):
        """Choose the next node after ``analyze``.

        Returns:
            "end" if the latest reply has no tool calls or the round budget is
            spent; "medical" if any call is a medical action (these must be
            safety-checked); otherwise "research".
        """
        messages = state.get("messages") or []
        tool_calls = (getattr(messages[-1], "tool_calls", None) or []) if messages else []
        if not tool_calls:
            return "end"
        rounds = sum(1 for message in messages if getattr(message, "tool_calls", None))
        if rounds > self.max_rounds:
            return "end"
        if any(call["name"] in self.tools["medical_actions"] for call in tool_calls):
            return "medical"
        return "research"

    def analyze_patient(self, state: AgentState):
        """Ask the model for its next step, given the patient data and all prior turns."""
        prompt = [
            SystemMessage(content=MedicalConfig.SYSTEM_PROMPT),
            HumanMessage(content=f"Patient Data: {state['patient_data']}"),
            # Earlier model turns and tool results; without them the model would
            # repeat the same request forever.
            *(state.get("messages") or []),
        ]
        response = self.llm_with_tools.invoke(prompt)
        return {"messages": [response]}

    def safety_checks(self, state: AgentState):
        """Record a risk verdict for every medical action in the latest model turn."""
        medical = self.tools["medical_actions"]
        approvals = [
            {"action": call, "risk": self.check_contraindications(state["patient_data"], call)}
            for call in state["messages"][-1].tool_calls
            if call["name"] in medical
        ]
        return {"approvals": approvals}

    def execute_actions(self, state: AgentState):
        """Answer each tool call of the latest model turn with a ``ToolMessage``.

        Medical actions run only with a clean safety verdict for that exact call.
        Research calls go to Tavily. Unknown tools get an error reply rather
        than raising.
        """
        last = state["messages"][-1]
        medical = self.tools["medical_actions"]
        research = self.tools["research"]
        # Approvals accumulate across rounds. Index them by tool-call id so only
        # this turn's verdicts apply and earlier actions are never re-executed.
        verdicts = {
            approval["action"].get("id"): approval for approval in state.get("approvals") or []
        }

        results = []
        for call in getattr(last, "tool_calls", None) or []:
            name = call["name"]
            if name in medical:
                verdict = verdicts.get(call.get("id"))
                # Fail closed: without a clean verdict the action is not executed.
                if verdict is None or verdict["risk"]:
                    content = f"Blocked dangerous action: {name}"
                else:
                    try:
                        content = medical[name](call["args"])
                    except KeyError as missing:
                        content = f"Missing argument {missing} for {name}"
            elif name == research.name:
                content = str(research.invoke(call["args"]))
            else:
                content = f"Unknown action: {name}"
            results.append(ToolMessage(content=content, tool_call_id=call["id"]))
        return {"messages": results}

    def check_contraindications(self, patient_data, action):
        """Return True if ``action`` is a prescription and the patient takes a high-risk drug.

        The intake form nests medications under ``history``; a top-level
        ``medications`` entry (string or list) is also honoured. Matching is a
        case-insensitive substring test.
        """
        if 'prescribe' not in action['name']:
            return False
        history = patient_data.get('history') or {}
        current = []
        for value in (patient_data.get('medications'), history.get('medications')):
            if isinstance(value, str):
                current.append(value)
            elif value:
                current.extend(str(item) for item in value)
        text = ' '.join(current).lower()
        return any(drug in text for drug in self.HIGH_RISK_MEDICATIONS)

    def order_lab_test(self, params):
        """Return a confirmation line for a lab order (``test_name``, ``reason``)."""
        return f"Lab ordered: {params['test_name']} ({params['reason']})"

    def prescribe_medication(self, params):
        """Return a confirmation line for a prescription (``name``, ``dosage``, ``frequency``)."""
        return f"Prescribed: {params['name']} {params['dosage']} {params['frequency']}"
|
|
|
|
|
def _render_analysis_results(result):
    """Show the messages from the most recent analysis run, if there was one."""
    if not result:
        return
    st.subheader("Analysis Results")
    for message in result.get("messages", []):
        if isinstance(message, dict):
            content, label = message.get("content"), "message"
        else:
            content = getattr(message, "content", "")
            label = getattr(message, "type", None) or "message"
        if content:
            st.markdown(f"**{label}:** {content}")


def main():
    """Render the clinical-support page and run the agent when requested."""
    st.set_page_config(page_title="AI Clinic", layout="wide")

    # Build the agent once per browser session and reuse it across reruns.
    if 'agent' not in st.session_state:
        st.session_state.agent = MedicalAgent()

    with st.sidebar:
        st.header("Patient Intake")
        patient_data = {
            "symptoms": st.multiselect("Symptoms", ["Fever", "Cough", "Chest Pain"]),
            "history": {
                "conditions": st.text_input("Medical History"),
                "medications": st.text_input("Current Medications"),
            },
            "vitals": {
                "temp": st.number_input("Temp (°C)", 35.0, 42.0, 37.0),
                "bp": st.text_input("BP (mmHg)", "120/80"),
            },
        }

    st.title("AI-Powered Clinical Support System")

    col1, col2 = st.columns([2, 1])

    with col1:
        st.subheader("Patient Overview")
        # The render helpers are module-level functions with a leading ``self``
        # parameter; main() has no instance, so None is passed for it.
        render_patient_dashboard(None, patient_data)

        if st.button("Start Analysis"):
            thread_id = f"case_{uuid4().hex[:8]}"
            with st.spinner("Analyzing..."):
                try:
                    result = st.session_state.agent.graph.invoke(
                        {"patient_data": patient_data},
                        {"configurable": {"thread_id": thread_id}},
                    )
                except Exception as exc:  # UI boundary: show the failure instead of a blank page
                    st.error("Analysis failed.")
                    st.exception(exc)
                else:
                    # Keep the result so it survives later reruns of the script.
                    st.session_state.last_result = result

        _render_analysis_results(st.session_state.get("last_result"))

    with col2:
        st.subheader("Action Center")
        render_approval_interface(None)
|
|
|
def parse_blood_pressure(reading):
    """Split a "systolic/diastolic" string into two ints; return None if it does not parse.

    >>> parse_blood_pressure("120/80")
    (120, 80)
    """
    if not isinstance(reading, str):
        return None
    systolic, separator, diastolic = reading.partition("/")
    if not separator:
        return None
    try:
        return int(systolic.strip()), int(diastolic.strip())
    except ValueError:
        return None


def render_patient_dashboard(self, data):
    """Render the vitals and timeline tabs for the intake ``data``.

    ``self`` is unused; it is kept so existing callers keep working.
    """
    vitals = data.get('vitals') or {}
    tab1, tab2 = st.tabs(["Vitals", "Timeline"])

    with tab1:
        readings = {"Temperature (°C)": vitals.get('temp')}
        blood_pressure = parse_blood_pressure(vitals.get('bp'))
        if blood_pressure is None:
            st.warning("Blood pressure must be entered as systolic/diastolic, e.g. 120/80.")
        else:
            # Plot the values actually entered instead of a fixed 120.
            readings["Systolic BP (mmHg)"], readings["Diastolic BP (mmHg)"] = blood_pressure
        # Bars, not a line: the readings are independent measurements in different units.
        fig = px.bar(x=list(readings), y=list(readings.values()), title="Vital Signs")
        st.plotly_chart(fig)

    with tab2:
        # The series below is a fixed synthetic ramp, not patient readings; say so.
        st.caption("Illustrative sample data — not recorded patient readings.")
        st.vega_lite_chart({
            "mark": {"type": "line", "point": True},
            "encoding": {
                "x": {"field": "hour", "type": "ordinal"},
                "y": {"field": "temp", "type": "quantitative"},
            },
            "data": {"values": [{"hour": i, "temp": 36.5 + i / 10} for i in range(24)]},
        })
|
|
|
def _record_approval(self, action, approved):
    """Pass the decision to ``self.process_approval`` if one exists; otherwise log it in session state."""
    handler = getattr(self, "process_approval", None)
    if callable(handler):
        handler(approved)
        return
    # No handler is defined in this app, so keep an audit trail and clear the
    # pending action rather than crashing with AttributeError.
    log = st.session_state.get("approval_log", [])
    log.append({"action": action, "approved": approved})
    st.session_state["approval_log"] = log
    st.session_state.pop("current_action", None)
    if approved:
        st.success("Action approved")
    else:
        st.error("Action rejected")


def render_approval_interface(self):
    """Show the pending action in ``st.session_state.current_action`` with Approve/Reject buttons.

    ``self`` may be None; a ``process_approval`` method on it is used when present.
    """
    action = st.session_state.get("current_action")
    if not action:
        st.info("No pending actions")
        return

    st.warning("Action Requires Approval")
    st.write(f"**Type:** {action['name'].replace('_', ' ').title()}")
    st.json(action.get('args', {}))

    if st.button("Approve"):
        _record_approval(self, action, True)
    elif st.button("Reject"):
        _record_approval(self, action, False)
|
|
|
# Entry point: build and render the app when this file is executed as a script
# (e.g. launched with `streamlit run`).
if __name__ == "__main__":
    main()