# mindspark121's picture
# Update app.py
# ba5898e verified
# app.py
import os
import requests
import json
import logging
import pandas as pd
import faiss
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# ✅ Point every Hugging Face cache variable at one writable directory inside
# the container.
# NOTE(review): these are assigned after transformers / sentence_transformers
# are imported above, so anything those libraries read at import time may not
# see them; the explicit cache_dir / cache_folder arguments passed when the
# models are loaded are what actually pin the location.
os.environ.update(
    {
        "HF_HOME": "/app/cache",
        "TRANSFORMERS_CACHE": "/app/cache",
        "SENTENCE_TRANSFORMERS_HOME": "/app/cache",
    }
)
# ✅ Initialize FastAPI
app = FastAPI()

# Chat-completions endpoint that deepseek_request() posts to.
OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1/chat/completions"

# ✅ Read the OpenRouter API key from the environment. The process refuses to
# start when it is unset or empty, because every LLM-backed endpoint needs it.
OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY")
if not OPENROUTER_API_KEY:
    raise ValueError("❌ OPENROUTER_API_KEY is missing. Set it as an environment variable.")
# ✅ Load AI models into an explicit, writable cache directory.
# trust_remote_code is deliberately NOT enabled: both checkpoints use
# architectures that ship with sentence-transformers / transformers, so letting
# the Hub repositories execute arbitrary Python would only add risk.
_MODEL_CACHE_DIR = "/app/cache"

try:
    # Required: the FAISS indexes and the retrieval endpoints all use this model.
    embedding_model = SentenceTransformer(
        "sentence-transformers/all-MiniLM-L6-v2",
        cache_folder=_MODEL_CACHE_DIR,
    )
except Exception as e:
    # Fail here with the real cause instead of a later NameError on
    # `embedding_model` when the FAISS indexes are built.
    print(f"❌ Model loading error: {e}")
    raise

try:
    # Optional: nothing visible in this file calls the local summarizer
    # (summaries are produced through deepseek_request), so a download or
    # load failure here should not take the whole service down.
    summarization_model = AutoModelForSeq2SeqLM.from_pretrained(
        "google/long-t5-tglobal-base",
        cache_dir=_MODEL_CACHE_DIR,
    )
    summarization_tokenizer = AutoTokenizer.from_pretrained(
        "google/long-t5-tglobal-base",
        cache_dir=_MODEL_CACHE_DIR,
    )
except Exception as e:
    print(f"❌ Model loading error: {e}")
    # Keep the names defined so later references fail predictably, not with NameError.
    summarization_model = None
    summarization_tokenizer = None
else:
    print("βœ… Models Loaded Successfully!")
# ✅ API Health Check
@app.get("/")
def health_check():
    """Report that the FastAPI service is up and answering requests."""
    payload = dict(status="FastAPI is running!")
    return payload
# ✅ Load Datasets
# Both CSVs are read relative to the current working directory.
try:
    recommendations_df = pd.read_csv("treatment_recommendations.csv")
    questions_df = pd.read_csv("symptom_questions.csv")
except FileNotFoundError as e:
    logging.error("❌ Missing dataset file: %s", e)
    # This runs at import time, outside any request, so an HTTPException would
    # never reach a client; re-raise the original error so startup fails with
    # the real cause (which file is missing).
    raise
else:
    print("βœ… Datasets Loaded Successfully!")
# ✅ Create FAISS Indexes
# Question index: exhaustive L2 search over embeddings of the "Questions"
# column. Vectors are added in row order, so FAISS id i is row i of questions_df.
question_embeddings = embedding_model.encode(
    questions_df["Questions"].tolist(),
    convert_to_numpy=True,
)
question_index = faiss.IndexFlatL2(question_embeddings.shape[1])
question_index.add(question_embeddings)

# Disorder index: exhaustive inner-product search over embeddings of the
# "Disorder" column; FAISS id i is row i of recommendations_df. The generic
# name `index` is kept because detect_disorders() refers to it.
treatment_embeddings = embedding_model.encode(
    recommendations_df["Disorder"].tolist(),
    convert_to_numpy=True,
)
index = faiss.IndexFlatIP(treatment_embeddings.shape[1])
index.add(treatment_embeddings)
# ✅ Chat history storage: one module-level list of message strings, appended
# to by chat() and read by the detection, treatment and summary endpoints.
# Being a module global, it is shared by every client of this process.
chat_history = list()
# ✅ Request Models
class ChatRequest(BaseModel):
    """Request body for POST /get_questions: a single patient message."""

    message: str  # free-text message typed by the patient
class SummaryRequest(BaseModel):
    """Request body carrying a chat transcript.

    NOTE(review): no endpoint visible in this file accepts this model;
    summarize_chat() reads the module-level chat_history instead.
    """

    chat_history: list  # presumably a list of message strings — TODO confirm
# ✅ Function: Call DeepSeek via OpenRouter
def deepseek_request(prompt, max_tokens=300, timeout=60):
    """Send *prompt* to OpenRouter's DeepSeek model and return the reply text.

    Args:
        prompt: Content of the single user-role message.
        max_tokens: Upper bound on the number of tokens generated.
        timeout: Seconds to wait for OpenRouter before giving up. Without a
            timeout ``requests`` can block indefinitely on a stalled server.

    Returns:
        The stripped content of the first choice (possibly ""). If the request
        fails, the response cannot be parsed, or it carries no usable content,
        a fixed supportive fallback message is returned instead.
    """
    fallback = "I'm here to support you. Can you share more about what you're feeling?"
    headers = {
        "Authorization": f"Bearer {OPENROUTER_API_KEY}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": "deepseek/deepseek-r1-distill-llama-8b",
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": 0.8,
    }
    # Broad catch is a deliberate best-effort boundary: callers always get text.
    try:
        response = requests.post(
            OPENROUTER_BASE_URL,
            headers=headers,
            json=payload,
            timeout=timeout,
        )
        response.raise_for_status()
        response_json = response.json()
        if "choices" in response_json and response_json["choices"]:
            # `or {}` tolerates an explicit null message; a null/non-string
            # content falls through to the fallback without a spurious error.
            message = response_json["choices"][0].get("message") or {}
            content = message.get("content", "")
            if isinstance(content, str):
                return content.strip()
    except Exception as e:
        logging.error(f"OpenRouter DeepSeek API error: {e}")
    return fallback
# ✅ Function: Retrieve Relevant Diagnostic Question
def retrieve_relevant_question(user_input):
    """Return the dataset question whose embedding is nearest to *user_input*.

    Uses the module-level FAISS ``question_index``; when FAISS reports no
    neighbour (id -1) a generic follow-up prompt is returned instead.
    """
    query_vector = embedding_model.encode([user_input], convert_to_numpy=True)
    _distances, neighbour_ids = question_index.search(query_vector, 1)
    best_id = neighbour_ids[0][0]
    if best_id != -1:
        return questions_df["Questions"].iloc[best_id]
    return "I'm here to listen. Can you tell me more about your symptoms?"
# ✅ API Endpoint: Chat Interaction
@app.post("/get_questions")
def chat(request: ChatRequest):
    """Answer one patient message and record both turns in ``chat_history``.

    Args:
        request: Body whose ``message`` is the patient's latest input.

    Returns:
        ``{"response": <AI reply>}``; the reply is never empty.
    """
    user_message = request.message
    # Snapshot the transcript *before* recording this turn, so the current
    # message appears once (as "User input") rather than also inside
    # "Previous conversation". One message per line instead of a list repr.
    previous_conversation = "\n".join(chat_history) or "(none yet)"
    chat_history.append(user_message)

    # Constructing the DeepSeek prompt
    prompt = f"""
You are an AI psychiatrist conducting a mental health consultation.
Engage in a supportive, natural conversation, maintaining an empathetic tone.
- Always provide a thoughtful and compassionate response.
- If a user shares distressing emotions, acknowledge their feelings and ask relevant follow-up questions.
Previous conversation:
{previous_conversation}
User input:
"{user_message}"
Generate:
- An empathetic response.
- A related follow-up question.
Ensure your response is meaningful and NEVER empty.
"""
    ai_response = deepseek_request(prompt, max_tokens=250)

    # ✅ Never return (or store) an empty/whitespace-only reply.
    if not (ai_response and ai_response.strip()):
        ai_response = "I'm here to listen. Can you tell me more about how you're feeling? Maybe I can help."

    chat_history.append(ai_response)
    return {"response": ai_response}
# ✅ API Endpoint: Detect Disorders from Chat History
@app.post("/detect_disorders")
def detect_disorders():
    """Match the whole chat transcript against the known disorders.

    All ``chat_history`` messages are joined into one text, embedded, and
    searched in the disorder FAISS ``index`` for the top 3 hits.

    Returns:
        ``{"disorders": [...]}`` with distinct disorder names from
        ``recommendations_df`` in ranked order, or
        ``["No matching disorder found."]`` when there is no conversation yet
        or FAISS returns no neighbour.
    """
    no_match = {"disorders": ["No matching disorder found."]}
    if not chat_history:
        # Embedding an empty string would still "match" arbitrary disorders.
        return no_match

    full_chat_text = " ".join(chat_history)
    text_embedding = embedding_model.encode([full_chat_text], convert_to_numpy=True)
    _, indices = index.search(text_embedding, 3)

    # FAISS pads with -1 when the index holds fewer than k vectors; passing -1
    # to .iloc would silently select the LAST row as a bogus match.
    hits = [i for i in indices[0] if i != -1]
    if not hits:
        return no_match

    # Several rows may share a disorder name; report each name once, best first.
    names = (recommendations_df["Disorder"].iloc[i] for i in hits)
    disorders = list(dict.fromkeys(names))
    return {"disorders": disorders}
# ✅ API Endpoint: Get Treatment Recommendations
@app.post("/get_treatment")
def get_treatment():
    """Return treatment recommendations for the disorders detected so far.

    Each disorder reported by detect_disorders() is looked up in
    ``recommendations_df`` (first matching row's "Treatment Recommendation");
    names absent from the dataset get an LLM-generated plan instead.

    Returns:
        ``{"treatments": {disorder: recommendation}}``; the mapping is empty
        when detect_disorders() reports no match.
    """
    detected_disorders = detect_disorders()["disorders"]
    treatments = {}
    for disorder in detected_disorders:
        if disorder == "No matching disorder found.":
            # Sentinel from detect_disorders(), not a diagnosis: asking the LLM
            # to "treat" it would fabricate a plan for a non-existent condition.
            continue

        matches = recommendations_df.loc[
            recommendations_df["Disorder"] == disorder, "Treatment Recommendation"
        ]
        if not matches.empty:
            treatments[disorder] = matches.iloc[0]
            continue

        # Generate treatment if not in dataset
        treatment_prompt = f"""
The user has been diagnosed with {disorder}. Provide a structured treatment plan including:
- **Therapy options** (CBT, psychotherapy, etc.).
- **Medications** (if applicable).
- **Lifestyle strategies** (exercise, mindfulness, etc.).
- **When to seek professional help**.
- **Encouragement**.
Ensure your response is clear and medically sound.
"""
        treatments[disorder] = deepseek_request(treatment_prompt, max_tokens=250)
    return {"treatments": treatments}
# ✅ API Endpoint: Summarize Chat
@app.post("/summarize_chat")
def summarize_chat():
    """Summarize the full chat session using DeepSeek.

    Returns:
        ``{"summary": <text>}``. When no messages have been exchanged yet, a
        fixed notice is returned without calling the LLM.
    """
    if not chat_history:
        # Nothing to summarize: skip an LLM call that could only invent a session.
        return {"summary": "No conversation to summarize yet."}
    # One message per line keeps turn boundaries visible to the model.
    chat_text = "\n".join(chat_history)
    summary_prompt = f"The following is a conversation between a patient and an AI psychiatrist. Summarize it clearly:\n{chat_text}"
    summary = deepseek_request(summary_prompt, max_tokens=500)
    return {"summary": summary}