mgbam commited on
Commit
63b0a52
·
verified ·
1 Parent(s): 03e1441

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -86
app.py CHANGED
@@ -2,6 +2,7 @@ import streamlit as st
2
  from langchain_groq import ChatGroq
3
  from langchain_community.tools.tavily_search import TavilySearchResults
4
  from langchain_core.messages import HumanMessage, SystemMessage
 
5
  from typing import TypedDict, Annotated, List, Optional
6
  import json
7
 
@@ -10,39 +11,55 @@ class MedicalConfig:
10
  SYSTEM_PROMPT = """You are an AI clinical assistant. Follow these rules:
11
  1. Analyze patient data using latest medical guidelines
12
  2. Always check for drug interactions
13
- 3. Suggest tests only when necessary
14
- 4. Use structured actions:
15
- - order_lab_test: {test_name, reason}
16
- - prescribe_medication: {name, dosage, frequency}
17
- 5. Flag high-risk conditions immediately"""
18
 
19
  RED_FLAGS = {
20
  "symptoms": ["chest pain", "shortness of breath", "severe headache"],
21
  "vitals": {"temp": (38.5, "°C"), "hr": (120, "bpm"), "bp": ("180/120", "mmHg")}
22
  }
23
 
24
- # State Management
25
- class AgentState(TypedDict):
26
- messages: Annotated[List[dict], lambda l, r: l + r]
27
- patient_data: dict
28
- approvals: Annotated[List[dict], lambda l, r: l + r]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  class MedicalAgent:
31
  def __init__(self):
32
  self.model = ChatGroq(temperature=0.1, model="Llama-3.3-70b-Specdec")
33
- self.tools = {
34
- "medical_actions": {
35
- "order_lab_test": self.order_lab_test,
36
- "prescribe_medication": self.prescribe_medication
37
- },
38
- "research": TavilySearchResults(max_results=3)
39
- }
40
 
41
  def analyze_patient(self, patient_data: dict) -> Optional[dict]:
42
  try:
43
- # Bind tools to the model
44
- model_with_tools = self.model.bind_tools(list(self.tools["medical_actions"].keys()))
45
- response = model_with_tools.invoke([
46
  SystemMessage(content=MedicalConfig.SYSTEM_PROMPT),
47
  HumanMessage(content=f"Patient Data: {json.dumps(patient_data)}")
48
  ])
@@ -53,72 +70,16 @@ class MedicalAgent:
53
 
54
  def process_action(self, action: dict) -> str:
55
  try:
56
- if action['name'] in self.tools['medical_actions']:
57
- return self.tools['medical_actions'][action['name']](action['args'])
58
- return f"Unknown action: {action['name']}"
 
 
 
 
 
 
59
  except Exception as e:
60
  return f"Error processing action: {str(e)}"
61
 
62
- def order_lab_test(self, params: dict) -> str:
63
- return f"✅ Lab ordered: {params['test_name']} ({params['reason']})"
64
-
65
- def prescribe_medication(self, params: dict) -> str:
66
- return f"✅ Prescribed: {params['name']} {params['dosage']} {params['frequency']}"
67
-
68
- def check_red_flags(self, patient_data: dict) -> List[str]:
69
- flags = []
70
- for symptom in patient_data.get("symptoms", []):
71
- if symptom.lower() in MedicalConfig.RED_FLAGS["symptoms"]:
72
- flags.append(f"🚨 High-risk symptom: {symptom}")
73
- return flags
74
-
75
- # Streamlit UI
76
- def main():
77
- st.set_page_config(page_title="AI Clinic", layout="wide")
78
-
79
- # Initialize agent
80
- if 'agent' not in st.session_state:
81
- st.session_state.agent = MedicalAgent()
82
-
83
- # Patient intake
84
- with st.sidebar:
85
- st.header("Patient Intake")
86
- patient_data = {
87
- "symptoms": st.multiselect("Symptoms", ["Fever", "Cough", "Chest Pain", "Shortness of Breath"]),
88
- "history": {
89
- "conditions": st.text_input("Medical History"),
90
- "medications": st.text_input("Current Medications")
91
- },
92
- "vitals": {
93
- "temp": st.number_input("Temp (°C)", 35.0, 42.0, 37.0),
94
- "bp": st.text_input("BP (mmHg)", "120/80")
95
- }
96
- }
97
-
98
- # Main interface
99
- st.title("AI-Powered Clinical Support System")
100
-
101
- # Check for red flags
102
- red_flags = st.session_state.agent.check_red_flags(patient_data)
103
- if red_flags:
104
- st.warning("High-Risk Conditions Detected")
105
- for flag in red_flags:
106
- st.error(flag)
107
-
108
- if st.button("Start Analysis"):
109
- with st.spinner("Analyzing patient data..."):
110
- response = st.session_state.agent.analyze_patient(patient_data)
111
-
112
- if response:
113
- st.subheader("Analysis Results")
114
- if hasattr(response, 'tool_calls') and response.tool_calls:
115
- for action in response.tool_calls:
116
- result = st.session_state.agent.process_action(action)
117
- st.success(result)
118
- else:
119
- st.info("No specific actions recommended. Here's the analysis:")
120
- st.write(response.content)
121
-
122
- # Run the app
123
- if __name__ == "__main__":
124
- main()
 
2
  from langchain_groq import ChatGroq
3
  from langchain_community.tools.tavily_search import TavilySearchResults
4
  from langchain_core.messages import HumanMessage, SystemMessage
5
+ from langchain_core.tools import tool
6
  from typing import TypedDict, Annotated, List, Optional
7
  import json
8
 
 
11
  SYSTEM_PROMPT = """You are an AI clinical assistant. Follow these rules:
12
  1. Analyze patient data using latest medical guidelines
13
  2. Always check for drug interactions
14
+ 3. Use structured actions when needed:
15
+ - lab_order: Order lab tests
16
+ - prescribe: Prescribe medication
17
+ 4. Flag high-risk conditions immediately"""
 
18
 
19
  RED_FLAGS = {
20
  "symptoms": ["chest pain", "shortness of breath", "severe headache"],
21
  "vitals": {"temp": (38.5, "°C"), "hr": (120, "bpm"), "bp": ("180/120", "mmHg")}
22
  }
23
 
24
+ # Define tools with proper schemas
25
+ @tool
26
+ def order_lab_test(test_name: str, reason: str) -> str:
27
+ """Orders a lab test with specified parameters.
28
+
29
+ Args:
30
+ test_name: Name of the lab test to order
31
+ reason: Clinical justification for the test
32
+
33
+ Returns:
34
+ Confirmation message with test details
35
+ """
36
+ return f"Lab ordered: {test_name} ({reason})"
37
+
38
+ @tool
39
+ def prescribe_medication(name: str, dosage: str, frequency: str) -> str:
40
+ """Prescribes medication with specific dosage instructions.
41
+
42
+ Args:
43
+ name: Name of the medication
44
+ dosage: Dosage amount (e.g., '500mg')
45
+ frequency: Administration frequency (e.g., 'every 6 hours')
46
+
47
+ Returns:
48
+ Confirmation message with prescription details
49
+ """
50
+ return f"Prescribed: {name} {dosage} {frequency}"
51
+
52
+ # Initialize tools
53
+ tools = [order_lab_test, prescribe_medication, TavilySearchResults(max_results=3)]
54
 
55
  class MedicalAgent:
56
  def __init__(self):
57
  self.model = ChatGroq(temperature=0.1, model="Llama-3.3-70b-Specdec")
58
+ self.model_with_tools = self.model.bind_tools(tools)
 
 
 
 
 
 
59
 
60
  def analyze_patient(self, patient_data: dict) -> Optional[dict]:
61
  try:
62
+ response = self.model_with_tools.invoke([
 
 
63
  SystemMessage(content=MedicalConfig.SYSTEM_PROMPT),
64
  HumanMessage(content=f"Patient Data: {json.dumps(patient_data)}")
65
  ])
 
70
 
71
  def process_action(self, action: dict) -> str:
72
  try:
73
+ tool_name = action['name']
74
+ args = action['args']
75
+
76
+ if tool_name == "order_lab_test":
77
+ return order_lab_test.invoke(args)
78
+ elif tool_name == "prescribe_medication":
79
+ return prescribe_medication.invoke(args)
80
+ else:
81
+ return f"Unknown action: {tool_name}"
82
  except Exception as e:
83
  return f"Error processing action: {str(e)}"
84
 
85
+ # Rest of the Streamlit UI code remains the same...