# app.py
import streamlit as st
from uuid import uuid4
import sqlite3
import plotly.express as px
from langchain_groq import ChatGroq
from langgraph.checkpoint.sqlite import SqliteSaver
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage
# Initialize components
@st.cache_resource
def init_components():
    """Create the LLM client, web-search tool, and SQLite-backed checkpointer.

    Decorated with st.cache_resource so these objects are built once per
    server process rather than on every Streamlit rerun (each widget
    interaction re-executes the whole script, which previously reopened the
    SQLite connection and re-created the API clients every time).

    Returns:
        tuple: (ChatGroq model, TavilySearchResults tool, SqliteSaver memory)
    """
    model = ChatGroq(temperature=0.1, model="Llama-3.3-70b-Specdec")
    tool = TavilySearchResults(max_results=3)
    # check_same_thread=False: Streamlit may service reruns from different
    # threads than the one that opened the connection.
    conn = sqlite3.connect("medical_checkpoints.db", check_same_thread=False)
    memory = SqliteSaver(conn)
    return model, tool, memory
# Medical Agent Class
class AdvancedMedicalAgent:
    """LLM-backed clinical assistant with human-in-the-loop tool approval."""

    def __init__(self, model, tools, memory):
        self.model = model.bind_tools(tools)
        self.tools = {t.name: t for t in tools}  # name -> tool, for dispatch
        self.memory = memory

    def process_query(self, patient_data, thread_id):
        """Render the patient dashboard, query the model, and log the reply.

        Args:
            patient_data: dict with 'symptoms', 'history', and 'vitals' keys,
                as assembled by the sidebar form in main().
            thread_id: conversation identifier; also used to namespace the
                approval-widget keys so reruns don't collide.
        """
        if 'history' not in st.session_state:
            st.session_state.history = []
        self.display_patient_dashboard(patient_data)
        response = self.model.invoke([
            SystemMessage(content=self.get_system_prompt()),
            HumanMessage(content=f"Patient Case:\n{patient_data}")
        ])
        # BUG FIX: an AIMessage always *has* a tool_calls attribute (normally
        # an empty list), so hasattr() was always True and the approval flow
        # ran even with no pending actions. Check truthiness instead.
        if getattr(response, 'tool_calls', None):
            self.handle_medical_actions(response.tool_calls, thread_id)
        st.session_state.history.append(("AI", response.content))

    def display_patient_dashboard(self, data):
        """Show Vitals / History / Timeline tabs for the current patient."""
        tab1, tab2, tab3 = st.tabs(["Vitals", "History", "Timeline"])
        with tab1:
            # BUG FIX: the BP value was hard-coded (130) instead of using the
            # clinician's input; parse the systolic reading from the "120/80"
            # string, falling back to 0 if it is unparsable.
            try:
                systolic = float(str(data['vitals'].get('bp', '')).split('/')[0])
            except (ValueError, IndexError):
                systolic = 0.0
            fig = px.line(
                x=["Temperature", "BP (systolic)"],
                y=[data['vitals']['temp'], systolic],
                title="Vital Signs"
            )
            st.plotly_chart(fig)
        with tab2:
            st.json(data['history'])
        with tab3:
            # NOTE(review): no data is supplied to this spec, so the timeline
            # renders empty — wire in real event rows when available.
            st.vega_lite_chart({
                "mark": {"type": "circle", "tooltip": True},
                "encoding": {
                    "x": {"field": "timestamp", "type": "temporal"},
                    "y": {"field": "severity", "type": "quantitative"}
                }
            })

    def handle_medical_actions(self, tool_calls, thread_id):
        """Surface each proposed tool call for explicit clinician approval.

        BUG FIX: the original created st.button/st.checkbox without key=,
        so more than one pending action raised DuplicateWidgetID. Keys are
        now namespaced by thread_id and the call's id (or loop index).
        """
        for idx, call in enumerate(tool_calls):
            widget_key = f"{thread_id}-{call.get('id', idx)}"
            st.warning(f"🩺 Pending Medical Action: {call['name']}")
            if st.button("Review Action Details", key=f"review-{widget_key}"):
                st.write(f"**Action Type:** {call['name']}")
                st.write(f"**Parameters:** {call['args']}")
                if st.checkbox("I approve this action", key=f"approve-{widget_key}"):
                    result = self.tools[call['name']].invoke(call['args'])
                    st.session_state.history.append(
                        ("System", f"Action executed: {result}")
                    )
                else:
                    st.session_state.history.append(
                        ("System", "Action cancelled by clinician")
                    )

    def get_system_prompt(self):
        """Return the clinical-workflow system prompt (verbatim)."""
        return """You are an advanced medical AI assistant. Follow these steps:
1. Analyze patient data using latest clinical guidelines
2. Consider drug interactions and contraindications
3. Propose differential diagnosis (3 possibilities)
4. Suggest evidence-based treatment options
5. Identify necessary lab tests with reasoning
6. Flag high-risk factors in RED
7. Maintain audit trail of all decisions"""
# Streamlit UI
def main():
    """Streamlit entry point: intake sidebar, analysis trigger, decision log."""
    st.set_page_config(page_title="AI Clinical Assistant", layout="wide")
    with st.sidebar:
        st.header("Patient Intake")
        patient_data = {
            "symptoms": st.multiselect("Symptoms", ["Fever", "Cough", "Chest Pain"]),
            "history": {
                "conditions": st.text_input("Medical History"),
                "medications": st.text_input("Current Medications")
            },
            "vitals": {
                "temp": st.number_input("Temperature (°C)", 36.0, 42.0, 37.0),
                "bp": st.text_input("Blood Pressure (mmHg)", "120/80")
            }
        }
    st.title("🧠 AI-Powered Clinical Decision Support")
    model, tool, memory = init_components()
    agent = AdvancedMedicalAgent(model, [tool], memory)
    # BUG FIX: a fresh uuid4 on every click meant the checkpointer thread id
    # never survived a rerun, defeating conversation persistence. Keep one
    # stable thread id per browser session instead.
    if "thread_id" not in st.session_state:
        st.session_state.thread_id = str(uuid4())
    if st.button("Start Analysis"):
        with st.spinner("Analyzing patient case..."):
            agent.process_query(patient_data, st.session_state.thread_id)
    if 'history' in st.session_state:
        st.subheader("Clinical Decision Log")
        for role, content in st.session_state.history:
            with st.chat_message(role):
                st.markdown(content)

if __name__ == "__main__":
    main()