File size: 6,097 Bytes
f6e0c20
93407fc
 
 
 
 
f6e0c20
93407fc
 
 
f6e0c20
93407fc
 
f6e0c20
 
 
 
 
 
93407fc
 
 
 
 
 
 
 
f6e0c20
 
 
 
 
 
93407fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f6e0c20
 
 
93407fc
 
 
f6e0c20
93407fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f6e0c20
 
 
 
 
 
 
 
 
 
 
 
 
 
93407fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f6e0c20
 
 
 
93407fc
 
 
 
 
 
 
 
 
 
f6e0c20
93407fc
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import logging
import os
from typing import List

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
import faiss
import pandas as pd
from groq import Groq
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# ✅ Initialize FastAPI
app = FastAPI()

# ✅ Securely fetch the API key from the GROQ_API_KEY environment variable.
# The argument to os.getenv is the *name* of the variable, never the key itself;
# a secret placed in source code ends up in version control and must be revoked.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    raise ValueError("GROQ_API_KEY is missing. Set it as an environment variable.")

# Shared Groq client used by the rephrasing helper.
client = Groq(api_key=GROQ_API_KEY)

# ✅ Load AI models once at import time; the request handlers below reuse them.
_LONG_T5_CHECKPOINT = "google/long-t5-tglobal-base"

# Sentence encoders: one for disorder matching, one for question retrieval.
similarity_model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

# Seq2seq model + tokenizer for chat summarization (same checkpoint for both).
summarization_tokenizer = AutoTokenizer.from_pretrained(_LONG_T5_CHECKPOINT)
summarization_model = AutoModelForSeq2SeqLM.from_pretrained(_LONG_T5_CHECKPOINT)

# ✅ Load the CSV datasets that back the retrieval indexes and treatment lookup.
try:
    recommendations_df = pd.read_csv("treatment_recommendations.csv")
    questions_df = pd.read_csv("symptom_questions.csv")
except FileNotFoundError as e:
    # This runs at import time, outside any request, so an HTTPException would
    # never reach a client; fail startup with an ordinary, chained exception.
    logging.error("Missing dataset file: %s", e)
    raise RuntimeError("Dataset files are missing.") from e

def _build_flat_index(embeddings, index_factory):
    """Create a flat FAISS index sized to ``embeddings`` and load them into it."""
    flat_index = index_factory(embeddings.shape[1])
    flat_index.add(embeddings)
    return flat_index


# ✅ FAISS inner-product index over disorder names (searched by detect_disorders).
# NOTE(review): inner product equals cosine similarity only for L2-normalized
# vectors — confirm the similarity model normalizes its embeddings.
treatment_embeddings = similarity_model.encode(
    recommendations_df["Disorder"].tolist(), convert_to_numpy=True
)
index = _build_flat_index(treatment_embeddings, faiss.IndexFlatIP)

# ✅ FAISS L2 index over diagnostic questions (searched by retrieve_questions).
question_embeddings = embedding_model.encode(
    questions_df["Questions"].tolist(), convert_to_numpy=True
)
question_index = _build_flat_index(question_embeddings, faiss.IndexFlatL2)

# ✅ Request Models
# Body of POST /get_questions.
class ChatRequest(BaseModel):
    # The user's latest chat message; used as the retrieval query and quoted
    # in the rephrasing prompt.
    message: str

# Body of POST /summarize_chat, /detect_disorders and /get_treatment.
class SummaryRequest(BaseModel):
    # Chat messages in order. Every consumer joins them with " ".join(...),
    # so elements must be strings; typing the list lets validation reject bad
    # payloads with a 422 instead of a TypeError (500) inside the handler.
    chat_history: List[str]

# ✅ Retrieve the most relevant question
def retrieve_questions(user_input):
    """Return the single stored diagnostic question closest to ``user_input``.

    The text is embedded with ``embedding_model`` and the nearest row of
    ``questions_df`` is found through ``question_index``. A "Questions" cell may
    hold several questions separated by ", "; only the first one is returned.

    Returns:
        The question string, or an apology message when no usable question is
        found (no search hit, or an empty/missing cell).
    """
    not_found = "I'm sorry, I couldn't find a relevant question."

    input_embedding = embedding_model.encode([user_input], convert_to_numpy=True)
    _, indices = question_index.search(input_embedding, 1)  # only the top hit
    best_row = indices[0][0]
    if best_row == -1:
        # FAISS reports a missing hit (e.g. an empty index) as -1.
        return not_found

    question_block = questions_df["Questions"].iloc[best_row]
    if pd.isna(question_block):
        # Empty CSV cells are read as NaN, which has no .split().
        return not_found

    # str.split always yields at least one element, so [0] is always valid.
    best_question = str(question_block).split(", ")[0].strip()
    return best_question or not_found

# ✅ Groq API for rephrasing
def generate_empathetic_response(user_input, retrieved_question, model="llama-3.1-8b-instant"):
    """Ask a Groq-hosted LLaMA model to rephrase one question empathetically.

    Args:
        user_input: The user's message, quoted in the prompt.
        retrieved_question: The diagnostic question to rephrase.
        model: Groq model id. "llama3-8b" (used previously) is not a Groq model
            id, so every request failed and fell through to the apology below.

    Returns:
        The model's reply (whitespace-stripped), or an apology string when the
        API call fails or returns no content.
    """
    fallback = "I'm sorry, I couldn't process your request."

    # NOTE(review): user_input is untrusted and interpolated verbatim, so users
    # can attempt prompt injection; treat the model output accordingly.
    prompt = f"""
    The user said: "{user_input}"
    
    Relevant Question:
    - {retrieved_question}

    You are an empathetic AI psychiatrist. Rephrase this question naturally in a human-like way.
    Acknowledge the user's emotions before asking the question.

    Example format:
    - "I understand that anxiety can be overwhelming. Can you tell me more about when you started feeling this way?"
    
    Generate only one empathetic response.
    """

    try:
        chat_completion = client.chat.completions.create(
            messages=[
                {"role": "system", "content": "You are a helpful, empathetic AI psychiatrist."},
                {"role": "user", "content": prompt},
            ],
            model=model,
            temperature=0.8,
            top_p=0.9,
        )
        content = chat_completion.choices[0].message.content
    except Exception as e:
        # Top-level boundary for an external service: log with traceback and
        # degrade to a friendly message instead of failing the request.
        logging.exception("Groq API error: %s", e)
        return fallback

    if not content or not content.strip():
        return fallback
    return content.strip()

# ✅ API Endpoint: Get Empathetic Questions (Hybrid RAG)
@app.post("/get_questions")
def get_recommended_questions(request: ChatRequest):
    """Answer a chat message with one empathetic follow-up question.

    Retrieval selects the closest stored diagnostic question, and the Groq LLM
    then rephrases it in light of the user's message.
    """
    user_message = request.message
    question = retrieve_questions(user_message)
    return {"question": generate_empathetic_response(user_message, question)}

# ✅ API Endpoint: Summarize Chat
@app.post("/summarize_chat")
def summarize_chat(request: SummaryRequest):
    """Summarize the full chat session with the Long-T5 model.

    Messages are joined with single spaces, prefixed with "summarize: ",
    truncated to 4096 tokens, and decoded with 4-beam search (at most 500
    output tokens).

    NOTE(review): google/long-t5-tglobal-base is a pre-trained checkpoint;
    confirm it has been fine-tuned for summarization before relying on output.
    """
    chat_text = " ".join(request.chat_history)
    inputs = summarization_tokenizer(
        "summarize: " + chat_text, return_tensors="pt", max_length=4096, truncation=True
    )
    # **inputs forwards the attention mask along with input_ids, so generate()
    # does not have to infer it from the pad token.
    summary_ids = summarization_model.generate(
        **inputs, max_length=500, num_beams=4, early_stopping=True
    )
    summary = summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return {"summary": summary}

# ✅ API Endpoint: Detect Disorders
@app.post("/detect_disorders")
def detect_disorders(request: SummaryRequest):
    """Return up to three disorders whose names best match the chat history.

    The joined chat text is embedded with ``similarity_model`` and searched
    against the disorder-name ``index``.

    Returns:
        ``{"disorders": [...]}`` with up to three disorder names in FAISS
        ranking order, or ``{"disorders": "No matching disorder found."}`` when
        the search yields no valid hit.
    """
    full_chat_text = " ".join(request.chat_history)
    text_embedding = similarity_model.encode([full_chat_text], convert_to_numpy=True)
    _, indices = index.search(text_embedding, 3)

    # FAISS fills missing hits with -1 (e.g. fewer than 3 disorders indexed).
    # Passing -1 to iloc would silently return the LAST row, so drop every one.
    hits = [int(i) for i in indices[0] if i != -1]
    if not hits:
        return {"disorders": "No matching disorder found."}

    disorders = [recommendations_df["Disorder"].iloc[i] for i in hits]
    return {"disorders": disorders}

# ✅ API Endpoint: Get Treatment Recommendations
@app.post("/get_treatment")
def get_treatment(request: SummaryRequest):
    """Map each detected disorder to its treatment recommendation.

    Disorders come from ``detect_disorders``. When a disorder appears in
    several CSV rows, the first row's "Treatment Recommendation" is used.

    Returns:
        ``{"treatments": {disorder: recommendation, ...}}``; empty when
        nothing was detected.
    """
    detected_disorders = detect_disorders(request)["disorders"]
    if isinstance(detected_disorders, str):
        # detect_disorders reports "no match" as a message string; iterating
        # it would walk its characters, so treat it as an empty result.
        detected_disorders = []

    # One pass over the table, keeping the first recommendation per disorder,
    # instead of two full-column scans per detected disorder.
    first_treatment = {}
    for name, treatment in zip(
        recommendations_df["Disorder"], recommendations_df["Treatment Recommendation"]
    ):
        first_treatment.setdefault(name, treatment)

    treatments = {
        disorder: first_treatment[disorder]
        for disorder in detected_disorders
        if disorder in first_treatment
    }
    return {"treatments": treatments}