# SynapseAI / app.py
# Author: mgbam — commit 788074d ("Add application file"), 4.89 kB
# app.py
import streamlit as st
from uuid import uuid4
import sqlite3
import plotly.express as px
from langchain_groq import ChatGroq
from langgraph.checkpoint.sqlite import SqliteSaver
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage
# Initialize components
@st.cache_resource
def init_components():
    """Create the chat model, web-search tool, and SQLite checkpoint saver.

    Streamlit re-executes the whole script on every widget interaction.
    Without caching, each rerun opened a fresh SQLite connection that was
    never closed. ``st.cache_resource`` builds these objects once per server
    process and returns the same instances on later reruns.

    Returns:
        tuple: ``(model, tool, memory)``:
            - a ``ChatGroq`` chat model at temperature 0.1;
            - a ``TavilySearchResults`` tool limited to 3 results;
            - a ``SqliteSaver`` backed by ``medical_checkpoints.db``.
    """
    # NOTE(review): Groq model IDs are usually lowercase
    # (e.g. "llama-3.3-70b-specdec"); confirm this spelling is accepted.
    model = ChatGroq(temperature=0.1, model="Llama-3.3-70b-Specdec")
    tool = TavilySearchResults(max_results=3)
    # check_same_thread=False: the cached connection is reused by later script
    # runs, which Streamlit may execute on different threads.
    conn = sqlite3.connect("medical_checkpoints.db", check_same_thread=False)
    memory = SqliteSaver(conn)
    return model, tool, memory
# Medical Agent Class
class AdvancedMedicalAgent:
    """Clinical decision-support agent that renders its output with Streamlit.

    The chat model is bound to the supplied tools, so it may request tool
    calls. No tool runs until a clinician approves it in the UI. Everything
    the agent reports is appended to ``st.session_state.history`` as
    ``(role, text)`` tuples for the page to render.
    """

    # Display labels for known keys of ``patient_data["vitals"]``.
    # Unknown keys are shown under their raw name.
    VITAL_LABELS = {
        "temp": "Temperature (°C)",
        "bp": "Blood Pressure (mmHg)",
    }

    def __init__(self, model, tools, memory):
        """Bind *tools* to *model* and index them by name.

        Args:
            model: Chat model exposing ``bind_tools`` (``ChatGroq`` in this file).
            tools: Iterable of tool objects that have ``name`` and ``invoke``.
            memory: Checkpoint saver. NOTE(review): stored but not used by
                any method yet, so conversations are not persisted.
        """
        tools = list(tools)  # consumed twice below, so materialize iterators
        self.model = model.bind_tools(tools)
        self.tools = {t.name: t for t in tools}
        self.memory = memory

    def process_query(self, patient_data, thread_id):
        """Show the dashboard, ask the model about the case, and queue tool calls.

        Args:
            patient_data: Intake dict. ``main`` builds it with the keys
                ``symptoms``, ``history`` and ``vitals``.
            thread_id: Identifier for this analysis. It is used to make the
                approval widgets' keys unique.
        """
        if "history" not in st.session_state:
            st.session_state.history = []

        self.display_patient_dashboard(patient_data)

        response = self.model.invoke([
            SystemMessage(content=self.get_system_prompt()),
            HumanMessage(content=f"Patient Case:\n{patient_data}"),
        ])

        # Log the answer before any actions it requested, so the decision log
        # reads chronologically. A response that only requests tools has empty
        # content, and logging it would render an empty chat bubble.
        if response.content:
            st.session_state.history.append(("AI", response.content))

        # Chat responses carry a ``tool_calls`` list that is often empty, so
        # test for a non-empty list rather than for the attribute's existence.
        tool_calls = getattr(response, "tool_calls", None)
        if tool_calls:
            self.handle_medical_actions(tool_calls, thread_id)

    def display_patient_dashboard(self, data):
        """Render Vitals / History / Timeline tabs using only values in *data*.

        Args:
            data: Intake dict. Reads the optional keys ``vitals`` (mapping),
                ``history`` (JSON-serializable) and ``timeline``. ``timeline``
                is presumably a list of ``{"timestamp", "severity"}`` dicts;
                nothing in this file supplies one yet, so verify with callers.
        """
        vitals = data.get("vitals") or {}
        tab_vitals, tab_history, tab_timeline = st.tabs(["Vitals", "History", "Timeline"])

        with tab_vitals:
            st.subheader("Vital Signs")
            if vitals:
                # One metric per recorded vital. Values have different units,
                # so they are not plotted on a shared axis.
                columns = st.columns(len(vitals))
                for column, (field, value) in zip(columns, vitals.items()):
                    column.metric(self.VITAL_LABELS.get(field, field), value)
            else:
                st.info("No vital signs recorded.")

        with tab_history:
            st.json(data.get("history", {}))

        with tab_timeline:
            events = data.get("timeline")
            if events:
                st.vega_lite_chart({
                    "data": {"values": list(events)},
                    "mark": {"type": "circle", "tooltip": True},
                    "encoding": {
                        "x": {"field": "timestamp", "type": "temporal"},
                        "y": {"field": "severity", "type": "quantitative"},
                    },
                })
            else:
                st.info("No timeline events recorded for this patient.")

    def handle_medical_actions(self, tool_calls, thread_id):
        """Render approve/reject controls for each tool call the model requested.

        Streamlit reruns the whole script on every interaction, and
        ``st.button`` is True only during the rerun its own click triggers. A
        checkbox nested under a button therefore can never be acted on.
        Instead, each decision is wired to an ``on_click`` callback, which
        Streamlit runs at the start of the next rerun.

        NOTE(review): these controls are drawn only on the run that called
        ``process_query``. After one click the page reruns without them, so
        with several pending calls only the first decision is captured.

        Args:
            tool_calls: List of tool-call dicts with ``name``, ``args`` and
                (optionally) ``id``.
            thread_id: Identifier of the current analysis, used in widget keys.
        """
        for index, call in enumerate(tool_calls):
            name = call["name"]
            # Widget keys must be unique, and one response may call the same
            # tool more than once.
            key = self._action_key(thread_id, call, index)

            st.warning(f"🩺 Pending Medical Action: {name}")
            with st.expander("Review Action Details"):
                st.write(f"**Action Type:** {name}")
                st.write(f"**Parameters:** {call['args']}")

            approve_col, reject_col = st.columns(2)
            approve_col.button(
                "I approve this action",
                key=f"approve-{key}",
                on_click=self._execute_action,
                args=(call,),
            )
            reject_col.button(
                "Reject this action",
                key=f"reject-{key}",
                on_click=self._log_system,
                args=("Action cancelled by clinician",),
            )

    def _execute_action(self, call):
        """Run an approved tool call and log the outcome (``on_click`` callback)."""
        tool = self.tools.get(call["name"])
        if tool is None:
            self._log_system(f"Action failed: unknown tool {call['name']!r}")
            return
        try:
            result = tool.invoke(call["args"])
        except Exception as err:  # UI boundary: report instead of crashing the page
            self._log_system(f"Action failed: {err}")
        else:
            self._log_system(f"Action executed: {result}")

    @staticmethod
    def _log_system(message):
        """Append a ``("System", message)`` entry to the session's decision log."""
        if "history" not in st.session_state:
            st.session_state.history = []
        st.session_state.history.append(("System", message))

    @staticmethod
    def _action_key(thread_id, call, index):
        """Return a widget-key suffix unique to one tool call of one analysis."""
        return f"{thread_id}-{call.get('id') or index}"

    def get_system_prompt(self):
        """Return the system prompt that frames the model as a clinical assistant."""
        # Model output is rendered with ``st.markdown``, which colours text
        # only through its ``:red[...]`` syntax, so the prompt names it.
        return """You are an advanced medical AI assistant. Follow these steps:
1. Analyze patient data using latest clinical guidelines
2. Consider drug interactions and contraindications
3. Propose differential diagnosis (3 possibilities)
4. Suggest evidence-based treatment options
5. Identify necessary lab tests with reasoning
6. Flag high-risk factors in red using Streamlit markdown, e.g. :red[high risk]
7. Maintain audit trail of all decisions"""
# Streamlit UI
def _collect_patient_intake():
    """Draw the sidebar intake form and return the entered values as a dict."""
    with st.sidebar:
        st.header("Patient Intake")
        symptoms = st.multiselect("Symptoms", ["Fever", "Cough", "Chest Pain"])
        conditions = st.text_input("Medical History")
        medications = st.text_input("Current Medications")
        temperature = st.number_input("Temperature (°C)", 36.0, 42.0, 37.0)
        blood_pressure = st.text_input("Blood Pressure (mmHg)", "120/80")
    return {
        "symptoms": symptoms,
        "history": {"conditions": conditions, "medications": medications},
        "vitals": {"temp": temperature, "bp": blood_pressure},
    }


def main():
    """Streamlit entry point: intake form, on-demand analysis, decision log."""
    # Must be the first Streamlit command executed in the script run.
    st.set_page_config(page_title="AI Clinical Assistant", layout="wide")
    case = _collect_patient_intake()

    st.title("🧠 AI-Powered Clinical Decision Support")
    llm, search_tool, checkpointer = init_components()
    assistant = AdvancedMedicalAgent(llm, [search_tool], checkpointer)

    if st.button("Start Analysis"):
        with st.spinner("Analyzing patient case..."):
            assistant.process_query(case, str(uuid4()))

    if "history" not in st.session_state:
        return
    st.subheader("Clinical Decision Log")
    for speaker, text in st.session_state.history:
        with st.chat_message(speaker):
            st.markdown(text)
# `streamlit run app.py` executes this module as __main__, so the app starts here.
if __name__ == "__main__":
    main()