from fastapi import FastAPI
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
import faiss
import pandas as pd
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

app = FastAPI()

# Load AI models
similarity_model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
summarization_model = AutoModelForSeq2SeqLM.from_pretrained("google/long-t5-tglobal-base")
summarization_tokenizer = AutoTokenizer.from_pretrained("google/long-t5-tglobal-base")

# Load datasets
recommendations_df = pd.read_csv("treatment_recommendations.csv")
questions_df = pd.read_csv("symptom_questions.csv")

# FAISS index for disorder detection.
# Embeddings are L2-normalized so that inner product equals cosine similarity;
# without normalization, IndexFlatIP would rank by raw dot product instead.
treatment_embeddings = similarity_model.encode(
    recommendations_df["Disorder"].tolist(),
    convert_to_numpy=True,
    normalize_embeddings=True,
)
index = faiss.IndexFlatIP(treatment_embeddings.shape[1])
index.add(treatment_embeddings)

# FAISS index for question retrieval (exact L2 distance)
question_embeddings = embedding_model.encode(questions_df["Questions"].tolist(), convert_to_numpy=True)
question_index = faiss.IndexFlatL2(question_embeddings.shape[1])
question_index.add(question_embeddings)

# Request models
class ChatRequest(BaseModel):
    message: str

class SummaryRequest(BaseModel):
    chat_history: list[str]  # list of chat messages

@app.post("/get_questions")
def get_recommended_questions(request: ChatRequest):
    """Retrieve the most relevant diagnostic questions."""
    input_embedding = embedding_model.encode([request.message], convert_to_numpy=True)
    distances, indices = question_index.search(input_embedding, 3)
    retrieved_questions = [questions_df["Questions"].iloc[i] for i in indices[0]]
    return {"questions": retrieved_questions}

@app.post("/summarize_chat")
def summarize_chat(request: SummaryRequest):
    """Summarize the full chat session at the end."""
    chat_text = " ".join(request.chat_history)
    inputs = summarization_tokenizer(
        "summarize: " + chat_text,
        return_tensors="pt",
        max_length=4096,
        truncation=True,
    )
    summary_ids = summarization_model.generate(
        inputs.input_ids,
        max_length=500,
        num_beams=4,
        early_stopping=True,
    )
    summary = summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return {"summary": summary}

@app.post("/detect_disorders")
def detect_disorders(request: SummaryRequest):
    """Detect psychiatric disorders from the full chat history at the end."""
    full_chat_text = " ".join(request.chat_history)
    # Normalize the query embedding to match the normalized index vectors.
    text_embedding = similarity_model.encode(
        [full_chat_text], convert_to_numpy=True, normalize_embeddings=True
    )
    distances, indices = index.search(text_embedding, 3)
    disorders = [recommendations_df["Disorder"].iloc[i] for i in indices[0]]
    return {"disorders": disorders}

@app.post("/get_treatment")
def get_treatment(request: SummaryRequest):
    """Retrieve treatment recommendations based on detected disorders."""
    detected_disorders = detect_disorders(request)["disorders"]
    treatments = {
        disorder: recommendations_df.loc[
            recommendations_df["Disorder"] == disorder, "Treatment Recommendation"
        ].values[0]
        for disorder in detected_disorders
    }
    return {"treatments": treatments}
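
# --- Usage sketch (assumption: this file is saved as main.py and uvicorn is
# installed via `pip install uvicorn`). The payloads below are hypothetical
# examples that mirror the Pydantic request models defined above.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example requests (illustrative only):
#   curl -X POST http://localhost:8000/get_questions \
#        -H "Content-Type: application/json" \
#        -d '{"message": "I have trouble sleeping and feel anxious"}'
#
#   curl -X POST http://localhost:8000/detect_disorders \
#        -H "Content-Type: application/json" \
#        -d '{"chat_history": ["I have trouble sleeping", "I worry constantly"]}'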