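"""Streamlit clinical decision-support demo.

A Groq-hosted Llama model is bound to a Tavily web-search tool, a SQLite-backed
LangGraph checkpointer is prepared for persistence, and any tool calls the
model proposes are surfaced to the clinician for explicit approval before they
are executed.
"""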
import streamlit as st
from uuid import uuid4
import sqlite3
import plotly.express as px

from langchain_groq import ChatGroq
from langgraph.checkpoint.sqlite import SqliteSaver
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage


def init_components():
    # Groq model IDs are lowercase; a low temperature keeps the clinical
    # reasoning as deterministic as possible.
    model = ChatGroq(temperature=0.1, model="llama-3.3-70b-specdec")
    tool = TavilySearchResults(max_results=3)
    # check_same_thread=False lets Streamlit reruns reuse the same SQLite
    # connection for the checkpointer.
    conn = sqlite3.connect("medical_checkpoints.db", check_same_thread=False)
    memory = SqliteSaver(conn)
    return model, tool, memory
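# NOTE: the SqliteSaver returned above is intended for LangGraph checkpointing;
# nothing in this sketch compiles a graph, so the agent below stores `memory`
# without using it (a compiled graph would receive it via
# `compile(checkpointer=memory)`).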

class AdvancedMedicalAgent:
    def __init__(self, model, tools, memory):
        # Bind the search tool so the model can emit structured tool calls,
        # and index the tools by name for dispatch on approval.
        self.model = model.bind_tools(tools)
        self.tools = {t.name: t for t in tools}
        self.memory = memory

    def process_query(self, patient_data, thread_id):
        if 'history' not in st.session_state:
            st.session_state.history = []

        self.display_patient_dashboard(patient_data)

        response = self.model.invoke([
            SystemMessage(content=self.get_system_prompt()),
            HumanMessage(content=f"Patient Case:\n{patient_data}")
        ])

        # AIMessage always defines `tool_calls` (an empty list when no tool is
        # requested), so test for a non-empty list rather than using hasattr.
        if getattr(response, 'tool_calls', None):
            self.handle_medical_actions(response.tool_calls, thread_id)

        st.session_state.history.append(("AI", response.content))

    def display_patient_dashboard(self, data):
        tab1, tab2, tab3 = st.tabs(["Vitals", "History", "Timeline"])

        with tab1:
            # Only temperature comes from the intake form; the BP and heart
            # rate values here are hard-coded placeholders.
            fig = px.line(
                x=["Temperature", "BP", "Heart Rate"],
                y=[data['vitals']['temp'], 130, 85],
                title="Vital Signs"
            )
            st.plotly_chart(fig)

        with tab2:
            st.json(data['history'])

        with tab3:
            # The spec below has no "data" entry, so the timeline renders
            # empty until event records with `timestamp` and `severity`
            # fields are supplied (in the spec or as a DataFrame argument).
            st.vega_lite_chart({
                "mark": {"type": "circle", "tooltip": True},
                "encoding": {
                    "x": {"field": "timestamp", "type": "temporal"},
                    "y": {"field": "severity", "type": "quantitative"}
                }
            })

    def handle_medical_actions(self, tool_calls, thread_id):
        for call in tool_calls:
            st.warning(f"🩺 Pending Medical Action: {call['name']}")

            # An expander keeps the review panel visible across Streamlit
            # reruns; st.button returns True for a single run only, so a
            # checkbox nested under it could never be read back as checked.
            with st.expander("Review Action Details"):
                st.write(f"**Action Type:** {call['name']}")
                st.write(f"**Parameters:** {call['args']}")

                # Unique keys avoid DuplicateWidgetID errors when the model
                # proposes several actions in one response.
                if st.checkbox("I approve this action", key=f"approve_{call['id']}"):
                    result = self.tools[call['name']].invoke(call['args'])
                    st.session_state.history.append(
                        ("System", f"Action executed: {result}")
                    )
                else:
                    st.session_state.history.append(
                        ("System", "Action cancelled by clinician")
                    )

    def get_system_prompt(self):
        return """You are an advanced medical AI assistant. Follow these steps:
1. Analyze patient data using the latest clinical guidelines
2. Consider drug interactions and contraindications
3. Propose a differential diagnosis (3 possibilities)
4. Suggest evidence-based treatment options
5. Identify necessary lab tests with reasoning
6. Flag high-risk factors in RED
7. Maintain an audit trail of all decisions"""


def main():
    st.set_page_config(page_title="AI Clinical Assistant", layout="wide")

    with st.sidebar:
        st.header("Patient Intake")
        patient_data = {
            "symptoms": st.multiselect("Symptoms", ["Fever", "Cough", "Chest Pain"]),
            "history": {
                "conditions": st.text_input("Medical History"),
                "medications": st.text_input("Current Medications")
            },
            "vitals": {
                "temp": st.number_input("Temperature (°C)", 36.0, 42.0, 37.0),
                "bp": st.text_input("Blood Pressure (mmHg)", "120/80")
            }
        }

    st.title("🧠 AI-Powered Clinical Decision Support")

    model, tool, memory = init_components()
    agent = AdvancedMedicalAgent(model, [tool], memory)

    if st.button("Start Analysis"):
        with st.spinner("Analyzing patient case..."):
            agent.process_query(patient_data, str(uuid4()))

    if 'history' in st.session_state:
        st.subheader("Clinical Decision Log")
        for role, content in st.session_state.history:
            with st.chat_message(role):
                st.markdown(content)


if __name__ == "__main__":
    main()
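# To try this sketch locally, GROQ_API_KEY and TAVILY_API_KEY must be set in
# the environment (ChatGroq and TavilySearchResults read them, respectively),
# then launch the app with:
#
#   streamlit run <this_file>.py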