mindspark121 commited on
Commit
df2baf4
·
verified ·
1 Parent(s): df2f743

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -36
app.py CHANGED
@@ -3,8 +3,7 @@ from pydantic import BaseModel
3
  from sentence_transformers import SentenceTransformer
4
  import faiss
5
  import pandas as pd
6
- import random
7
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoModelForCausalLM
8
 
9
  app = FastAPI()
10
 
@@ -14,11 +13,6 @@ embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
14
  summarization_model = AutoModelForSeq2SeqLM.from_pretrained("google/long-t5-tglobal-base")
15
  summarization_tokenizer = AutoTokenizer.from_pretrained("google/long-t5-tglobal-base")
16
 
17
- # New: Load Local LLM for Dynamic Emotional Responses (Mistral/Llama)
18
- response_model_name = "mistralai/Mistral-7B-Instruct"
19
- response_tokenizer = AutoTokenizer.from_pretrained(response_model_name)
20
- response_model = AutoModelForCausalLM.from_pretrained(response_model_name)
21
-
22
  # Load datasets
23
  recommendations_df = pd.read_csv("treatment_recommendations.csv")
24
  questions_df = pd.read_csv("symptom_questions.csv")
@@ -40,38 +34,13 @@ class ChatRequest(BaseModel):
40
  class SummaryRequest(BaseModel):
41
  chat_history: list # List of messages
42
 
43
-
44
  @app.post("/get_questions")
45
  def get_recommended_questions(request: ChatRequest):
46
- """Retrieve the most relevant diagnostic questions with a dynamically generated conversational response."""
47
-
48
- # Step 1: Encode the input message for FAISS search
49
  input_embedding = embedding_model.encode([request.message], convert_to_numpy=True)
50
  distances, indices = question_index.search(input_embedding, 3)
51
-
52
- # Step 2: Retrieve the top 3 relevant questions
53
  retrieved_questions = [questions_df["Questions"].iloc[i] for i in indices[0]]
54
-
55
- # Step 3: Use a local LLM to generate context-aware empathetic responses
56
- prompt = f"""
57
- User: {request.message}
58
-
59
- You are a compassionate psychiatric assistant. Before asking a diagnostic question, respond empathetically.
60
-
61
- Questions:
62
- 1. {retrieved_questions[0]}
63
- 2. {retrieved_questions[1]}
64
- 3. {retrieved_questions[2]}
65
-
66
- Generate a conversational response that introduces each question naturally.
67
- """
68
-
69
- inputs = response_tokenizer(prompt, return_tensors="pt")
70
- output = response_model.generate(**inputs, max_length=300)
71
- enhanced_responses = response_tokenizer.decode(output[0], skip_special_tokens=True).split("\n")
72
-
73
- return {"questions": enhanced_responses}
74
-
75
 
76
  @app.post("/summarize_chat")
77
  def summarize_chat(request: SummaryRequest):
@@ -82,7 +51,6 @@ def summarize_chat(request: SummaryRequest):
82
  summary = summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
83
  return {"summary": summary}
84
 
85
-
86
  @app.post("/detect_disorders")
87
  def detect_disorders(request: SummaryRequest):
88
  """Detect psychiatric disorders from full chat history at the end."""
@@ -92,7 +60,6 @@ def detect_disorders(request: SummaryRequest):
92
  disorders = [recommendations_df["Disorder"].iloc[i] for i in indices[0]]
93
  return {"disorders": disorders}
94
 
95
-
96
  @app.post("/get_treatment")
97
  def get_treatment(request: SummaryRequest):
98
  """Retrieve treatment recommendations based on detected disorders."""
@@ -102,3 +69,4 @@ def get_treatment(request: SummaryRequest):
102
  for disorder in detected_disorders
103
  }
104
  return {"treatments": treatments}
 
 
3
  from sentence_transformers import SentenceTransformer
4
  import faiss
5
  import pandas as pd
6
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 
7
 
8
  app = FastAPI()
9
 
 
13
  summarization_model = AutoModelForSeq2SeqLM.from_pretrained("google/long-t5-tglobal-base")
14
  summarization_tokenizer = AutoTokenizer.from_pretrained("google/long-t5-tglobal-base")
15
 
 
 
 
 
 
16
  # Load datasets
17
  recommendations_df = pd.read_csv("treatment_recommendations.csv")
18
  questions_df = pd.read_csv("symptom_questions.csv")
 
34
  class SummaryRequest(BaseModel):
35
  chat_history: list # List of messages
36
 
 
37
  @app.post("/get_questions")
38
  def get_recommended_questions(request: ChatRequest):
39
+ """Retrieve the most relevant diagnostic questions."""
 
 
40
  input_embedding = embedding_model.encode([request.message], convert_to_numpy=True)
41
  distances, indices = question_index.search(input_embedding, 3)
 
 
42
  retrieved_questions = [questions_df["Questions"].iloc[i] for i in indices[0]]
43
+ return {"questions": retrieved_questions}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
  @app.post("/summarize_chat")
46
  def summarize_chat(request: SummaryRequest):
 
51
  summary = summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
52
  return {"summary": summary}
53
 
 
54
  @app.post("/detect_disorders")
55
  def detect_disorders(request: SummaryRequest):
56
  """Detect psychiatric disorders from full chat history at the end."""
 
60
  disorders = [recommendations_df["Disorder"].iloc[i] for i in indices[0]]
61
  return {"disorders": disorders}
62
 
 
63
  @app.post("/get_treatment")
64
  def get_treatment(request: SummaryRequest):
65
  """Retrieve treatment recommendations based on detected disorders."""
 
69
  for disorder in detected_disorders
70
  }
71
  return {"treatments": treatments}
72
+