mindspark121 committed · verified
Commit df2f743 · 1 Parent(s): 97b87f4

Update app.py

Files changed (1)
app.py  +38 -1
app.py CHANGED
@@ -4,7 +4,7 @@ from sentence_transformers import SentenceTransformer
 import faiss
 import pandas as pd
 import random
-from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoModelForCausalLM
 
 app = FastAPI()
 
@@ -14,6 +14,11 @@ embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
 summarization_model = AutoModelForSeq2SeqLM.from_pretrained("google/long-t5-tglobal-base")
 summarization_tokenizer = AutoTokenizer.from_pretrained("google/long-t5-tglobal-base")
 
+# New: Load Local LLM for Dynamic Emotional Responses (Mistral/Llama)
+response_model_name = "mistralai/Mistral-7B-Instruct"
+response_tokenizer = AutoTokenizer.from_pretrained(response_model_name)
+response_model = AutoModelForCausalLM.from_pretrained(response_model_name)
+
 # Load datasets
 recommendations_df = pd.read_csv("treatment_recommendations.csv")
 questions_df = pd.read_csv("symptom_questions.csv")
@@ -36,6 +41,36 @@ class SummaryRequest(BaseModel):
     chat_history: list # List of messages
 
 
+@app.post("/get_questions")
+def get_recommended_questions(request: ChatRequest):
+    """Retrieve the most relevant diagnostic questions with a dynamically generated conversational response."""
+
+    # Step 1: Encode the input message for FAISS search
+    input_embedding = embedding_model.encode([request.message], convert_to_numpy=True)
+    distances, indices = question_index.search(input_embedding, 3)
+
+    # Step 2: Retrieve the top 3 relevant questions
+    retrieved_questions = [questions_df["Questions"].iloc[i] for i in indices[0]]
+
+    # Step 3: Use a local LLM to generate context-aware empathetic responses
+    prompt = f"""
+    User: {request.message}
+
+    You are a compassionate psychiatric assistant. Before asking a diagnostic question, respond empathetically.
+
+    Questions:
+    1. {retrieved_questions[0]}
+    2. {retrieved_questions[1]}
+    3. {retrieved_questions[2]}
+
+    Generate a conversational response that introduces each question naturally.
+    """
+
+    inputs = response_tokenizer(prompt, return_tensors="pt")
+    output = response_model.generate(**inputs, max_length=300)
+    enhanced_responses = response_tokenizer.decode(output[0], skip_special_tokens=True).split("\n")
+
+    return {"questions": enhanced_responses}
+
 
 @app.post("/summarize_chat")
@@ -47,6 +82,7 @@ def summarize_chat(request: SummaryRequest):
     summary = summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
     return {"summary": summary}
 
+
 @app.post("/detect_disorders")
 def detect_disorders(request: SummaryRequest):
     """Detect psychiatric disorders from full chat history at the end."""
@@ -56,6 +92,7 @@ def detect_disorders(request: SummaryRequest):
     disorders = [recommendations_df["Disorder"].iloc[i] for i in indices[0]]
     return {"disorders": disorders}
 
+
 @app.post("/get_treatment")
 def get_treatment(request: SummaryRequest):
     """Retrieve treatment recommendations based on detected disorders."""