# app.py
import os
import requests
import json
import logging
import pandas as pd
import faiss
from fastapi import FastAPI
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Set a writable cache directory inside the container
os.environ["HF_HOME"] = "/app/cache"
os.environ["TRANSFORMERS_CACHE"] = "/app/cache"
os.environ["SENTENCE_TRANSFORMERS_HOME"] = "/app/cache"
# Initialize FastAPI
app = FastAPI()
# Securely fetch the OpenRouter API key
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
if not OPENROUTER_API_KEY:
    raise ValueError("OPENROUTER_API_KEY is missing. Set it as an environment variable.")
OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1/chat/completions"
# Load AI models with explicit caching & remote code trust
try:
    embedding_model = SentenceTransformer(
        "sentence-transformers/all-MiniLM-L6-v2",
        cache_folder="/app/cache",
        trust_remote_code=True,  # fix potential caching issues
    )
    summarization_model = AutoModelForSeq2SeqLM.from_pretrained(
        "google/long-t5-tglobal-base",
        cache_dir="/app/cache",
        trust_remote_code=True,
    )
    summarization_tokenizer = AutoTokenizer.from_pretrained(
        "google/long-t5-tglobal-base",
        cache_dir="/app/cache",
        trust_remote_code=True,
    )
    print("Models loaded successfully!")
except Exception as e:
    logging.error(f"Model loading error: {e}")
    raise  # the models are required; fail fast rather than continue with undefined names
# API health check
@app.get("/")
def health_check():
    return {"status": "FastAPI is running!"}
# Load datasets
try:
    recommendations_df = pd.read_csv("treatment_recommendations.csv")
    questions_df = pd.read_csv("symptom_questions.csv")
    print("Datasets loaded successfully!")
except FileNotFoundError as e:
    logging.error(f"Missing dataset file: {e}")
    # This runs at import time, so raise a plain error rather than an HTTPException
    raise RuntimeError(f"Dataset file not found: {e}")
# Build FAISS indexes
question_embeddings = embedding_model.encode(questions_df["Questions"].tolist(), convert_to_numpy=True)
question_index = faiss.IndexFlatL2(question_embeddings.shape[1])
question_index.add(question_embeddings)

treatment_embeddings = embedding_model.encode(recommendations_df["Disorder"].tolist(), convert_to_numpy=True)
faiss.normalize_L2(treatment_embeddings)  # normalize so the inner-product index ranks by cosine similarity
index = faiss.IndexFlatIP(treatment_embeddings.shape[1])
index.add(treatment_embeddings)
# Chat history storage
chat_history = []
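# NOTE: module-level history is shared across every client of this Space; a
# multi-user deployment would presumably key conversations per session instead.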
# Request models
class ChatRequest(BaseModel):
    message: str


class SummaryRequest(BaseModel):
    chat_history: list
# Function: call DeepSeek via OpenRouter
def deepseek_request(prompt, max_tokens=300):
    """Send a request to OpenRouter's DeepSeek model."""
    headers = {
        "Authorization": f"Bearer {OPENROUTER_API_KEY}",
        "Content-Type": "application/json"
    }
    payload = {
        "model": "deepseek/deepseek-r1-distill-llama-8b",
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": 0.8
    }
    try:
        response = requests.post(OPENROUTER_BASE_URL, headers=headers, data=json.dumps(payload), timeout=60)
        response.raise_for_status()
        response_json = response.json()
        if "choices" in response_json and response_json["choices"]:
            return response_json["choices"][0].get("message", {}).get("content", "").strip()
    except Exception as e:
        logging.error(f"OpenRouter DeepSeek API error: {e}")
    # Fallback so callers never receive an empty reply
    return "I'm here to support you. Can you share more about what you're feeling?"
# Function: retrieve the most relevant diagnostic question
# (helper; not currently wired to an endpoint)
def retrieve_relevant_question(user_input):
    """Find the most relevant diagnostic question from the dataset using FAISS."""
    input_embedding = embedding_model.encode([user_input], convert_to_numpy=True)
    _, indices = question_index.search(input_embedding, 1)
    if indices[0][0] == -1:
        return "I'm here to listen. Can you tell me more about your symptoms?"
    return questions_df["Questions"].iloc[indices[0][0]]
# API endpoint: chat interaction
@app.post("/get_questions")
def chat(request: ChatRequest):
    """Patient enters data, AI responds and stores the conversation."""
    user_message = request.message
    chat_history.append(user_message)

    # Construct the DeepSeek prompt
    prompt = f"""
    You are an AI psychiatrist conducting a mental health consultation.
    Engage in a supportive, natural conversation, maintaining an empathetic tone.
    - Always provide a thoughtful and compassionate response.
    - If a user shares distressing emotions, acknowledge their feelings and ask relevant follow-up questions.

    Previous conversation:
    {chat_history}

    User input:
    "{user_message}"

    Generate:
    - An empathetic response.
    - A related follow-up question.

    Ensure your response is meaningful and NEVER empty.
    """

    # Call the DeepSeek API
    ai_response = deepseek_request(prompt, max_tokens=250)

    # Ensure the response is never empty
    if not ai_response or not ai_response.strip():
        ai_response = "I'm here to listen. Can you tell me more about how you're feeling? Maybe I can help."

    chat_history.append(ai_response)
    return {"response": ai_response}
# API endpoint: detect disorders from chat history
@app.post("/detect_disorders")
def detect_disorders():
    """Detect psychiatric disorders based on the full chat history."""
    full_chat_text = " ".join(chat_history)
    text_embedding = embedding_model.encode([full_chat_text], convert_to_numpy=True)
    faiss.normalize_L2(text_embedding)  # match the normalized treatment index (cosine similarity)
    _, indices = index.search(text_embedding, 3)
    if indices[0][0] == -1:
        return {"disorders": ["No matching disorder found."]}
    # FAISS pads missing results with -1, so skip those entries
    disorders = [recommendations_df["Disorder"].iloc[i] for i in indices[0] if i != -1]
    return {"disorders": disorders}
# API endpoint: get treatment recommendations
@app.post("/get_treatment")
def get_treatment():
    """Retrieve treatment recommendations based on detected disorders."""
    detected_disorders = detect_disorders()["disorders"]
    treatments = {}
    for disorder in detected_disorders:
        if disorder in recommendations_df["Disorder"].values:
            treatments[disorder] = recommendations_df[
                recommendations_df["Disorder"] == disorder
            ]["Treatment Recommendation"].values[0]
        else:
            # Generate a treatment plan if the disorder is not in the dataset
            treatment_prompt = f"""
            The user has been diagnosed with {disorder}. Provide a structured treatment plan including:
            - **Therapy options** (CBT, psychotherapy, etc.).
            - **Medications** (if applicable).
            - **Lifestyle strategies** (exercise, mindfulness, etc.).
            - **When to seek professional help**.
            - **Encouragement**.
            Ensure your response is clear and medically sound.
            """
            treatments[disorder] = deepseek_request(treatment_prompt, max_tokens=250)
    return {"treatments": treatments}
# API endpoint: summarize chat
@app.post("/summarize_chat")
def summarize_chat():
    """Summarize the full chat session using DeepSeek."""
    chat_text = " ".join(chat_history)
    summary_prompt = f"The following is a conversation between a patient and an AI psychiatrist. Summarize it clearly:\n{chat_text}"
    summary = deepseek_request(summary_prompt, max_tokens=500)
    return {"summary": summary}