"""Nurse Oge: a FastAPI + Gradio front end for a medical-assistant LLM.

The service walks every conversation through a fixed health-assessment
questionnaire and only then asks the language model for an assessment that
is based on the collected answers.
"""

import json
import re
from enum import Enum
from typing import List, Optional, Dict

import gradio as gr
from fastapi import FastAPI, HTTPException
from llama_cpp import Llama
from pydantic import BaseModel


class ConsultationState(Enum):
    """Phases a single consultation moves through."""

    INITIAL = "initial"
    GATHERING_INFO = "gathering_info"
    DIAGNOSIS = "diagnosis"


class Message(BaseModel):
    """One chat message as sent by the client."""

    role: str
    content: str


class ChatRequest(BaseModel):
    """Request body of ``POST /chat``."""

    messages: List[Message]
    # Optional so existing clients keep working; clients that want their own
    # isolated consultation should send a unique value.
    conversation_id: str = "default"


class ChatResponse(BaseModel):
    """Reply returned to the client."""

    response: str
    # False while further assessment questions are still pending.
    finished: bool


# Standard health assessment questions that Nurse Oge always asks
HEALTH_ASSESSMENT_QUESTIONS = [
    "What are your current symptoms and how long have you been experiencing them?",
    "Do you have any pre-existing medical conditions or chronic illnesses?",
    "Are you currently taking any medications? If yes, please list them.",
    "Is there any relevant family medical history I should know about?",
    "Have you had any similar symptoms in the past? If yes, what treatments worked?",
]

# Personality prompts for Nurse Oge
NURSE_OGE_IDENTITY = """
You are Nurse Oge, a medical AI assistant focused on serving patients in Nigeria.
Always be empathetic, professional, and thorough in your assessments.
When asked about your identity, explain that you are Nurse Oge, a medical AI assistant serving Nigerian communities.
Remember that you must gather complete health information before providing any medical advice.
"""

# Compiled once at import time. IGNORECASE is required: the patterns contain
# upper-case letters (e.g. "I") that would never match lower-cased input.
_IDENTITY_PATTERNS = tuple(
    re.compile(pattern, re.IGNORECASE)
    for pattern in (
        r"who are you",
        r"what are you",
        r"your name",
        r"what should I call you",
        r"tell me about yourself",
    )
)

_LOCATION_PATTERNS = tuple(
    re.compile(pattern, re.IGNORECASE)
    for pattern in (
        r"where are you",
        r"which country",
        r"your location",
        r"where do you work",
        r"where are you based",
    )
)


class NurseOgeAssistant:
    """Stateful assistant that interviews the patient before consulting the LLM.

    State is kept in memory, keyed by conversation id:

    * ``consultation_states`` - current ``ConsultationState``.
    * ``gathered_info`` - answers given so far, in question order.
    * ``initial_queries`` - the message that opened the consultation.
    """

    def __init__(self):
        """Load the GGUF model and create empty per-conversation stores."""
        self.llm = Llama.from_pretrained(
            repo_id="mradermacher/Llama3-Med42-8B-GGUF",
            filename="Llama3-Med42-8B.IQ3_M.gguf",
            verbose=False,
        )
        self.consultation_states = {}  # Tracks state for each conversation
        self.gathered_info = {}  # Stores gathered health information
        self.initial_queries = {}  # First (medical) message of each consultation

    def _is_identity_question(self, message: str) -> bool:
        """Return True if *message* asks who or what the assistant is."""
        return any(pattern.search(message) for pattern in _IDENTITY_PATTERNS)

    def _is_location_question(self, message: str) -> bool:
        """Return True if *message* asks where the assistant is located."""
        return any(pattern.search(message) for pattern in _LOCATION_PATTERNS)

    def _get_next_assessment_question(self, conversation_id: str) -> Optional[str]:
        """Return the next unanswered assessment question, or None when done.

        The number of stored answers doubles as the index of the next
        question. An answer list is created for unknown conversations.
        """
        answers = self.gathered_info.setdefault(conversation_id, [])
        if len(answers) < len(HEALTH_ASSESSMENT_QUESTIONS):
            return HEALTH_ASSESSMENT_QUESTIONS[len(answers)]
        return None

    def _generate_assessment(self, conversation_id: str, latest_message: str) -> ChatResponse:
        """Ask the model for an assessment and reset the consultation.

        The conversation state is reset even when the model call raises, so
        a failed generation cannot leave the conversation stuck in
        ``DIAGNOSIS``; the exception itself still propagates to the caller.
        """
        self.consultation_states[conversation_id] = ConsultationState.DIAGNOSIS

        # Prepare complete context for final response
        answers = self.gathered_info.get(conversation_id, [])
        context = "\n".join(
            f"Q: {question}\nA: {answer}"
            for question, answer in zip(HEALTH_ASSESSMENT_QUESTIONS, answers)
        )
        # Use the message that opened the consultation, not the last answer.
        original_query = self.initial_queries.get(conversation_id, latest_message)

        messages = [
            {"role": "system", "content": NURSE_OGE_IDENTITY},
            {
                "role": "user",
                "content": (
                    "Based on the following patient information, provide a thorough "
                    f"assessment, diagnosis and recommendations:\n\n{context}\n\n"
                    f"Original query: {original_query}"
                ),
            },
        ]

        try:
            # NOTE: this call is synchronous and blocks the event loop while
            # the model generates.
            completion = self.llm.create_chat_completion(
                messages=messages,
                max_tokens=1024,
                temperature=0.7,
            )
        finally:
            # Reset state for next consultation
            self.consultation_states[conversation_id] = ConsultationState.INITIAL
            self.gathered_info[conversation_id] = []
            self.initial_queries.pop(conversation_id, None)

        return ChatResponse(
            response=completion["choices"][0]["message"]["content"],
            finished=True,
        )

    async def process_message(self, conversation_id: str, message: str, history: List[Dict]) -> ChatResponse:
        """Advance the consultation identified by *conversation_id* by one turn.

        Args:
            conversation_id: Key under which this conversation's state is kept.
            message: The latest user message.
            history: Earlier messages of the conversation. Currently unused;
                the assistant relies on its own stored state instead.

        Returns:
            A ``ChatResponse``. ``finished`` is False while assessment
            questions are still pending and True for identity/location
            answers and for the final model-generated assessment.

        Raises:
            Exception: Whatever the model call raises during the final
                assessment; the conversation is reset to ``INITIAL`` first.
        """
        # Initialize state if new conversation
        state = self.consultation_states.setdefault(
            conversation_id, ConsultationState.INITIAL
        )

        # Handle identity questions
        if self._is_identity_question(message):
            return ChatResponse(
                response="I am Nurse Oge, a medical AI assistant dedicated to helping patients in Nigeria. "
                "I'm here to provide medical guidance while ensuring I gather all necessary health information "
                "for accurate assessments.",
                finished=True,
            )

        # Handle location questions
        if self._is_location_question(message):
            return ChatResponse(
                response="I am based in Nigeria and specifically trained to serve Nigerian communities, "
                "taking into account local healthcare contexts and needs.",
                finished=True,
            )

        # Start health assessment if it's a medical query
        if state is ConsultationState.INITIAL:
            self.consultation_states[conversation_id] = ConsultationState.GATHERING_INFO
            self.gathered_info[conversation_id] = []
            # Remember what the patient actually asked for the final prompt.
            self.initial_queries[conversation_id] = message
            next_question = self._get_next_assessment_question(conversation_id)
            return ChatResponse(
                response="Before I can provide any medical advice, I need to gather some important health information. "
                f"{next_question}",
                finished=False,
            )

        # Continue gathering information. Any non-INITIAL state is handled
        # here so that this method never falls through and returns None.
        self.gathered_info.setdefault(conversation_id, []).append(message)
        next_question = self._get_next_assessment_question(conversation_id)
        if next_question is not None:
            return ChatResponse(
                response=f"Thank you for that information. {next_question}",
                finished=False,
            )

        # All questions answered: generate final response using the model
        return self._generate_assessment(conversation_id, message)


# Initialize FastAPI and Nurse Oge
app = FastAPI()
nurse_oge = NurseOgeAssistant()


@app.post("/chat")
async def chat_endpoint(request: ChatRequest) -> ChatResponse:
    """Handle one chat turn sent over HTTP.

    Only the last entry of ``request.messages`` is treated as the new user
    message. Clients that omit ``conversation_id`` share the ``"default"``
    conversation.

    Raises:
        HTTPException: 400 when ``request.messages`` is empty.
    """
    # Extract the latest message
    if not request.messages:
        raise HTTPException(status_code=400, detail="No messages provided")

    latest_message = request.messages[-1].content

    # Process the message
    response = await nurse_oge.process_message(
        conversation_id=request.conversation_id,
        message=latest_message,
        history=request.messages[:-1],
    )

    return response


# Initialize Gradio interface (optional, for testing)
async def gradio_chat(message, history):
    """Gradio chat callback; all Gradio users share the ``"gradio_user"`` state.

    Declared ``async`` because ``process_message`` is a coroutine function
    and must be awaited to obtain the ``ChatResponse``.
    """
    response = await nurse_oge.process_message("gradio_user", message, history)
    return response.response


demo = gr.ChatInterface(
    fn=gradio_chat,
    title="Nurse Oge",
    description="Finetuned llama 3.0 for medical diagnosis and all. This is just a demo",
    theme="soft",
)

# Mount both FastAPI and Gradio
app = gr.mount_gradio_app(app, demo, path="/gradio")

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)