File size: 4,321 Bytes
c4f53c6
 
23943df
 
 
f7e81db
df2f743
c4f53c6
 
 
23943df
c4f53c6
 
 
 
 
df2f743
 
 
 
 
23943df
 
 
c4f53c6
23943df
 
 
 
c4f53c6
23943df
 
 
 
c4f53c6
23943df
c4f53c6
 
 
23943df
 
c4f53c6
ef0b933
df2f743
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef0b933
c4f53c6
 
23943df
 
 
 
c4f53c6
 
23943df
 
df2f743
23943df
 
 
 
 
 
 
 
 
df2f743
23943df
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import random
from typing import List

import faiss
import pandas as pd
from fastapi import FastAPI
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoModelForCausalLM

app = FastAPI()

# Load AI models. Everything below runs at import time: weights are fetched
# from the Hugging Face Hub (or the local cache) and held in memory for the
# lifetime of the process.
similarity_model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")  # disorder matching
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")  # question retrieval
# NOTE(review): google/long-t5-tglobal-base appears to be a pretrained checkpoint
# rather than one fine-tuned for summarization; confirm output quality or switch
# to a summarization-tuned LongT5 model.
summarization_model = AutoModelForSeq2SeqLM.from_pretrained("google/long-t5-tglobal-base")
summarization_tokenizer = AutoTokenizer.from_pretrained("google/long-t5-tglobal-base")

# Local instruction-tuned LLM used to phrase diagnostic questions empathetically.
# Mistral publishes only versioned instruct repos on the Hub (-v0.1 / -v0.2 / -v0.3);
# the unversioned "mistralai/Mistral-7B-Instruct" id does not resolve.
# NOTE(review): the repo may require accepting its terms / an HF access token.
response_model_name = "mistralai/Mistral-7B-Instruct-v0.2"
response_tokenizer = AutoTokenizer.from_pretrained(response_model_name)
response_model = AutoModelForCausalLM.from_pretrained(response_model_name)

# Load datasets. This block reads the "Disorder" column of the recommendations
# table and the "Questions" column of the questions table.
recommendations_df = pd.read_csv("treatment_recommendations.csv")
questions_df = pd.read_csv("symptom_questions.csv")

# FAISS index for disorder detection: one vector per row's "Disorder" name.
# Stored embeddings are L2-normalized so inner-product scores (IndexFlatIP) are
# cosine similarities; with unit-length stored vectors the ranking is cosine
# order regardless of the query vector's norm.
treatment_embeddings = similarity_model.encode(
    recommendations_df["Disorder"].tolist(),
    convert_to_numpy=True,
    normalize_embeddings=True,
)
index = faiss.IndexFlatIP(treatment_embeddings.shape[1])
index.add(treatment_embeddings)

# FAISS index for question retrieval: exact L2 distance over question embeddings.
question_embeddings = embedding_model.encode(questions_df["Questions"].tolist(), convert_to_numpy=True)
question_index = faiss.IndexFlatL2(question_embeddings.shape[1])
question_index.add(question_embeddings)

# Request models
# Request body for POST /get_questions: a single chat message from the user.
class ChatRequest(BaseModel):
    message: str  # user's message text; embedded for question retrieval and quoted in the LLM prompt

class SummaryRequest(BaseModel):
    """Request body carrying a whole chat session as a list of message strings."""

    # One entry per chat message. Typed as List[str] so Pydantic rejects
    # non-string items with a 422 validation error, instead of the handlers
    # failing later inside " ".join(...) with a 500.
    chat_history: List[str]


@app.post("/get_questions")
def get_recommended_questions(request: ChatRequest):
    """Retrieve the most relevant diagnostic questions with a dynamically generated conversational response.

    Steps:
      1. Embed ``request.message`` with ``embedding_model`` and look up the
         three nearest rows of ``questions_df["Questions"]`` in ``question_index``.
      2. Ask the local causal LM to introduce those questions empathetically.

    Returns:
        ``{"questions": [...]}``: the non-empty lines of the text the LLM
        generated. The prompt itself is not echoed back. If no question could
        be retrieved, the list is empty and the LLM is not called.
    """
    # Step 1: Encode the input message and search the question index.
    input_embedding = embedding_model.encode([request.message], convert_to_numpy=True)
    _, indices = question_index.search(input_embedding, 3)

    # Step 2: Retrieve up to 3 questions. FAISS pads with -1 when the index holds
    # fewer than k vectors; -1 must not reach .iloc, where it selects the last row.
    retrieved_questions = [questions_df["Questions"].iloc[i] for i in indices[0] if i >= 0]
    if not retrieved_questions:
        return {"questions": []}

    # Step 3: Ask the local LLM for an empathetic, conversational framing.
    # NOTE(review): request.message is untrusted user text embedded verbatim in the prompt.
    numbered_questions = "\n".join(
        f"{number}. {question}" for number, question in enumerate(retrieved_questions, start=1)
    )
    prompt = (
        f"User: {request.message}\n\n"
        "You are a compassionate psychiatric assistant. "
        "Before asking a diagnostic question, respond empathetically.\n\n"
        f"Questions:\n{numbered_questions}\n\n"
        "Generate a conversational response that introduces each question naturally."
    )

    if getattr(response_tokenizer, "chat_template", None):
        # Instruction-tuned checkpoints expect their own chat format. The rendered
        # template already contains the BOS token, so the tokenizer must not add another.
        templated_prompt = response_tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            tokenize=False,
            add_generation_prompt=True,
        )
        inputs = response_tokenizer(templated_prompt, return_tensors="pt", add_special_tokens=False)
    else:
        inputs = response_tokenizer(prompt, return_tensors="pt")

    # max_new_tokens bounds only the reply; max_length would also count the
    # prompt tokens and could leave no room for generation at all.
    output = response_model.generate(**inputs, max_new_tokens=300)

    # generate() returns prompt + continuation for decoder-only models; decode
    # only the newly generated tokens.
    prompt_length = inputs["input_ids"].shape[1]
    generated_text = response_tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)
    enhanced_responses = [line.strip() for line in generated_text.split("\n") if line.strip()]

    return {"questions": enhanced_responses}


@app.post("/summarize_chat")
def summarize_chat(request: SummaryRequest):
    """Summarize full chat session at the end.

    Messages are joined with single spaces, prefixed with ``"summarize: "``,
    tokenized (truncated to 4096 tokens) and passed to the seq2seq
    summarization model. Decoding uses beam search (4 beams) and produces at
    most 500 tokens.

    Returns:
        ``{"summary": str}``. An empty or whitespace-only chat yields ``""``
        without running the model.
    """
    chat_text = " ".join(request.chat_history)
    if not chat_text.strip():
        # Nothing to summarize; avoid generating text from an empty prompt.
        return {"summary": ""}

    # truncation=True keeps only the first 4096 tokens, so the end of a very
    # long session does not reach the model.
    inputs = summarization_tokenizer(
        "summarize: " + chat_text,
        return_tensors="pt",
        max_length=4096,
        truncation=True,
    )
    # Pass the attention mask along with the input ids (not input_ids alone).
    summary_ids = summarization_model.generate(**inputs, max_length=500, num_beams=4, early_stopping=True)
    summary = summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return {"summary": summary}


@app.post("/detect_disorders")
def detect_disorders(request: SummaryRequest):
    """Detect psychiatric disorders from full chat history at the end.

    The chat history is joined into one string, embedded with
    ``similarity_model`` and matched against ``index`` (the FAISS
    inner-product index built over ``recommendations_df["Disorder"]``).

    Returns:
        ``{"disorders": [...]}``: up to three disorder names, best match
        first. An empty or whitespace-only chat yields an empty list.
    """
    full_chat_text = " ".join(request.chat_history)
    if not full_chat_text.strip():
        # Nothing to match against; do not report disorders for an empty chat.
        return {"disorders": []}

    # NOTE(review): SentenceTransformer truncates input to the model's
    # max_seq_length, so only the beginning of a long chat influences the match.
    text_embedding = similarity_model.encode([full_chat_text], convert_to_numpy=True)
    _, indices = index.search(text_embedding, 3)
    # FAISS pads results with -1 when the index holds fewer than k vectors;
    # -1 must not reach .iloc, where it would silently select the last row.
    disorders = [recommendations_df["Disorder"].iloc[i] for i in indices[0] if i >= 0]
    return {"disorders": disorders}


@app.post("/get_treatment")
def get_treatment(request: SummaryRequest):
    """Map each disorder detected in the chat to its treatment recommendation.

    Detection is delegated to ``detect_disorders``. For every disorder name it
    returns, the "Treatment Recommendation" value of the first matching row in
    ``recommendations_df`` is used.

    Returns:
        ``{"treatments": {disorder_name: recommendation, ...}}``.
    """
    detected = detect_disorders(request)["disorders"]
    treatments = {}
    for disorder in detected:
        # The boolean mask selects every row for this disorder; the first one wins.
        matching_rows = recommendations_df.loc[recommendations_df["Disorder"] == disorder]
        treatments[disorder] = matching_rows["Treatment Recommendation"].values[0]
    return {"treatments": treatments}