mindspark121 commited on
Commit
244be9c
Β·
verified Β·
1 Parent(s): 4adb072

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +143 -95
app.py CHANGED
@@ -1,131 +1,179 @@
 
1
  import os
2
  import requests
3
  import json
 
4
  import pandas as pd
 
5
  from fastapi import FastAPI, HTTPException
6
  from pydantic import BaseModel
 
 
7
 
8
- # Load DSM-5 Dataset
9
- file_path = "dsm5_final_cleaned.csv"
10
- df = pd.read_csv(file_path)
11
 
12
- # OpenRouter API Configuration (DeepSeek Model)
13
- OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY") # Use environment variable for security
14
  if not OPENROUTER_API_KEY:
15
- raise ValueError("OPENROUTER_API_KEY is missing. Set it as an environment variable.")
 
16
  OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1/chat/completions"
17
 
18
- # Initialize FastAPI
19
- app = FastAPI()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
- # Pydantic Models
 
 
 
22
  class ChatRequest(BaseModel):
23
  message: str
24
 
25
  class SummaryRequest(BaseModel):
26
  chat_history: list
27
 
 
 
28
  def deepseek_request(prompt, max_tokens=300):
29
- """Helper function to send a request to DeepSeek API and handle response."""
30
- headers = {"Authorization": f"Bearer {OPENROUTER_API_KEY}", "Content-Type": "application/json"}
 
 
 
31
  payload = {
32
  "model": "deepseek/deepseek-r1-distill-llama-8b",
33
  "messages": [{"role": "user", "content": prompt}],
34
  "max_tokens": max_tokens,
35
- "temperature": 0.7
36
  }
37
- response = requests.post(OPENROUTER_BASE_URL, headers=headers, data=json.dumps(payload))
38
- if response.status_code == 200:
 
39
  response_json = response.json()
 
40
  if "choices" in response_json and response_json["choices"]:
41
  return response_json["choices"][0].get("message", {}).get("content", "").strip()
42
- return "Error: Unable to process the request."
43
-
44
- def match_disorders(chat_history):
45
- """Match user symptoms with DSM-5 disorders based on keyword occurrence."""
46
- disorder_scores = {}
47
- for _, row in df.iterrows():
48
- disorder = row["Disorder"]
49
- keywords = row["Criteria"].split(", ")
50
- match_count = sum(1 for word in keywords if word in chat_history.lower())
51
- if match_count > 0:
52
- disorder_scores[disorder] = match_count
53
- sorted_disorders = sorted(disorder_scores, key=disorder_scores.get, reverse=True)
54
- return sorted_disorders[:3] if sorted_disorders else []
55
 
56
- @app.post("/detect_disorders")
57
- def detect_disorders(request: SummaryRequest):
58
- """Detect psychiatric disorders using DSM-5 keyword matching + DeepSeek validation."""
59
- full_chat = " ".join(request.chat_history)
60
- matched_disorders = match_disorders(full_chat)
61
-
62
- prompt = f"""
63
- The following is a psychiatric conversation:
64
- {full_chat}
65
-
66
- Based on DSM-5 diagnostic criteria, analyze the symptoms and determine the most probable psychiatric disorders.
67
- Here are possible disorder matches from DSM-5 keyword analysis: {', '.join(matched_disorders) if matched_disorders else 'None found'}.
68
- If no clear matches exist, diagnose based purely on symptom patterns and clinical reasoning.
69
- Return a **list** of disorders, separated by commas, without extra text.
70
- """
71
-
72
- response = deepseek_request(prompt, max_tokens=150)
73
- disorders = [disorder.strip() for disorder in response.split(",")] if response and response.lower() != "unspecified disorder" else matched_disorders
74
- return {"disorders": disorders if disorders else ["Unspecified Disorder"]}
75
 
76
- @app.post("/get_treatment")
77
- def get_treatment(request: SummaryRequest):
78
- """Retrieve structured treatment recommendations based on detected disorders."""
79
- detected_disorders = detect_disorders(request)["disorders"]
80
- disorders_text = ", ".join(detected_disorders)
81
- prompt = f"""
82
- The user has been diagnosed with: {disorders_text}.
83
- Provide a structured, evidence-based treatment plan including:
84
- - Therapy recommendations (e.g., CBT, DBT, psychotherapy).
85
- - Possible medications if applicable (e.g., SSRIs, anxiolytics, sleep aids).
86
- - Lifestyle and self-care strategies (e.g., sleep hygiene, mindfulness, exercise).
87
- If the user has suicidal thoughts, emphasize **immediate crisis intervention and emergency medical support.**
88
- """
89
- treatment_response = deepseek_request(prompt, max_tokens=200)
90
- return {"treatments": treatment_response}
91
 
92
- @app.post("/summarize_chat")
93
- def summarize_chat(request: SummaryRequest):
94
- """Generate a structured summary of the psychiatric consultation."""
95
- full_chat = " ".join(request.chat_history)
96
- detected_disorders = detect_disorders(request)["disorders"]
97
- treatment_response = get_treatment(request)["treatments"]
98
- prompt = f"""
99
- Summarize the following psychiatric conversation:
100
- {full_chat}
101
-
102
- - **Detected Disorders:** {', '.join(detected_disorders)}
103
- - **Suggested Treatments:** {treatment_response}
104
-
105
- The summary should include:
106
- - Main concerns reported by the user.
107
- - Key symptoms observed.
108
- - Possible underlying psychological conditions.
109
- - Recommended next steps, including professional consultation and self-care strategies.
110
- If suicidal thoughts were mentioned, highlight the **need for immediate crisis intervention and professional support.**
111
- """
112
- summary = deepseek_request(prompt, max_tokens=300)
113
- return {"summary": summary}
114
 
115
- @app.post("/chat")
 
116
  def chat(request: ChatRequest):
117
- """Generate AI psychiatric response for user input."""
 
 
 
 
 
 
 
118
  prompt = f"""
119
  You are an AI psychiatrist conducting a mental health consultation.
120
- The user is discussing their concerns and symptoms. Engage in a supportive conversation,
121
- ask relevant follow-up questions, and maintain an empathetic tone.
122
-
 
 
 
 
 
 
123
  User input:
124
- {request.message}
125
-
126
- Provide a meaningful response and a follow-up question if necessary.
127
- If the user mentions suicidal thoughts, respond with an urgent and compassionate tone,
128
- suggesting that they seek immediate professional help while providing emotional support.
 
 
 
129
  """
130
- response = deepseek_request(prompt, max_tokens=200)
131
- return {"response": response}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
  import os
3
  import requests
4
  import json
5
+ import logging
6
  import pandas as pd
7
+ import faiss
8
  from fastapi import FastAPI, HTTPException
9
  from pydantic import BaseModel
10
+ from sentence_transformers import SentenceTransformer
11
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
12
 
13
+ # βœ… Initialize FastAPI
14
+ app = FastAPI()
 
15
 
16
+ # βœ… Set OpenRouter API Key
17
+ OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
18
  if not OPENROUTER_API_KEY:
19
+ raise ValueError("❌ OPENROUTER_API_KEY is missing. Set it as an environment variable.")
20
+
21
  OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1/chat/completions"
22
 
23
+ # βœ… Load AI Models
24
+ embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
25
+ summarization_model = AutoModelForSeq2SeqLM.from_pretrained("google/long-t5-tglobal-base")
26
+ summarization_tokenizer = AutoTokenizer.from_pretrained("google/long-t5-tglobal-base")
27
+
28
+ # βœ… Load Datasets
29
+ try:
30
+ recommendations_df = pd.read_csv("treatment_recommendations .csv")
31
+ questions_df = pd.read_csv("symptom_questions.csv")
32
+ print("βœ… Datasets Loaded Successfully!")
33
+ except FileNotFoundError as e:
34
+ logging.error(f"❌ Missing dataset file: {e}")
35
+ raise HTTPException(status_code=500, detail=f"Dataset file not found: {str(e)}")
36
+
37
+ # βœ… Create FAISS Indexes
38
+ question_embeddings = embedding_model.encode(questions_df["Questions"].tolist(), convert_to_numpy=True)
39
+ question_index = faiss.IndexFlatL2(question_embeddings.shape[1])
40
+ question_index.add(question_embeddings)
41
+
42
+ treatment_embeddings = embedding_model.encode(recommendations_df["Disorder"].tolist(), convert_to_numpy=True)
43
+ index = faiss.IndexFlatIP(treatment_embeddings.shape[1])
44
+ index.add(treatment_embeddings)
45
 
46
+ # βœ… Chat History Storage
47
+ chat_history = []
48
+
49
+ # βœ… Request Models
50
  class ChatRequest(BaseModel):
51
  message: str
52
 
53
  class SummaryRequest(BaseModel):
54
  chat_history: list
55
 
56
+
57
+ # βœ… Function: Call DeepSeek via OpenRouter
58
  def deepseek_request(prompt, max_tokens=300):
59
+ """Send a request to OpenRouter's DeepSeek model."""
60
+ headers = {
61
+ "Authorization": f"Bearer {OPENROUTER_API_KEY}",
62
+ "Content-Type": "application/json"
63
+ }
64
  payload = {
65
  "model": "deepseek/deepseek-r1-distill-llama-8b",
66
  "messages": [{"role": "user", "content": prompt}],
67
  "max_tokens": max_tokens,
68
+ "temperature": 0.8
69
  }
70
+ try:
71
+ response = requests.post(OPENROUTER_BASE_URL, headers=headers, data=json.dumps(payload))
72
+ response.raise_for_status()
73
  response_json = response.json()
74
+
75
  if "choices" in response_json and response_json["choices"]:
76
  return response_json["choices"][0].get("message", {}).get("content", "").strip()
77
+ except Exception as e:
78
+ logging.error(f"OpenRouter DeepSeek API error: {e}")
79
+ return "I'm here to support you. Can you share more about what you're feeling?"
 
 
 
 
 
 
 
 
 
 
80
 
81
+ # βœ… Function: Retrieve Relevant Diagnostic Question
82
+ def retrieve_relevant_question(user_input):
83
+ """Find the most relevant diagnostic question from the dataset using FAISS."""
84
+ input_embedding = embedding_model.encode([user_input], convert_to_numpy=True)
85
+ _, indices = question_index.search(input_embedding, 1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
+ if indices[0][0] == -1:
88
+ return "I'm here to listen. Can you tell me more about your symptoms?"
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
+ return questions_df["Questions"].iloc[indices[0][0]]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
+ # βœ… API Endpoint: Chat Interaction
93
+ @app.post("/get_questions")
94
  def chat(request: ChatRequest):
95
+ """Patient enters data, AI responds and stores conversation."""
96
+ user_message = request.message
97
+ chat_history.append(user_message)
98
+
99
+ # Retrieve relevant diagnostic question
100
+ relevant_question = retrieve_relevant_question(user_message)
101
+
102
+ # Constructing the DeepSeek prompt
103
  prompt = f"""
104
  You are an AI psychiatrist conducting a mental health consultation.
105
+ Engage in a supportive, natural conversation, maintaining an empathetic tone.
106
+
107
+ - Always provide a thoughtful and compassionate response.
108
+ - If a user shares distressing emotions, acknowledge their feelings and ask relevant follow-up questions.
109
+ - Ask a symptom-related question to explore their concerns in depth.
110
+
111
+ Previous conversation:
112
+ {chat_history}
113
+
114
  User input:
115
+ "{user_message}"
116
+
117
+ Generate:
118
+ - An empathetic response.
119
+ - A related follow-up question.
120
+ - The most relevant diagnostic question: "{relevant_question}".
121
+
122
+ Ensure your response is always meaningful and non-empty.
123
  """
124
+
125
+ ai_response = deepseek_request(prompt, max_tokens=250)
126
+
127
+ chat_history.append(ai_response)
128
+
129
+ return {"response": ai_response}
130
+
131
+ # βœ… API Endpoint: Detect Disorders from Chat History
132
+ @app.post("/detect_disorders")
133
+ def detect_disorders():
134
+ """Detect psychiatric disorders based on full chat history."""
135
+ full_chat_text = " ".join(chat_history)
136
+ text_embedding = embedding_model.encode([full_chat_text], convert_to_numpy=True)
137
+ distances, indices = index.search(text_embedding, 3)
138
+
139
+ if indices[0][0] == -1:
140
+ return {"disorders": ["No matching disorder found."]}
141
+
142
+ disorders = [recommendations_df["Disorder"].iloc[i] for i in indices[0]]
143
+ return {"disorders": disorders}
144
+
145
+ # βœ… API Endpoint: Get Treatment Recommendations
146
+ @app.post("/get_treatment")
147
+ def get_treatment():
148
+ """Retrieve treatment recommendations based on detected disorders."""
149
+ detected_disorders = detect_disorders()["disorders"]
150
+ treatments = {}
151
+
152
+ for disorder in detected_disorders:
153
+ if disorder in recommendations_df["Disorder"].values:
154
+ treatments[disorder] = recommendations_df[recommendations_df["Disorder"] == disorder]["Treatment Recommendation"].values[0]
155
+ else:
156
+ # Generate treatment if not in dataset
157
+ treatment_prompt = f"""
158
+ The user has been diagnosed with {disorder}. Provide a structured treatment plan including:
159
+
160
+ - **Therapy options** (CBT, psychotherapy, etc.).
161
+ - **Medications** (if applicable).
162
+ - **Lifestyle strategies** (exercise, mindfulness, etc.).
163
+ - **When to seek professional help**.
164
+ - **Encouragement**.
165
+
166
+ Ensure your response is clear and medically sound.
167
+ """
168
+ treatments[disorder] = deepseek_request(treatment_prompt, max_tokens=250)
169
+
170
+ return {"treatments": treatments}
171
+
172
+ # βœ… API Endpoint: Summarize Chat
173
+ @app.post("/summarize_chat")
174
+ def summarize_chat():
175
+ """Summarize full chat session using DeepSeek."""
176
+ chat_text = " ".join(chat_history)
177
+ summary_prompt = f"The following is a conversation between a patient and an AI psychiatrist. Summarize it clearly:\n{chat_text}"
178
+ summary = deepseek_request(summary_prompt, max_tokens=500)
179
+ return {"summary": summary}