Update app.py
Browse files
app.py
CHANGED
@@ -2,7 +2,9 @@ import streamlit as st
|
|
2 |
from langchain_groq import ChatGroq
|
3 |
from langchain_community.tools.tavily_search import TavilySearchResults
|
4 |
from langchain_core.messages import HumanMessage, SystemMessage
|
5 |
-
from typing import TypedDict, Annotated, List
|
|
|
|
|
6 |
|
7 |
# Configuration
|
8 |
class MedicalConfig:
|
@@ -12,12 +14,19 @@ class MedicalConfig:
|
|
12 |
3. Suggest tests only when necessary
|
13 |
4. Use structured actions:
|
14 |
- order_lab_test: {test_name, reason}
|
15 |
-
- prescribe_medication: {name, dosage, frequency}
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
# State Management
|
18 |
class AgentState(TypedDict):
|
19 |
messages: Annotated[List[dict], lambda l, r: l + r]
|
20 |
patient_data: dict
|
|
|
21 |
|
22 |
class MedicalAgent:
|
23 |
def __init__(self):
|
@@ -30,23 +39,37 @@ class MedicalAgent:
|
|
30 |
"research": TavilySearchResults(max_results=3)
|
31 |
}
|
32 |
|
33 |
-
def analyze_patient(self, patient_data):
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
|
40 |
-
def
|
41 |
-
|
42 |
-
return self.tools['medical_actions'][action['name']](action['args'])
|
43 |
-
return "Unknown action"
|
44 |
|
45 |
-
def
|
46 |
-
return f"
|
47 |
|
48 |
-
def
|
49 |
-
|
|
|
|
|
|
|
|
|
50 |
|
51 |
# Streamlit UI
|
52 |
def main():
|
@@ -60,7 +83,7 @@ def main():
|
|
60 |
with st.sidebar:
|
61 |
st.header("Patient Intake")
|
62 |
patient_data = {
|
63 |
-
"symptoms": st.multiselect("Symptoms", ["Fever", "Cough", "Chest Pain"]),
|
64 |
"history": {
|
65 |
"conditions": st.text_input("Medical History"),
|
66 |
"medications": st.text_input("Current Medications")
|
@@ -74,16 +97,25 @@ def main():
|
|
74 |
# Main interface
|
75 |
st.title("AI-Powered Clinical Support System")
|
76 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
if st.button("Start Analysis"):
|
78 |
with st.spinner("Analyzing patient data..."):
|
79 |
response = st.session_state.agent.analyze_patient(patient_data)
|
80 |
|
81 |
-
if
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
|
|
87 |
|
|
|
88 |
if __name__ == "__main__":
|
89 |
-
main()
|
|
|
2 |
from langchain_groq import ChatGroq
|
3 |
from langchain_community.tools.tavily_search import TavilySearchResults
|
4 |
from langchain_core.messages import HumanMessage, SystemMessage
|
5 |
+
from typing import TypedDict, Annotated, List, Optional
|
6 |
+
import json
|
7 |
+
import time
|
8 |
|
9 |
# Configuration
|
10 |
class MedicalConfig:
|
|
|
14 |
3. Suggest tests only when necessary
|
15 |
4. Use structured actions:
|
16 |
- order_lab_test: {test_name, reason}
|
17 |
+
- prescribe_medication: {name, dosage, frequency}
|
18 |
+
5. Flag high-risk conditions immediately"""
|
19 |
+
|
20 |
+
RED_FLAGS = {
|
21 |
+
"symptoms": ["chest pain", "shortness of breath", "severe headache"],
|
22 |
+
"vitals": {"temp": (38.5, "°C"), "hr": (120, "bpm"), "bp": ("180/120", "mmHg")}
|
23 |
+
}
|
24 |
|
25 |
# State Management
|
26 |
class AgentState(TypedDict):
|
27 |
messages: Annotated[List[dict], lambda l, r: l + r]
|
28 |
patient_data: dict
|
29 |
+
approvals: Annotated[List[dict], lambda l, r: l + r]
|
30 |
|
31 |
class MedicalAgent:
|
32 |
def __init__(self):
|
|
|
39 |
"research": TavilySearchResults(max_results=3)
|
40 |
}
|
41 |
|
42 |
+
def analyze_patient(self, patient_data: dict) -> Optional[dict]:
|
43 |
+
try:
|
44 |
+
response = self.model.invoke([
|
45 |
+
SystemMessage(content=MedicalConfig.SYSTEM_PROMPT),
|
46 |
+
HumanMessage(content=f"Patient Data: {json.dumps(patient_data)}")
|
47 |
+
])
|
48 |
+
return response
|
49 |
+
except Exception as e:
|
50 |
+
st.error(f"Error analyzing patient data: {str(e)}")
|
51 |
+
return None
|
52 |
+
|
53 |
+
def process_action(self, action: dict) -> str:
|
54 |
+
try:
|
55 |
+
if action['name'] in self.tools['medical_actions']:
|
56 |
+
return self.tools['medical_actions'][action['name']](action['args'])
|
57 |
+
return "Unknown action"
|
58 |
+
except Exception as e:
|
59 |
+
return f"Error processing action: {str(e)}"
|
60 |
|
61 |
+
def order_lab_test(self, params: dict) -> str:
|
62 |
+
return f"✅ Lab ordered: {params['test_name']} ({params['reason']})"
|
|
|
|
|
63 |
|
64 |
+
def prescribe_medication(self, params: dict) -> str:
|
65 |
+
return f"✅ Prescribed: {params['name']} {params['dosage']} {params['frequency']}"
|
66 |
|
67 |
+
def check_red_flags(self, patient_data: dict) -> List[str]:
|
68 |
+
flags = []
|
69 |
+
for symptom in patient_data.get("symptoms", []):
|
70 |
+
if symptom.lower() in MedicalConfig.RED_FLAGS["symptoms"]:
|
71 |
+
flags.append(f"🚨 High-risk symptom: {symptom}")
|
72 |
+
return flags
|
73 |
|
74 |
# Streamlit UI
|
75 |
def main():
|
|
|
83 |
with st.sidebar:
|
84 |
st.header("Patient Intake")
|
85 |
patient_data = {
|
86 |
+
"symptoms": st.multiselect("Symptoms", ["Fever", "Cough", "Chest Pain", "Shortness of Breath"]),
|
87 |
"history": {
|
88 |
"conditions": st.text_input("Medical History"),
|
89 |
"medications": st.text_input("Current Medications")
|
|
|
97 |
# Main interface
|
98 |
st.title("AI-Powered Clinical Support System")
|
99 |
|
100 |
+
# Check for red flags
|
101 |
+
red_flags = st.session_state.agent.check_red_flags(patient_data)
|
102 |
+
if red_flags:
|
103 |
+
st.warning("High-Risk Conditions Detected")
|
104 |
+
for flag in red_flags:
|
105 |
+
st.error(flag)
|
106 |
+
|
107 |
if st.button("Start Analysis"):
|
108 |
with st.spinner("Analyzing patient data..."):
|
109 |
response = st.session_state.agent.analyze_patient(patient_data)
|
110 |
|
111 |
+
if response:
|
112 |
+
if hasattr(response, 'tool_calls'):
|
113 |
+
for action in response.tool_calls:
|
114 |
+
result = st.session_state.agent.process_action(action)
|
115 |
+
st.success(result)
|
116 |
+
else:
|
117 |
+
st.info(response.content)
|
118 |
|
119 |
+
# Run the app
|
120 |
if __name__ == "__main__":
|
121 |
+
main()
|