mgbam commited on
Commit
896de2d
Β·
verified Β·
1 Parent(s): b564942

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +538 -98
app.py CHANGED
@@ -1,121 +1,561 @@
1
  import streamlit as st
2
  from langchain_groq import ChatGroq
3
  from langchain_community.tools.tavily_search import TavilySearchResults
4
- from langchain_core.messages import HumanMessage, SystemMessage
 
 
 
5
  from langchain_core.tools import tool
6
- from typing import Optional
7
  import json
 
8
 
9
- # Configuration
10
- class MedicalConfig:
11
- SYSTEM_PROMPT = """You are an AI clinical assistant. Follow these rules:
12
- 1. Analyze patient data using latest medical guidelines
13
- 2. Always check for drug interactions
14
- 3. Use structured actions when needed:
15
- - lab_order: Order lab tests
16
- - prescribe: Prescribe medication
17
- 4. Flag high-risk conditions immediately"""
18
-
19
- RED_FLAGS = {
20
- "symptoms": ["chest pain", "shortness of breath", "severe headache"],
21
- "vitals": {"temp": (38.5, "Β°C"), "hr": (120, "bpm"), "bp": ("180/120", "mmHg")}
22
- }
23
-
24
- # Define tools with proper schemas
25
- @tool
26
- def order_lab_test(test_name: str, reason: str) -> str:
27
- """Orders a lab test with specified parameters."""
28
- return f"βœ… Lab ordered: {test_name} ({reason})"
29
-
30
- @tool
31
- def prescribe_medication(name: str, dosage: str, frequency: str) -> str:
32
- """Prescribes medication with specific dosage instructions."""
33
- return f"βœ… Prescribed: {name} {dosage} {frequency}"
34
-
35
- # Initialize tools and model
36
- tools = [order_lab_test, prescribe_medication, TavilySearchResults(max_results=3)]
37
-
38
- class MedicalAgent:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  def __init__(self):
40
- self.model = ChatGroq(temperature=0.1, model="Llama-3.3-70b-Specdec")
41
- self.model_with_tools = self.model.bind_tools(tools)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
- def analyze_patient(self, patient_data: dict) -> Optional[dict]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  try:
45
- response = self.model_with_tools.invoke([
46
- SystemMessage(content=MedicalConfig.SYSTEM_PROMPT),
47
- HumanMessage(content=f"Patient Data: {json.dumps(patient_data)}")
48
- ])
49
- return response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  except Exception as e:
51
- st.error(f"Error analyzing patient data: {str(e)}")
52
- return None
 
 
 
 
 
 
 
 
 
53
 
54
- def process_action(self, action: dict) -> str:
55
  try:
56
- tool_name = action['name']
57
- args = action['args']
58
-
59
- if tool_name == "order_lab_test":
60
- return order_lab_test.invoke(args)
61
- elif tool_name == "prescribe_medication":
62
- return prescribe_medication.invoke(args)
63
- else:
64
- return f"Unknown action: {tool_name}"
65
  except Exception as e:
66
- return f"Error processing action: {str(e)}"
 
67
 
68
- # Streamlit UI
69
  def main():
70
- st.set_page_config(page_title="AI Clinic", layout="wide")
71
-
72
- # Initialize agent
 
 
73
  if 'agent' not in st.session_state:
74
- st.session_state.agent = MedicalAgent()
75
-
76
- # Patient intake form
 
 
 
 
 
 
 
 
77
  with st.sidebar:
78
- st.header("Patient Intake")
79
- symptoms = st.multiselect("Symptoms", ["Fever", "Cough", "Chest Pain", "Shortness of Breath"])
80
- medical_history = st.text_input("Medical History")
81
- current_meds = st.text_input("Current Medications")
82
- temp = st.number_input("Temperature (Β°C)", 35.0, 42.0, 37.0)
83
- bp = st.text_input("Blood Pressure (mmHg)", "120/80")
84
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  patient_data = {
86
- "symptoms": symptoms,
87
- "history": {
88
- "conditions": medical_history,
89
- "medications": current_meds
90
- },
 
 
 
 
91
  "vitals": {
92
- "temp": temp,
93
- "bp": bp
94
- }
 
95
  }
96
-
97
- # Main interface
98
- st.title("AI-Powered Clinical Support System")
99
-
100
- # Analysis button
101
- if st.button("Start Analysis", type="primary"):
102
- with st.spinner("Analyzing patient data..."):
103
- response = st.session_state.agent.analyze_patient(patient_data)
104
-
105
- if response:
106
- st.subheader("Clinical Recommendations")
107
-
108
- # Check for tool calls
109
- if hasattr(response, 'tool_calls') and response.tool_calls:
110
- for action in response.tool_calls:
111
- result = st.session_state.agent.process_action({
112
- "name": action['name'],
113
- "args": action['args']
114
- })
115
- st.success(result)
116
- else:
117
- st.info("No specific actions recommended. Clinical Summary:")
118
- st.write(response.content)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
  if __name__ == "__main__":
121
  main()
 
1
  import streamlit as st
2
  from langchain_groq import ChatGroq
3
  from langchain_community.tools.tavily_search import TavilySearchResults
4
+ from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
5
+ from langchain_core.prompts import ChatPromptTemplate
6
+ from langchain_core.output_parsers import StrOutputParser
7
+ from langchain_core.pydantic_v1 import BaseModel, Field
8
  from langchain_core.tools import tool
9
+ from typing import Optional, List, Dict, Any
10
  import json
11
+ import re # For parsing vitals like BP
12
 
13
+ # --- Configuration & Constants ---
14
+ class ClinicalAppSettings:
15
+ APP_TITLE = "SynapseAI: Advanced Clinical Decision Support"
16
+ PAGE_LAYOUT = "wide"
17
+ MODEL_NAME = "llama3-70b-8192" # Use a powerful model like Groq's Llama3 70b
18
+ TEMPERATURE = 0.1
19
+ MAX_SEARCH_RESULTS = 3
20
+
21
+ class ClinicalPrompts:
22
+ SYSTEM_PROMPT = """
23
+ You are SynapseAI, an expert AI clinical assistant designed to support healthcare professionals.
24
+ Your primary function is to analyze patient data, provide differential diagnoses, suggest evidence-based management plans, and identify potential risks according to the latest medical guidelines and safety protocols.
25
+
26
+ **Core Directives:**
27
+ 1. **Comprehensive Analysis:** Thoroughly analyze ALL provided patient data (demographics, HPI, PMH, PSH, Allergies, Meds, SH, FH, ROS, Vitals, Exam).
28
+ 2. **Structured Output:** ALWAYS format your response using the following JSON structure:
29
+ ```json
30
+ {
31
+ "assessment": "Concise summary of the patient's presentation and key findings.",
32
+ "differential_diagnosis": [
33
+ {"diagnosis": "Primary Diagnosis", "likelihood": "High/Medium/Low", "rationale": "Supporting evidence..."},
34
+ {"diagnosis": "Alternative Diagnosis 1", "likelihood": "Medium/Low", "rationale": "Supporting/Refuting evidence..."},
35
+ {"diagnosis": "Alternative Diagnosis 2", "likelihood": "Low", "rationale": "Why it's less likely but considered..."}
36
+ ],
37
+ "risk_assessment": {
38
+ "identified_red_flags": ["List any triggered red flags based on input"],
39
+ "immediate_concerns": ["Specific urgent issues requiring attention (e.g., sepsis risk, ACS rule-out)"],
40
+ "potential_complications": ["Possible future issues based on presentation"]
41
+ },
42
+ "recommended_plan": {
43
+ "investigations": ["List specific lab tests or imaging required. Use 'order_lab_test' tool."],
44
+ "therapeutics": ["Suggest specific treatments or prescriptions. Use 'prescribe_medication' tool."],
45
+ "consultations": ["Recommend specialist consultations if needed."],
46
+ "patient_education": ["Key points for patient communication."]
47
+ },
48
+ "rationale_summary": "Brief justification for the overall assessment and plan, referencing guidelines or evidence where possible. Use 'tavily_search_results' tool if needed to find supporting evidence/guidelines.",
49
+ "interaction_check_summary": "Summary of findings from the 'check_drug_interactions' tool IF a new medication was considered or prescribed."
50
+ }
51
+ ```
52
+ 3. **Safety First - Red Flags:** Immediately identify and report any conditions matching the defined RED_FLAGS. Use the `flag_risk` tool if critical.
53
+ 4. **Safety First - Drug Interactions:** BEFORE suggesting *any* new prescription, you MUST use the `check_drug_interactions` tool to verify against the patient's current medications and allergies. Mention the result in `interaction_check_summary`.
54
+ 5. **Tool Utilization:** Employ the provided tools (`order_lab_test`, `prescribe_medication`, `check_drug_interactions`, `flag_risk`, `tavily_search_results`) precisely when indicated by your plan. Adhere strictly to tool schemas. Do NOT hallucinate tool usage results; wait for actual tool output if required in a multi-turn scenario (though this implementation focuses on single-turn analysis with tool calls).
55
+ 6. **Evidence-Based:** Briefly cite reasoning, drawing on general medical knowledge. Use Tavily Search for specific guideline checks or novel information when necessary.
56
+ 7. **Clarity and Conciseness:** Be clear, avoiding ambiguity. Use standard medical terminology.
57
+ """
58
+
59
+ # --- Mock Data / Helpers ---
60
+ # (In a real system, this would be a proper API/database)
61
+ MOCK_INTERACTION_DB = {
62
+ ("Lisinopril", "Spironolactone"): "High risk of hyperkalemia. Monitor potassium closely.",
63
+ ("Warfarin", "Amiodarone"): "Increased bleeding risk. Monitor INR frequently and adjust Warfarin dose.",
64
+ ("Simvastatin", "Clarithromycin"): "Increased risk of myopathy/rhabdomyolysis. Avoid combination or use lower statin dose.",
65
+ ("Aspirin", "Ibuprofen"): "Concurrent use may decrease Aspirin's cardioprotective effect. Potential for increased GI bleeding."
66
+ }
67
+
68
+ ALLERGY_INTERACTIONS = {
69
+ "Penicillin": ["Amoxicillin", "Ampicillin", "Piperacillin"],
70
+ "Sulfa": ["Sulfamethoxazole", "Sulfasalazine"],
71
+ "Aspirin": ["Ibuprofen", "Naproxen"] # Cross-reactivity example for NSAIDs
72
+ }
73
+
74
+ def parse_bp(bp_string: str) -> Optional[tuple[int, int]]:
75
+ """Parses BP string like '120/80' into (systolic, diastolic) integers."""
76
+ match = re.match(r"(\d{1,3})\s*/\s*(\d{1,3})", bp_string)
77
+ if match:
78
+ return int(match.group(1)), int(match.group(2))
79
+ return None
80
+
81
+ def check_red_flags(patient_data: dict) -> List[str]:
82
+ """Checks patient data against predefined red flags."""
83
+ flags = []
84
+ symptoms = patient_data.get("hpi", {}).get("symptoms", [])
85
+ vitals = patient_data.get("vitals", {})
86
+ history = patient_data.get("pmh", {}).get("conditions", "")
87
+
88
+ # Symptom Flags
89
+ if "chest pain" in [s.lower() for s in symptoms]: flags.append("Red Flag: Chest Pain reported.")
90
+ if "shortness of breath" in [s.lower() for s in symptoms]: flags.append("Red Flag: Shortness of Breath reported.")
91
+ if "severe headache" in [s.lower() for s in symptoms]: flags.append("Red Flag: Severe Headache reported.")
92
+ if "sudden vision loss" in [s.lower() for s in symptoms]: flags.append("Red Flag: Sudden Vision Loss reported.")
93
+ if "weakness on one side" in [s.lower() for s in symptoms]: flags.append("Red Flag: Unilateral Weakness reported (potential stroke).")
94
+
95
+ # Vital Sign Flags (add more checks as needed)
96
+ if "temp_c" in vitals and vitals["temp_c"] >= 38.5: flags.append(f"Red Flag: Fever (Temperature: {vitals['temp_c']}Β°C).")
97
+ if "hr_bpm" in vitals and vitals["hr_bpm"] >= 120: flags.append(f"Red Flag: Tachycardia (Heart Rate: {vitals['hr_bpm']} bpm).")
98
+ if "rr_rpm" in vitals and vitals["rr_rpm"] >= 24: flags.append(f"Red Flag: Tachypnea (Respiratory Rate: {vitals['rr_rpm']} rpm).")
99
+ if "spo2_percent" in vitals and vitals["spo2_percent"] <= 92: flags.append(f"Red Flag: Hypoxia (SpO2: {vitals['spo2_percent']}%).")
100
+ if "bp_mmhg" in vitals:
101
+ bp = parse_bp(vitals["bp_mmhg"])
102
+ if bp:
103
+ if bp[0] >= 180 or bp[1] >= 110: flags.append(f"Red Flag: Hypertensive Urgency/Emergency (BP: {vitals['bp_mmhg']} mmHg).")
104
+ if bp[0] <= 90 and bp[1] <= 60: flags.append(f"Red Flag: Hypotension (BP: {vitals['bp_mmhg']} mmHg).")
105
+
106
+ # History Flags (Simple examples)
107
+ if "history of mi" in history.lower() and "chest pain" in [s.lower() for s in symptoms]: flags.append("Red Flag: History of MI with current Chest Pain.")
108
+
109
+ return flags
110
+
111
+ # --- Enhanced Tool Definitions ---
112
+
113
+ # Use Pydantic models for robust argument validation
114
+ class LabOrderInput(BaseModel):
115
+ test_name: str = Field(..., description="Specific name of the lab test or panel (e.g., 'CBC', 'BMP', 'Troponin I', 'Urinalysis').")
116
+ reason: str = Field(..., description="Clinical justification for ordering the test (e.g., 'Rule out infection', 'Assess renal function', 'Evaluate for ACS').")
117
+ priority: str = Field("Routine", description="Priority of the test (e.g., 'STAT', 'Routine').")
118
+
119
+ @tool("order_lab_test", args_schema=LabOrderInput)
120
+ def order_lab_test(test_name: str, reason: str, priority: str = "Routine") -> str:
121
+ """Orders a specific lab test with clinical justification and priority."""
122
+ return json.dumps({
123
+ "status": "success",
124
+ "message": f"Lab Ordered: {test_name} ({priority})",
125
+ "details": f"Reason: {reason}"
126
+ })
127
+
128
+ class PrescriptionInput(BaseModel):
129
+ medication_name: str = Field(..., description="Name of the medication.")
130
+ dosage: str = Field(..., description="Dosage amount and unit (e.g., '500 mg', '10 mg').")
131
+ route: str = Field(..., description="Route of administration (e.g., 'PO', 'IV', 'IM', 'Topical').")
132
+ frequency: str = Field(..., description="How often the medication should be taken (e.g., 'BID', 'QDaily', 'Q4-6H PRN').")
133
+ duration: str = Field("As directed", description="Duration of treatment (e.g., '7 days', '1 month', 'Until follow-up').")
134
+ reason: str = Field(..., description="Clinical indication for the prescription.")
135
+
136
+ @tool("prescribe_medication", args_schema=PrescriptionInput)
137
+ def prescribe_medication(medication_name: str, dosage: str, route: str, frequency: str, duration: str, reason: str) -> str:
138
+ """Prescribes a medication with detailed instructions and clinical indication."""
139
+ # In a real scenario, this would trigger an e-prescription workflow
140
+ return json.dumps({
141
+ "status": "success",
142
+ "message": f"Prescription Prepared: {medication_name} {dosage} {route} {frequency}",
143
+ "details": f"Duration: {duration}. Reason: {reason}"
144
+ })
145
+
146
+ class InteractionCheckInput(BaseModel):
147
+ potential_prescription: str = Field(..., description="The name of the NEW medication being considered.")
148
+ current_medications: List[str] = Field(..., description="List of the patient's CURRENT medication names.")
149
+ allergies: List[str] = Field(..., description="List of the patient's known allergies.")
150
+
151
+ @tool("check_drug_interactions", args_schema=InteractionCheckInput)
152
+ def check_drug_interactions(potential_prescription: str, current_medications: List[str], allergies: List[str]) -> str:
153
+ """Checks for potential drug-drug and drug-allergy interactions BEFORE prescribing."""
154
+ warnings = []
155
+ potential_med_lower = potential_prescription.lower()
156
+
157
+ # Check Allergies
158
+ for allergy in allergies:
159
+ allergy_lower = allergy.lower()
160
+ # Simple direct check
161
+ if allergy_lower == potential_med_lower:
162
+ warnings.append(f"CRITICAL ALLERGY: Patient allergic to {allergy}. Cannot prescribe {potential_prescription}.")
163
+ continue
164
+ # Check cross-reactivity (using simplified mock data)
165
+ if allergy_lower in ALLERGY_INTERACTIONS:
166
+ for cross_reactant in ALLERGY_INTERACTIONS[allergy_lower]:
167
+ if cross_reactant.lower() == potential_med_lower:
168
+ warnings.append(f"POTENTIAL CROSS-ALLERGY: Patient allergic to {allergy}. High risk with {potential_prescription}.")
169
+
170
+ # Check Drug-Drug Interactions (using simplified mock data)
171
+ current_meds_lower = [med.lower() for med in current_medications]
172
+ for current_med in current_meds_lower:
173
+ # Check pairs in both orders
174
+ pair1 = (current_med, potential_med_lower)
175
+ pair2 = (potential_med_lower, current_med)
176
+ if pair1 in MOCK_INTERACTION_DB:
177
+ warnings.append(f"Interaction Found: {potential_prescription} with {current_med.capitalize()} - {MOCK_INTERACTION_DB[pair1]}")
178
+ elif pair2 in MOCK_INTERACTION_DB:
179
+ warnings.append(f"Interaction Found: {potential_prescription} with {current_med.capitalize()} - {MOCK_INTERACTION_DB[pair2]}")
180
+
181
+ if not warnings:
182
+ return json.dumps({"status": "clear", "message": f"No major interactions identified for {potential_prescription} with current meds/allergies.", "warnings": []})
183
+ else:
184
+ return json.dumps({"status": "warning", "message": f"Potential interactions identified for {potential_prescription}.", "warnings": warnings})
185
+
186
+ class FlagRiskInput(BaseModel):
187
+ risk_description: str = Field(..., description="Specific critical risk identified (e.g., 'Suspected Sepsis', 'Acute Coronary Syndrome', 'Stroke Alert').")
188
+ urgency: str = Field("High", description="Urgency level (e.g., 'Critical', 'High', 'Moderate').")
189
+
190
+ @tool("flag_risk", args_schema=FlagRiskInput)
191
+ def flag_risk(risk_description: str, urgency: str) -> str:
192
+ """Flags a critical risk identified during analysis for immediate attention."""
193
+ st.error(f"🚨 **{urgency.upper()} RISK FLAGGED:** {risk_description}", icon="🚨")
194
+ return json.dumps({
195
+ "status": "flagged",
196
+ "message": f"Risk '{risk_description}' flagged with {urgency} urgency."
197
+ })
198
+
199
+
200
+ # Initialize Search Tool
201
+ search_tool = TavilySearchResults(max_results=ClinicalAppSettings.MAX_SEARCH_RESULTS)
202
+
203
+ # --- Core Agent Logic ---
204
+ class ClinicalAgent:
205
  def __init__(self):
206
+ self.model = ChatGroq(
207
+ temperature=ClinicalAppSettings.TEMPERATURE,
208
+ model=ClinicalAppSettings.MODEL_NAME
209
+ )
210
+ # Combine all tools
211
+ self.tools = [
212
+ order_lab_test,
213
+ prescribe_medication,
214
+ check_drug_interactions,
215
+ flag_risk,
216
+ search_tool
217
+ ]
218
+ # Bind tools to the model
219
+ self.model_with_tools = self.model.bind_tools(self.tools)
220
+ # History for context (simple implementation)
221
+ self.history = []
222
 
223
+ def _format_patient_data_for_prompt(self, data: dict) -> str:
224
+ """Formats the patient dictionary into a readable string for the LLM."""
225
+ prompt_str = "Patient Data:\n"
226
+ for key, value in data.items():
227
+ if isinstance(value, dict):
228
+ prompt_str += f" {key.replace('_', ' ').title()}:\n"
229
+ for sub_key, sub_value in value.items():
230
+ if sub_value: # Only include if there's data
231
+ prompt_str += f" - {sub_key.replace('_', ' ').title()}: {sub_value}\n"
232
+ elif isinstance(value, list) and value:
233
+ prompt_str += f" {key.replace('_', ' ').title()}: {', '.join(map(str, value))}\n"
234
+ elif value: # Only include non-empty fields
235
+ prompt_str += f" {key.replace('_', ' ').title()}: {value}\n"
236
+ return prompt_str.strip()
237
+
238
+
239
+ def analyze(self, patient_data: dict) -> tuple[Optional[dict], List[dict]]:
240
+ """Runs the analysis, handling tool calls and parsing the structured output."""
241
  try:
242
+ # Add System Prompt and formatted Patient Data
243
+ # Simple history management: add previous messages if any
244
+ messages = [SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT)]
245
+ # Include history if needed - consider token limits
246
+ # messages.extend(self.history)
247
+ formatted_data = self._format_patient_data_for_prompt(patient_data)
248
+ messages.append(HumanMessage(content=formatted_data))
249
+
250
+ # Invoke the model
251
+ ai_response = self.model_with_tools.invoke(messages)
252
+
253
+ # Store conversation turn
254
+ # self.history.append(HumanMessage(content=formatted_data))
255
+ # self.history.append(ai_response) # AIMessage includes tool calls
256
+
257
+ response_content = None
258
+ tool_calls = []
259
+
260
+ if isinstance(ai_response, AIMessage):
261
+ # Check if the response contains the structured JSON output
262
+ try:
263
+ # Sometimes the JSON is embedded in the content, sometimes it's the primary content
264
+ # Look for ```json ... ``` block first
265
+ json_match = re.search(r"```json\n(\{.*?\})\n```", ai_response.content, re.DOTALL)
266
+ if json_match:
267
+ response_content = json.loads(json_match.group(1))
268
+ else:
269
+ # Try parsing the whole content as JSON
270
+ response_content = json.loads(ai_response.content)
271
+ except json.JSONDecodeError:
272
+ st.warning("AI did not return valid JSON in the expected format. Displaying raw content.")
273
+ st.code(ai_response.content, language=None) # Display raw if not JSON
274
+ response_content = {"assessment": ai_response.content, "error": "Output format incorrect"}
275
+
276
+ # Extract tool calls separately
277
+ if ai_response.tool_calls:
278
+ tool_calls = ai_response.tool_calls
279
+
280
+ return response_content, tool_calls
281
+
282
  except Exception as e:
283
+ st.error(f"Error during AI analysis: {str(e)}")
284
+ return None, []
285
+
286
+ def process_tool_call(self, tool_call: Dict[str, Any]) -> Any:
287
+ """Executes a single tool call."""
288
+ tool_name = tool_call.get("name")
289
+ tool_args = tool_call.get("args", {})
290
+ selected_tool = {t.name: t for t in self.tools}.get(tool_name)
291
+
292
+ if not selected_tool:
293
+ return json.dumps({"status": "error", "message": f"Unknown tool: {tool_name}"})
294
 
 
295
  try:
296
+ # Ensure args are correctly passed (Pydantic models handle validation)
297
+ return selected_tool.invoke(tool_args)
 
 
 
 
 
 
 
298
  except Exception as e:
299
+ st.error(f"Error executing tool '{tool_name}': {str(e)}")
300
+ return json.dumps({"status": "error", "message": f"Failed to execute {tool_name}: {str(e)}"})
301
 
302
+ # --- Streamlit UI ---
303
  def main():
304
+ st.set_page_config(page_title=ClinicalAppSettings.APP_TITLE, layout=ClinicalAppSettings.PAGE_LAYOUT)
305
+ st.title(f"🩺 {ClinicalAppSettings.APP_TITLE}")
306
+ st.caption(f"Powered by Langchain & Groq ({ClinicalAppSettings.MODEL_NAME})")
307
+
308
+ # Initialize Agent in session state
309
  if 'agent' not in st.session_state:
310
+ st.session_state.agent = ClinicalAgent()
311
+ if 'analysis_complete' not in st.session_state:
312
+ st.session_state.analysis_complete = False
313
+ if 'analysis_result' not in st.session_state:
314
+ st.session_state.analysis_result = None
315
+ if 'tool_call_results' not in st.session_state:
316
+ st.session_state.tool_call_results = []
317
+ if 'red_flags' not in st.session_state:
318
+ st.session_state.red_flags = []
319
+
320
+ # --- Patient Data Input Sidebar ---
321
  with st.sidebar:
322
+ st.header("πŸ“„ Patient Intake Form")
323
+
324
+ # Demographics
325
+ st.subheader("Demographics")
326
+ age = st.number_input("Age", min_value=0, max_value=120, value=55)
327
+ sex = st.selectbox("Biological Sex", ["Male", "Female", "Other/Prefer not to say"])
328
+
329
+ # History of Present Illness (HPI)
330
+ st.subheader("History of Present Illness (HPI)")
331
+ chief_complaint = st.text_input("Chief Complaint", "Chest pain")
332
+ hpi_details = st.text_area("Detailed HPI", "55 y/o male presents with substernal chest pain started 2 hours ago, described as pressure, radiating to left arm. Associated with nausea and diaphoresis. Pain is 8/10 severity. No relief with rest.")
333
+ symptoms = st.multiselect("Associated Symptoms", ["Nausea", "Diaphoresis", "Shortness of Breath", "Dizziness", "Palpitations", "Fever", "Cough"], default=["Nausea", "Diaphoresis"])
334
+
335
+ # Past Medical/Surgical History (PMH/PSH)
336
+ st.subheader("Past History")
337
+ pmh = st.text_area("Past Medical History (PMH)", "Hypertension (HTN), Hyperlipidemia (HLD), Type 2 Diabetes Mellitus (DM2)")
338
+ psh = st.text_area("Past Surgical History (PSH)", "Appendectomy (2005)")
339
+
340
+ # Medications & Allergies
341
+ st.subheader("Medications & Allergies")
342
+ current_meds = st.text_area("Current Medications (name, dose, freq)", "Lisinopril 10mg daily\nMetformin 1000mg BID\nAtorvastatin 40mg daily\nAspirin 81mg daily")
343
+ allergies = st.text_area("Allergies (comma separated)", "Penicillin (rash)")
344
+
345
+ # Social & Family History (SH/FH)
346
+ st.subheader("Social/Family History")
347
+ social_history = st.text_area("Social History (SH)", "Smoker (1 ppd x 30 years), occasional alcohol.")
348
+ family_history = st.text_area("Family History (FHx)", "Father had MI at age 60. Mother has HTN.")
349
+
350
+ # Review of Systems (ROS) - Simplified
351
+ # st.subheader("Review of Systems (ROS)") # Keep UI cleaner for now
352
+ # ros_constitutional = st.checkbox("ROS: Constitutional (Fever, Chills, Weight loss)")
353
+ # ros_cardiac = st.checkbox("ROS: Cardiac (Chest pain, Palpitations)", value=True) # Pre-check based on HPI
354
+
355
+ # Vitals & Basic Exam
356
+ st.subheader("Vitals & Exam Findings")
357
+ col1, col2 = st.columns(2)
358
+ with col1:
359
+ temp_c = st.number_input("Temperature (Β°C)", 35.0, 42.0, 36.8, format="%.1f")
360
+ hr_bpm = st.number_input("Heart Rate (bpm)", 30, 250, 95)
361
+ rr_rpm = st.number_input("Respiratory Rate (rpm)", 5, 50, 18)
362
+ with col2:
363
+ bp_mmhg = st.text_input("Blood Pressure (SYS/DIA)", "155/90")
364
+ spo2_percent = st.number_input("SpO2 (%)", 70, 100, 96)
365
+ pain_scale = st.slider("Pain (0-10)", 0, 10, 8)
366
+ exam_notes = st.text_area("Brief Physical Exam Notes", "Awake, alert, oriented x3. Mild distress. Lungs clear. Cardiac exam: Regular rhythm, no murmurs/gallops. Abdomen soft. No edema.")
367
+
368
+ # Clean medication list and allergies for processing
369
+ current_meds_list = [med.strip() for med in current_meds.split('\n') if med.strip()]
370
+ current_med_names = [med.split(' ')[0].strip() for med in current_meds_list] # Simplified name extraction
371
+ allergies_list = [a.strip() for a in allergies.split(',') if a.strip()]
372
+
373
+ # Compile Patient Data Dictionary
374
  patient_data = {
375
+ "demographics": {"age": age, "sex": sex},
376
+ "hpi": {"chief_complaint": chief_complaint, "details": hpi_details, "symptoms": symptoms},
377
+ "pmh": {"conditions": pmh},
378
+ "psh": {"procedures": psh},
379
+ "medications": {"current": current_meds_list, "names_only": current_med_names},
380
+ "allergies": allergies_list,
381
+ "social_history": {"details": social_history},
382
+ "family_history": {"details": family_history},
383
+ # "ros": {"constitutional": ros_constitutional, "cardiac": ros_cardiac}, # Add if using ROS inputs
384
  "vitals": {
385
+ "temp_c": temp_c, "hr_bpm": hr_bpm, "bp_mmhg": bp_mmhg,
386
+ "rr_rpm": rr_rpm, "spo2_percent": spo2_percent, "pain_scale": pain_scale
387
+ },
388
+ "exam_findings": {"notes": exam_notes}
389
  }
390
+
391
+ # --- Main Analysis Area ---
392
+ st.header("πŸ€– AI Clinical Analysis")
393
+
394
+ # Action Button
395
+ if st.button("Analyze Patient Data", type="primary", use_container_width=True):
396
+ st.session_state.analysis_complete = False
397
+ st.session_state.analysis_result = None
398
+ st.session_state.tool_call_results = []
399
+ st.session_state.red_flags = []
400
+
401
+ # 1. Initial Red Flag Check (Client-side before LLM)
402
+ st.session_state.red_flags = check_red_flags(patient_data)
403
+ if st.session_state.red_flags:
404
+ st.warning("**Initial Red Flags Detected:**")
405
+ for flag in st.session_state.red_flags:
406
+ st.warning(f"- {flag}")
407
+ st.warning("Proceeding with AI analysis, but these require immediate attention.")
408
+
409
+ # 2. Call AI Agent
410
+ with st.spinner("SynapseAI is processing the case... Please wait."):
411
+ analysis_output, tool_calls = st.session_state.agent.analyze(patient_data)
412
+
413
+ if analysis_output:
414
+ st.session_state.analysis_result = analysis_output
415
+ st.session_state.analysis_complete = True
416
+
417
+ # 3. Process any Tool Calls requested by the AI
418
+ if tool_calls:
419
+ st.info(f"AI recommended {len(tool_calls)} action(s). Executing...")
420
+ tool_results = []
421
+ with st.spinner("Executing recommended actions..."):
422
+ for call in tool_calls:
423
+ st.write(f"βš™οΈ Requesting: `{call['name']}` with args `{call['args']}`")
424
+ # Pass patient context if needed (e.g., for interaction check)
425
+ if call['name'] == 'check_drug_interactions':
426
+ call['args']['current_medications'] = patient_data['medications']['names_only']
427
+ call['args']['allergies'] = patient_data['allergies']
428
+ elif call['name'] == 'prescribe_medication':
429
+ # Pre-flight check: Ensure interaction check was requested *before* this prescribe call
430
+ interaction_check_requested = any(tc['name'] == 'check_drug_interactions' and tc['args'].get('potential_prescription') == call['args'].get('medication_name') for tc in tool_calls)
431
+ if not interaction_check_requested:
432
+ st.error(f"**Safety Violation:** AI attempted to prescribe '{call['args'].get('medication_name')}' without requesting `check_drug_interactions` first. Prescription blocked.")
433
+ tool_results.append({"tool_call_id": call['id'], "name": call['name'], "output": json.dumps({"status":"error", "message": "Interaction check not performed prior to prescription attempt."})})
434
+ continue # Skip this tool call
435
+
436
+ result = st.session_state.agent.process_tool_call(call)
437
+ tool_results.append({"tool_call_id": call['id'], "name": call['name'], "output": result}) # Store result with ID
438
+
439
+ # Display tool result immediately
440
+ try:
441
+ result_data = json.loads(result)
442
+ if result_data.get("status") == "success" or result_data.get("status") == "clear" or result_data.get("status") == "flagged":
443
+ st.success(f"βœ… Action `{call['name']}`: {result_data.get('message')}", icon="βœ…")
444
+ if result_data.get("details"): st.caption(f"Details: {result_data.get('details')}")
445
+ elif result_data.get("status") == "warning":
446
+ st.warning(f"⚠️ Action `{call['name']}`: {result_data.get('message')}", icon="⚠️")
447
+ if result_data.get("warnings"):
448
+ for warn in result_data["warnings"]: st.caption(f"- {warn}")
449
+ else:
450
+ st.error(f"❌ Action `{call['name']}`: {result_data.get('message')}", icon="❌")
451
+ except json.JSONDecodeError:
452
+ st.error(f"Tool `{call['name']}` returned non-JSON: {result}") # Fallback for non-JSON results
453
+
454
+ st.session_state.tool_call_results = tool_results
455
+ # Optionally: Send results back to LLM for final summary (requires multi-turn agent)
456
+ else:
457
+ st.error("Analysis failed. Please check the input data or try again.")
458
+
459
+ # --- Display Analysis Results ---
460
+ if st.session_state.analysis_complete and st.session_state.analysis_result:
461
+ st.divider()
462
+ st.header("πŸ“Š Analysis & Recommendations")
463
+
464
+ res = st.session_state.analysis_result
465
+
466
+ # Layout columns for better readability
467
+ col_assessment, col_plan = st.columns(2)
468
+
469
+ with col_assessment:
470
+ st.subheader("πŸ“‹ Assessment")
471
+ st.write(res.get("assessment", "N/A"))
472
+
473
+ st.subheader("πŸ€” Differential Diagnosis")
474
+ ddx = res.get("differential_diagnosis", [])
475
+ if ddx:
476
+ for item in ddx:
477
+ likelihood = item.get('likelihood', 'Unknown').capitalize()
478
+ icon = "πŸ₯‡" if likelihood=="High" else ("πŸ₯ˆ" if likelihood=="Medium" else "πŸ₯‰")
479
+ with st.expander(f"{icon} {item.get('diagnosis', 'Unknown Diagnosis')} ({likelihood} Likelihood)", expanded=(likelihood=="High")):
480
+ st.write(f"**Rationale:** {item.get('rationale', 'N/A')}")
481
+ else:
482
+ st.info("No differential diagnosis provided.")
483
+
484
+ st.subheader("🚨 Risk Assessment")
485
+ risk = res.get("risk_assessment", {})
486
+ flags = risk.get("identified_red_flags", []) + [f.replace("Red Flag: ", "") for f in st.session_state.red_flags] # Combine AI and initial flags
487
+ if flags:
488
+ st.warning(f"**Identified Red Flags:** {', '.join(flags)}")
489
+ else:
490
+ st.success("No immediate red flags identified by AI in this analysis.")
491
+
492
+ if risk.get("immediate_concerns"):
493
+ st.warning(f"**Immediate Concerns:** {', '.join(risk.get('immediate_concerns'))}")
494
+ if risk.get("potential_complications"):
495
+ st.info(f"**Potential Complications:** {', '.join(risk.get('potential_complications'))}")
496
+
497
+
498
+ with col_plan:
499
+ st.subheader("πŸ“ Recommended Plan")
500
+ plan = res.get("recommended_plan", {})
501
+
502
+ st.markdown("**Investigations:**")
503
+ if plan.get("investigations"):
504
+ st.markdown("\n".join([f"- {inv}" for inv in plan.get("investigations")]))
505
+ else: st.markdown("_None suggested._")
506
+
507
+ st.markdown("**Therapeutics:**")
508
+ if plan.get("therapeutics"):
509
+ st.markdown("\n".join([f"- {thx}" for thx in plan.get("therapeutics")]))
510
+ else: st.markdown("_None suggested._")
511
+
512
+ st.markdown("**Consultations:**")
513
+ if plan.get("consultations"):
514
+ st.markdown("\n".join([f"- {con}" for con in plan.get("consultations")]))
515
+ else: st.markdown("_None suggested._")
516
+
517
+ st.markdown("**Patient Education:**")
518
+ if plan.get("patient_education"):
519
+ st.markdown("\n".join([f"- {edu}" for edu in plan.get("patient_education")]))
520
+ else: st.markdown("_None specified._")
521
+
522
+ # Display Rationale and Interaction Summary below the columns
523
+ st.subheader("🧠 AI Rationale & Checks")
524
+ with st.expander("Show AI Reasoning Summary", expanded=False):
525
+ st.write(res.get("rationale_summary", "No rationale summary provided."))
526
+
527
+ interaction_summary = res.get("interaction_check_summary", "")
528
+ if interaction_summary: # Only show if interaction check was relevant/performed
529
+ with st.expander("Drug Interaction Check Summary", expanded=True):
530
+ st.write(interaction_summary)
531
+ # Also show detailed results from the tool call itself if available
532
+ for tool_res in st.session_state.tool_call_results:
533
+ if tool_res['name'] == 'check_drug_interactions':
534
+ try:
535
+ data = json.loads(tool_res['output'])
536
+ if data.get('warnings'):
537
+ st.warning("Interaction Details:")
538
+ for warn in data['warnings']:
539
+ st.caption(f"- {warn}")
540
+ else:
541
+ st.success("Interaction Details: " + data.get('message', 'Check complete.'))
542
+ except: pass # Ignore parsing errors here
543
+
544
+ # Display raw JSON if needed for debugging
545
+ with st.expander("Show Raw AI Output (JSON)"):
546
+ st.json(res)
547
+
548
+ st.divider()
549
+ st.success("Analysis Complete.")
550
+
551
+ # Disclaimer
552
+ st.markdown("---")
553
+ st.warning(
554
+ """**Disclaimer:** SynapseAI is an AI assistant for clinical decision support and does not replace professional medical judgment.
555
+ All outputs should be critically reviewed by a qualified healthcare provider before making any clinical decisions.
556
+ Verify all information, especially dosages and interactions, independently."""
557
+ )
558
+
559
 
560
  if __name__ == "__main__":
561
  main()