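"""FastAPI service for an AI-assisted psychiatric chat pipeline.

Endpoints (all POST):
  /get_questions    - retrieve diagnostic questions via FAISS and wrap them
                      in an empathetic response from a local LLM
  /summarize_chat   - summarize a full chat session with LongT5
  /detect_disorders - match the chat history against known disorders
  /get_treatment    - look up treatment recommendations for detected disorders
"""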
from fastapi import FastAPI
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
import faiss
import pandas as pd
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer

app = FastAPI()
# Load AI models: one encoder for disorder matching, a lighter one for
# question retrieval, and LongT5 for long-input summarization.
similarity_model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
summarization_model = AutoModelForSeq2SeqLM.from_pretrained("google/long-t5-tglobal-base")
summarization_tokenizer = AutoTokenizer.from_pretrained("google/long-t5-tglobal-base")
# New: Load local LLM for dynamic emotional responses. Note: the bare
# "mistralai/Mistral-7B-Instruct" repo does not exist on the Hugging Face Hub;
# a versioned ID such as the one below is required. Loading a 7B model in full
# precision needs roughly 28 GB of RAM; passing torch_dtype=torch.float16 and
# device_map="auto" (with accelerate installed) reduces this substantially.
response_model_name = "mistralai/Mistral-7B-Instruct-v0.2"
response_tokenizer = AutoTokenizer.from_pretrained(response_model_name)
response_model = AutoModelForCausalLM.from_pretrained(response_model_name)
# Load datasets
recommendations_df = pd.read_csv("treatment_recommendations.csv")
questions_df = pd.read_csv("symptom_questions.csv")
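# Assumed CSV schemas (inferred from the column accesses below):
#   treatment_recommendations.csv: "Disorder", "Treatment Recommendation"
#   symptom_questions.csv:         "Questions"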
# FAISS index for disorder detection. Embeddings are L2-normalized so that the
# inner-product (IndexFlatIP) search below is equivalent to cosine similarity.
treatment_embeddings = similarity_model.encode(
    recommendations_df["Disorder"].tolist(), convert_to_numpy=True, normalize_embeddings=True
)
index = faiss.IndexFlatIP(treatment_embeddings.shape[1])
index.add(treatment_embeddings)
# FAISS Index for Question Retrieval
question_embeddings = embedding_model.encode(questions_df["Questions"].tolist(), convert_to_numpy=True)
question_index = faiss.IndexFlatL2(question_embeddings.shape[1])
question_index.add(question_embeddings)
# Request models
class ChatRequest(BaseModel):
    message: str


class SummaryRequest(BaseModel):
    chat_history: list  # list of chat messages (strings)
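# Example payloads (hypothetical values, shown for illustration only):
#   ChatRequest:    {"message": "I haven't been sleeping well lately."}
#   SummaryRequest: {"chat_history": ["I feel anxious.", "How long has this lasted?"]}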
@app.post("/get_questions")
def get_recommended_questions(request: ChatRequest):
"""Retrieve the most relevant diagnostic questions with a dynamically generated conversational response."""
# Step 1: Encode the input message for FAISS search
input_embedding = embedding_model.encode([request.message], convert_to_numpy=True)
distances, indices = question_index.search(input_embedding, 3)
# Step 2: Retrieve the top 3 relevant questions
retrieved_questions = [questions_df["Questions"].iloc[i] for i in indices[0]]
# Step 3: Use a local LLM to generate context-aware empathetic responses
prompt = f"""
User: {request.message}
You are a compassionate psychiatric assistant. Before asking a diagnostic question, respond empathetically.
Questions:
1. {retrieved_questions[0]}
2. {retrieved_questions[1]}
3. {retrieved_questions[2]}
Generate a conversational response that introduces each question naturally.
"""
inputs = response_tokenizer(prompt, return_tensors="pt")
output = response_model.generate(**inputs, max_length=300)
enhanced_responses = response_tokenizer.decode(output[0], skip_special_tokens=True).split("\n")
return {"questions": enhanced_responses}
@app.post("/summarize_chat")
def summarize_chat(request: SummaryRequest):
"""Summarize full chat session at the end."""
chat_text = " ".join(request.chat_history)
inputs = summarization_tokenizer("summarize: " + chat_text, return_tensors="pt", max_length=4096, truncation=True)
summary_ids = summarization_model.generate(inputs.input_ids, max_length=500, num_beams=4, early_stopping=True)
summary = summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
return {"summary": summary}
@app.post("/detect_disorders")
def detect_disorders(request: SummaryRequest):
"""Detect psychiatric disorders from full chat history at the end."""
full_chat_text = " ".join(request.chat_history)
text_embedding = similarity_model.encode([full_chat_text], convert_to_numpy=True)
distances, indices = index.search(text_embedding, 3)
disorders = [recommendations_df["Disorder"].iloc[i] for i in indices[0]]
return {"disorders": disorders}
@app.post("/get_treatment")
def get_treatment(request: SummaryRequest):
"""Retrieve treatment recommendations based on detected disorders."""
detected_disorders = detect_disorders(request)["disorders"]
treatments = {
disorder: recommendations_df[recommendations_df["Disorder"] == disorder]["Treatment Recommendation"].values[0]
for disorder in detected_disorders
}
return {"treatments": treatments}