import asyncio
import gc
import json
import logging
import os
import re
import time
from contextlib import asynccontextmanager
from enum import Enum
from typing import List, Optional, Dict

import gradio as gr
from fastapi import FastAPI, HTTPException, Request
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
from pydantic import BaseModel


# Model configuration, overridable through environment variables.
# MODEL_REPO_ID / MODEL_FILENAME are handed to Llama.from_pretrained and
# N_THREADS becomes its n_threads argument.
MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "mradermacher/Llama3-Med42-8B-GGUF")
MODEL_FILENAME = os.getenv("MODEL_FILENAME", "Llama3-Med42-8B.Q5_K_M.gguf")
# NOTE: a non-numeric N_THREADS value raises ValueError at import time.
N_THREADS = int(os.getenv("N_THREADS", "4"))


class ConsultationState(Enum):
    """Stage of a consultation, tracked per conversation id."""

    # No intake question has been asked yet.
    INITIAL = "initial"
    # The scripted health-assessment questions are being asked one by one.
    GATHERING_INFO = "gathering_info"
    # All answers are collected and the model is asked for an assessment.
    DIAGNOSIS = "diagnosis"


class Message(BaseModel):
    """A single chat message as submitted to the /chat endpoint."""

    # Speaker label supplied by the client; not validated beyond being a string.
    role: str
    # Text of the message.
    content: str


class ChatRequest(BaseModel):
    """Request body of the /chat endpoint."""

    # Conversation so far; the last entry is treated as the message to answer.
    messages: List[Message]


class ChatResponse(BaseModel):
    """Reply produced by the assistant for one incoming message."""

    # Text to show to the user.
    response: str
    # False while further intake questions are still pending, True otherwise.
    finished: bool


# Intake questions asked in this order before any assessment is generated.
# The position of an answer in the per-conversation answer list identifies
# the question it belongs to.
HEALTH_ASSESSMENT_QUESTIONS = [
    "What are your current symptoms and how long have you been experiencing them?",
    "Do you have any pre-existing medical conditions or chronic illnesses?",
    "Are you currently taking any medications? If yes, please list them.",
    "Is there any relevant family medical history I should know about?",
    "Have you had any similar symptoms in the past? If yes, what treatments worked?",
]


# System prompt sent with every assessment request to the model.
NURSE_OGE_IDENTITY = """
You are Nurse Oge, a medical AI assistant focused on serving patients in Nigeria. Always be empathetic,
professional, and thorough in your assessments. When asked about your identity, explain that you are
Nurse Oge, a medical AI assistant serving Nigerian communities. Remember that you must gather complete
health information before providing any medical advice.
"""


class NurseOgeAssistant:
    """Conversation manager that runs a scripted intake before querying the model.

    For each conversation id the assistant keeps a ``ConsultationState``, the
    user's first message and the answers given to
    ``HEALTH_ASSESSMENT_QUESTIONS``. Once every question has been answered the
    collected answers are sent to the llama.cpp model and the conversation is
    reset to ``INITIAL``.
    """

    _logger = logging.getLogger(__name__)

    # Compiled once, case-insensitively. The previous implementation matched
    # the patterns against ``message.lower()``, so the pattern containing an
    # upper-case "I" ("what should I call you") could never match.
    _IDENTITY_PATTERNS = tuple(
        re.compile(pattern, re.IGNORECASE)
        for pattern in (
            r"who are you",
            r"what are you",
            r"your name",
            r"what should I call you",
            r"tell me about yourself",
        )
    )
    _LOCATION_PATTERNS = tuple(
        re.compile(pattern, re.IGNORECASE)
        for pattern in (
            r"where are you",
            r"which country",
            r"your location",
            r"where do you work",
            r"where are you based",
        )
    )

    def __init__(self):
        """Load the GGUF model and create the empty per-conversation stores.

        Raises:
            RuntimeError: if the model cannot be downloaded or loaded; the
                underlying exception is chained as ``__cause__``.
        """
        try:
            self.llm = Llama.from_pretrained(
                repo_id=MODEL_REPO_ID,
                filename=MODEL_FILENAME,
                n_ctx=2048,
                n_threads=N_THREADS,
                n_gpu_layers=0,
            )
        except Exception as e:
            raise RuntimeError(f"Failed to initialize the model: {str(e)}") from e

        # conversation id -> ConsultationState
        self.consultation_states = {}
        # conversation id -> answers, in HEALTH_ASSESSMENT_QUESTIONS order
        self.gathered_info = {}
        # conversation id -> the message that opened the consultation
        self.initial_queries = {}

    def _is_identity_question(self, message: str) -> bool:
        """Return True if the message asks who or what the assistant is."""
        return any(pattern.search(message) for pattern in self._IDENTITY_PATTERNS)

    def _is_location_question(self, message: str) -> bool:
        """Return True if the message asks where the assistant is located."""
        return any(pattern.search(message) for pattern in self._LOCATION_PATTERNS)

    def _get_next_assessment_question(self, conversation_id: str) -> Optional[str]:
        """Return the next unanswered intake question, or None when all are answered.

        The number of stored answers is the index of the next question. An
        empty answer list is created for conversation ids not seen before.
        """
        answers = self.gathered_info.setdefault(conversation_id, [])
        if len(answers) < len(HEALTH_ASSESSMENT_QUESTIONS):
            return HEALTH_ASSESSMENT_QUESTIONS[len(answers)]
        return None

    async def process_message(self, conversation_id: str, message: str, history: List[Dict]) -> "ChatResponse":
        """Advance the consultation for ``conversation_id`` by one user message.

        Args:
            conversation_id: key under which this conversation's state is kept.
            message: the latest user message.
            history: earlier messages of the conversation. Accepted for
                interface compatibility; it is not consulted.

        Returns:
            A ``ChatResponse``. ``finished`` is False while intake questions
            remain and True for identity/location answers, the final
            assessment and error replies.

        Any exception is logged and converted into a generic error reply, so
        this method does not raise.
        """
        try:
            state = self.consultation_states.setdefault(
                conversation_id, ConsultationState.INITIAL
            )

            # Identity and location questions are answered directly and do
            # not advance the consultation.
            if self._is_identity_question(message):
                return ChatResponse(
                    response="I am Nurse Oge, a medical AI assistant dedicated to helping patients in Nigeria. "
                             "I'm here to provide medical guidance while ensuring I gather all necessary health information "
                             "for accurate assessments.",
                    finished=True,
                )

            if self._is_location_question(message):
                return ChatResponse(
                    response="I am based in Nigeria and specifically trained to serve Nigerian communities, "
                             "taking into account local healthcare contexts and needs.",
                    finished=True,
                )

            if state == ConsultationState.INITIAL:
                self.consultation_states[conversation_id] = ConsultationState.GATHERING_INFO
                # Keep the opening message: it is the user's actual query and
                # would otherwise never reach the model.
                self.initial_queries[conversation_id] = message
                next_question = self._get_next_assessment_question(conversation_id)
                return ChatResponse(
                    response=f"Before I can provide any medical advice, I need to gather some important health information. "
                             f"{next_question}",
                    finished=False,
                )

            if state == ConsultationState.GATHERING_INFO:
                self.gathered_info.setdefault(conversation_id, []).append(message)
                next_question = self._get_next_assessment_question(conversation_id)
                if next_question:
                    return ChatResponse(
                        response=f"Thank you for that information. {next_question}",
                        finished=False,
                    )
                self.consultation_states[conversation_id] = ConsultationState.DIAGNOSIS

            # Reached right after the last answer, or on a follow-up message
            # when a previous assessment attempt failed and the state is
            # still DIAGNOSIS.
            return await self._generate_assessment(conversation_id, message)

        except Exception:
            self._logger.exception(
                "Unhandled error while processing conversation %r", conversation_id
            )
            return ChatResponse(
                response=f"An error occurred while processing your request. Please try again.",
                finished=True,
            )

    async def _generate_assessment(self, conversation_id: str, message: str) -> "ChatResponse":
        """Ask the model for an assessment from the collected answers.

        The call is attempted up to three times with a two second pause in
        between. On success the conversation is reset to ``INITIAL``; after
        the last failed attempt an apology is returned and the collected
        answers are left in place so a later message can retry.
        """
        answers = self.gathered_info.get(conversation_id, [])
        context = "\n".join(
            f"Q: {q}\nA: {a}" for q, a in zip(HEALTH_ASSESSMENT_QUESTIONS, answers)
        )
        # Fall back to the current message if no opening query was recorded.
        original_query = self.initial_queries.get(conversation_id, message)

        messages = [
            {"role": "system", "content": NURSE_OGE_IDENTITY},
            {"role": "user", "content": f"Based on the following patient information, provide thorough assessment, diagnosis and recommendations:\n\n{context}\n\nOriginal query: {original_query}"},
        ]

        max_retries = 3
        retry_delay = 2

        for attempt in range(max_retries):
            try:
                # NOTE: this is a synchronous call and occupies the event
                # loop for the duration of the generation.
                response = self.llm.create_chat_completion(
                    messages=messages,
                    max_tokens=512,
                    temperature=0.7,
                    top_p=0.95,
                    stop=["</s>"],
                )
                break
            except Exception:
                self._logger.exception(
                    "Model call failed (attempt %d of %d)", attempt + 1, max_retries
                )
                if attempt < max_retries - 1:
                    # Non-blocking pause; time.sleep() here would stall every
                    # other request served by the same event loop.
                    await asyncio.sleep(retry_delay)
                    continue
                return ChatResponse(
                    response="I'm sorry, I'm experiencing some technical difficulties. Please try again in a moment.",
                    finished=True,
                )

        # Extract the text first so a malformed reply leaves the collected
        # answers intact.
        answer_text = response['choices'][0]['message']['content']

        self.consultation_states[conversation_id] = ConsultationState.INITIAL
        self.gathered_info[conversation_id] = []
        self.initial_queries.pop(conversation_id, None)

        return ChatResponse(response=answer_text, finished=True)


# Shared assistant instance. Bound to None up front so request handlers can
# test it even if startup fails or never runs; lifespan() replaces it once
# the model has loaded.
nurse_oge = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the shared NurseOgeAssistant when the application starts.

    A failed initialisation is reported on stdout and leaves ``nurse_oge`` as
    None, so the application still starts and handlers can report that the
    assistant is unavailable.
    """
    global nurse_oge
    try:
        nurse_oge = NurseOgeAssistant()
    except Exception as e:
        # Previously the name stayed unbound on failure, turning every
        # ``nurse_oge is None`` check into a NameError.
        nurse_oge = None
        print(f"Failed to initialize NurseOgeAssistant: {e}")
    yield


# ASGI application; the lifespan hook builds the shared assistant at startup.
app = FastAPI(lifespan=lifespan)


@app.middleware("http")
async def add_memory_management(request: Request, call_next):
    """Run a garbage-collection pass before and after every HTTP request.

    Args:
        request: the incoming request.
        call_next: callable that forwards the request to the rest of the stack.

    Returns:
        The response produced downstream, unchanged.
    """
    # NOTE(review): a full gc.collect() on both sides of every request adds
    # latency to each call - confirm this is still needed.
    gc.collect()
    response = await call_next(request)
    gc.collect()
    return response


@app.get("/health")
async def health_check():
    """Report that the service is up and whether the assistant is loaded."""
    # ``nurse_oge`` is only bound once the lifespan hook has run, so it is
    # read through globals(): a health probe must answer even when the name
    # does not exist yet instead of failing with NameError.
    assistant = globals().get("nurse_oge")
    return {"status": "healthy", "model_loaded": assistant is not None}


@app.post("/chat")
async def chat_endpoint(request: ChatRequest, conversation_id: str = "default"):
    """Answer the last message of the submitted conversation.

    Args:
        request: body holding the conversation; its last message is the one
            being answered.
        conversation_id: optional query parameter selecting which
            consultation the message belongs to. Clients that omit it share
            the "default" consultation, which is the previous behaviour;
            concurrent clients should each send their own id so their intake
            answers are not mixed together.

    Returns:
        The ``ChatResponse`` produced by the assistant.

    Raises:
        HTTPException: 503 if the assistant is not loaded, 400 if the
            request contains no messages.
    """
    # Read through globals() so an assistant that was never initialised
    # yields the 503 below rather than a NameError.
    assistant = globals().get("nurse_oge")
    if assistant is None:
        raise HTTPException(
            status_code=503,
            detail="The medical assistant is not available at the moment. Please try again later.",
        )

    if not request.messages:
        raise HTTPException(status_code=400, detail="No messages provided")

    *earlier_messages, latest = request.messages
    # process_message is annotated to take plain dicts, not pydantic models.
    history = [
        {"role": item.role, "content": item.content} for item in earlier_messages
    ]

    response = await assistant.process_message(
        conversation_id=conversation_id,
        message=latest.content,
        history=history,
    )
    return response


async def gradio_chat(message, history, request: gr.Request = None):
    """Handler for the Gradio chat interface.

    Args:
        message: text the user just entered.
        history: chat history supplied by Gradio; forwarded unchanged.
        request: filled in by Gradio because of the ``gr.Request`` type hint;
            None when the function is called without one.

    Returns:
        The assistant's reply text, or a fixed notice when the assistant is
        not loaded.
    """
    # Read through globals() so a missing assistant produces the notice below
    # rather than a NameError.
    assistant = globals().get("nurse_oge")
    if assistant is None:
        return "The medical assistant is not available at the moment. Please try again later."

    # Give each browser session its own consultation so that answers from
    # different users are not interleaved. Falls back to the former shared id
    # when no session hash is available.
    session_hash = getattr(request, "session_hash", None)
    conversation_id = f"gradio_{session_hash}" if session_hash else "gradio_user"

    response = await assistant.process_message(conversation_id, message, history)
    return response.response


# Gradio chat UI backed by gradio_chat.
demo = gr.ChatInterface(
    fn=gradio_chat,
    title="Nurse Oge - Medical Assistant",
    description=(
        "Welcome to Nurse Oge, your AI medical assistant specialized in serving Nigerian communities.\n"
        "This system provides medical guidance while ensuring comprehensive health information gathering."
    ),
    examples=[
        ["What are the common symptoms of malaria?"],
        ["I've been having headaches for the past week"],
        ["How can I prevent typhoid fever?"],
    ],
    theme=gr.themes.Soft(
        primary_hue="blue",
        secondary_hue="purple",
    ),
)


# Custom styling, assigned on the already-built interface.
# NOTE(review): Gradio documents ``css`` as a constructor argument - confirm
# that assigning the attribute afterwards is honoured by the installed version.
demo.css = """
.gradio-container {
    font-family: 'Arial', sans-serif;
}
.chat-message {
    padding: 1rem;
    border-radius: 0.5rem;
    margin-bottom: 0.5rem;
}
"""


# Serve the Gradio UI from the same application under /gradio.
app = gr.mount_gradio_app(app, demo, path="/gradio")


if __name__ == "__main__":
    # Imported here so uvicorn is only required when run as a script.
    import uvicorn

    # Listen on all interfaces, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)