Update app.py
app.py
CHANGED
@@ -2,6 +2,7 @@ import streamlit as st
 from langchain_groq import ChatGroq
 from langchain_community.tools.tavily_search import TavilySearchResults
 from langchain_core.messages import HumanMessage, SystemMessage
+from langchain_core.tools import tool
 from typing import TypedDict, Annotated, List, Optional
 import json

@@ -10,39 +11,55 @@ class MedicalConfig:
     SYSTEM_PROMPT = """You are an AI clinical assistant. Follow these rules:
     1. Analyze patient data using latest medical guidelines
     2. Always check for drug interactions
-    5. Flag high-risk conditions immediately"""
+    3. Use structured actions when needed:
+       - lab_order: Order lab tests
+       - prescribe: Prescribe medication
+    4. Flag high-risk conditions immediately"""

     RED_FLAGS = {
         "symptoms": ["chest pain", "shortness of breath", "severe headache"],
         "vitals": {"temp": (38.5, "°C"), "hr": (120, "bpm"), "bp": ("180/120", "mmHg")}
     }

+# Define tools with proper schemas
+@tool
+def order_lab_test(test_name: str, reason: str) -> str:
+    """Orders a lab test with specified parameters.
+
+    Args:
+        test_name: Name of the lab test to order
+        reason: Clinical justification for the test
+
+    Returns:
+        Confirmation message with test details
+    """
+    return f"Lab ordered: {test_name} ({reason})"
+
+@tool
+def prescribe_medication(name: str, dosage: str, frequency: str) -> str:
+    """Prescribes medication with specific dosage instructions.
+
+    Args:
+        name: Name of the medication
+        dosage: Dosage amount (e.g., '500mg')
+        frequency: Administration frequency (e.g., 'every 6 hours')
+
+    Returns:
+        Confirmation message with prescription details
+    """
+    return f"Prescribed: {name} {dosage} {frequency}"
+
+# Initialize tools
+tools = [order_lab_test, prescribe_medication, TavilySearchResults(max_results=3)]

 class MedicalAgent:
     def __init__(self):
         self.model = ChatGroq(temperature=0.1, model="Llama-3.3-70b-Specdec")
-        self.tools = {
-            "medical_actions": {
-                "order_lab_test": self.order_lab_test,
-                "prescribe_medication": self.prescribe_medication
-            },
-            "research": TavilySearchResults(max_results=3)
-        }
+        self.model_with_tools = self.model.bind_tools(tools)

     def analyze_patient(self, patient_data: dict) -> Optional[dict]:
         try:
-            model_with_tools = self.model.bind_tools(list(self.tools["medical_actions"].keys()))
-            response = model_with_tools.invoke([
+            response = self.model_with_tools.invoke([
                 SystemMessage(content=MedicalConfig.SYSTEM_PROMPT),
                 HumanMessage(content=f"Patient Data: {json.dumps(patient_data)}")
             ])
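Each `@tool`-decorated function above is a structured LangChain tool, so it can be invoked directly with a dict of arguments — the same way the reworked process_action calls it in the next hunk. A quick sanity check (the argument values here are hypothetical, purely for illustration):

# Illustrative only: these test values are not part of the commit
print(order_lab_test.invoke({"test_name": "CBC", "reason": "persistent fever"}))
# Lab ordered: CBC (persistent fever)
print(prescribe_medication.invoke({"name": "paracetamol", "dosage": "500mg", "frequency": "every 6 hours"}))
# Prescribed: paracetamol 500mg every 6 hours

Binding these tool objects (rather than the bare name strings the old code passed to bind_tools) is what gives the model the argument schemas it needs to emit well-formed tool calls.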
@@ -53,72 +70,16 @@ class MedicalAgent:

     def process_action(self, action: dict) -> str:
         try:
+            tool_name = action['name']
+            args = action['args']
+
+            if tool_name == "order_lab_test":
+                return order_lab_test.invoke(args)
+            elif tool_name == "prescribe_medication":
+                return prescribe_medication.invoke(args)
+            else:
+                return f"Unknown action: {tool_name}"
         except Exception as e:
             return f"Error processing action: {str(e)}"

-    def order_lab_test(self, params: dict) -> str:
-        return f"✅ Lab ordered: {params['test_name']} ({params['reason']})"
-
-    def prescribe_medication(self, params: dict) -> str:
-        return f"✅ Prescribed: {params['name']} {params['dosage']} {params['frequency']}"
-
-    def check_red_flags(self, patient_data: dict) -> List[str]:
-        flags = []
-        for symptom in patient_data.get("symptoms", []):
-            if symptom.lower() in MedicalConfig.RED_FLAGS["symptoms"]:
-                flags.append(f"🚨 High-risk symptom: {symptom}")
-        return flags
-
-# Streamlit UI
-def main():
-    st.set_page_config(page_title="AI Clinic", layout="wide")
-
-    # Initialize agent
-    if 'agent' not in st.session_state:
-        st.session_state.agent = MedicalAgent()
-
-    # Patient intake
-    with st.sidebar:
-        st.header("Patient Intake")
-        patient_data = {
-            "symptoms": st.multiselect("Symptoms", ["Fever", "Cough", "Chest Pain", "Shortness of Breath"]),
-            "history": {
-                "conditions": st.text_input("Medical History"),
-                "medications": st.text_input("Current Medications")
-            },
-            "vitals": {
-                "temp": st.number_input("Temp (°C)", 35.0, 42.0, 37.0),
-                "bp": st.text_input("BP (mmHg)", "120/80")
-            }
-        }
-
-    # Main interface
-    st.title("AI-Powered Clinical Support System")
-
-    # Check for red flags
-    red_flags = st.session_state.agent.check_red_flags(patient_data)
-    if red_flags:
-        st.warning("High-Risk Conditions Detected")
-        for flag in red_flags:
-            st.error(flag)
-
-    if st.button("Start Analysis"):
-        with st.spinner("Analyzing patient data..."):
-            response = st.session_state.agent.analyze_patient(patient_data)
-
-        if response:
-            st.subheader("Analysis Results")
-            if hasattr(response, 'tool_calls') and response.tool_calls:
-                for action in response.tool_calls:
-                    result = st.session_state.agent.process_action(action)
-                    st.success(result)
-            else:
-                st.info("No specific actions recommended. Here's the analysis:")
-                st.write(response.content)
-
-# Run the app
-if __name__ == "__main__":
-    main()
+# Rest of the Streamlit UI code remains the same...
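Since the diff elides the UI, here is a rough end-to-end sketch of how the updated agent is driven; the hand-written tool call below is hypothetical, and in the running app the same dict arrives via response.tool_calls inside the existing main():

agent = MedicalAgent()

# A tool call in the shape the bound model typically returns
# (hypothetical values, for illustration only)
sample_call = {
    "name": "order_lab_test",
    "args": {"test_name": "chest X-ray", "reason": "chest pain with shortness of breath"},
}
print(agent.process_action(sample_call))
# Lab ordered: chest X-ray (chest pain with shortness of breath)

The Streamlit handler that the closing comment refers to loops over response.tool_calls from analyze_patient and passes each entry to process_action in exactly this form, as in the previous main() shown above.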