SynapseAI / app.py
mgbam's picture
Update app.py
03e1441 verified
raw
history blame
4.75 kB
import streamlit as st
from langchain_groq import ChatGroq
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import HumanMessage, SystemMessage
from typing import TypedDict, Annotated, List, Optional
import json
# Configuration
class MedicalConfig:
    """Static configuration: the system prompt and the red-flag tables."""

    # System message text, assembled line by line (no trailing newline).
    SYSTEM_PROMPT = "\n".join([
        "You are an AI clinical assistant. Follow these rules:",
        "1. Analyze patient data using latest medical guidelines",
        "2. Always check for drug interactions",
        "3. Suggest tests only when necessary",
        "4. Use structured actions:",
        "- order_lab_test: {test_name, reason}",
        "- prescribe_medication: {name, dosage, frequency}",
        "5. Flag high-risk conditions immediately",
    ])

    # "symptoms" holds lowercase symptom names; "vitals" maps a vital-sign key
    # to a (value, unit) pair.
    # NOTE(review): the vitals values look like alert thresholds, but nothing
    # in the visible code reads them -- confirm intended use.
    RED_FLAGS = {
        "symptoms": ["chest pain", "shortness of breath", "severe headache"],
        "vitals": {
            "temp": (38.5, "°C"),
            "hr": (120, "bpm"),
            "bp": ("180/120", "mmHg"),
        },
    }
# State Management
def _concat(existing: List[dict], incoming: List[dict]) -> List[dict]:
    """Return a new list holding ``existing`` followed by ``incoming``."""
    return existing + incoming


class AgentState(TypedDict):
    """Shape of the agent's state dictionary.

    ``messages`` and ``approvals`` carry a merge function as ``Annotated``
    metadata; that function concatenates two lists.
    """

    messages: Annotated[List[dict], _concat]
    patient_data: dict
    approvals: Annotated[List[dict], _concat]
class MedicalAgent:
    """Clinical assistant that asks a chat model which medical actions to take.

    The model is given function schemas for the supported actions; tool calls
    it returns are dispatched by name to the handler methods of this class.
    """

    # OpenAI-style function schemas advertised to the model. ``bind_tools``
    # needs tool *definitions* (schema dicts, pydantic models, callables or
    # tool objects); bare name strings are rejected, which previously made
    # every analysis fail. The names here must match the keys of
    # ``self.tools["medical_actions"]`` so returned calls can be dispatched.
    TOOL_SCHEMAS = (
        {
            "type": "function",
            "function": {
                "name": "order_lab_test",
                "description": "Order a laboratory test for the patient.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "test_name": {
                            "type": "string",
                            "description": "Name of the lab test to order.",
                        },
                        "reason": {
                            "type": "string",
                            "description": "Clinical reason for ordering the test.",
                        },
                    },
                    "required": ["test_name", "reason"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "prescribe_medication",
                "description": "Prescribe a medication for the patient.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "name": {
                            "type": "string",
                            "description": "Name of the medication.",
                        },
                        "dosage": {
                            "type": "string",
                            "description": "Dose to give, including units.",
                        },
                        "frequency": {
                            "type": "string",
                            "description": "How often the medication is taken.",
                        },
                    },
                    "required": ["name", "dosage", "frequency"],
                },
            },
        },
    )

    def __init__(self):
        """Create the chat model and the registry of available tools."""
        # NOTE(review): confirm this model id is still served by the provider.
        self.model = ChatGroq(temperature=0.1, model="Llama-3.3-70b-Specdec")
        self.tools = {
            "medical_actions": {
                "order_lab_test": self.order_lab_test,
                "prescribe_medication": self.prescribe_medication,
            },
            "research": TavilySearchResults(max_results=3),
        }

    def analyze_patient(self, patient_data: dict) -> Optional[dict]:
        """Send the patient data to the model and return its response.

        Args:
            patient_data: JSON-serializable patient record.

        Returns:
            The object returned by the model's ``invoke`` call, or ``None``
            when anything fails (the error is shown via ``st.error``).
        """
        try:
            # Bind the action schemas so the model can answer with tool calls.
            model_with_tools = self.model.bind_tools(list(self.TOOL_SCHEMAS))
            response = model_with_tools.invoke([
                SystemMessage(content=MedicalConfig.SYSTEM_PROMPT),
                HumanMessage(content=f"Patient Data: {json.dumps(patient_data)}"),
            ])
            return response
        except Exception as e:
            # UI boundary: report the failure to the user instead of crashing.
            st.error(f"Error analyzing patient data: {str(e)}")
            return None

    def process_action(self, action: dict) -> str:
        """Run the handler named by a tool call and return its message.

        Args:
            action: Mapping with a ``'name'`` key and, for known actions, an
                ``'args'`` mapping passed to the handler.

        Returns:
            The handler's result, ``"Unknown action: ..."`` for unregistered
            names, or ``"Error processing action: ..."`` if anything raises.
        """
        try:
            name = action['name']
            handler = self.tools['medical_actions'].get(name)
            if handler is None:
                return f"Unknown action: {name}"
            return handler(action['args'])
        except Exception as e:
            return f"Error processing action: {str(e)}"

    def order_lab_test(self, params: dict) -> str:
        """Return a confirmation message for a lab order.

        Raises:
            KeyError: If ``test_name`` or ``reason`` is missing from ``params``.
        """
        return f"✅ Lab ordered: {params['test_name']} ({params['reason']})"

    def prescribe_medication(self, params: dict) -> str:
        """Return a confirmation message for a prescription.

        Raises:
            KeyError: If ``name``, ``dosage`` or ``frequency`` is missing.
        """
        return f"✅ Prescribed: {params['name']} {params['dosage']} {params['frequency']}"

    def check_red_flags(self, patient_data: dict) -> List[str]:
        """Return one warning per reported symptom listed as high-risk.

        Symptoms are compared case-insensitively against
        ``MedicalConfig.RED_FLAGS["symptoms"]``.
        """
        # NOTE(review): RED_FLAGS["vitals"] is not evaluated here, so abnormal
        # vitals never raise a flag -- confirm whether that is intended.
        high_risk = MedicalConfig.RED_FLAGS["symptoms"]
        return [
            f"🚨 High-risk symptom: {symptom}"
            for symptom in patient_data.get("symptoms", [])
            if symptom.lower() in high_risk
        ]
# Streamlit UI
def main():
    """Render the page: intake sidebar, red-flag alerts and model analysis."""
    st.set_page_config(page_title="AI Clinic", layout="wide")

    # Build the agent only when the session does not already hold one.
    if 'agent' not in st.session_state:
        st.session_state.agent = MedicalAgent()
    agent = st.session_state.agent

    # Intake form; widgets are created in the order they appear in the sidebar.
    with st.sidebar:
        st.header("Patient Intake")
        symptoms = st.multiselect(
            "Symptoms", ["Fever", "Cough", "Chest Pain", "Shortness of Breath"]
        )
        conditions = st.text_input("Medical History")
        medications = st.text_input("Current Medications")
        temperature = st.number_input("Temp (°C)", 35.0, 42.0, 37.0)
        blood_pressure = st.text_input("BP (mmHg)", "120/80")

    patient_data = {
        "symptoms": symptoms,
        "history": {"conditions": conditions, "medications": medications},
        "vitals": {"temp": temperature, "bp": blood_pressure},
    }

    st.title("AI-Powered Clinical Support System")

    # Surface high-risk findings before any analysis is requested.
    alerts = agent.check_red_flags(patient_data)
    if alerts:
        st.warning("High-Risk Conditions Detected")
        for alert in alerts:
            st.error(alert)

    if not st.button("Start Analysis"):
        return

    with st.spinner("Analyzing patient data..."):
        response = agent.analyze_patient(patient_data)
    if not response:
        return

    st.subheader("Analysis Results")
    tool_calls = getattr(response, 'tool_calls', None)
    if not tool_calls:
        # No structured actions came back: show the model's text instead.
        st.info("No specific actions recommended. Here's the analysis:")
        st.write(response.content)
        return

    for call in tool_calls:
        st.success(agent.process_action(call))
# Run the app
# Start the UI only when this file is executed as the entry script.
if __name__ == "__main__":
    main()