|
import streamlit as st |
|
from langchain_groq import ChatGroq |
|
from langchain_community.tools.tavily_search import TavilySearchResults |
|
from langchain_core.messages import HumanMessage, SystemMessage |
|
from typing import TypedDict, Annotated, List |
|
|
|
|
|
class MedicalConfig:
    """Holds the system prompt used by the clinical assistant."""

    # Sent as the SystemMessage ahead of every patient analysis. The action
    # names listed in rule 4 are the keys MedicalAgent dispatches on in its
    # "medical_actions" tool table.
    SYSTEM_PROMPT = "\n".join(
        (
            "You are an AI clinical assistant. Follow these rules:",
            "1. Analyze patient data using latest medical guidelines",
            "2. Always check for drug interactions",
            "3. Suggest tests only when necessary",
            "4. Use structured actions:",
            "   - order_lab_test: {test_name, reason}",
            "   - prescribe_medication: {name, dosage, frequency}",
        )
    )
|
|
|
|
|
def _append_messages(existing: List[dict], incoming: List[dict]) -> List[dict]:
    """Merge two message lists by concatenation, keeping *existing* first."""
    return existing + incoming


class AgentState(TypedDict):
    """Typed shape of an agent state record.

    NOTE(review): this type is not referenced anywhere else in the visible
    file; the Annotated reducer suggests it is meant for a graph framework
    such as LangGraph -- confirm before relying on it.
    """

    # Annotated metadata is a reducer that concatenates two message lists.
    messages: Annotated[List[dict], _append_messages]
    # Patient intake record; its keys are not constrained by this type.
    patient_data: dict
|
|
|
class MedicalAgent:
    """Clinical assistant backed by a Groq chat model.

    Sends patient data to the model together with the system prompt and
    executes the structured actions (tool calls) the model requests.
    """

    # OpenAI-style function schemas for the actions named in the system
    # prompt. They are bound to the chat model so it can actually emit
    # ``tool_calls``; without binding, the model only returns free text and
    # the action handlers below are unreachable.
    ACTION_TOOL_SCHEMAS = (
        {
            "type": "function",
            "function": {
                "name": "order_lab_test",
                "description": "Order a laboratory test for the patient.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "test_name": {
                            "type": "string",
                            "description": "Name of the lab test to order.",
                        },
                        "reason": {
                            "type": "string",
                            "description": "Clinical reason for ordering the test.",
                        },
                    },
                    "required": ["test_name", "reason"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "prescribe_medication",
                "description": "Prescribe a medication for the patient.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "name": {
                            "type": "string",
                            "description": "Medication name.",
                        },
                        "dosage": {
                            "type": "string",
                            "description": "Dose per administration, e.g. '500 mg'.",
                        },
                        "frequency": {
                            "type": "string",
                            "description": "How often to take it, e.g. 'twice daily'.",
                        },
                    },
                    "required": ["name", "dosage", "frequency"],
                },
            },
        },
    )

    def __init__(self, model_name="llama-3.3-70b-specdec", temperature=0.1):
        """Create the chat model, its tool-bound variant and the tool table.

        Args:
            model_name: Groq model id. Groq ids are lowercase; the previous
                hard-coded ``"Llama-3.3-70b-Specdec"`` did not match that form.
                NOTE(review): confirm this model is still served, because Groq
                retires models periodically.
            temperature: Sampling temperature; kept low for consistent output.
        """
        self.model = ChatGroq(temperature=temperature, model=model_name)
        # Separate runnable with the action schemas attached. ``self.model``
        # stays the plain chat model for any caller that uses it directly.
        self.model_with_actions = self.model.bind_tools(list(self.ACTION_TOOL_SCHEMAS))
        self.tools = {
            # Dispatch table used by process_action; the keys must match the
            # function names in ACTION_TOOL_SCHEMAS.
            "medical_actions": {
                "order_lab_test": self.order_lab_test,
                "prescribe_medication": self.prescribe_medication,
            },
            # Not used by any method in this class.
            # NOTE(review): constructing it may require Tavily credentials;
            # confirm before deployment.
            "research": TavilySearchResults(max_results=3),
        }

    def analyze_patient(self, patient_data):
        """Ask the model to analyse *patient_data*.

        Args:
            patient_data: Intake record; it is embedded in the prompt via
                its ``str()`` form.

        Returns:
            The model's response message. Its ``tool_calls`` list holds any
            actions the model requested, and ``content`` holds any text.
        """
        return self.model_with_actions.invoke(
            [
                SystemMessage(content=MedicalConfig.SYSTEM_PROMPT),
                HumanMessage(content=f"Patient Data: {patient_data}"),
            ]
        )

    def process_action(self, action):
        """Dispatch one model tool call to its handler.

        Args:
            action: Mapping with a ``"name"`` key and an ``"args"`` mapping,
                as found in a response's ``tool_calls``.

        Returns:
            The handler's confirmation string. Returns ``"Unknown action"``
            for a name that is not a registered medical action, and an
            ``"Invalid arguments ..."`` message when the arguments are not a
            mapping or lack a required field.
        """
        name = action.get("name")
        handler = self.tools["medical_actions"].get(name)
        if handler is None:
            return "Unknown action"
        args = action.get("args") or {}
        if not isinstance(args, dict):
            return f"Invalid arguments for {name}: expected an object"
        try:
            return handler(args)
        except KeyError as exc:
            # The model omitted a required field; report it instead of
            # crashing the caller.
            return f"Invalid arguments for {name}: missing {exc}"

    def order_lab_test(self, params):
        """Return a confirmation for ordering ``params['test_name']``.

        Raises:
            KeyError: if ``test_name`` or ``reason`` is missing.
        """
        return f"Lab ordered: {params['test_name']} ({params['reason']})"

    def prescribe_medication(self, params):
        """Return a confirmation for prescribing ``params['name']``.

        Raises:
            KeyError: if ``name``, ``dosage`` or ``frequency`` is missing.
        """
        return f"Prescribed: {params['name']} {params['dosage']} {params['frequency']}"
|
|
|
|
|
def _collect_patient_data():
    """Render the sidebar intake form and return the entered patient data."""
    with st.sidebar:
        st.header("Patient Intake")
        return {
            "symptoms": st.multiselect("Symptoms", ["Fever", "Cough", "Chest Pain"]),
            "history": {
                "conditions": st.text_input("Medical History"),
                "medications": st.text_input("Current Medications"),
            },
            "vitals": {
                "temp": st.number_input("Temp (°C)", 35.0, 42.0, 37.0),
                "bp": st.text_input("BP (mmHg)", "120/80"),
            },
        }


def _render_analysis(agent, response):
    """Show the model's text answer and the outcome of each requested action."""
    # A chat response message always carries a ``tool_calls`` attribute (an
    # empty list when no action was requested), so hasattr() is always true.
    # Test the list's contents instead, and show text whenever there is any.
    tool_calls = getattr(response, "tool_calls", None) or []
    content = getattr(response, "content", "")
    if content:
        st.info(content)
    for action in tool_calls:
        st.success(agent.process_action(action))
    if not content and not tool_calls:
        st.warning("The model returned an empty response.")


def main():
    """Streamlit entry point: page setup, agent creation, intake and analysis."""
    st.set_page_config(page_title="AI Clinic", layout="wide")

    # The agent is created once per browser session and reused across reruns.
    if "agent" not in st.session_state:
        try:
            st.session_state.agent = MedicalAgent()
        except Exception as exc:  # UI boundary: report the failure and halt the script run
            st.error(f"Could not initialise the clinical agent: {exc}")
            st.stop()

    patient_data = _collect_patient_data()

    st.title("AI-Powered Clinical Support System")

    if st.button("Start Analysis"):
        agent = st.session_state.agent
        with st.spinner("Analyzing patient data..."):
            try:
                response = agent.analyze_patient(patient_data)
            except Exception as exc:  # UI boundary: network/API errors from the LLM call
                st.error(f"Analysis failed: {exc}")
                return
        _render_analysis(agent, response)


if __name__ == "__main__":
    main()