# SynapseAI / app.py — Hugging Face Space by mgbam
# Last commit: fc636ce ("Update app.py", verified); original file size 6.59 kB
# app.py
import streamlit as st
import sqlite3
import plotly.express as px
from uuid import uuid4
from langgraph.graph import START, StateGraph, END
from langchain_groq import ChatGroq
from langgraph.checkpoint.sqlite import SqliteSaver
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage
from typing import TypedDict, Annotated, List
# Configuration
class MedicalConfig:
    """Static configuration for the clinical assistant.

    SYSTEM_PROMPT is the system message MedicalAgent.analyze_patient sends
    with every model call.  RED_FLAGS groups warning symptoms and per-vital
    (value, unit) pairs — presumably alert thresholds; nothing in this file
    reads it yet.
    """

    # Assembled line by line; the resulting text is one rule per line.
    SYSTEM_PROMPT = "\n".join((
        "You are an AI clinical assistant. Follow these rules:",
        "1. Analyze patient data using latest medical guidelines",
        "2. Always check for drug interactions",
        "3. Suggest tests only when necessary",
        "4. Use structured actions:",
        "- order_lab_test: {test_name, reason}",
        "- prescribe_medication: {name, dosage, frequency}",
        "5. Research using tavily_search when uncertain",
    ))

    RED_FLAGS = dict(
        symptoms=["chest pain", "shortness of breath", "severe headache"],
        # Each vital maps to (value, unit).
        vitals=dict(
            temp=(38.5, "°C"),
            hr=(120, "bpm"),
            bp=("180/120", "mmHg"),
        ),
    )
# State Management
class AgentState(TypedDict):
    """State passed between the LangGraph nodes of MedicalAgent.

    ``messages`` and ``approvals`` carry a callable in their ``Annotated``
    metadata that LangGraph uses as a reducer: it concatenates the stored
    list with the list a node returns, so nodes return only new items.
    """

    # Conversation history; grows by concatenation.
    messages: Annotated[List[dict], lambda existing, new: existing + new]
    # Intake data supplied when the graph is invoked.
    patient_data: dict
    # Safety-check verdicts; grows by concatenation, never replaced.
    approvals: Annotated[List[dict], lambda existing, new: existing + new]
class MedicalAgent:
    """LangGraph agent that reviews intake data and performs clinical actions.

    Graph: ``START -> analyze``.  After ``analyze``, ``route_actions``:
    - sends turns with medical tool calls through ``safety_check`` to
      ``execute``;
    - sends turns with only other tool calls (web research) straight to
      ``execute``;
    - ends the run when the model requests no tools.

    ``execute`` loops back to ``analyze`` so the model sees every tool result.
    """

    # A patient already taking any of these has every new prescription
    # blocked.  Matched case-insensitively as substrings of the recorded
    # medications, so "Warfarin 5mg" counts.
    HIGH_RISK_MEDICATIONS = ("warfarin", "insulin")

    # Function schemas for the medical actions.  ``bind_tools`` needs tool
    # definitions (schemas, callables or tool objects), not bare names.
    MEDICAL_ACTION_SCHEMAS = (
        {
            "type": "function",
            "function": {
                "name": "order_lab_test",
                "description": "Order a laboratory test for the patient.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "test_name": {"type": "string", "description": "Test to order."},
                        "reason": {"type": "string", "description": "Clinical reason for the test."},
                    },
                    "required": ["test_name", "reason"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "prescribe_medication",
                "description": "Prescribe a medication for the patient.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "name": {"type": "string", "description": "Medication name."},
                        "dosage": {"type": "string", "description": "Dose, e.g. '500 mg'."},
                        "frequency": {"type": "string", "description": "How often, e.g. 'twice daily'."},
                    },
                    "required": ["name", "dosage", "frequency"],
                },
            },
        },
    )

    def __init__(self):
        """Create the model, tools, SQLite checkpointer and compiled graph."""
        self.model = ChatGroq(temperature=0.1, model="Llama-3.3-70b-Specdec")
        self.tools = {
            "medical_actions": {
                "order_lab_test": self.order_lab_test,
                "prescribe_medication": self.prescribe_medication,
            },
            "research": TavilySearchResults(max_results=3),
        }
        # The agent is cached in Streamlit session state and reused on later
        # reruns, which may run on a different thread.  By default a sqlite3
        # connection refuses use from any thread other than its creator.
        self.checkpointer = SqliteSaver(
            sqlite3.connect("medical.db", check_same_thread=False)
        )
        self._build_graph()

    def _build_graph(self):
        """Wire the nodes together and compile the graph into ``self.graph``."""
        graph = StateGraph(AgentState)
        graph.add_node("analyze", self.analyze_patient)
        graph.add_node("execute", self.execute_actions)
        graph.add_node("safety_check", self.safety_checks)
        graph.add_edge(START, "analyze")
        graph.add_conditional_edges(
            "analyze",
            self.route_actions,
            # "done" gives the analyze/execute loop a way to terminate.
            {"medical": "safety_check", "research": "execute", "done": END},
        )
        graph.add_edge("safety_check", "execute")
        graph.add_edge("execute", "analyze")
        self.graph = graph.compile(checkpointer=self.checkpointer)

    def analyze_patient(self, state: "AgentState"):
        """Ask the model for the next step, given the patient and prior turns.

        Earlier AI messages and tool results are replayed after the patient
        data, so the model can react to executed actions.
        """
        tools = [*self.MEDICAL_ACTION_SCHEMAS, self.tools["research"]]
        prompt = [
            SystemMessage(content=MedicalConfig.SYSTEM_PROMPT),
            HumanMessage(content=f"Patient Data: {state['patient_data']}"),
            *(state.get("messages") or []),
        ]
        response = self.model.bind_tools(tools).invoke(prompt)
        return {"messages": [response]}

    def route_actions(self, state: "AgentState"):
        """Pick the node after ``analyze``.

        Returns:
            "medical" if the latest message requests any medical action,
            "research" if it requests only other tools, and "done" if it
            requests none (or there are no messages).
        """
        messages = state.get("messages") or []
        tool_calls = getattr(messages[-1], "tool_calls", None) if messages else None
        if not tool_calls:
            return "done"
        medical = self.tools["medical_actions"]
        if any(call["name"] in medical for call in tool_calls):
            return "medical"
        return "research"

    def safety_checks(self, state: "AgentState"):
        """Record a risk verdict for every medical tool call of this turn."""
        calls = getattr(state["messages"][-1], "tool_calls", None) or []
        medical = self.tools["medical_actions"]
        return {
            "approvals": [
                {
                    "action": call,
                    "risk": self.check_contraindications(state["patient_data"], call),
                }
                for call in calls
                if call["name"] in medical
            ]
        }

    def execute_actions(self, state: "AgentState"):
        """Run every tool call of the latest model message.

        Returns one ``ToolMessage`` per tool call, linked by ``tool_call_id``,
        so the model receives each result on its next ``analyze`` turn.
        """
        calls = getattr(state["messages"][-1], "tool_calls", None) or []
        medical = self.tools["medical_actions"]
        n_medical = sum(1 for call in calls if call["name"] in medical)
        approvals = state.get("approvals") or []
        # ``approvals`` accumulates across turns.  safety_checks has just
        # appended one verdict per medical call of *this* turn, in order.
        # Older verdicts must not trigger their actions again.  If verdicts
        # are missing, use none, so every medical call is blocked.
        current = approvals[-n_medical:] if 0 < n_medical <= len(approvals) else []
        verdicts = iter(current)
        return {
            "messages": [
                ToolMessage(
                    content=self._run_tool_call(call, verdicts),
                    tool_call_id=call.get("id") or "",
                )
                for call in calls
            ]
        }

    def _run_tool_call(self, call, verdicts):
        """Execute one tool call and return its result as text.

        ``verdicts`` yields this turn's safety-check entries for the medical
        calls, in order.  A medical call without a verdict is blocked.
        """
        name = call["name"]
        args = call.get("args") or {}
        medical = self.tools["medical_actions"]
        if name in medical:
            verdict = next(verdicts, None)
            if verdict is None or verdict["risk"]:
                return f"Blocked dangerous action: {name}"
            try:
                return medical[name](args)
            except KeyError as missing:
                # Report the bad tool call to the model instead of crashing.
                return f"Cannot run {name}: missing argument {missing}"
        research = self.tools["research"]
        if name == getattr(research, "name", None):
            return str(research.invoke(args))
        return f"Unknown action: {name}"

    def check_contraindications(self, patient_data, action):
        """Return True when ``action`` should be blocked for this patient.

        Only prescriptions are checked.  Any prescription is flagged when the
        patient's current medications mention one of HIGH_RISK_MEDICATIONS.
        Medications are read from ``patient_data["medications"]`` or, as the
        Streamlit intake form stores them, ``patient_data["history"]
        ["medications"]``.  Either a string or a list of strings is accepted.
        """
        if "prescribe" not in action["name"]:
            return False
        current = patient_data.get("medications")
        history = patient_data.get("history")
        if current is None and isinstance(history, dict):
            current = history.get("medications")
        if not current:
            return False
        if not isinstance(current, str):
            current = " ".join(map(str, current))
        current = current.lower()
        return any(drug in current for drug in self.HIGH_RISK_MEDICATIONS)

    def order_lab_test(self, params):
        """Format a lab-order confirmation.

        ``params`` needs ``test_name`` and ``reason``.  This only builds a
        string; no external system is contacted.
        """
        return f"Lab ordered: {params['test_name']} ({params['reason']})"

    def prescribe_medication(self, params):
        """Format a prescription confirmation.

        ``params`` needs ``name``, ``dosage`` and ``frequency``.  This only
        builds a string; no external system is contacted.
        """
        return f"Prescribed: {params['name']} {params['dosage']} {params['frequency']}"
# Streamlit UI
def main():
    """Render the clinical-support page and run the agent on request.

    Lays out:
    - a patient-intake sidebar;
    - a patient overview with an analysis button;
    - an action-center column.

    The session-state check keeps one MedicalAgent per browser session
    across Streamlit reruns.
    """
    st.set_page_config(page_title="AI Clinic", layout="wide")

    # Initialize agent once and reuse it on later reruns.
    if 'agent' not in st.session_state:
        st.session_state.agent = MedicalAgent()

    # Patient intake
    with st.sidebar:
        st.header("Patient Intake")
        patient_data = {
            "symptoms": st.multiselect("Symptoms", ["Fever", "Cough", "Chest Pain"]),
            "history": {
                "conditions": st.text_input("Medical History"),
                "medications": st.text_input("Current Medications"),
            },
            "vitals": {
                "temp": st.number_input("Temp (°C)", 35.0, 42.0, 37.0),
                "bp": st.text_input("BP (mmHg)", "120/80"),
            },
        }

    # Main interface
    st.title("AI-Powered Clinical Support System")
    col1, col2 = st.columns([2, 1])

    with col1:
        st.subheader("Patient Overview")
        # The render_* helpers are plain functions declared with a leading
        # ``self`` parameter.  There is no instance here, so pass None
        # (the previous ``self.`` calls raised NameError).
        render_patient_dashboard(None, patient_data)

        if st.button("Start Analysis"):
            thread_id = f"case_{uuid4().hex[:8]}"
            result = None
            with st.spinner("Analyzing..."):
                try:
                    result = st.session_state.agent.graph.invoke(
                        {"patient_data": patient_data},
                        {"configurable": {"thread_id": thread_id}},
                    )
                except Exception as exc:  # UI boundary: report, keep rendering the page
                    st.error(f"Analysis failed: {exc}")
            messages = (result or {}).get("messages") or []
            if messages:
                final = messages[-1]
                if isinstance(final, dict):
                    text = final.get("content", "")
                else:
                    text = getattr(final, "content", str(final))
                # Keep the answer so it is still shown after later reruns.
                st.session_state.last_analysis = {"thread_id": thread_id, "text": text}

        last = st.session_state.get("last_analysis")
        if last:
            st.subheader("Assessment")
            st.caption(f"Case {last['thread_id']}")
            st.markdown(str(last["text"]))

    with col2:
        st.subheader("Action Center")
        render_approval_interface(None)
def render_patient_dashboard(self, data):
    """Show the intake vitals and a temperature timeline in two tabs.

    Args:
        self: Unused; kept so existing callers keep working.
        data: Intake dict.  ``data["vitals"]`` is read for ``temp`` (°C) and
            ``bp`` (mmHg, a "systolic/diastolic" string such as "120/80").
    """
    tab1, tab2 = st.tabs(["Vitals", "Timeline"])
    vitals = data.get("vitals", {})

    # Parse the entered blood pressure instead of plotting a fixed value.
    # If it is malformed, leave BP off the chart.
    systolic = diastolic = None
    try:
        systolic, diastolic = (int(part) for part in str(vitals.get("bp", "")).split("/", 1))
    except ValueError:
        pass

    with tab1:
        readings = {"Temperature (°C)": vitals.get("temp")}
        if systolic is not None:
            readings["Systolic BP (mmHg)"] = systolic
            readings["Diastolic BP (mmHg)"] = diastolic
        # Independent measurements, not a series over time: use bars.
        fig = px.bar(
            x=list(readings),
            y=list(readings.values()),
            labels={"x": "Measurement", "y": "Value"},
            title="Vital Signs",
        )
        st.plotly_chart(fig)
    with tab2:
        # The intake form collects no history, so this series is synthetic.
        # It is labelled as sample data so it is not mistaken for readings.
        st.vega_lite_chart({
            "title": "Temperature timeline (sample data)",
            "mark": {"type": "line", "point": True},
            "encoding": {
                "x": {"field": "hour", "type": "ordinal"},
                "y": {"field": "temp", "type": "quantitative"},
            },
            "data": {"values": [{"hour": i, "temp": 36.5 + i / 10} for i in range(24)]},
        })
def render_approval_interface(self):
    """Show the pending action in ``st.session_state.current_action``, if any.

    The action is a dict with ``name`` and ``args``.  It is shown with
    Approve/Reject buttons.  On a click, the decision goes to
    ``self.process_approval(bool)`` when ``self`` provides one.  Otherwise
    it is appended to ``st.session_state.approval_log`` and the pending
    action is cleared.

    Args:
        self: Object that may provide ``process_approval``; may be None.
    """
    action = st.session_state.get("current_action")
    if action is None:
        st.info("No pending actions")
        return

    st.warning("Action Requires Approval")
    st.write(f"**Type:** {action['name'].replace('_', ' ').title()}")
    st.json(action['args'])

    decision = None
    if st.button("Approve"):
        decision = True
    if st.button("Reject"):
        decision = False
    if decision is None:
        return

    handler = getattr(self, "process_approval", None)
    if callable(handler):
        handler(decision)
    else:
        # No handler is wired up.  Record the decision rather than raising
        # AttributeError, and clear the pending action.
        st.session_state.setdefault("approval_log", []).append(
            {"action": action, "approved": decision}
        )
        del st.session_state["current_action"]
if __name__ == "__main__":
    # Run the app only when this file is executed as the main script
    # (presumably via `streamlit run app.py`), not when it is imported.
    main()