File size: 7,481 Bytes
244be9c
b93cc95
 
 
244be9c
b93cc95
244be9c
b93cc95
 
244be9c
 
96f1d8c
 
 
 
 
 
244be9c
 
b93cc95
96f1d8c
244be9c
b93cc95
244be9c
 
b93cc95
96f1d8c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244be9c
 
 
d7eec32
244be9c
 
 
 
 
 
 
 
 
 
 
 
 
 
b93cc95
244be9c
 
 
 
b93cc95
 
 
 
 
 
244be9c
 
b93cc95
244be9c
 
 
 
 
b93cc95
 
 
 
244be9c
b93cc95
244be9c
 
 
b93cc95
244be9c
b93cc95
 
244be9c
 
 
b93cc95
244be9c
 
 
 
 
b93cc95
244be9c
 
b93cc95
244be9c
b93cc95
244be9c
 
b93cc95
244be9c
 
 
 
 
b93cc95
 
244be9c
 
 
 
 
 
 
 
b93cc95
244be9c
 
 
 
 
ba5898e
 
b93cc95
244be9c
ba5898e
244be9c
 
ba5898e
 
 
244be9c
ba5898e
244be9c
 
ba5898e
244be9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
# app.py
import os
import requests
import json
import logging
import pandas as pd
import faiss
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Point every Hugging Face cache at a directory the container can write to.
# NOTE(review): transformers / sentence_transformers are imported above, before
# these variables are set; if those libraries read them at import time, only
# the explicit cache_dir / cache_folder arguments used when loading the models
# take effect — confirm before relying on the environment variables alone.
os.environ.update(
    dict.fromkeys(
        ("HF_HOME", "TRANSFORMERS_CACHE", "SENTENCE_TRANSFORMERS_HOME"),
        "/app/cache",
    )
)

# The ASGI application every route below is registered on.
app = FastAPI()

# The OpenRouter key must come from the environment; refuse to start without it
# rather than failing on the first model call.
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
if not OPENROUTER_API_KEY:
    raise ValueError("❌ OPENROUTER_API_KEY is missing. Set it as an environment variable.")

# Chat-completions endpoint used by deepseek_request().
OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1/chat/completions"

# Load the embedding and summarization models, caching weights under /app/cache.
# A failure here is fatal on purpose: the FAISS indexes and every endpoint use
# ``embedding_model``. Previously the error was printed and swallowed, so startup
# crashed a few lines later with an unrelated-looking NameError.
# Both checkpoints load with classes shipped in the libraries themselves, so
# ``trust_remote_code`` is not passed. Enabling it only permits executing Python
# fetched from the Hub repository, and it has no effect on caching.
# NOTE(review): summarization_model / summarization_tokenizer are not referenced
# by any endpoint in this file (/summarize_chat uses DeepSeek) — confirm they
# are still needed before paying their load time and memory.
try:
    embedding_model = SentenceTransformer(
        "sentence-transformers/all-MiniLM-L6-v2",
        cache_folder="/app/cache",
    )
    summarization_model = AutoModelForSeq2SeqLM.from_pretrained(
        "google/long-t5-tglobal-base",
        cache_dir="/app/cache",
    )
    summarization_tokenizer = AutoTokenizer.from_pretrained(
        "google/long-t5-tglobal-base",
        cache_dir="/app/cache",
    )
except Exception as e:
    # Log with traceback, then stop startup instead of running without models.
    logging.exception("❌ Model loading error: %s", e)
    raise
else:
    print("βœ… Models Loaded Successfully!")

# Liveness probe served at the root path.
@app.get("/")
def health_check():
    """Report that the service is up.

    The payload is constant; it does not inspect models, datasets or history.
    """
    status_payload = {"status": "FastAPI is running!"}
    return status_payload

# Load the CSV datasets behind question retrieval and treatment lookup.
# A missing file is fatal at startup. This runs at import time, outside any
# request, so an HTTPException (the previous choice) has no client to be
# returned to. Raise a plain RuntimeError chained to the original error instead.
try:
    recommendations_df = pd.read_csv("treatment_recommendations.csv")
    questions_df = pd.read_csv("symptom_questions.csv")
except FileNotFoundError as e:
    logging.error(f"❌ Missing dataset file: {e}")
    raise RuntimeError(f"Dataset file not found: {str(e)}") from e
else:
    print("βœ… Datasets Loaded Successfully!")

def _build_faiss_index(texts, index_type):
    """Embed *texts* and load the vectors into a new flat FAISS index.

    ``index_type`` is a FAISS flat-index class; it is constructed with the
    embedding dimension. Returns ``(embeddings, index)``.
    """
    vectors = embedding_model.encode(texts, convert_to_numpy=True)
    built_index = index_type(vectors.shape[1])
    built_index.add(vectors)
    return vectors, built_index


# Diagnostic questions are searched by L2 distance.
question_embeddings, question_index = _build_faiss_index(
    questions_df["Questions"].tolist(), faiss.IndexFlatL2
)

# Disorder names are searched by inner product; the row order matches
# recommendations_df, so a returned id is also a row position there.
treatment_embeddings, index = _build_faiss_index(
    recommendations_df["Disorder"].tolist(), faiss.IndexFlatIP
)

# Conversation transcript: a single module-level list, so every client of this
# server process reads and appends the same history.
# The /get_questions handler appends a patient message followed by the AI reply;
# /detect_disorders, /get_treatment and /summarize_chat only read it.
# NOTE(review): nothing in this file clears or bounds it — per-session storage
# is probably needed before serving more than one user.
chat_history = []

# Request bodies accepted by the POST endpoints.
class ChatRequest(BaseModel):
    """Body of ``POST /get_questions``: one message typed by the patient."""

    message: str


class SummaryRequest(BaseModel):
    """Body model carrying a conversation transcript as a list.

    NOTE(review): none of the endpoints in this file accept this model;
    /summarize_chat reads the module-level history instead.
    """

    chat_history: list


# βœ… Function: Call DeepSeek via OpenRouter
def deepseek_request(prompt, max_tokens=300):
    """Send a request to OpenRouter's DeepSeek model."""
    headers = {
        "Authorization": f"Bearer {OPENROUTER_API_KEY}",
        "Content-Type": "application/json"
    }
    payload = {
        "model": "deepseek/deepseek-r1-distill-llama-8b",
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": 0.8
    }
    try:
        response = requests.post(OPENROUTER_BASE_URL, headers=headers, data=json.dumps(payload))
        response.raise_for_status()
        response_json = response.json()
        
        if "choices" in response_json and response_json["choices"]:
            return response_json["choices"][0].get("message", {}).get("content", "").strip()
    except Exception as e:
        logging.error(f"OpenRouter DeepSeek API error: {e}")
        return "I'm here to support you. Can you share more about what you're feeling?"

# Nearest-neighbour lookup of a diagnostic question for free-text input.
def retrieve_relevant_question(user_input):
    """Return the dataset question whose embedding is closest to *user_input*.

    Searches ``question_index`` for the single nearest entry of
    ``questions_df["Questions"]``. If FAISS reports no neighbour (id ``-1``),
    a generic request for more detail is returned instead.
    """
    query_vector = embedding_model.encode([user_input], convert_to_numpy=True)
    _distances, neighbour_ids = question_index.search(query_vector, 1)
    best_id = neighbour_ids[0][0]
    if best_id != -1:
        return questions_df["Questions"].iloc[best_id]
    return "I'm here to listen. Can you tell me more about your symptoms?"

# Chat endpoint: one conversational turn with the AI psychiatrist.
@app.post("/get_questions")
def chat(request: ChatRequest):
    """Reply to the patient's message and record the exchange.

    The prompt contains the earlier turns from the module-level
    ``chat_history`` (shared by every request this process serves) followed by
    the new message. The message and the reply are then appended together.

    Returns:
        ``{"response": <reply>}``; the reply is never empty.
    """
    user_message = request.message

    # Render earlier turns *before* recording the new message. Appending first
    # (as before) duplicated the current message under "Previous conversation".
    # Entries are labelled by position because this endpoint stores them as
    # (patient message, AI reply) pairs.
    previous_conversation = "\n".join(
        f"{'Patient' if i % 2 == 0 else 'Psychiatrist'}: {entry}"
        for i, entry in enumerate(chat_history)
    ) or "(no previous messages)"

    # Constructing the DeepSeek prompt
    prompt = f"""
    You are an AI psychiatrist conducting a mental health consultation.
    Engage in a supportive, natural conversation, maintaining an empathetic tone.

    - Always provide a thoughtful and compassionate response.
    - If a user shares distressing emotions, acknowledge their feelings and ask relevant follow-up questions.

    Previous conversation:
    {previous_conversation}

    User input:
    "{user_message}"

    Generate:
    - An empathetic response.
    - A related follow-up question.
    
    Ensure your response is meaningful and NEVER empty.
    """

    ai_response = deepseek_request(prompt, max_tokens=250)

    # Never return an empty reply (None, "" or whitespace only).
    if not ai_response or not ai_response.strip():
        ai_response = "I'm here to listen. Can you tell me more about how you're feeling? Maybe I can help."

    # Add the patient message and the reply in a single call, so the pair is
    # recorded together and the positional labelling above stays aligned.
    chat_history.extend([user_message, ai_response])
    return {"response": ai_response}


# Disorder detection: match the whole conversation against known disorders.
@app.post("/detect_disorders")
def detect_disorders():
    """Return up to three dataset disorders most similar to the chat history.

    The space-joined conversation is embedded and searched against ``index``
    (the inner-product FAISS index over ``recommendations_df["Disorder"]``).

    Returns:
        ``{"disorders": [...]}`` with matched disorder names, best first, or
        ``["No matching disorder found."]`` when there is no conversation yet
        or FAISS returns no neighbours.
    """
    no_match = {"disorders": ["No matching disorder found."]}

    full_chat_text = " ".join(chat_history)
    # Embedding an empty conversation would still "match" arbitrary disorders.
    if not full_chat_text.strip():
        return no_match

    text_embedding = embedding_model.encode([full_chat_text], convert_to_numpy=True)
    _, indices = index.search(text_embedding, 3)

    # FAISS pads missing neighbours with -1 when the index holds fewer than k
    # vectors. ``iloc[-1]`` would silently return the last row, so drop those slots.
    disorders = [recommendations_df["Disorder"].iloc[i] for i in indices[0] if i != -1]
    if not disorders:
        return no_match
    return {"disorders": disorders}

# Treatment lookup for the disorders detected from the conversation.
@app.post("/get_treatment")
def get_treatment():
    """Map each detected disorder to a treatment recommendation.

    A disorder present in ``recommendations_df`` with a non-missing
    "Treatment Recommendation" uses the first such dataset row. Otherwise a
    plan is generated with DeepSeek. The "No matching disorder found."
    placeholder from ``detect_disorders`` is skipped rather than sent to the
    model as if it were a diagnosis. Duplicate names are handled once.

    Returns:
        ``{"treatments": {disorder: recommendation}}`` (empty when nothing matched).
    """
    no_match_placeholder = "No matching disorder found."
    detected_disorders = detect_disorders()["disorders"]
    treatments = {}

    for disorder in detected_disorders:
        if disorder == no_match_placeholder or disorder in treatments:
            continue

        matches = recommendations_df.loc[
            recommendations_df["Disorder"] == disorder, "Treatment Recommendation"
        ]
        # A missing (NaN) cell would otherwise be returned as a float, which
        # cannot be serialized to JSON; fall back to generating a plan.
        if not matches.empty and pd.notna(matches.iloc[0]):
            treatments[disorder] = matches.iloc[0]
            continue

        # Generate treatment if not in dataset
        treatment_prompt = f"""
            The user has been diagnosed with {disorder}. Provide a structured treatment plan including:

            - **Therapy options** (CBT, psychotherapy, etc.).
            - **Medications** (if applicable).
            - **Lifestyle strategies** (exercise, mindfulness, etc.).
            - **When to seek professional help**.
            - **Encouragement**.

            Ensure your response is clear and medically sound.
            """
        treatments[disorder] = deepseek_request(treatment_prompt, max_tokens=250)

    return {"treatments": treatments}

# Session summary generated by DeepSeek.
@app.post("/summarize_chat")
def summarize_chat():
    """Summarize the full chat session using DeepSeek.

    Returns:
        ``{"summary": <text>}``. When nothing has been said yet, a fixed
        message is returned instead of asking the model to summarize an empty
        transcript, which invites a fabricated summary.
    """
    if not chat_history:
        return {"summary": "No conversation to summarize yet."}

    # One labelled turn per line keeps speakers and turn boundaries visible to
    # the model; the /get_questions handler stores entries as
    # (patient message, AI reply) pairs.
    chat_text = "\n".join(
        f"{'Patient' if i % 2 == 0 else 'Psychiatrist'}: {entry}"
        for i, entry in enumerate(chat_history)
    )
    summary_prompt = f"The following is a conversation between a patient and an AI psychiatrist. Summarize it clearly:\n{chat_text}"
    summary = deepseek_request(summary_prompt, max_tokens=500)
    return {"summary": summary}