import os

# Point every model cache at /tmp, which is writable inside the Space container.
# These must be set before transformers/sentence_transformers are imported,
# since those libraries resolve their default cache locations at import time.
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
os.environ["SENTENCE_TRANSFORMERS_HOME"] = "/tmp/huggingface"

import logging

import faiss
import pandas as pd
from fastapi import FastAPI
from groq import Groq
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Initialize FastAPI
app = FastAPI()

# Fetch the Groq API key from the environment; never hard-code it.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    raise ValueError("GROQ_API_KEY is missing. Set it as an environment variable.")

client = Groq(api_key=GROQ_API_KEY)

# Load the models (all caches point at /tmp/huggingface)
similarity_model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2", cache_folder="/tmp/huggingface")
embedding_model = SentenceTransformer("all-MiniLM-L6-v2", cache_folder="/tmp/huggingface")
summarization_model = AutoModelForSeq2SeqLM.from_pretrained("google/long-t5-tglobal-base", cache_dir="/tmp/huggingface")
summarization_tokenizer = AutoTokenizer.from_pretrained("google/long-t5-tglobal-base", cache_dir="/tmp/huggingface")

# Log which files are present so missing datasets are easy to spot in the Space logs.
print("Available files:", os.listdir("."))

# Load the datasets; fail fast at startup if either file is missing.
# Note the space in "treatment_recommendations .csv" -- this must match the
# uploaded filename exactly, so double-check it against the files in the repo.
try:
    recommendations_df = pd.read_csv("treatment_recommendations .csv")
    questions_df = pd.read_csv("symptom_questions.csv")
except FileNotFoundError as e:
    logging.error(f"Missing dataset file: {e}")
    # HTTPException only makes sense inside a request handler; at import time,
    # raise a plain error so the Space fails with a clear startup message.
    raise RuntimeError(f"Dataset file not found: {e}") from e

# FAISS index for disorder detection.
# IndexFlatIP ranks by inner product; L2-normalizing the embeddings first makes
# that ranking equivalent to cosine similarity.
treatment_embeddings = similarity_model.encode(recommendations_df["Disorder"].tolist(), convert_to_numpy=True)
faiss.normalize_L2(treatment_embeddings)
index = faiss.IndexFlatIP(treatment_embeddings.shape[1])
index.add(treatment_embeddings)

# FAISS index for question retrieval (L2 distance: smaller is closer).
question_embeddings = embedding_model.encode(questions_df["Questions"].tolist(), convert_to_numpy=True)
question_index = faiss.IndexFlatL2(question_embeddings.shape[1])
question_index.add(question_embeddings)
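
# Optional sanity check (a minimal sketch, left commented out so import-time
# behaviour is unchanged): probe the question index with a sample query and log
# the nearest neighbour. The sample text here is illustrative only.
# _probe = embedding_model.encode(["I have trouble sleeping and feel on edge"], convert_to_numpy=True)
# _dist, _idx = question_index.search(_probe, 1)
# logging.info("Nearest question: %r (L2 distance %.3f)",
#              questions_df["Questions"].iloc[int(_idx[0][0])], float(_dist[0][0]))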

# Request models
class ChatRequest(BaseModel):
    message: str

class SummaryRequest(BaseModel):
    chat_history: list  # full chat history as a list of plain-text messages

def retrieve_questions(user_input):
    """Retrieve the most relevant diagnostic question via FAISS."""
    input_embedding = embedding_model.encode([user_input], convert_to_numpy=True)
    _, indices = question_index.search(input_embedding, 1)  # retrieve the single best match
    if indices[0][0] == -1:
        return "I'm sorry, I couldn't find a relevant question."
    # Some rows pack several comma-separated questions together; keep the first one.
    question_block = questions_df["Questions"].iloc[indices[0][0]]
    split_questions = question_block.split(", ")
    return split_questions[0] if split_questions else question_block

def generate_empathetic_response(user_input, retrieved_question):
    """Use the Groq API (LLaMA-3) to rephrase the retrieved question empathetically."""
    prompt = f"""
The user said: "{user_input}"

Relevant question:
- {retrieved_question}

You are an empathetic AI psychiatrist. Rephrase this question naturally, in a human-like way.
Acknowledge the user's emotions before asking the question.

Example format:
- "I understand that anxiety can be overwhelming. Can you tell me more about when you started feeling this way?"

Generate only one empathetic response.
"""
    try:
        chat_completion = client.chat.completions.create(
            messages=[
                {"role": "system", "content": "You are a helpful, empathetic AI psychiatrist."},
                {"role": "user", "content": prompt},
            ],
            model="llama-3.3-70b-versatile",
            temperature=0.8,
            top_p=0.9,
        )
        return chat_completion.choices[0].message.content
    except Exception as e:
        logging.error(f"Groq API error: {e}")
        return "I'm sorry, I couldn't process your request."

# Endpoint: retrieve a question and soften it via Groq (hybrid RAG).
# The route path is an assumption; adjust it to whatever your client calls.
@app.post("/get_questions")
def get_recommended_questions(request: ChatRequest):
    """Retrieve the most relevant diagnostic question and rephrase it empathetically."""
    retrieved_question = retrieve_questions(request.message)
    empathetic_response = generate_empathetic_response(request.message, retrieved_question)
    return {"question": empathetic_response}

# Endpoint: summarize the full chat session. Route path assumed, as above.
@app.post("/summarize_chat")
def summarize_chat(request: SummaryRequest):
    """Summarize the full chat session at the end."""
    chat_text = " ".join(request.chat_history)
    inputs = summarization_tokenizer("summarize: " + chat_text, return_tensors="pt", max_length=4096, truncation=True)
    summary_ids = summarization_model.generate(inputs.input_ids, max_length=500, num_beams=4, early_stopping=True)
    summary = summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return {"summary": summary}

# Endpoint: detect likely disorders from the full chat history. Route path assumed.
@app.post("/detect_disorders")
def detect_disorders(request: SummaryRequest):
    """Detect psychiatric disorders from the full chat history."""
    full_chat_text = " ".join(request.chat_history)
    text_embedding = similarity_model.encode([full_chat_text], convert_to_numpy=True)
    faiss.normalize_L2(text_embedding)  # match the normalized index so ranking is cosine
    distances, indices = index.search(text_embedding, 3)
    if indices[0][0] == -1:
        # An empty list keeps the response type consistent for get_treatment below.
        return {"disorders": []}
    disorders = [recommendations_df["Disorder"].iloc[i] for i in indices[0]]
    return {"disorders": disorders}

# Endpoint: map detected disorders to treatment recommendations. Route path assumed.
@app.post("/get_treatment")
def get_treatment(request: SummaryRequest):
    """Retrieve treatment recommendations for the detected disorders."""
    detected_disorders = detect_disorders(request)["disorders"]
    treatments = {
        disorder: recommendations_df[recommendations_df["Disorder"] == disorder]["Treatment Recommendation"].values[0]
        for disorder in detected_disorders
        if disorder in recommendations_df["Disorder"].values
    }
    return {"treatments": treatments}