ghuman7 commited on
Commit
49dcda7
·
verified ·
1 Parent(s): 3fc5646

Upload 22 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ data/Mental[[:space:]]Health[[:space:]]Handbook[[:space:]]English.pdf filter=lfs diff=lfs merge=lfs -text
37
+ MentalHealth/data/Mental[[:space:]]Health[[:space:]]Handbook[[:space:]]English.pdf filter=lfs diff=lfs merge=lfs -text
Evaluation_MH/.ipynb_checkpoints/Evaluation-checkpoint.ipynb ADDED
@@ -0,0 +1,1403 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "f7b87c2c",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Imports"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": 5,
14
+ "id": "c22401c2-2fd2-4459-9ee8-71bc3bd362c8",
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "# pip install -U sentence-transformers"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": 1,
24
+ "id": "8a7cc9d8",
25
+ "metadata": {},
26
+ "outputs": [
27
+ {
28
+ "name": "stderr",
29
+ "output_type": "stream",
30
+ "text": [
31
+ "/Users/arnabchakraborty/anaconda3/lib/python3.11/site-packages/sentence_transformers/cross_encoder/CrossEncoder.py:11: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n",
32
+ " from tqdm.autonotebook import tqdm, trange\n"
33
+ ]
34
+ }
35
+ ],
36
+ "source": [
37
+ "from sentence_transformers import SentenceTransformer\n",
38
+ "from langchain.prompts import PromptTemplate\n",
39
+ "from langchain.chains import LLMChain\n",
40
+ "from langchain_community.llms import Ollama\n",
41
+ "from langchain.evaluation import load_evaluator\n",
42
+ "import faiss\n",
43
+ "import pandas as pd\n",
44
+ "import numpy as np\n",
45
+ "import pickle\n",
46
+ "import time\n",
47
+ "from tqdm import tqdm"
48
+ ]
49
+ },
50
+ {
51
+ "cell_type": "markdown",
52
+ "id": "b6efca1d",
53
+ "metadata": {},
54
+ "source": [
55
+ "# Intialization"
56
+ ]
57
+ },
58
+ {
59
+ "cell_type": "code",
60
+ "execution_count": 2,
61
+ "id": "cc9a49d2",
62
+ "metadata": {},
63
+ "outputs": [],
64
+ "source": [
65
+ "# Load the FAISS index\n",
66
+ "index = faiss.read_index(\"database/pdf_sections_index.faiss\")"
67
+ ]
68
+ },
69
+ {
70
+ "cell_type": "code",
71
+ "execution_count": 3,
72
+ "id": "9af39b55",
73
+ "metadata": {},
74
+ "outputs": [],
75
+ "source": [
76
+ "model = SentenceTransformer('all-MiniLM-L6-v2')"
77
+ ]
78
+ },
79
+ {
80
+ "cell_type": "code",
81
+ "execution_count": 4,
82
+ "id": "fee8cdfd",
83
+ "metadata": {},
84
+ "outputs": [],
85
+ "source": [
86
+ "with open('database/pdf_sections_data.pkl', 'rb') as f:\n",
87
+ " sections_data = pickle.load(f)"
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "markdown",
92
+ "id": "d6a1ba6a",
93
+ "metadata": {},
94
+ "source": [
95
+ "# RAG functions"
96
+ ]
97
+ },
98
+ {
99
+ "cell_type": "code",
100
+ "execution_count": 5,
101
+ "id": "182bdbd8",
102
+ "metadata": {},
103
+ "outputs": [],
104
+ "source": [
105
+ "def search_faiss(query, k=3):\n",
106
+ " query_vector = model.encode([query])[0].astype('float32')\n",
107
+ " query_vector = np.expand_dims(query_vector, axis=0)\n",
108
+ " distances, indices = index.search(query_vector, k)\n",
109
+ " \n",
110
+ " results = []\n",
111
+ " for dist, idx in zip(distances[0], indices[0]):\n",
112
+ " results.append({\n",
113
+ " 'distance': dist,\n",
114
+ " 'content': sections_data[idx]['content'],\n",
115
+ " 'metadata': sections_data[idx]['metadata']\n",
116
+ " })\n",
117
+ " \n",
118
+ " return results"
119
+ ]
120
+ },
121
+ {
122
+ "cell_type": "code",
123
+ "execution_count": 15,
124
+ "id": "67edc46a",
125
+ "metadata": {},
126
+ "outputs": [],
127
+ "source": [
128
+ "# Create a prompt template\n",
129
+ "prompt_template = \"\"\"\n",
130
+ "You are an AI assistant specialized in Mental Health guidelines. \n",
131
+ "Use the following pieces of context to answer the question. \n",
132
+ "If you don't know the answer, just say that you don't know, don't try to make up an answer.\n",
133
+ "\n",
134
+ "Context:\n",
135
+ "{context}\n",
136
+ "\n",
137
+ "Question: {question}\n",
138
+ "\n",
139
+ "Answer:\"\"\"\n",
140
+ "\n",
141
+ "prompt = PromptTemplate(template=prompt_template, input_variables=[\"context\", \"question\"])\n",
142
+ "\n",
143
+ "llm = Ollama(\n",
144
+ " model=\"llama3\"\n",
145
+ ")\n",
146
+ "\n",
147
+ "# Create the chain\n",
148
+ "chain = LLMChain(llm=llm, prompt=prompt)\n",
149
+ "\n",
150
+ "def answer_question(query):\n",
151
+ " # Search for relevant context\n",
152
+ " search_results = search_faiss(query)\n",
153
+ " \n",
154
+ " # Combine the content from the search results\n",
155
+ " context = \"\\n\\n\".join([result['content'] for result in search_results])\n",
156
+ "\n",
157
+ " # Run the chain\n",
158
+ " response = chain.run(context=context, question=query)\n",
159
+ " \n",
160
+ " return response"
161
+ ]
162
+ },
163
+ {
164
+ "cell_type": "markdown",
165
+ "id": "3b176af9",
166
+ "metadata": {},
167
+ "source": [
168
+ "# Reading GT"
169
+ ]
170
+ },
171
+ {
172
+ "cell_type": "code",
173
+ "execution_count": 16,
174
+ "id": "4ab68dff",
175
+ "metadata": {},
176
+ "outputs": [],
177
+ "source": [
178
+ "df = pd.read_csv('data/MentalHealth_Dataset.csv')"
179
+ ]
180
+ },
181
+ {
182
+ "cell_type": "code",
183
+ "execution_count": 17,
184
+ "id": "4e7e22d7",
185
+ "metadata": {},
186
+ "outputs": [
187
+ {
188
+ "name": "stderr",
189
+ "output_type": "stream",
190
+ "text": [
191
+ "100%|███████████████████████████████████████████| 10/10 [01:45<00:00, 10.55s/it]\n"
192
+ ]
193
+ }
194
+ ],
195
+ "source": [
196
+ "time_list=[]\n",
197
+ "response_list=[]\n",
198
+ "for i in tqdm(range(len(df))):\n",
199
+ " query = df['Questions'].values[i]\n",
200
+ " start = time.time()\n",
201
+ " response = answer_question(query)\n",
202
+ " end = time.time() \n",
203
+ " time_list.append(end-start)\n",
204
+ " response_list.append(response)"
205
+ ]
206
+ },
207
+ {
208
+ "cell_type": "code",
209
+ "execution_count": 18,
210
+ "id": "2b327e90",
211
+ "metadata": {},
212
+ "outputs": [],
213
+ "source": [
214
+ "df['latency'] = time_list\n",
215
+ "df['response'] = response_list"
216
+ ]
217
+ },
218
+ {
219
+ "cell_type": "markdown",
220
+ "id": "3c147204",
221
+ "metadata": {},
222
+ "source": [
223
+ "# Evaluation"
224
+ ]
225
+ },
226
+ {
227
+ "cell_type": "code",
228
+ "execution_count": 29,
229
+ "id": "d799e541",
230
+ "metadata": {},
231
+ "outputs": [],
232
+ "source": [
233
+ "eval_llm = Ollama(\n",
234
+ " model=\"phi3\"\n",
235
+ ")"
236
+ ]
237
+ },
238
+ {
239
+ "cell_type": "code",
240
+ "execution_count": 30,
241
+ "id": "c2f788dc",
242
+ "metadata": {},
243
+ "outputs": [],
244
+ "source": [
245
+ "metrics = ['correctness', 'relevance', 'coherence', 'conciseness']"
246
+ ]
247
+ },
248
+ {
249
+ "cell_type": "code",
250
+ "execution_count": 31,
251
+ "id": "83ec2b8d",
252
+ "metadata": {},
253
+ "outputs": [
254
+ {
255
+ "name": "stderr",
256
+ "output_type": "stream",
257
+ "text": [
258
+ "100%|███████████████████████████████████████████| 10/10 [01:15<00:00, 7.51s/it]\n",
259
+ "100%|███████████████████████████████████████████| 10/10 [00:59<00:00, 5.99s/it]\n",
260
+ "100%|███████████████████████████████████████████| 10/10 [00:50<00:00, 5.10s/it]\n",
261
+ "100%|███████████████████████████████████████████| 10/10 [00:48<00:00, 4.88s/it]\n"
262
+ ]
263
+ }
264
+ ],
265
+ "source": [
266
+ "for metric in metrics:\n",
267
+ " evaluator = load_evaluator(\"labeled_criteria\", criteria=metric, llm=eval_llm)\n",
268
+ " \n",
269
+ " reasoning = []\n",
270
+ " value = []\n",
271
+ " score = []\n",
272
+ " \n",
273
+ " for i in tqdm(range(len(df))):\n",
274
+ " eval_result = evaluator.evaluate_strings(\n",
275
+ " prediction=df.response.values[i],\n",
276
+ " input=df.Questions.values[i],\n",
277
+ " reference=df.Answers.values[i]\n",
278
+ " )\n",
279
+ " reasoning.append(eval_result['reasoning'])\n",
280
+ " value.append(eval_result['value'])\n",
281
+ " score.append(eval_result['score'])\n",
282
+ " \n",
283
+ " df[metric+'_reasoning'] = reasoning\n",
284
+ " df[metric+'_value'] = value\n",
285
+ " df[metric+'_score'] = score "
286
+ ]
287
+ },
288
+ {
289
+ "cell_type": "code",
290
+ "execution_count": 78,
291
+ "id": "f1673a31",
292
+ "metadata": {},
293
+ "outputs": [
294
+ {
295
+ "data": {
296
+ "text/html": [
297
+ "<div>\n",
298
+ "<style scoped>\n",
299
+ " .dataframe tbody tr th:only-of-type {\n",
300
+ " vertical-align: middle;\n",
301
+ " }\n",
302
+ "\n",
303
+ " .dataframe tbody tr th {\n",
304
+ " vertical-align: top;\n",
305
+ " }\n",
306
+ "\n",
307
+ " .dataframe thead th {\n",
308
+ " text-align: right;\n",
309
+ " }\n",
310
+ "</style>\n",
311
+ "<table border=\"1\" class=\"dataframe\">\n",
312
+ " <thead>\n",
313
+ " <tr style=\"text-align: right;\">\n",
314
+ " <th></th>\n",
315
+ " <th>Questions</th>\n",
316
+ " <th>Answers</th>\n",
317
+ " <th>latency</th>\n",
318
+ " <th>response</th>\n",
319
+ " <th>correctness_reasoning</th>\n",
320
+ " <th>correctness_value</th>\n",
321
+ " <th>correctness_score</th>\n",
322
+ " <th>relevance_reasoning</th>\n",
323
+ " <th>relevance_value</th>\n",
324
+ " <th>relevance_score</th>\n",
325
+ " <th>coherence_reasoning</th>\n",
326
+ " <th>coherence_value</th>\n",
327
+ " <th>coherence_score</th>\n",
328
+ " <th>conciseness_reasoning</th>\n",
329
+ " <th>conciseness_value</th>\n",
330
+ " <th>conciseness_score</th>\n",
331
+ " </tr>\n",
332
+ " </thead>\n",
333
+ " <tbody>\n",
334
+ " <tr>\n",
335
+ " <th>0</th>\n",
336
+ " <td>What is Mental Health</td>\n",
337
+ " <td>Mental Health is a \" state of well-being in wh...</td>\n",
338
+ " <td>11.974234</td>\n",
339
+ " <td>Based on the provided context, specifically fr...</td>\n",
340
+ " <td>The submission refers to the provided input wh...</td>\n",
341
+ " <td>Y</td>\n",
342
+ " <td>1</td>\n",
343
+ " <td>Step 1: Evaluate relevance criterion\\nThe subm...</td>\n",
344
+ " <td>Y</td>\n",
345
+ " <td>1</td>\n",
346
+ " <td>Step 1: Assess coherence\\nThe submission direc...</td>\n",
347
+ " <td>Y</td>\n",
348
+ " <td>1</td>\n",
349
+ " <td>1. The submission directly answers the questio...</td>\n",
350
+ " <td>Y</td>\n",
351
+ " <td>1</td>\n",
352
+ " </tr>\n",
353
+ " <tr>\n",
354
+ " <th>1</th>\n",
355
+ " <td>What are the most common mental disorders ment...</td>\n",
356
+ " <td>The most common mental disorders include depre...</td>\n",
357
+ " <td>5.863329</td>\n",
358
+ " <td>Based on the provided context, the mental diso...</td>\n",
359
+ " <td>Step 1: Check if the submission is factually a...</td>\n",
360
+ " <td>Y</td>\n",
361
+ " <td>1</td>\n",
362
+ " <td>Step 1: Analyze the relevance criterion\\nThe s...</td>\n",
363
+ " <td>Y</td>\n",
364
+ " <td>1</td>\n",
365
+ " <td>The submission begins with an appropriate ques...</td>\n",
366
+ " <td>Y</td>\n",
367
+ " <td>1</td>\n",
368
+ " <td>Step 1: Review conciseness criterion\\nThe subm...</td>\n",
369
+ " <td>Y</td>\n",
370
+ " <td>1</td>\n",
371
+ " </tr>\n",
372
+ " <tr>\n",
373
+ " <th>2</th>\n",
374
+ " <td>What are the early warning signs and symptoms ...</td>\n",
375
+ " <td>Early warning signs and symptoms of depression...</td>\n",
376
+ " <td>13.434543</td>\n",
377
+ " <td>Based on the provided context, I found a refer...</td>\n",
378
+ " <td>Step 1: Evaluate Correctness\\nThe submission a...</td>\n",
379
+ " <td>Y</td>\n",
380
+ " <td>1</td>\n",
381
+ " <td>Step 1: Identify the relevant criterion from t...</td>\n",
382
+ " <td>Y</td>\n",
383
+ " <td>1</td>\n",
384
+ " <td>Step 1: Evaluate coherence\\nThe submission is ...</td>\n",
385
+ " <td>Y</td>\n",
386
+ " <td>1</td>\n",
387
+ " <td>Step 1: Evaluate conciseness - The submission ...</td>\n",
388
+ " <td>Y</td>\n",
389
+ " <td>1</td>\n",
390
+ " </tr>\n",
391
+ " <tr>\n",
392
+ " <th>3</th>\n",
393
+ " <td>How can someone help a person who suffers from...</td>\n",
394
+ " <td>To help someone with anxiety, one can support ...</td>\n",
395
+ " <td>13.838464</td>\n",
396
+ " <td>According to the provided context, specificall...</td>\n",
397
+ " <td>Step 1: Correctness\\nThe submission accurately...</td>\n",
398
+ " <td>Y</td>\n",
399
+ " <td>1</td>\n",
400
+ " <td>Step 1: Analyze relevance criterion\\nThe submi...</td>\n",
401
+ " <td>Y</td>\n",
402
+ " <td>1</td>\n",
403
+ " <td>Step 1: Evaluate coherence\\nThe submission dis...</td>\n",
404
+ " <td>Y</td>\n",
405
+ " <td>1</td>\n",
406
+ " <td>Step 1: Evaluate conciseness - The submission ...</td>\n",
407
+ " <td>N</td>\n",
408
+ " <td>0</td>\n",
409
+ " </tr>\n",
410
+ " <tr>\n",
411
+ " <th>4</th>\n",
412
+ " <td>What are the causes of mental illness listed i...</td>\n",
413
+ " <td>Causes of mental illness include abnormal func...</td>\n",
414
+ " <td>6.871735</td>\n",
415
+ " <td>According to the provided context, the causes ...</td>\n",
416
+ " <td>The submission lists factors that align with t...</td>\n",
417
+ " <td>N</td>\n",
418
+ " <td>0</td>\n",
419
+ " <td>Step 1: Review relevance criterion - Check if ...</td>\n",
420
+ " <td>Y</td>\n",
421
+ " <td>1</td>\n",
422
+ " <td>Step 1: Compare the submission with the provid...</td>\n",
423
+ " <td>Y</td>\n",
424
+ " <td>1</td>\n",
425
+ " <td>Step 1: Assess conciseness\\nThe submission is ...</td>\n",
426
+ " <td>Y</td>\n",
427
+ " <td>1</td>\n",
428
+ " </tr>\n",
429
+ " </tbody>\n",
430
+ "</table>\n",
431
+ "</div>"
432
+ ],
433
+ "text/plain": [
434
+ " Questions \\\n",
435
+ "0 What is Mental Health \n",
436
+ "1 What are the most common mental disorders ment... \n",
437
+ "2 What are the early warning signs and symptoms ... \n",
438
+ "3 How can someone help a person who suffers from... \n",
439
+ "4 What are the causes of mental illness listed i... \n",
440
+ "\n",
441
+ " Answers latency \\\n",
442
+ "0 Mental Health is a \" state of well-being in wh... 11.974234 \n",
443
+ "1 The most common mental disorders include depre... 5.863329 \n",
444
+ "2 Early warning signs and symptoms of depression... 13.434543 \n",
445
+ "3 To help someone with anxiety, one can support ... 13.838464 \n",
446
+ "4 Causes of mental illness include abnormal func... 6.871735 \n",
447
+ "\n",
448
+ " response \\\n",
449
+ "0 Based on the provided context, specifically fr... \n",
450
+ "1 Based on the provided context, the mental diso... \n",
451
+ "2 Based on the provided context, I found a refer... \n",
452
+ "3 According to the provided context, specificall... \n",
453
+ "4 According to the provided context, the causes ... \n",
454
+ "\n",
455
+ " correctness_reasoning correctness_value \\\n",
456
+ "0 The submission refers to the provided input wh... Y \n",
457
+ "1 Step 1: Check if the submission is factually a... Y \n",
458
+ "2 Step 1: Evaluate Correctness\\nThe submission a... Y \n",
459
+ "3 Step 1: Correctness\\nThe submission accurately... Y \n",
460
+ "4 The submission lists factors that align with t... N \n",
461
+ "\n",
462
+ " correctness_score relevance_reasoning \\\n",
463
+ "0 1 Step 1: Evaluate relevance criterion\\nThe subm... \n",
464
+ "1 1 Step 1: Analyze the relevance criterion\\nThe s... \n",
465
+ "2 1 Step 1: Identify the relevant criterion from t... \n",
466
+ "3 1 Step 1: Analyze relevance criterion\\nThe submi... \n",
467
+ "4 0 Step 1: Review relevance criterion - Check if ... \n",
468
+ "\n",
469
+ " relevance_value relevance_score \\\n",
470
+ "0 Y 1 \n",
471
+ "1 Y 1 \n",
472
+ "2 Y 1 \n",
473
+ "3 Y 1 \n",
474
+ "4 Y 1 \n",
475
+ "\n",
476
+ " coherence_reasoning coherence_value \\\n",
477
+ "0 Step 1: Assess coherence\\nThe submission direc... Y \n",
478
+ "1 The submission begins with an appropriate ques... Y \n",
479
+ "2 Step 1: Evaluate coherence\\nThe submission is ... Y \n",
480
+ "3 Step 1: Evaluate coherence\\nThe submission dis... Y \n",
481
+ "4 Step 1: Compare the submission with the provid... Y \n",
482
+ "\n",
483
+ " coherence_score conciseness_reasoning \\\n",
484
+ "0 1 1. The submission directly answers the questio... \n",
485
+ "1 1 Step 1: Review conciseness criterion\\nThe subm... \n",
486
+ "2 1 Step 1: Evaluate conciseness - The submission ... \n",
487
+ "3 1 Step 1: Evaluate conciseness - The submission ... \n",
488
+ "4 1 Step 1: Assess conciseness\\nThe submission is ... \n",
489
+ "\n",
490
+ " conciseness_value conciseness_score \n",
491
+ "0 Y 1 \n",
492
+ "1 Y 1 \n",
493
+ "2 Y 1 \n",
494
+ "3 N 0 \n",
495
+ "4 Y 1 "
496
+ ]
497
+ },
498
+ "execution_count": 78,
499
+ "metadata": {},
500
+ "output_type": "execute_result"
501
+ }
502
+ ],
503
+ "source": [
504
+ "df.head()"
505
+ ]
506
+ },
507
+ {
508
+ "cell_type": "code",
509
+ "execution_count": 32,
510
+ "id": "7797a360",
511
+ "metadata": {},
512
+ "outputs": [
513
+ {
514
+ "data": {
515
+ "text/plain": [
516
+ "correctness_score 0.800000\n",
517
+ "relevance_score 0.900000\n",
518
+ "coherence_score 1.000000\n",
519
+ "conciseness_score 0.800000\n",
520
+ "latency 10.544803\n",
521
+ "dtype: float64"
522
+ ]
523
+ },
524
+ "execution_count": 32,
525
+ "metadata": {},
526
+ "output_type": "execute_result"
527
+ }
528
+ ],
529
+ "source": [
530
+ "df[['correctness_score','relevance_score','coherence_score','conciseness_score','latency']].mean()"
531
+ ]
532
+ },
533
+ {
534
+ "cell_type": "code",
535
+ "execution_count": 34,
536
+ "id": "fe667926",
537
+ "metadata": {},
538
+ "outputs": [],
539
+ "source": [
540
+ "irr_q=pd.read_csv('data/Unrelated_questions.csv')"
541
+ ]
542
+ },
543
+ {
544
+ "cell_type": "code",
545
+ "execution_count": 35,
546
+ "id": "189f8a0f",
547
+ "metadata": {},
548
+ "outputs": [
549
+ {
550
+ "name": "stderr",
551
+ "output_type": "stream",
552
+ "text": [
553
+ "100%|███████████████████████████████████████████| 10/10 [01:02<00:00, 6.30s/it]\n"
554
+ ]
555
+ }
556
+ ],
557
+ "source": [
558
+ "time_list=[]\n",
559
+ "response_list=[]\n",
560
+ "for i in tqdm(range(len(irr_q))):\n",
561
+ " query = irr_q['Questions'].values[i]\n",
562
+ " start = time.time()\n",
563
+ " response = answer_question(query)\n",
564
+ " end = time.time() \n",
565
+ " time_list.append(end-start)\n",
566
+ " response_list.append(response)"
567
+ ]
568
+ },
569
+ {
570
+ "cell_type": "code",
571
+ "execution_count": 36,
572
+ "id": "b0244ea0",
573
+ "metadata": {},
574
+ "outputs": [],
575
+ "source": [
576
+ "irr_q['response']=response_list\n",
577
+ "irr_q['latency']=time_list"
578
+ ]
579
+ },
580
+ {
581
+ "cell_type": "code",
582
+ "execution_count": 79,
583
+ "id": "dc3b1ade",
584
+ "metadata": {},
585
+ "outputs": [
586
+ {
587
+ "data": {
588
+ "text/html": [
589
+ "<div>\n",
590
+ "<style scoped>\n",
591
+ " .dataframe tbody tr th:only-of-type {\n",
592
+ " vertical-align: middle;\n",
593
+ " }\n",
594
+ "\n",
595
+ " .dataframe tbody tr th {\n",
596
+ " vertical-align: top;\n",
597
+ " }\n",
598
+ "\n",
599
+ " .dataframe thead th {\n",
600
+ " text-align: right;\n",
601
+ " }\n",
602
+ "</style>\n",
603
+ "<table border=\"1\" class=\"dataframe\">\n",
604
+ " <thead>\n",
605
+ " <tr style=\"text-align: right;\">\n",
606
+ " <th></th>\n",
607
+ " <th>Questions</th>\n",
608
+ " <th>response</th>\n",
609
+ " <th>latency</th>\n",
610
+ " <th>irrelevant_score</th>\n",
611
+ " </tr>\n",
612
+ " </thead>\n",
613
+ " <tbody>\n",
614
+ " <tr>\n",
615
+ " <th>0</th>\n",
616
+ " <td>What is the capital of Mars?</td>\n",
617
+ " <td>I don't know. The provided context does not se...</td>\n",
618
+ " <td>12.207266</td>\n",
619
+ " <td>True</td>\n",
620
+ " </tr>\n",
621
+ " <tr>\n",
622
+ " <th>1</th>\n",
623
+ " <td>How many unicorns live in New York City?</td>\n",
624
+ " <td>I don't know. The information provided does no...</td>\n",
625
+ " <td>2.368774</td>\n",
626
+ " <td>True</td>\n",
627
+ " </tr>\n",
628
+ " <tr>\n",
629
+ " <th>2</th>\n",
630
+ " <td>What is the color of happiness?</td>\n",
631
+ " <td>I don't know! The provided context only talks ...</td>\n",
632
+ " <td>5.480067</td>\n",
633
+ " <td>True</td>\n",
634
+ " </tr>\n",
635
+ " <tr>\n",
636
+ " <th>3</th>\n",
637
+ " <td>Can cats fly on Tuesdays?</td>\n",
638
+ " <td>I don't know the answer to this question as it...</td>\n",
639
+ " <td>5.272529</td>\n",
640
+ " <td>True</td>\n",
641
+ " </tr>\n",
642
+ " <tr>\n",
643
+ " <th>4</th>\n",
644
+ " <td>How much does a thought weigh?</td>\n",
645
+ " <td>I don't know. The context provided is about me...</td>\n",
646
+ " <td>5.253224</td>\n",
647
+ " <td>True</td>\n",
648
+ " </tr>\n",
649
+ " </tbody>\n",
650
+ "</table>\n",
651
+ "</div>"
652
+ ],
653
+ "text/plain": [
654
+ " Questions \\\n",
655
+ "0 What is the capital of Mars? \n",
656
+ "1 How many unicorns live in New York City? \n",
657
+ "2 What is the color of happiness? \n",
658
+ "3 Can cats fly on Tuesdays? \n",
659
+ "4 How much does a thought weigh? \n",
660
+ "\n",
661
+ " response latency \\\n",
662
+ "0 I don't know. The provided context does not se... 12.207266 \n",
663
+ "1 I don't know. The information provided does no... 2.368774 \n",
664
+ "2 I don't know! The provided context only talks ... 5.480067 \n",
665
+ "3 I don't know the answer to this question as it... 5.272529 \n",
666
+ "4 I don't know. The context provided is about me... 5.253224 \n",
667
+ "\n",
668
+ " irrelevant_score \n",
669
+ "0 True \n",
670
+ "1 True \n",
671
+ "2 True \n",
672
+ "3 True \n",
673
+ "4 True "
674
+ ]
675
+ },
676
+ "execution_count": 79,
677
+ "metadata": {},
678
+ "output_type": "execute_result"
679
+ }
680
+ ],
681
+ "source": [
682
+ "irr_q.head()"
683
+ ]
684
+ },
685
+ {
686
+ "cell_type": "code",
687
+ "execution_count": 37,
688
+ "id": "8620e50c",
689
+ "metadata": {},
690
+ "outputs": [
691
+ {
692
+ "data": {
693
+ "text/plain": [
694
+ "0 12.207266\n",
695
+ "1 2.368774\n",
696
+ "2 5.480067\n",
697
+ "3 5.272529\n",
698
+ "4 5.253224\n",
699
+ "5 5.351224\n",
700
+ "6 8.118429\n",
701
+ "7 7.288261\n",
702
+ "8 3.856500\n",
703
+ "9 7.745016\n",
704
+ "Name: latency, dtype: float64"
705
+ ]
706
+ },
707
+ "execution_count": 37,
708
+ "metadata": {},
709
+ "output_type": "execute_result"
710
+ }
711
+ ],
712
+ "source": [
713
+ "irr_q['latency']"
714
+ ]
715
+ },
716
+ {
717
+ "cell_type": "code",
718
+ "execution_count": 39,
719
+ "id": "debd3461",
720
+ "metadata": {},
721
+ "outputs": [],
722
+ "source": [
723
+ "irr_q['irrelevant_score'] = irr_q['response'].str.contains(\"I don't know\")"
724
+ ]
725
+ },
726
+ {
727
+ "cell_type": "code",
728
+ "execution_count": 40,
729
+ "id": "bef1d3a4",
730
+ "metadata": {},
731
+ "outputs": [
732
+ {
733
+ "data": {
734
+ "text/plain": [
735
+ "irrelevant_score 0.900000\n",
736
+ "latency 6.294129\n",
737
+ "dtype: float64"
738
+ ]
739
+ },
740
+ "execution_count": 40,
741
+ "metadata": {},
742
+ "output_type": "execute_result"
743
+ }
744
+ ],
745
+ "source": [
746
+ "irr_q[['irrelevant_score','latency']].mean()"
747
+ ]
748
+ },
749
+ {
750
+ "cell_type": "markdown",
751
+ "id": "c1610a70",
752
+ "metadata": {},
753
+ "source": [
754
+ "# Improvement"
755
+ ]
756
+ },
757
+ {
758
+ "cell_type": "code",
759
+ "execution_count": 48,
760
+ "id": "ff6614f9",
761
+ "metadata": {},
762
+ "outputs": [],
763
+ "source": [
764
+ "new_prompt_template = \"\"\"\n",
765
+ "You are an AI assistant specialized in Mental Health guidelines.\n",
766
+ "Use the provided context to answer the question short and accurately. \n",
767
+ "If you don't know the answer, simply say, \"I don't know.\"\n",
768
+ "\n",
769
+ "Context:\n",
770
+ "{context}\n",
771
+ "\n",
772
+ "Question: {question}\n",
773
+ "\n",
774
+ "Answer:\"\"\"\n",
775
+ "\n",
776
+ "prompt = PromptTemplate(template=new_prompt_template, input_variables=[\"context\", \"question\"])\n",
777
+ "\n",
778
+ "llm = Ollama(\n",
779
+ " model=\"llama3\"\n",
780
+ ")\n",
781
+ "\n",
782
+ "# Create the chain\n",
783
+ "chain = LLMChain(llm=llm, prompt=prompt)\n",
784
+ "\n",
785
+ "def answer_question_new(query):\n",
786
+ " # Search for relevant context\n",
787
+ " search_results = search_faiss(query)\n",
788
+ " \n",
789
+ " # Combine the content from the search results\n",
790
+ " context = \"\\n\\n\".join([result['content'] for result in search_results])\n",
791
+ "\n",
792
+ " # Run the chain\n",
793
+ " response = chain.run(context=context, question=query)\n",
794
+ " \n",
795
+ " return response"
796
+ ]
797
+ },
798
+ {
799
+ "cell_type": "code",
800
+ "execution_count": 49,
801
+ "id": "20580d50",
802
+ "metadata": {},
803
+ "outputs": [],
804
+ "source": [
805
+ "df2=df.copy()"
806
+ ]
807
+ },
808
+ {
809
+ "cell_type": "code",
810
+ "execution_count": 50,
811
+ "id": "b1b3d725",
812
+ "metadata": {},
813
+ "outputs": [
814
+ {
815
+ "name": "stderr",
816
+ "output_type": "stream",
817
+ "text": [
818
+ "100%|███████████████████████████████████████████| 10/10 [01:34<00:00, 9.40s/it]\n"
819
+ ]
820
+ }
821
+ ],
822
+ "source": [
823
+ "time_list=[]\n",
824
+ "response_list=[]\n",
825
+ "for i in tqdm(range(len(df2))):\n",
826
+ " query = df2['Questions'].values[i]\n",
827
+ " start = time.time()\n",
828
+ " response = answer_question(query)\n",
829
+ " end = time.time() \n",
830
+ " time_list.append(end-start)\n",
831
+ " response_list.append(response)"
832
+ ]
833
+ },
834
+ {
835
+ "cell_type": "code",
836
+ "execution_count": 51,
837
+ "id": "63f41256",
838
+ "metadata": {},
839
+ "outputs": [],
840
+ "source": [
841
+ "df2['latency'] = time_list\n",
842
+ "df2['response'] = response_list"
843
+ ]
844
+ },
845
+ {
846
+ "cell_type": "code",
847
+ "execution_count": 52,
848
+ "id": "0d8a6065",
849
+ "metadata": {},
850
+ "outputs": [
851
+ {
852
+ "name": "stderr",
853
+ "output_type": "stream",
854
+ "text": [
855
+ "100%|███████████████████████████████████████████| 10/10 [01:00<00:00, 6.01s/it]\n",
856
+ "100%|███████████████████████████████████████████| 10/10 [00:53<00:00, 5.35s/it]\n",
857
+ "100%|███████████████████████████████████████████| 10/10 [00:47<00:00, 4.77s/it]\n",
858
+ "100%|███████████████████████████████████████████| 10/10 [00:55<00:00, 5.60s/it]\n"
859
+ ]
860
+ }
861
+ ],
862
+ "source": [
863
+ "for metric in metrics:\n",
864
+ " evaluator = load_evaluator(\"labeled_criteria\", criteria=metric, llm=eval_llm)\n",
865
+ " \n",
866
+ " reasoning = []\n",
867
+ " value = []\n",
868
+ " score = []\n",
869
+ " \n",
870
+ " for i in tqdm(range(len(df2))):\n",
871
+ " eval_result = evaluator.evaluate_strings(\n",
872
+ " prediction=df2.response.values[i],\n",
873
+ " input=df2.Questions.values[i],\n",
874
+ " reference=df2.Answers.values[i]\n",
875
+ " )\n",
876
+ " reasoning.append(eval_result['reasoning'])\n",
877
+ " value.append(eval_result['value'])\n",
878
+ " score.append(eval_result['score'])\n",
879
+ " \n",
880
+ " df2[metric+'_reasoning'] = reasoning\n",
881
+ " df2[metric+'_value'] = value\n",
882
+ " df2[metric+'_score'] = score "
883
+ ]
884
+ },
885
+ {
886
+ "cell_type": "code",
887
+ "execution_count": 77,
888
+ "id": "c648632c",
889
+ "metadata": {},
890
+ "outputs": [
891
+ {
892
+ "data": {
893
+ "text/html": [
894
+ "<div>\n",
895
+ "<style scoped>\n",
896
+ " .dataframe tbody tr th:only-of-type {\n",
897
+ " vertical-align: middle;\n",
898
+ " }\n",
899
+ "\n",
900
+ " .dataframe tbody tr th {\n",
901
+ " vertical-align: top;\n",
902
+ " }\n",
903
+ "\n",
904
+ " .dataframe thead th {\n",
905
+ " text-align: right;\n",
906
+ " }\n",
907
+ "</style>\n",
908
+ "<table border=\"1\" class=\"dataframe\">\n",
909
+ " <thead>\n",
910
+ " <tr style=\"text-align: right;\">\n",
911
+ " <th></th>\n",
912
+ " <th>Questions</th>\n",
913
+ " <th>Answers</th>\n",
914
+ " <th>latency</th>\n",
915
+ " <th>response</th>\n",
916
+ " <th>correctness_reasoning</th>\n",
917
+ " <th>correctness_value</th>\n",
918
+ " <th>correctness_score</th>\n",
919
+ " <th>relevance_reasoning</th>\n",
920
+ " <th>relevance_value</th>\n",
921
+ " <th>relevance_score</th>\n",
922
+ " <th>coherence_reasoning</th>\n",
923
+ " <th>coherence_value</th>\n",
924
+ " <th>coherence_score</th>\n",
925
+ " <th>conciseness_reasoning</th>\n",
926
+ " <th>conciseness_value</th>\n",
927
+ " <th>conciseness_score</th>\n",
928
+ " </tr>\n",
929
+ " </thead>\n",
930
+ " <tbody>\n",
931
+ " <tr>\n",
932
+ " <th>0</th>\n",
933
+ " <td>What is Mental Health</td>\n",
934
+ " <td>Mental Health is a \" state of well-being in wh...</td>\n",
935
+ " <td>11.046327</td>\n",
936
+ " <td>Based on the context provided, mental health r...</td>\n",
937
+ " <td>Step 1: Evaluate if the submission is factuall...</td>\n",
938
+ " <td>N</td>\n",
939
+ " <td>0</td>\n",
940
+ " <td>Step 1: Analyze the relevance criterion\\nThe s...</td>\n",
941
+ " <td>N</td>\n",
942
+ " <td>0</td>\n",
943
+ " <td>The submission discusses mental health in rela...</td>\n",
944
+ " <td>Y</td>\n",
945
+ " <td>1</td>\n",
946
+ " <td>Step 1: Analyze conciseness criterion\\nThe sub...</td>\n",
947
+ " <td>Y</td>\n",
948
+ " <td>1</td>\n",
949
+ " </tr>\n",
950
+ " <tr>\n",
951
+ " <th>1</th>\n",
952
+ " <td>What are the most common mental disorders ment...</td>\n",
953
+ " <td>The most common mental disorders include depre...</td>\n",
954
+ " <td>4.509713</td>\n",
955
+ " <td>The handbook mentions several mental illnesses...</td>\n",
956
+ " <td>The submission mentions depression and schizop...</td>\n",
957
+ " <td>N</td>\n",
958
+ " <td>0</td>\n",
959
+ " <td>Step 1: Analyze relevance criterion - Check if...</td>\n",
960
+ " <td>Y</td>\n",
961
+ " <td>1</td>\n",
962
+ " <td>Step 1: Assess coherence\\nThe submission menti...</td>\n",
963
+ " <td>N</td>\n",
964
+ " <td>0</td>\n",
965
+ " <td>Step 1: Analyze conciseness criterion\\nThe sub...</td>\n",
966
+ " <td>N</td>\n",
967
+ " <td>0</td>\n",
968
+ " </tr>\n",
969
+ " <tr>\n",
970
+ " <th>2</th>\n",
971
+ " <td>What are the early warning signs and symptoms ...</td>\n",
972
+ " <td>Early warning signs and symptoms of depression...</td>\n",
973
+ " <td>8.501180</td>\n",
974
+ " <td>According to the provided context, specificall...</td>\n",
975
+ " <td>The submission matches the reference data in t...</td>\n",
976
+ " <td>Y</td>\n",
977
+ " <td>1</td>\n",
978
+ " <td>The submission refers directly to information ...</td>\n",
979
+ " <td>Y</td>\n",
980
+ " <td>1</td>\n",
981
+ " <td>Step 1: Evaluate coherence - The submission is...</td>\n",
982
+ " <td>Y</td>\n",
983
+ " <td>1</td>\n",
984
+ " <td>The submission is concise and includes most of...</td>\n",
985
+ " <td>Y</td>\n",
986
+ " <td>1</td>\n",
987
+ " </tr>\n",
988
+ " <tr>\n",
989
+ " <th>3</th>\n",
990
+ " <td>How can someone help a person who suffers from...</td>\n",
991
+ " <td>To help someone with anxiety, one can support ...</td>\n",
992
+ " <td>10.611402</td>\n",
993
+ " <td>According to the Mental Health Handbook, when ...</td>\n",
994
+ " <td>The submission seems consistent with the refer...</td>\n",
995
+ " <td>Y</td>\n",
996
+ " <td>1</td>\n",
997
+ " <td>Step 1: Review relevance criterion\\nThe submis...</td>\n",
998
+ " <td>Y</td>\n",
999
+ " <td>1</td>\n",
1000
+ " <td>The submission is coherent, well-structured, a...</td>\n",
1001
+ " <td>Y</td>\n",
1002
+ " <td>1</td>\n",
1003
+ " <td>The submission is relatively concise and cover...</td>\n",
1004
+ " <td>Y</td>\n",
1005
+ " <td>1</td>\n",
1006
+ " </tr>\n",
1007
+ " <tr>\n",
1008
+ " <th>4</th>\n",
1009
+ " <td>What are the causes of mental illness listed i...</td>\n",
1010
+ " <td>Causes of mental illness include abnormal func...</td>\n",
1011
+ " <td>6.299272</td>\n",
1012
+ " <td>According to the context, the causes of mental...</td>\n",
1013
+ " <td>The submission lists causes such as neglect, s...</td>\n",
1014
+ " <td>N</td>\n",
1015
+ " <td>0</td>\n",
1016
+ " <td>The submission mentions factors that are part ...</td>\n",
1017
+ " <td>N</td>\n",
1018
+ " <td>0</td>\n",
1019
+ " <td>The submission is coherent and well-structured...</td>\n",
1020
+ " <td>Y</td>\n",
1021
+ " <td>1</td>\n",
1022
+ " <td>Step 1: Read and understand both the input dat...</td>\n",
1023
+ " <td>N</td>\n",
1024
+ " <td>0</td>\n",
1025
+ " </tr>\n",
1026
+ " </tbody>\n",
1027
+ "</table>\n",
1028
+ "</div>"
1029
+ ],
1030
+ "text/plain": [
1031
+ " Questions \\\n",
1032
+ "0 What is Mental Health \n",
1033
+ "1 What are the most common mental disorders ment... \n",
1034
+ "2 What are the early warning signs and symptoms ... \n",
1035
+ "3 How can someone help a person who suffers from... \n",
1036
+ "4 What are the causes of mental illness listed i... \n",
1037
+ "\n",
1038
+ " Answers latency \\\n",
1039
+ "0 Mental Health is a \" state of well-being in wh... 11.046327 \n",
1040
+ "1 The most common mental disorders include depre... 4.509713 \n",
1041
+ "2 Early warning signs and symptoms of depression... 8.501180 \n",
1042
+ "3 To help someone with anxiety, one can support ... 10.611402 \n",
1043
+ "4 Causes of mental illness include abnormal func... 6.299272 \n",
1044
+ "\n",
1045
+ " response \\\n",
1046
+ "0 Based on the context provided, mental health r... \n",
1047
+ "1 The handbook mentions several mental illnesses... \n",
1048
+ "2 According to the provided context, specificall... \n",
1049
+ "3 According to the Mental Health Handbook, when ... \n",
1050
+ "4 According to the context, the causes of mental... \n",
1051
+ "\n",
1052
+ " correctness_reasoning correctness_value \\\n",
1053
+ "0 Step 1: Evaluate if the submission is factuall... N \n",
1054
+ "1 The submission mentions depression and schizop... N \n",
1055
+ "2 The submission matches the reference data in t... Y \n",
1056
+ "3 The submission seems consistent with the refer... Y \n",
1057
+ "4 The submission lists causes such as neglect, s... N \n",
1058
+ "\n",
1059
+ " correctness_score relevance_reasoning \\\n",
1060
+ "0 0 Step 1: Analyze the relevance criterion\\nThe s... \n",
1061
+ "1 0 Step 1: Analyze relevance criterion - Check if... \n",
1062
+ "2 1 The submission refers directly to information ... \n",
1063
+ "3 1 Step 1: Review relevance criterion\\nThe submis... \n",
1064
+ "4 0 The submission mentions factors that are part ... \n",
1065
+ "\n",
1066
+ " relevance_value relevance_score \\\n",
1067
+ "0 N 0 \n",
1068
+ "1 Y 1 \n",
1069
+ "2 Y 1 \n",
1070
+ "3 Y 1 \n",
1071
+ "4 N 0 \n",
1072
+ "\n",
1073
+ " coherence_reasoning coherence_value \\\n",
1074
+ "0 The submission discusses mental health in rela... Y \n",
1075
+ "1 Step 1: Assess coherence\\nThe submission menti... N \n",
1076
+ "2 Step 1: Evaluate coherence - The submission is... Y \n",
1077
+ "3 The submission is coherent, well-structured, a... Y \n",
1078
+ "4 The submission is coherent and well-structured... Y \n",
1079
+ "\n",
1080
+ " coherence_score conciseness_reasoning \\\n",
1081
+ "0 1 Step 1: Analyze conciseness criterion\\nThe sub... \n",
1082
+ "1 0 Step 1: Analyze conciseness criterion\\nThe sub... \n",
1083
+ "2 1 The submission is concise and includes most of... \n",
1084
+ "3 1 The submission is relatively concise and cover... \n",
1085
+ "4 1 Step 1: Read and understand both the input dat... \n",
1086
+ "\n",
1087
+ " conciseness_value conciseness_score \n",
1088
+ "0 Y 1 \n",
1089
+ "1 N 0 \n",
1090
+ "2 Y 1 \n",
1091
+ "3 Y 1 \n",
1092
+ "4 N 0 "
1093
+ ]
1094
+ },
1095
+ "execution_count": 77,
1096
+ "metadata": {},
1097
+ "output_type": "execute_result"
1098
+ }
1099
+ ],
1100
+ "source": [
1101
+ "df2.head()"
1102
+ ]
1103
+ },
1104
+ {
1105
+ "cell_type": "code",
1106
+ "execution_count": 47,
1107
+ "id": "2d1002b2",
1108
+ "metadata": {},
1109
+ "outputs": [
1110
+ {
1111
+ "data": {
1112
+ "text/plain": [
1113
+ "correctness_score 0.500000\n",
1114
+ "relevance_score 0.888889\n",
1115
+ "coherence_score 0.888889\n",
1116
+ "conciseness_score 0.900000\n",
1117
+ "latency 8.190205\n",
1118
+ "dtype: float64"
1119
+ ]
1120
+ },
1121
+ "execution_count": 47,
1122
+ "metadata": {},
1123
+ "output_type": "execute_result"
1124
+ }
1125
+ ],
1126
+ "source": [
1127
+ "df2[['correctness_score','relevance_score','coherence_score','conciseness_score','latency']].mean()"
1128
+ ]
1129
+ },
1130
+ {
1131
+ "cell_type": "markdown",
1132
+ "id": "e808bdcf",
1133
+ "metadata": {},
1134
+ "source": [
1135
+ "# Query relevance"
1136
+ ]
1137
+ },
1138
+ {
1139
+ "cell_type": "code",
1140
+ "execution_count": 66,
1141
+ "id": "6b541f3d",
1142
+ "metadata": {},
1143
+ "outputs": [],
1144
+ "source": [
1145
+ "def new_search_faiss(query, k=3, threshold=1.5):\n",
1146
+ " query_vector = model.encode([query])[0].astype('float32')\n",
1147
+ " query_vector = np.expand_dims(query_vector, axis=0)\n",
1148
+ " distances, indices = index.search(query_vector, k)\n",
1149
+ " \n",
1150
+ " results = []\n",
1151
+ " for dist, idx in zip(distances[0], indices[0]):\n",
1152
+ " if dist < threshold: # Only include results within the threshold distance\n",
1153
+ " results.append({\n",
1154
+ " 'distance': dist,\n",
1155
+ " 'content': sections_data[idx]['content'],\n",
1156
+ " 'metadata': sections_data[idx]['metadata']\n",
1157
+ " })\n",
1158
+ " \n",
1159
+ " return results"
1160
+ ]
1161
+ },
1162
+ {
1163
+ "cell_type": "code",
1164
+ "execution_count": 70,
1165
+ "id": "4f579654",
1166
+ "metadata": {},
1167
+ "outputs": [],
1168
+ "source": [
1169
+ "new_prompt_template = \"\"\"\n",
1170
+ "You are an AI assistant specialized in Mental Health guidelines.\n",
1171
+ "Use the provided context to answer the question short and accurately. \n",
1172
+ "If you don't know the answer, simply say, \"I don't know.\"\n",
1173
+ "\n",
1174
+ "Context:\n",
1175
+ "{context}\n",
1176
+ "\n",
1177
+ "Question: {question}\n",
1178
+ "\n",
1179
+ "Answer:\"\"\"\n",
1180
+ "\n",
1181
+ "prompt = PromptTemplate(template=prompt_template, input_variables=[\"context\", \"question\"])\n",
1182
+ "\n",
1183
+ "llm = Ollama(\n",
1184
+ " model=\"llama3\"\n",
1185
+ ")\n",
1186
+ "\n",
1187
+ "# Create the chain\n",
1188
+ "chain = LLMChain(llm=llm, prompt=prompt)\n",
1189
+ "\n",
1190
+ "def new_answer_question(query):\n",
1191
+ " # Search for relevant context\n",
1192
+ " search_results = new_search_faiss(query)\n",
1193
+ " \n",
1194
+ " if search_results==[]:\n",
1195
+ " response=\"I don't know, sorry\"\n",
1196
+ " else:\n",
1197
+ " context = \"\\n\\n\".join([result['content'] for result in search_results])\n",
1198
+ " response = chain.run(context=context, question=query)\n",
1199
+ " \n",
1200
+ " return response"
1201
+ ]
1202
+ },
1203
+ {
1204
+ "cell_type": "code",
1205
+ "execution_count": 71,
1206
+ "id": "1f83ef1b",
1207
+ "metadata": {},
1208
+ "outputs": [],
1209
+ "source": [
1210
+ "irr_q2=irr_q.copy()"
1211
+ ]
1212
+ },
1213
+ {
1214
+ "cell_type": "code",
1215
+ "execution_count": 72,
1216
+ "id": "f06474e3",
1217
+ "metadata": {},
1218
+ "outputs": [
1219
+ {
1220
+ "name": "stderr",
1221
+ "output_type": "stream",
1222
+ "text": [
1223
+ "100%|███████████████████████████████████████████| 10/10 [00:00<00:00, 61.93it/s]\n"
1224
+ ]
1225
+ }
1226
+ ],
1227
+ "source": [
1228
+ "time_list=[]\n",
1229
+ "response_list=[]\n",
1230
+ "for i in tqdm(range(len(irr_q2))):\n",
1231
+ " query = irr_q['Questions'].values[i]\n",
1232
+ " start = time.time()\n",
1233
+ " response = new_answer_question(query)\n",
1234
+ " end = time.time() \n",
1235
+ " time_list.append(end-start)\n",
1236
+ " response_list.append(response)"
1237
+ ]
1238
+ },
1239
+ {
1240
+ "cell_type": "code",
1241
+ "execution_count": 73,
1242
+ "id": "52db6b82",
1243
+ "metadata": {},
1244
+ "outputs": [],
1245
+ "source": [
1246
+ "irr_q2['response']=response_list\n",
1247
+ "irr_q2['latency']=time_list"
1248
+ ]
1249
+ },
1250
+ {
1251
+ "cell_type": "code",
1252
+ "execution_count": 80,
1253
+ "id": "80a178ee",
1254
+ "metadata": {},
1255
+ "outputs": [
1256
+ {
1257
+ "data": {
1258
+ "text/html": [
1259
+ "<div>\n",
1260
+ "<style scoped>\n",
1261
+ " .dataframe tbody tr th:only-of-type {\n",
1262
+ " vertical-align: middle;\n",
1263
+ " }\n",
1264
+ "\n",
1265
+ " .dataframe tbody tr th {\n",
1266
+ " vertical-align: top;\n",
1267
+ " }\n",
1268
+ "\n",
1269
+ " .dataframe thead th {\n",
1270
+ " text-align: right;\n",
1271
+ " }\n",
1272
+ "</style>\n",
1273
+ "<table border=\"1\" class=\"dataframe\">\n",
1274
+ " <thead>\n",
1275
+ " <tr style=\"text-align: right;\">\n",
1276
+ " <th></th>\n",
1277
+ " <th>Questions</th>\n",
1278
+ " <th>response</th>\n",
1279
+ " <th>latency</th>\n",
1280
+ " <th>irrelevant_score</th>\n",
1281
+ " </tr>\n",
1282
+ " </thead>\n",
1283
+ " <tbody>\n",
1284
+ " <tr>\n",
1285
+ " <th>0</th>\n",
1286
+ " <td>What is the capital of Mars?</td>\n",
1287
+ " <td>I don't know, sorry</td>\n",
1288
+ " <td>0.061378</td>\n",
1289
+ " <td>True</td>\n",
1290
+ " </tr>\n",
1291
+ " <tr>\n",
1292
+ " <th>1</th>\n",
1293
+ " <td>How many unicorns live in New York City?</td>\n",
1294
+ " <td>I don't know, sorry</td>\n",
1295
+ " <td>0.012511</td>\n",
1296
+ " <td>True</td>\n",
1297
+ " </tr>\n",
1298
+ " <tr>\n",
1299
+ " <th>2</th>\n",
1300
+ " <td>What is the color of happiness?</td>\n",
1301
+ " <td>I don't know, sorry</td>\n",
1302
+ " <td>0.011900</td>\n",
1303
+ " <td>True</td>\n",
1304
+ " </tr>\n",
1305
+ " <tr>\n",
1306
+ " <th>3</th>\n",
1307
+ " <td>Can cats fly on Tuesdays?</td>\n",
1308
+ " <td>I don't know, sorry</td>\n",
1309
+ " <td>0.011438</td>\n",
1310
+ " <td>True</td>\n",
1311
+ " </tr>\n",
1312
+ " <tr>\n",
1313
+ " <th>4</th>\n",
1314
+ " <td>How much does a thought weigh?</td>\n",
1315
+ " <td>I don't know, sorry</td>\n",
1316
+ " <td>0.010644</td>\n",
1317
+ " <td>True</td>\n",
1318
+ " </tr>\n",
1319
+ " </tbody>\n",
1320
+ "</table>\n",
1321
+ "</div>"
1322
+ ],
1323
+ "text/plain": [
1324
+ " Questions response latency \\\n",
1325
+ "0 What is the capital of Mars? I don't know, sorry 0.061378 \n",
1326
+ "1 How many unicorns live in New York City? I don't know, sorry 0.012511 \n",
1327
+ "2 What is the color of happiness? I don't know, sorry 0.011900 \n",
1328
+ "3 Can cats fly on Tuesdays? I don't know, sorry 0.011438 \n",
1329
+ "4 How much does a thought weigh? I don't know, sorry 0.010644 \n",
1330
+ "\n",
1331
+ " irrelevant_score \n",
1332
+ "0 True \n",
1333
+ "1 True \n",
1334
+ "2 True \n",
1335
+ "3 True \n",
1336
+ "4 True "
1337
+ ]
1338
+ },
1339
+ "execution_count": 80,
1340
+ "metadata": {},
1341
+ "output_type": "execute_result"
1342
+ }
1343
+ ],
1344
+ "source": [
1345
+ "irr_q2.head()"
1346
+ ]
1347
+ },
1348
+ {
1349
+ "cell_type": "code",
1350
+ "execution_count": 74,
1351
+ "id": "4508de9e",
1352
+ "metadata": {},
1353
+ "outputs": [],
1354
+ "source": [
1355
+ "irr_q2['irrelevant_score'] = irr_q2['response'].str.contains(\"I don't know\")"
1356
+ ]
1357
+ },
1358
+ {
1359
+ "cell_type": "code",
1360
+ "execution_count": 75,
1361
+ "id": "3d34ba06",
1362
+ "metadata": {},
1363
+ "outputs": [
1364
+ {
1365
+ "data": {
1366
+ "text/plain": [
1367
+ "irrelevant_score 1.000000\n",
1368
+ "latency 0.016068\n",
1369
+ "dtype: float64"
1370
+ ]
1371
+ },
1372
+ "execution_count": 75,
1373
+ "metadata": {},
1374
+ "output_type": "execute_result"
1375
+ }
1376
+ ],
1377
+ "source": [
1378
+ "irr_q2[['irrelevant_score','latency']].mean()"
1379
+ ]
1380
+ }
1381
+ ],
1382
+ "metadata": {
1383
+ "kernelspec": {
1384
+ "display_name": "Python 3 (ipykernel)",
1385
+ "language": "python",
1386
+ "name": "python3"
1387
+ },
1388
+ "language_info": {
1389
+ "codemirror_mode": {
1390
+ "name": "ipython",
1391
+ "version": 3
1392
+ },
1393
+ "file_extension": ".py",
1394
+ "mimetype": "text/x-python",
1395
+ "name": "python",
1396
+ "nbconvert_exporter": "python",
1397
+ "pygments_lexer": "ipython3",
1398
+ "version": "3.11.5"
1399
+ }
1400
+ },
1401
+ "nbformat": 4,
1402
+ "nbformat_minor": 5
1403
+ }
Evaluation_MH/Evaluation.ipynb ADDED
@@ -0,0 +1,1403 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "f7b87c2c",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Imports"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": 5,
14
+ "id": "c22401c2-2fd2-4459-9ee8-71bc3bd362c8",
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "# pip install -U sentence-transformers"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": 1,
24
+ "id": "8a7cc9d8",
25
+ "metadata": {},
26
+ "outputs": [
27
+ {
28
+ "name": "stderr",
29
+ "output_type": "stream",
30
+ "text": [
31
+ "/Users/arnabchakraborty/anaconda3/lib/python3.11/site-packages/sentence_transformers/cross_encoder/CrossEncoder.py:11: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n",
32
+ " from tqdm.autonotebook import tqdm, trange\n"
33
+ ]
34
+ }
35
+ ],
36
+ "source": [
37
+ "from sentence_transformers import SentenceTransformer\n",
38
+ "from langchain.prompts import PromptTemplate\n",
39
+ "from langchain.chains import LLMChain\n",
40
+ "from langchain_community.llms import Ollama\n",
41
+ "from langchain.evaluation import load_evaluator\n",
42
+ "import faiss\n",
43
+ "import pandas as pd\n",
44
+ "import numpy as np\n",
45
+ "import pickle\n",
46
+ "import time\n",
47
+ "from tqdm import tqdm"
48
+ ]
49
+ },
50
+ {
51
+ "cell_type": "markdown",
52
+ "id": "b6efca1d",
53
+ "metadata": {},
54
+ "source": [
55
+ "# Intialization"
56
+ ]
57
+ },
58
+ {
59
+ "cell_type": "code",
60
+ "execution_count": 2,
61
+ "id": "cc9a49d2",
62
+ "metadata": {},
63
+ "outputs": [],
64
+ "source": [
65
+ "# Load the FAISS index\n",
66
+ "index = faiss.read_index(\"database/pdf_sections_index.faiss\")"
67
+ ]
68
+ },
69
+ {
70
+ "cell_type": "code",
71
+ "execution_count": 3,
72
+ "id": "9af39b55",
73
+ "metadata": {},
74
+ "outputs": [],
75
+ "source": [
76
+ "model = SentenceTransformer('all-MiniLM-L6-v2')"
77
+ ]
78
+ },
79
+ {
80
+ "cell_type": "code",
81
+ "execution_count": 4,
82
+ "id": "fee8cdfd",
83
+ "metadata": {},
84
+ "outputs": [],
85
+ "source": [
86
+ "with open('database/pdf_sections_data.pkl', 'rb') as f:\n",
87
+ " sections_data = pickle.load(f)"
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "markdown",
92
+ "id": "d6a1ba6a",
93
+ "metadata": {},
94
+ "source": [
95
+ "# RAG functions"
96
+ ]
97
+ },
98
+ {
99
+ "cell_type": "code",
100
+ "execution_count": 5,
101
+ "id": "182bdbd8",
102
+ "metadata": {},
103
+ "outputs": [],
104
+ "source": [
105
+ "def search_faiss(query, k=3):\n",
106
+ " query_vector = model.encode([query])[0].astype('float32')\n",
107
+ " query_vector = np.expand_dims(query_vector, axis=0)\n",
108
+ " distances, indices = index.search(query_vector, k)\n",
109
+ " \n",
110
+ " results = []\n",
111
+ " for dist, idx in zip(distances[0], indices[0]):\n",
112
+ " results.append({\n",
113
+ " 'distance': dist,\n",
114
+ " 'content': sections_data[idx]['content'],\n",
115
+ " 'metadata': sections_data[idx]['metadata']\n",
116
+ " })\n",
117
+ " \n",
118
+ " return results"
119
+ ]
120
+ },
121
+ {
122
+ "cell_type": "code",
123
+ "execution_count": 15,
124
+ "id": "67edc46a",
125
+ "metadata": {},
126
+ "outputs": [],
127
+ "source": [
128
+ "# Create a prompt template\n",
129
+ "prompt_template = \"\"\"\n",
130
+ "You are an AI assistant specialized in Mental Health guidelines. \n",
131
+ "Use the following pieces of context to answer the question. \n",
132
+ "If you don't know the answer, just say that you don't know, don't try to make up an answer.\n",
133
+ "\n",
134
+ "Context:\n",
135
+ "{context}\n",
136
+ "\n",
137
+ "Question: {question}\n",
138
+ "\n",
139
+ "Answer:\"\"\"\n",
140
+ "\n",
141
+ "prompt = PromptTemplate(template=prompt_template, input_variables=[\"context\", \"question\"])\n",
142
+ "\n",
143
+ "llm = Ollama(\n",
144
+ " model=\"llama3\"\n",
145
+ ")\n",
146
+ "\n",
147
+ "# Create the chain\n",
148
+ "chain = LLMChain(llm=llm, prompt=prompt)\n",
149
+ "\n",
150
+ "def answer_question(query):\n",
151
+ " # Search for relevant context\n",
152
+ " search_results = search_faiss(query)\n",
153
+ " \n",
154
+ " # Combine the content from the search results\n",
155
+ " context = \"\\n\\n\".join([result['content'] for result in search_results])\n",
156
+ "\n",
157
+ " # Run the chain\n",
158
+ " response = chain.run(context=context, question=query)\n",
159
+ " \n",
160
+ " return response"
161
+ ]
162
+ },
163
+ {
164
+ "cell_type": "markdown",
165
+ "id": "3b176af9",
166
+ "metadata": {},
167
+ "source": [
168
+ "# Reading GT"
169
+ ]
170
+ },
171
+ {
172
+ "cell_type": "code",
173
+ "execution_count": 16,
174
+ "id": "4ab68dff",
175
+ "metadata": {},
176
+ "outputs": [],
177
+ "source": [
178
+ "df = pd.read_csv('data/MentalHealth_Dataset.csv')"
179
+ ]
180
+ },
181
+ {
182
+ "cell_type": "code",
183
+ "execution_count": 17,
184
+ "id": "4e7e22d7",
185
+ "metadata": {},
186
+ "outputs": [
187
+ {
188
+ "name": "stderr",
189
+ "output_type": "stream",
190
+ "text": [
191
+ "100%|███████████████████████████████████████████| 10/10 [01:45<00:00, 10.55s/it]\n"
192
+ ]
193
+ }
194
+ ],
195
+ "source": [
196
+ "time_list=[]\n",
197
+ "response_list=[]\n",
198
+ "for i in tqdm(range(len(df))):\n",
199
+ " query = df['Questions'].values[i]\n",
200
+ " start = time.time()\n",
201
+ " response = answer_question(query)\n",
202
+ " end = time.time() \n",
203
+ " time_list.append(end-start)\n",
204
+ " response_list.append(response)"
205
+ ]
206
+ },
207
+ {
208
+ "cell_type": "code",
209
+ "execution_count": 18,
210
+ "id": "2b327e90",
211
+ "metadata": {},
212
+ "outputs": [],
213
+ "source": [
214
+ "df['latency'] = time_list\n",
215
+ "df['response'] = response_list"
216
+ ]
217
+ },
218
+ {
219
+ "cell_type": "markdown",
220
+ "id": "3c147204",
221
+ "metadata": {},
222
+ "source": [
223
+ "# Evaluation"
224
+ ]
225
+ },
226
+ {
227
+ "cell_type": "code",
228
+ "execution_count": 29,
229
+ "id": "d799e541",
230
+ "metadata": {},
231
+ "outputs": [],
232
+ "source": [
233
+ "eval_llm = Ollama(\n",
234
+ " model=\"phi3\"\n",
235
+ ")"
236
+ ]
237
+ },
238
+ {
239
+ "cell_type": "code",
240
+ "execution_count": 30,
241
+ "id": "c2f788dc",
242
+ "metadata": {},
243
+ "outputs": [],
244
+ "source": [
245
+ "metrics = ['correctness', 'relevance', 'coherence', 'conciseness']"
246
+ ]
247
+ },
248
+ {
249
+ "cell_type": "code",
250
+ "execution_count": 31,
251
+ "id": "83ec2b8d",
252
+ "metadata": {},
253
+ "outputs": [
254
+ {
255
+ "name": "stderr",
256
+ "output_type": "stream",
257
+ "text": [
258
+ "100%|███████████████████████████████████████████| 10/10 [01:15<00:00, 7.51s/it]\n",
259
+ "100%|███████████████████████████████████████████| 10/10 [00:59<00:00, 5.99s/it]\n",
260
+ "100%|███████████████████████████████████████████| 10/10 [00:50<00:00, 5.10s/it]\n",
261
+ "100%|███████████████████████████████████████████| 10/10 [00:48<00:00, 4.88s/it]\n"
262
+ ]
263
+ }
264
+ ],
265
+ "source": [
266
+ "for metric in metrics:\n",
267
+ " evaluator = load_evaluator(\"labeled_criteria\", criteria=metric, llm=eval_llm)\n",
268
+ " \n",
269
+ " reasoning = []\n",
270
+ " value = []\n",
271
+ " score = []\n",
272
+ " \n",
273
+ " for i in tqdm(range(len(df))):\n",
274
+ " eval_result = evaluator.evaluate_strings(\n",
275
+ " prediction=df.response.values[i],\n",
276
+ " input=df.Questions.values[i],\n",
277
+ " reference=df.Answers.values[i]\n",
278
+ " )\n",
279
+ " reasoning.append(eval_result['reasoning'])\n",
280
+ " value.append(eval_result['value'])\n",
281
+ " score.append(eval_result['score'])\n",
282
+ " \n",
283
+ " df[metric+'_reasoning'] = reasoning\n",
284
+ " df[metric+'_value'] = value\n",
285
+ " df[metric+'_score'] = score "
286
+ ]
287
+ },
288
+ {
289
+ "cell_type": "code",
290
+ "execution_count": 78,
291
+ "id": "f1673a31",
292
+ "metadata": {},
293
+ "outputs": [
294
+ {
295
+ "data": {
296
+ "text/html": [
297
+ "<div>\n",
298
+ "<style scoped>\n",
299
+ " .dataframe tbody tr th:only-of-type {\n",
300
+ " vertical-align: middle;\n",
301
+ " }\n",
302
+ "\n",
303
+ " .dataframe tbody tr th {\n",
304
+ " vertical-align: top;\n",
305
+ " }\n",
306
+ "\n",
307
+ " .dataframe thead th {\n",
308
+ " text-align: right;\n",
309
+ " }\n",
310
+ "</style>\n",
311
+ "<table border=\"1\" class=\"dataframe\">\n",
312
+ " <thead>\n",
313
+ " <tr style=\"text-align: right;\">\n",
314
+ " <th></th>\n",
315
+ " <th>Questions</th>\n",
316
+ " <th>Answers</th>\n",
317
+ " <th>latency</th>\n",
318
+ " <th>response</th>\n",
319
+ " <th>correctness_reasoning</th>\n",
320
+ " <th>correctness_value</th>\n",
321
+ " <th>correctness_score</th>\n",
322
+ " <th>relevance_reasoning</th>\n",
323
+ " <th>relevance_value</th>\n",
324
+ " <th>relevance_score</th>\n",
325
+ " <th>coherence_reasoning</th>\n",
326
+ " <th>coherence_value</th>\n",
327
+ " <th>coherence_score</th>\n",
328
+ " <th>conciseness_reasoning</th>\n",
329
+ " <th>conciseness_value</th>\n",
330
+ " <th>conciseness_score</th>\n",
331
+ " </tr>\n",
332
+ " </thead>\n",
333
+ " <tbody>\n",
334
+ " <tr>\n",
335
+ " <th>0</th>\n",
336
+ " <td>What is Mental Health</td>\n",
337
+ " <td>Mental Health is a \" state of well-being in wh...</td>\n",
338
+ " <td>11.974234</td>\n",
339
+ " <td>Based on the provided context, specifically fr...</td>\n",
340
+ " <td>The submission refers to the provided input wh...</td>\n",
341
+ " <td>Y</td>\n",
342
+ " <td>1</td>\n",
343
+ " <td>Step 1: Evaluate relevance criterion\\nThe subm...</td>\n",
344
+ " <td>Y</td>\n",
345
+ " <td>1</td>\n",
346
+ " <td>Step 1: Assess coherence\\nThe submission direc...</td>\n",
347
+ " <td>Y</td>\n",
348
+ " <td>1</td>\n",
349
+ " <td>1. The submission directly answers the questio...</td>\n",
350
+ " <td>Y</td>\n",
351
+ " <td>1</td>\n",
352
+ " </tr>\n",
353
+ " <tr>\n",
354
+ " <th>1</th>\n",
355
+ " <td>What are the most common mental disorders ment...</td>\n",
356
+ " <td>The most common mental disorders include depre...</td>\n",
357
+ " <td>5.863329</td>\n",
358
+ " <td>Based on the provided context, the mental diso...</td>\n",
359
+ " <td>Step 1: Check if the submission is factually a...</td>\n",
360
+ " <td>Y</td>\n",
361
+ " <td>1</td>\n",
362
+ " <td>Step 1: Analyze the relevance criterion\\nThe s...</td>\n",
363
+ " <td>Y</td>\n",
364
+ " <td>1</td>\n",
365
+ " <td>The submission begins with an appropriate ques...</td>\n",
366
+ " <td>Y</td>\n",
367
+ " <td>1</td>\n",
368
+ " <td>Step 1: Review conciseness criterion\\nThe subm...</td>\n",
369
+ " <td>Y</td>\n",
370
+ " <td>1</td>\n",
371
+ " </tr>\n",
372
+ " <tr>\n",
373
+ " <th>2</th>\n",
374
+ " <td>What are the early warning signs and symptoms ...</td>\n",
375
+ " <td>Early warning signs and symptoms of depression...</td>\n",
376
+ " <td>13.434543</td>\n",
377
+ " <td>Based on the provided context, I found a refer...</td>\n",
378
+ " <td>Step 1: Evaluate Correctness\\nThe submission a...</td>\n",
379
+ " <td>Y</td>\n",
380
+ " <td>1</td>\n",
381
+ " <td>Step 1: Identify the relevant criterion from t...</td>\n",
382
+ " <td>Y</td>\n",
383
+ " <td>1</td>\n",
384
+ " <td>Step 1: Evaluate coherence\\nThe submission is ...</td>\n",
385
+ " <td>Y</td>\n",
386
+ " <td>1</td>\n",
387
+ " <td>Step 1: Evaluate conciseness - The submission ...</td>\n",
388
+ " <td>Y</td>\n",
389
+ " <td>1</td>\n",
390
+ " </tr>\n",
391
+ " <tr>\n",
392
+ " <th>3</th>\n",
393
+ " <td>How can someone help a person who suffers from...</td>\n",
394
+ " <td>To help someone with anxiety, one can support ...</td>\n",
395
+ " <td>13.838464</td>\n",
396
+ " <td>According to the provided context, specificall...</td>\n",
397
+ " <td>Step 1: Correctness\\nThe submission accurately...</td>\n",
398
+ " <td>Y</td>\n",
399
+ " <td>1</td>\n",
400
+ " <td>Step 1: Analyze relevance criterion\\nThe submi...</td>\n",
401
+ " <td>Y</td>\n",
402
+ " <td>1</td>\n",
403
+ " <td>Step 1: Evaluate coherence\\nThe submission dis...</td>\n",
404
+ " <td>Y</td>\n",
405
+ " <td>1</td>\n",
406
+ " <td>Step 1: Evaluate conciseness - The submission ...</td>\n",
407
+ " <td>N</td>\n",
408
+ " <td>0</td>\n",
409
+ " </tr>\n",
410
+ " <tr>\n",
411
+ " <th>4</th>\n",
412
+ " <td>What are the causes of mental illness listed i...</td>\n",
413
+ " <td>Causes of mental illness include abnormal func...</td>\n",
414
+ " <td>6.871735</td>\n",
415
+ " <td>According to the provided context, the causes ...</td>\n",
416
+ " <td>The submission lists factors that align with t...</td>\n",
417
+ " <td>N</td>\n",
418
+ " <td>0</td>\n",
419
+ " <td>Step 1: Review relevance criterion - Check if ...</td>\n",
420
+ " <td>Y</td>\n",
421
+ " <td>1</td>\n",
422
+ " <td>Step 1: Compare the submission with the provid...</td>\n",
423
+ " <td>Y</td>\n",
424
+ " <td>1</td>\n",
425
+ " <td>Step 1: Assess conciseness\\nThe submission is ...</td>\n",
426
+ " <td>Y</td>\n",
427
+ " <td>1</td>\n",
428
+ " </tr>\n",
429
+ " </tbody>\n",
430
+ "</table>\n",
431
+ "</div>"
432
+ ],
433
+ "text/plain": [
434
+ " Questions \\\n",
435
+ "0 What is Mental Health \n",
436
+ "1 What are the most common mental disorders ment... \n",
437
+ "2 What are the early warning signs and symptoms ... \n",
438
+ "3 How can someone help a person who suffers from... \n",
439
+ "4 What are the causes of mental illness listed i... \n",
440
+ "\n",
441
+ " Answers latency \\\n",
442
+ "0 Mental Health is a \" state of well-being in wh... 11.974234 \n",
443
+ "1 The most common mental disorders include depre... 5.863329 \n",
444
+ "2 Early warning signs and symptoms of depression... 13.434543 \n",
445
+ "3 To help someone with anxiety, one can support ... 13.838464 \n",
446
+ "4 Causes of mental illness include abnormal func... 6.871735 \n",
447
+ "\n",
448
+ " response \\\n",
449
+ "0 Based on the provided context, specifically fr... \n",
450
+ "1 Based on the provided context, the mental diso... \n",
451
+ "2 Based on the provided context, I found a refer... \n",
452
+ "3 According to the provided context, specificall... \n",
453
+ "4 According to the provided context, the causes ... \n",
454
+ "\n",
455
+ " correctness_reasoning correctness_value \\\n",
456
+ "0 The submission refers to the provided input wh... Y \n",
457
+ "1 Step 1: Check if the submission is factually a... Y \n",
458
+ "2 Step 1: Evaluate Correctness\\nThe submission a... Y \n",
459
+ "3 Step 1: Correctness\\nThe submission accurately... Y \n",
460
+ "4 The submission lists factors that align with t... N \n",
461
+ "\n",
462
+ " correctness_score relevance_reasoning \\\n",
463
+ "0 1 Step 1: Evaluate relevance criterion\\nThe subm... \n",
464
+ "1 1 Step 1: Analyze the relevance criterion\\nThe s... \n",
465
+ "2 1 Step 1: Identify the relevant criterion from t... \n",
466
+ "3 1 Step 1: Analyze relevance criterion\\nThe submi... \n",
467
+ "4 0 Step 1: Review relevance criterion - Check if ... \n",
468
+ "\n",
469
+ " relevance_value relevance_score \\\n",
470
+ "0 Y 1 \n",
471
+ "1 Y 1 \n",
472
+ "2 Y 1 \n",
473
+ "3 Y 1 \n",
474
+ "4 Y 1 \n",
475
+ "\n",
476
+ " coherence_reasoning coherence_value \\\n",
477
+ "0 Step 1: Assess coherence\\nThe submission direc... Y \n",
478
+ "1 The submission begins with an appropriate ques... Y \n",
479
+ "2 Step 1: Evaluate coherence\\nThe submission is ... Y \n",
480
+ "3 Step 1: Evaluate coherence\\nThe submission dis... Y \n",
481
+ "4 Step 1: Compare the submission with the provid... Y \n",
482
+ "\n",
483
+ " coherence_score conciseness_reasoning \\\n",
484
+ "0 1 1. The submission directly answers the questio... \n",
485
+ "1 1 Step 1: Review conciseness criterion\\nThe subm... \n",
486
+ "2 1 Step 1: Evaluate conciseness - The submission ... \n",
487
+ "3 1 Step 1: Evaluate conciseness - The submission ... \n",
488
+ "4 1 Step 1: Assess conciseness\\nThe submission is ... \n",
489
+ "\n",
490
+ " conciseness_value conciseness_score \n",
491
+ "0 Y 1 \n",
492
+ "1 Y 1 \n",
493
+ "2 Y 1 \n",
494
+ "3 N 0 \n",
495
+ "4 Y 1 "
496
+ ]
497
+ },
498
+ "execution_count": 78,
499
+ "metadata": {},
500
+ "output_type": "execute_result"
501
+ }
502
+ ],
503
+ "source": [
504
+ "df.head()"
505
+ ]
506
+ },
507
+ {
508
+ "cell_type": "code",
509
+ "execution_count": 32,
510
+ "id": "7797a360",
511
+ "metadata": {},
512
+ "outputs": [
513
+ {
514
+ "data": {
515
+ "text/plain": [
516
+ "correctness_score 0.800000\n",
517
+ "relevance_score 0.900000\n",
518
+ "coherence_score 1.000000\n",
519
+ "conciseness_score 0.800000\n",
520
+ "latency 10.544803\n",
521
+ "dtype: float64"
522
+ ]
523
+ },
524
+ "execution_count": 32,
525
+ "metadata": {},
526
+ "output_type": "execute_result"
527
+ }
528
+ ],
529
+ "source": [
530
+ "df[['correctness_score','relevance_score','coherence_score','conciseness_score','latency']].mean()"
531
+ ]
532
+ },
533
+ {
534
+ "cell_type": "code",
535
+ "execution_count": 34,
536
+ "id": "fe667926",
537
+ "metadata": {},
538
+ "outputs": [],
539
+ "source": [
540
+ "irr_q=pd.read_csv('data/Unrelated_questions.csv')"
541
+ ]
542
+ },
543
+ {
544
+ "cell_type": "code",
545
+ "execution_count": 35,
546
+ "id": "189f8a0f",
547
+ "metadata": {},
548
+ "outputs": [
549
+ {
550
+ "name": "stderr",
551
+ "output_type": "stream",
552
+ "text": [
553
+ "100%|███████████████████████████████████████████| 10/10 [01:02<00:00, 6.30s/it]\n"
554
+ ]
555
+ }
556
+ ],
557
+ "source": [
558
+ "time_list=[]\n",
559
+ "response_list=[]\n",
560
+ "for i in tqdm(range(len(irr_q))):\n",
561
+ " query = irr_q['Questions'].values[i]\n",
562
+ " start = time.time()\n",
563
+ " response = answer_question(query)\n",
564
+ " end = time.time() \n",
565
+ " time_list.append(end-start)\n",
566
+ " response_list.append(response)"
567
+ ]
568
+ },
569
+ {
570
+ "cell_type": "code",
571
+ "execution_count": 36,
572
+ "id": "b0244ea0",
573
+ "metadata": {},
574
+ "outputs": [],
575
+ "source": [
576
+ "irr_q['response']=response_list\n",
577
+ "irr_q['latency']=time_list"
578
+ ]
579
+ },
580
+ {
581
+ "cell_type": "code",
582
+ "execution_count": 79,
583
+ "id": "dc3b1ade",
584
+ "metadata": {},
585
+ "outputs": [
586
+ {
587
+ "data": {
588
+ "text/html": [
589
+ "<div>\n",
590
+ "<style scoped>\n",
591
+ " .dataframe tbody tr th:only-of-type {\n",
592
+ " vertical-align: middle;\n",
593
+ " }\n",
594
+ "\n",
595
+ " .dataframe tbody tr th {\n",
596
+ " vertical-align: top;\n",
597
+ " }\n",
598
+ "\n",
599
+ " .dataframe thead th {\n",
600
+ " text-align: right;\n",
601
+ " }\n",
602
+ "</style>\n",
603
+ "<table border=\"1\" class=\"dataframe\">\n",
604
+ " <thead>\n",
605
+ " <tr style=\"text-align: right;\">\n",
606
+ " <th></th>\n",
607
+ " <th>Questions</th>\n",
608
+ " <th>response</th>\n",
609
+ " <th>latency</th>\n",
610
+ " <th>irrelevant_score</th>\n",
611
+ " </tr>\n",
612
+ " </thead>\n",
613
+ " <tbody>\n",
614
+ " <tr>\n",
615
+ " <th>0</th>\n",
616
+ " <td>What is the capital of Mars?</td>\n",
617
+ " <td>I don't know. The provided context does not se...</td>\n",
618
+ " <td>12.207266</td>\n",
619
+ " <td>True</td>\n",
620
+ " </tr>\n",
621
+ " <tr>\n",
622
+ " <th>1</th>\n",
623
+ " <td>How many unicorns live in New York City?</td>\n",
624
+ " <td>I don't know. The information provided does no...</td>\n",
625
+ " <td>2.368774</td>\n",
626
+ " <td>True</td>\n",
627
+ " </tr>\n",
628
+ " <tr>\n",
629
+ " <th>2</th>\n",
630
+ " <td>What is the color of happiness?</td>\n",
631
+ " <td>I don't know! The provided context only talks ...</td>\n",
632
+ " <td>5.480067</td>\n",
633
+ " <td>True</td>\n",
634
+ " </tr>\n",
635
+ " <tr>\n",
636
+ " <th>3</th>\n",
637
+ " <td>Can cats fly on Tuesdays?</td>\n",
638
+ " <td>I don't know the answer to this question as it...</td>\n",
639
+ " <td>5.272529</td>\n",
640
+ " <td>True</td>\n",
641
+ " </tr>\n",
642
+ " <tr>\n",
643
+ " <th>4</th>\n",
644
+ " <td>How much does a thought weigh?</td>\n",
645
+ " <td>I don't know. The context provided is about me...</td>\n",
646
+ " <td>5.253224</td>\n",
647
+ " <td>True</td>\n",
648
+ " </tr>\n",
649
+ " </tbody>\n",
650
+ "</table>\n",
651
+ "</div>"
652
+ ],
653
+ "text/plain": [
654
+ " Questions \\\n",
655
+ "0 What is the capital of Mars? \n",
656
+ "1 How many unicorns live in New York City? \n",
657
+ "2 What is the color of happiness? \n",
658
+ "3 Can cats fly on Tuesdays? \n",
659
+ "4 How much does a thought weigh? \n",
660
+ "\n",
661
+ " response latency \\\n",
662
+ "0 I don't know. The provided context does not se... 12.207266 \n",
663
+ "1 I don't know. The information provided does no... 2.368774 \n",
664
+ "2 I don't know! The provided context only talks ... 5.480067 \n",
665
+ "3 I don't know the answer to this question as it... 5.272529 \n",
666
+ "4 I don't know. The context provided is about me... 5.253224 \n",
667
+ "\n",
668
+ " irrelevant_score \n",
669
+ "0 True \n",
670
+ "1 True \n",
671
+ "2 True \n",
672
+ "3 True \n",
673
+ "4 True "
674
+ ]
675
+ },
676
+ "execution_count": 79,
677
+ "metadata": {},
678
+ "output_type": "execute_result"
679
+ }
680
+ ],
681
+ "source": [
682
+ "irr_q.head()"
683
+ ]
684
+ },
685
+ {
686
+ "cell_type": "code",
687
+ "execution_count": 37,
688
+ "id": "8620e50c",
689
+ "metadata": {},
690
+ "outputs": [
691
+ {
692
+ "data": {
693
+ "text/plain": [
694
+ "0 12.207266\n",
695
+ "1 2.368774\n",
696
+ "2 5.480067\n",
697
+ "3 5.272529\n",
698
+ "4 5.253224\n",
699
+ "5 5.351224\n",
700
+ "6 8.118429\n",
701
+ "7 7.288261\n",
702
+ "8 3.856500\n",
703
+ "9 7.745016\n",
704
+ "Name: latency, dtype: float64"
705
+ ]
706
+ },
707
+ "execution_count": 37,
708
+ "metadata": {},
709
+ "output_type": "execute_result"
710
+ }
711
+ ],
712
+ "source": [
713
+ "irr_q['latency']"
714
+ ]
715
+ },
716
+ {
717
+ "cell_type": "code",
718
+ "execution_count": 39,
719
+ "id": "debd3461",
720
+ "metadata": {},
721
+ "outputs": [],
722
+ "source": [
723
+ "irr_q['irrelevant_score'] = irr_q['response'].str.contains(\"I don't know\")"
724
+ ]
725
+ },
726
+ {
727
+ "cell_type": "code",
728
+ "execution_count": 40,
729
+ "id": "bef1d3a4",
730
+ "metadata": {},
731
+ "outputs": [
732
+ {
733
+ "data": {
734
+ "text/plain": [
735
+ "irrelevant_score 0.900000\n",
736
+ "latency 6.294129\n",
737
+ "dtype: float64"
738
+ ]
739
+ },
740
+ "execution_count": 40,
741
+ "metadata": {},
742
+ "output_type": "execute_result"
743
+ }
744
+ ],
745
+ "source": [
746
+ "irr_q[['irrelevant_score','latency']].mean()"
747
+ ]
748
+ },
749
+ {
750
+ "cell_type": "markdown",
751
+ "id": "c1610a70",
752
+ "metadata": {},
753
+ "source": [
754
+ "# Improvement"
755
+ ]
756
+ },
757
+ {
758
+ "cell_type": "code",
759
+ "execution_count": 48,
760
+ "id": "ff6614f9",
761
+ "metadata": {},
762
+ "outputs": [],
763
+ "source": [
764
+ "new_prompt_template = \"\"\"\n",
765
+ "You are an AI assistant specialized in Mental Health guidelines.\n",
766
+ "Use the provided context to answer the question short and accurately. \n",
767
+ "If you don't know the answer, simply say, \"I don't know.\"\n",
768
+ "\n",
769
+ "Context:\n",
770
+ "{context}\n",
771
+ "\n",
772
+ "Question: {question}\n",
773
+ "\n",
774
+ "Answer:\"\"\"\n",
775
+ "\n",
776
+ "prompt = PromptTemplate(template=new_prompt_template, input_variables=[\"context\", \"question\"])\n",
777
+ "\n",
778
+ "llm = Ollama(\n",
779
+ " model=\"llama3\"\n",
780
+ ")\n",
781
+ "\n",
782
+ "# Create the chain\n",
783
+ "chain = LLMChain(llm=llm, prompt=prompt)\n",
784
+ "\n",
785
+ "def answer_question_new(query):\n",
786
+ " # Search for relevant context\n",
787
+ " search_results = search_faiss(query)\n",
788
+ " \n",
789
+ " # Combine the content from the search results\n",
790
+ " context = \"\\n\\n\".join([result['content'] for result in search_results])\n",
791
+ "\n",
792
+ " # Run the chain\n",
793
+ " response = chain.run(context=context, question=query)\n",
794
+ " \n",
795
+ " return response"
796
+ ]
797
+ },
798
+ {
799
+ "cell_type": "code",
800
+ "execution_count": 49,
801
+ "id": "20580d50",
802
+ "metadata": {},
803
+ "outputs": [],
804
+ "source": [
805
+ "df2=df.copy()"
806
+ ]
807
+ },
808
+ {
809
+ "cell_type": "code",
810
+ "execution_count": 50,
811
+ "id": "b1b3d725",
812
+ "metadata": {},
813
+ "outputs": [
814
+ {
815
+ "name": "stderr",
816
+ "output_type": "stream",
817
+ "text": [
818
+ "100%|███████████████████████████████████████████| 10/10 [01:34<00:00, 9.40s/it]\n"
819
+ ]
820
+ }
821
+ ],
822
+ "source": [
823
+ "time_list=[]\n",
824
+ "response_list=[]\n",
825
+ "for i in tqdm(range(len(df2))):\n",
826
+ " query = df2['Questions'].values[i]\n",
827
+ " start = time.time()\n",
828
+ " response = answer_question(query)\n",
829
+ " end = time.time() \n",
830
+ " time_list.append(end-start)\n",
831
+ " response_list.append(response)"
832
+ ]
833
+ },
834
+ {
835
+ "cell_type": "code",
836
+ "execution_count": 51,
837
+ "id": "63f41256",
838
+ "metadata": {},
839
+ "outputs": [],
840
+ "source": [
841
+ "df2['latency'] = time_list\n",
842
+ "df2['response'] = response_list"
843
+ ]
844
+ },
845
+ {
846
+ "cell_type": "code",
847
+ "execution_count": 52,
848
+ "id": "0d8a6065",
849
+ "metadata": {},
850
+ "outputs": [
851
+ {
852
+ "name": "stderr",
853
+ "output_type": "stream",
854
+ "text": [
855
+ "100%|███████████████████████████████████████████| 10/10 [01:00<00:00, 6.01s/it]\n",
856
+ "100%|███████████████████████████████████████████| 10/10 [00:53<00:00, 5.35s/it]\n",
857
+ "100%|███████████████████████████████████████████| 10/10 [00:47<00:00, 4.77s/it]\n",
858
+ "100%|███████████████████████████████████████████| 10/10 [00:55<00:00, 5.60s/it]\n"
859
+ ]
860
+ }
861
+ ],
862
+ "source": [
863
+ "for metric in metrics:\n",
864
+ " evaluator = load_evaluator(\"labeled_criteria\", criteria=metric, llm=eval_llm)\n",
865
+ " \n",
866
+ " reasoning = []\n",
867
+ " value = []\n",
868
+ " score = []\n",
869
+ " \n",
870
+ " for i in tqdm(range(len(df2))):\n",
871
+ " eval_result = evaluator.evaluate_strings(\n",
872
+ " prediction=df2.response.values[i],\n",
873
+ " input=df2.Questions.values[i],\n",
874
+ " reference=df2.Answers.values[i]\n",
875
+ " )\n",
876
+ " reasoning.append(eval_result['reasoning'])\n",
877
+ " value.append(eval_result['value'])\n",
878
+ " score.append(eval_result['score'])\n",
879
+ " \n",
880
+ " df2[metric+'_reasoning'] = reasoning\n",
881
+ " df2[metric+'_value'] = value\n",
882
+ " df2[metric+'_score'] = score "
883
+ ]
884
+ },
885
+ {
886
+ "cell_type": "code",
887
+ "execution_count": 77,
888
+ "id": "c648632c",
889
+ "metadata": {},
890
+ "outputs": [
891
+ {
892
+ "data": {
893
+ "text/html": [
894
+ "<div>\n",
895
+ "<style scoped>\n",
896
+ " .dataframe tbody tr th:only-of-type {\n",
897
+ " vertical-align: middle;\n",
898
+ " }\n",
899
+ "\n",
900
+ " .dataframe tbody tr th {\n",
901
+ " vertical-align: top;\n",
902
+ " }\n",
903
+ "\n",
904
+ " .dataframe thead th {\n",
905
+ " text-align: right;\n",
906
+ " }\n",
907
+ "</style>\n",
908
+ "<table border=\"1\" class=\"dataframe\">\n",
909
+ " <thead>\n",
910
+ " <tr style=\"text-align: right;\">\n",
911
+ " <th></th>\n",
912
+ " <th>Questions</th>\n",
913
+ " <th>Answers</th>\n",
914
+ " <th>latency</th>\n",
915
+ " <th>response</th>\n",
916
+ " <th>correctness_reasoning</th>\n",
917
+ " <th>correctness_value</th>\n",
918
+ " <th>correctness_score</th>\n",
919
+ " <th>relevance_reasoning</th>\n",
920
+ " <th>relevance_value</th>\n",
921
+ " <th>relevance_score</th>\n",
922
+ " <th>coherence_reasoning</th>\n",
923
+ " <th>coherence_value</th>\n",
924
+ " <th>coherence_score</th>\n",
925
+ " <th>conciseness_reasoning</th>\n",
926
+ " <th>conciseness_value</th>\n",
927
+ " <th>conciseness_score</th>\n",
928
+ " </tr>\n",
929
+ " </thead>\n",
930
+ " <tbody>\n",
931
+ " <tr>\n",
932
+ " <th>0</th>\n",
933
+ " <td>What is Mental Health</td>\n",
934
+ " <td>Mental Health is a \" state of well-being in wh...</td>\n",
935
+ " <td>11.046327</td>\n",
936
+ " <td>Based on the context provided, mental health r...</td>\n",
937
+ " <td>Step 1: Evaluate if the submission is factuall...</td>\n",
938
+ " <td>N</td>\n",
939
+ " <td>0</td>\n",
940
+ " <td>Step 1: Analyze the relevance criterion\\nThe s...</td>\n",
941
+ " <td>N</td>\n",
942
+ " <td>0</td>\n",
943
+ " <td>The submission discusses mental health in rela...</td>\n",
944
+ " <td>Y</td>\n",
945
+ " <td>1</td>\n",
946
+ " <td>Step 1: Analyze conciseness criterion\\nThe sub...</td>\n",
947
+ " <td>Y</td>\n",
948
+ " <td>1</td>\n",
949
+ " </tr>\n",
950
+ " <tr>\n",
951
+ " <th>1</th>\n",
952
+ " <td>What are the most common mental disorders ment...</td>\n",
953
+ " <td>The most common mental disorders include depre...</td>\n",
954
+ " <td>4.509713</td>\n",
955
+ " <td>The handbook mentions several mental illnesses...</td>\n",
956
+ " <td>The submission mentions depression and schizop...</td>\n",
957
+ " <td>N</td>\n",
958
+ " <td>0</td>\n",
959
+ " <td>Step 1: Analyze relevance criterion - Check if...</td>\n",
960
+ " <td>Y</td>\n",
961
+ " <td>1</td>\n",
962
+ " <td>Step 1: Assess coherence\\nThe submission menti...</td>\n",
963
+ " <td>N</td>\n",
964
+ " <td>0</td>\n",
965
+ " <td>Step 1: Analyze conciseness criterion\\nThe sub...</td>\n",
966
+ " <td>N</td>\n",
967
+ " <td>0</td>\n",
968
+ " </tr>\n",
969
+ " <tr>\n",
970
+ " <th>2</th>\n",
971
+ " <td>What are the early warning signs and symptoms ...</td>\n",
972
+ " <td>Early warning signs and symptoms of depression...</td>\n",
973
+ " <td>8.501180</td>\n",
974
+ " <td>According to the provided context, specificall...</td>\n",
975
+ " <td>The submission matches the reference data in t...</td>\n",
976
+ " <td>Y</td>\n",
977
+ " <td>1</td>\n",
978
+ " <td>The submission refers directly to information ...</td>\n",
979
+ " <td>Y</td>\n",
980
+ " <td>1</td>\n",
981
+ " <td>Step 1: Evaluate coherence - The submission is...</td>\n",
982
+ " <td>Y</td>\n",
983
+ " <td>1</td>\n",
984
+ " <td>The submission is concise and includes most of...</td>\n",
985
+ " <td>Y</td>\n",
986
+ " <td>1</td>\n",
987
+ " </tr>\n",
988
+ " <tr>\n",
989
+ " <th>3</th>\n",
990
+ " <td>How can someone help a person who suffers from...</td>\n",
991
+ " <td>To help someone with anxiety, one can support ...</td>\n",
992
+ " <td>10.611402</td>\n",
993
+ " <td>According to the Mental Health Handbook, when ...</td>\n",
994
+ " <td>The submission seems consistent with the refer...</td>\n",
995
+ " <td>Y</td>\n",
996
+ " <td>1</td>\n",
997
+ " <td>Step 1: Review relevance criterion\\nThe submis...</td>\n",
998
+ " <td>Y</td>\n",
999
+ " <td>1</td>\n",
1000
+ " <td>The submission is coherent, well-structured, a...</td>\n",
1001
+ " <td>Y</td>\n",
1002
+ " <td>1</td>\n",
1003
+ " <td>The submission is relatively concise and cover...</td>\n",
1004
+ " <td>Y</td>\n",
1005
+ " <td>1</td>\n",
1006
+ " </tr>\n",
1007
+ " <tr>\n",
1008
+ " <th>4</th>\n",
1009
+ " <td>What are the causes of mental illness listed i...</td>\n",
1010
+ " <td>Causes of mental illness include abnormal func...</td>\n",
1011
+ " <td>6.299272</td>\n",
1012
+ " <td>According to the context, the causes of mental...</td>\n",
1013
+ " <td>The submission lists causes such as neglect, s...</td>\n",
1014
+ " <td>N</td>\n",
1015
+ " <td>0</td>\n",
1016
+ " <td>The submission mentions factors that are part ...</td>\n",
1017
+ " <td>N</td>\n",
1018
+ " <td>0</td>\n",
1019
+ " <td>The submission is coherent and well-structured...</td>\n",
1020
+ " <td>Y</td>\n",
1021
+ " <td>1</td>\n",
1022
+ " <td>Step 1: Read and understand both the input dat...</td>\n",
1023
+ " <td>N</td>\n",
1024
+ " <td>0</td>\n",
1025
+ " </tr>\n",
1026
+ " </tbody>\n",
1027
+ "</table>\n",
1028
+ "</div>"
1029
+ ],
1030
+ "text/plain": [
1031
+ " Questions \\\n",
1032
+ "0 What is Mental Health \n",
1033
+ "1 What are the most common mental disorders ment... \n",
1034
+ "2 What are the early warning signs and symptoms ... \n",
1035
+ "3 How can someone help a person who suffers from... \n",
1036
+ "4 What are the causes of mental illness listed i... \n",
1037
+ "\n",
1038
+ " Answers latency \\\n",
1039
+ "0 Mental Health is a \" state of well-being in wh... 11.046327 \n",
1040
+ "1 The most common mental disorders include depre... 4.509713 \n",
1041
+ "2 Early warning signs and symptoms of depression... 8.501180 \n",
1042
+ "3 To help someone with anxiety, one can support ... 10.611402 \n",
1043
+ "4 Causes of mental illness include abnormal func... 6.299272 \n",
1044
+ "\n",
1045
+ " response \\\n",
1046
+ "0 Based on the context provided, mental health r... \n",
1047
+ "1 The handbook mentions several mental illnesses... \n",
1048
+ "2 According to the provided context, specificall... \n",
1049
+ "3 According to the Mental Health Handbook, when ... \n",
1050
+ "4 According to the context, the causes of mental... \n",
1051
+ "\n",
1052
+ " correctness_reasoning correctness_value \\\n",
1053
+ "0 Step 1: Evaluate if the submission is factuall... N \n",
1054
+ "1 The submission mentions depression and schizop... N \n",
1055
+ "2 The submission matches the reference data in t... Y \n",
1056
+ "3 The submission seems consistent with the refer... Y \n",
1057
+ "4 The submission lists causes such as neglect, s... N \n",
1058
+ "\n",
1059
+ " correctness_score relevance_reasoning \\\n",
1060
+ "0 0 Step 1: Analyze the relevance criterion\\nThe s... \n",
1061
+ "1 0 Step 1: Analyze relevance criterion - Check if... \n",
1062
+ "2 1 The submission refers directly to information ... \n",
1063
+ "3 1 Step 1: Review relevance criterion\\nThe submis... \n",
1064
+ "4 0 The submission mentions factors that are part ... \n",
1065
+ "\n",
1066
+ " relevance_value relevance_score \\\n",
1067
+ "0 N 0 \n",
1068
+ "1 Y 1 \n",
1069
+ "2 Y 1 \n",
1070
+ "3 Y 1 \n",
1071
+ "4 N 0 \n",
1072
+ "\n",
1073
+ " coherence_reasoning coherence_value \\\n",
1074
+ "0 The submission discusses mental health in rela... Y \n",
1075
+ "1 Step 1: Assess coherence\\nThe submission menti... N \n",
1076
+ "2 Step 1: Evaluate coherence - The submission is... Y \n",
1077
+ "3 The submission is coherent, well-structured, a... Y \n",
1078
+ "4 The submission is coherent and well-structured... Y \n",
1079
+ "\n",
1080
+ " coherence_score conciseness_reasoning \\\n",
1081
+ "0 1 Step 1: Analyze conciseness criterion\\nThe sub... \n",
1082
+ "1 0 Step 1: Analyze conciseness criterion\\nThe sub... \n",
1083
+ "2 1 The submission is concise and includes most of... \n",
1084
+ "3 1 The submission is relatively concise and cover... \n",
1085
+ "4 1 Step 1: Read and understand both the input dat... \n",
1086
+ "\n",
1087
+ " conciseness_value conciseness_score \n",
1088
+ "0 Y 1 \n",
1089
+ "1 N 0 \n",
1090
+ "2 Y 1 \n",
1091
+ "3 Y 1 \n",
1092
+ "4 N 0 "
1093
+ ]
1094
+ },
1095
+ "execution_count": 77,
1096
+ "metadata": {},
1097
+ "output_type": "execute_result"
1098
+ }
1099
+ ],
1100
+ "source": [
1101
+ "df2.head()"
1102
+ ]
1103
+ },
1104
+ {
1105
+ "cell_type": "code",
1106
+ "execution_count": 47,
1107
+ "id": "2d1002b2",
1108
+ "metadata": {},
1109
+ "outputs": [
1110
+ {
1111
+ "data": {
1112
+ "text/plain": [
1113
+ "correctness_score 0.500000\n",
1114
+ "relevance_score 0.888889\n",
1115
+ "coherence_score 0.888889\n",
1116
+ "conciseness_score 0.900000\n",
1117
+ "latency 8.190205\n",
1118
+ "dtype: float64"
1119
+ ]
1120
+ },
1121
+ "execution_count": 47,
1122
+ "metadata": {},
1123
+ "output_type": "execute_result"
1124
+ }
1125
+ ],
1126
+ "source": [
1127
+ "df2[['correctness_score','relevance_score','coherence_score','conciseness_score','latency']].mean()"
1128
+ ]
1129
+ },
1130
+ {
1131
+ "cell_type": "markdown",
1132
+ "id": "e808bdcf",
1133
+ "metadata": {},
1134
+ "source": [
1135
+ "# Query relevance"
1136
+ ]
1137
+ },
1138
+ {
1139
+ "cell_type": "code",
1140
+ "execution_count": 66,
1141
+ "id": "6b541f3d",
1142
+ "metadata": {},
1143
+ "outputs": [],
1144
+ "source": [
1145
+ "def new_search_faiss(query, k=3, threshold=1.5):\n",
1146
+ " query_vector = model.encode([query])[0].astype('float32')\n",
1147
+ " query_vector = np.expand_dims(query_vector, axis=0)\n",
1148
+ " distances, indices = index.search(query_vector, k)\n",
1149
+ " \n",
1150
+ " results = []\n",
1151
+ " for dist, idx in zip(distances[0], indices[0]):\n",
1152
+ " if dist < threshold: # Only include results within the threshold distance\n",
1153
+ " results.append({\n",
1154
+ " 'distance': dist,\n",
1155
+ " 'content': sections_data[idx]['content'],\n",
1156
+ " 'metadata': sections_data[idx]['metadata']\n",
1157
+ " })\n",
1158
+ " \n",
1159
+ " return results"
1160
+ ]
1161
+ },
1162
+ {
1163
+ "cell_type": "code",
1164
+ "execution_count": 70,
1165
+ "id": "4f579654",
1166
+ "metadata": {},
1167
+ "outputs": [],
1168
+ "source": [
1169
+ "new_prompt_template = \"\"\"\n",
1170
+ "You are an AI assistant specialized in Mental Health guidelines.\n",
1171
+ "Use the provided context to answer the question short and accurately. \n",
1172
+ "If you don't know the answer, simply say, \"I don't know.\"\n",
1173
+ "\n",
1174
+ "Context:\n",
1175
+ "{context}\n",
1176
+ "\n",
1177
+ "Question: {question}\n",
1178
+ "\n",
1179
+ "Answer:\"\"\"\n",
1180
+ "\n",
1181
+ "prompt = PromptTemplate(template=prompt_template, input_variables=[\"context\", \"question\"])\n",
1182
+ "\n",
1183
+ "llm = Ollama(\n",
1184
+ " model=\"llama3\"\n",
1185
+ ")\n",
1186
+ "\n",
1187
+ "# Create the chain\n",
1188
+ "chain = LLMChain(llm=llm, prompt=prompt)\n",
1189
+ "\n",
1190
+ "def new_answer_question(query):\n",
1191
+ " # Search for relevant context\n",
1192
+ " search_results = new_search_faiss(query)\n",
1193
+ " \n",
1194
+ " if search_results==[]:\n",
1195
+ " response=\"I don't know, sorry\"\n",
1196
+ " else:\n",
1197
+ " context = \"\\n\\n\".join([result['content'] for result in search_results])\n",
1198
+ " response = chain.run(context=context, question=query)\n",
1199
+ " \n",
1200
+ " return response"
1201
+ ]
1202
+ },
1203
+ {
1204
+ "cell_type": "code",
1205
+ "execution_count": 71,
1206
+ "id": "1f83ef1b",
1207
+ "metadata": {},
1208
+ "outputs": [],
1209
+ "source": [
1210
+ "irr_q2=irr_q.copy()"
1211
+ ]
1212
+ },
1213
+ {
1214
+ "cell_type": "code",
1215
+ "execution_count": 72,
1216
+ "id": "f06474e3",
1217
+ "metadata": {},
1218
+ "outputs": [
1219
+ {
1220
+ "name": "stderr",
1221
+ "output_type": "stream",
1222
+ "text": [
1223
+ "100%|███████████████████████████████████████████| 10/10 [00:00<00:00, 61.93it/s]\n"
1224
+ ]
1225
+ }
1226
+ ],
1227
+ "source": [
1228
+ "time_list=[]\n",
1229
+ "response_list=[]\n",
1230
+ "for i in tqdm(range(len(irr_q2))):\n",
1231
+ " query = irr_q['Questions'].values[i]\n",
1232
+ " start = time.time()\n",
1233
+ " response = new_answer_question(query)\n",
1234
+ " end = time.time() \n",
1235
+ " time_list.append(end-start)\n",
1236
+ " response_list.append(response)"
1237
+ ]
1238
+ },
1239
+ {
1240
+ "cell_type": "code",
1241
+ "execution_count": 73,
1242
+ "id": "52db6b82",
1243
+ "metadata": {},
1244
+ "outputs": [],
1245
+ "source": [
1246
+ "irr_q2['response']=response_list\n",
1247
+ "irr_q2['latency']=time_list"
1248
+ ]
1249
+ },
1250
+ {
1251
+ "cell_type": "code",
1252
+ "execution_count": 80,
1253
+ "id": "80a178ee",
1254
+ "metadata": {},
1255
+ "outputs": [
1256
+ {
1257
+ "data": {
1258
+ "text/html": [
1259
+ "<div>\n",
1260
+ "<style scoped>\n",
1261
+ " .dataframe tbody tr th:only-of-type {\n",
1262
+ " vertical-align: middle;\n",
1263
+ " }\n",
1264
+ "\n",
1265
+ " .dataframe tbody tr th {\n",
1266
+ " vertical-align: top;\n",
1267
+ " }\n",
1268
+ "\n",
1269
+ " .dataframe thead th {\n",
1270
+ " text-align: right;\n",
1271
+ " }\n",
1272
+ "</style>\n",
1273
+ "<table border=\"1\" class=\"dataframe\">\n",
1274
+ " <thead>\n",
1275
+ " <tr style=\"text-align: right;\">\n",
1276
+ " <th></th>\n",
1277
+ " <th>Questions</th>\n",
1278
+ " <th>response</th>\n",
1279
+ " <th>latency</th>\n",
1280
+ " <th>irrelevant_score</th>\n",
1281
+ " </tr>\n",
1282
+ " </thead>\n",
1283
+ " <tbody>\n",
1284
+ " <tr>\n",
1285
+ " <th>0</th>\n",
1286
+ " <td>What is the capital of Mars?</td>\n",
1287
+ " <td>I don't know, sorry</td>\n",
1288
+ " <td>0.061378</td>\n",
1289
+ " <td>True</td>\n",
1290
+ " </tr>\n",
1291
+ " <tr>\n",
1292
+ " <th>1</th>\n",
1293
+ " <td>How many unicorns live in New York City?</td>\n",
1294
+ " <td>I don't know, sorry</td>\n",
1295
+ " <td>0.012511</td>\n",
1296
+ " <td>True</td>\n",
1297
+ " </tr>\n",
1298
+ " <tr>\n",
1299
+ " <th>2</th>\n",
1300
+ " <td>What is the color of happiness?</td>\n",
1301
+ " <td>I don't know, sorry</td>\n",
1302
+ " <td>0.011900</td>\n",
1303
+ " <td>True</td>\n",
1304
+ " </tr>\n",
1305
+ " <tr>\n",
1306
+ " <th>3</th>\n",
1307
+ " <td>Can cats fly on Tuesdays?</td>\n",
1308
+ " <td>I don't know, sorry</td>\n",
1309
+ " <td>0.011438</td>\n",
1310
+ " <td>True</td>\n",
1311
+ " </tr>\n",
1312
+ " <tr>\n",
1313
+ " <th>4</th>\n",
1314
+ " <td>How much does a thought weigh?</td>\n",
1315
+ " <td>I don't know, sorry</td>\n",
1316
+ " <td>0.010644</td>\n",
1317
+ " <td>True</td>\n",
1318
+ " </tr>\n",
1319
+ " </tbody>\n",
1320
+ "</table>\n",
1321
+ "</div>"
1322
+ ],
1323
+ "text/plain": [
1324
+ " Questions response latency \\\n",
1325
+ "0 What is the capital of Mars? I don't know, sorry 0.061378 \n",
1326
+ "1 How many unicorns live in New York City? I don't know, sorry 0.012511 \n",
1327
+ "2 What is the color of happiness? I don't know, sorry 0.011900 \n",
1328
+ "3 Can cats fly on Tuesdays? I don't know, sorry 0.011438 \n",
1329
+ "4 How much does a thought weigh? I don't know, sorry 0.010644 \n",
1330
+ "\n",
1331
+ " irrelevant_score \n",
1332
+ "0 True \n",
1333
+ "1 True \n",
1334
+ "2 True \n",
1335
+ "3 True \n",
1336
+ "4 True "
1337
+ ]
1338
+ },
1339
+ "execution_count": 80,
1340
+ "metadata": {},
1341
+ "output_type": "execute_result"
1342
+ }
1343
+ ],
1344
+ "source": [
1345
+ "irr_q2.head()"
1346
+ ]
1347
+ },
1348
+ {
1349
+ "cell_type": "code",
1350
+ "execution_count": 74,
1351
+ "id": "4508de9e",
1352
+ "metadata": {},
1353
+ "outputs": [],
1354
+ "source": [
1355
+ "irr_q2['irrelevant_score'] = irr_q2['response'].str.contains(\"I don't know\")"
1356
+ ]
1357
+ },
1358
+ {
1359
+ "cell_type": "code",
1360
+ "execution_count": 75,
1361
+ "id": "3d34ba06",
1362
+ "metadata": {},
1363
+ "outputs": [
1364
+ {
1365
+ "data": {
1366
+ "text/plain": [
1367
+ "irrelevant_score 1.000000\n",
1368
+ "latency 0.016068\n",
1369
+ "dtype: float64"
1370
+ ]
1371
+ },
1372
+ "execution_count": 75,
1373
+ "metadata": {},
1374
+ "output_type": "execute_result"
1375
+ }
1376
+ ],
1377
+ "source": [
1378
+ "irr_q2[['irrelevant_score','latency']].mean()"
1379
+ ]
1380
+ }
1381
+ ],
1382
+ "metadata": {
1383
+ "kernelspec": {
1384
+ "display_name": "Python 3 (ipykernel)",
1385
+ "language": "python",
1386
+ "name": "python3"
1387
+ },
1388
+ "language_info": {
1389
+ "codemirror_mode": {
1390
+ "name": "ipython",
1391
+ "version": 3
1392
+ },
1393
+ "file_extension": ".py",
1394
+ "mimetype": "text/x-python",
1395
+ "name": "python",
1396
+ "nbconvert_exporter": "python",
1397
+ "pygments_lexer": "ipython3",
1398
+ "version": "3.11.0"
1399
+ }
1400
+ },
1401
+ "nbformat": 4,
1402
+ "nbformat_minor": 5
1403
+ }
Evaluation_MH/Mental Health Evaluation Report.pdf ADDED
Binary file (72.9 kB). View file
 
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Aditi Yadav
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
MentalHealth/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Aditi Yadav
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
MentalHealth/app.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from sentence_transformers import SentenceTransformer
3
+ from langchain.prompts import PromptTemplate
4
+ from langchain.chains import LLMChain
5
+ from langchain_community.llms import Ollama
6
+ import faiss
7
+ import numpy as np
8
+ import pickle
9
+
10
+ # Load the FAISS index
11
+ @st.cache(allow_output_mutation=True)
12
+ def load_faiss_index():
13
+ try:
14
+ return faiss.read_index("database/pdf_sections_index.faiss")
15
+ except FileNotFoundError:
16
+ st.error("FAISS index file not found. Please ensure 'pdf_sections_index.faiss' exists.")
17
+ st.stop()
18
+
19
+ # Load the embedding model
20
+ @st.cache(allow_output_mutation=True)
21
+ def load_embedding_model():
22
+ return SentenceTransformer('all-MiniLM-L6-v2')
23
+
24
+ # Load sections data
25
+ @st.cache(allow_output_mutation=True)
26
+ def load_sections_data():
27
+ try:
28
+ with open('database/pdf_sections_data.pkl', 'rb') as f:
29
+ return pickle.load(f)
30
+ except FileNotFoundError:
31
+ st.error("Sections data file not found. Please ensure 'pdf_sections_data.pkl' exists.")
32
+ st.stop()
33
+
34
+ # Initialize resources
35
+ index = load_faiss_index()
36
+ model = load_embedding_model()
37
+ sections_data = load_sections_data()
38
+
39
+ def search_faiss(query, k=3):
40
+ query_vector = model.encode([query])[0].astype('float32')
41
+ query_vector = np.expand_dims(query_vector, axis=0)
42
+ distances, indices = index.search(query_vector, k)
43
+
44
+ results = []
45
+ for dist, idx in zip(distances[0], indices[0]):
46
+ results.append({
47
+ 'distance': dist,
48
+ 'content': sections_data[idx]['content'],
49
+ 'metadata': sections_data[idx]['metadata']
50
+ })
51
+
52
+ return results
53
+
54
+ prompt_template = """
55
+ You are an AI assistant specialized in dietary guidelines. Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
56
+
57
+ Context:
58
+ {context}
59
+
60
+ Question: {question}
61
+
62
+ Answer:"""
63
+
64
+ prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
65
+
66
+ @st.cache(allow_output_mutation=True)
67
+ def load_llm():
68
+ return Ollama(model="llama3")
69
+
70
+ llm = load_llm()
71
+ chain = LLMChain(llm=llm, prompt=prompt)
72
+
73
+ def answer_question(query):
74
+ search_results = search_faiss(query)
75
+ context = "\n\n".join([result['content'] for result in search_results])
76
+ response = chain.run(context=context, question=query)
77
+ return response, context
78
+
79
+ # Streamlit UI
80
+ st.title("Mental Health Guidelines Q&A")
81
+
82
+ query = st.text_input("Enter your question about Mental Health guidelines:")
83
+
84
+ if st.button("Get Answer"):
85
+ if query:
86
+ with st.spinner("Searching and generating answer..."):
87
+ answer, context = answer_question(query)
88
+ st.subheader("Answer:")
89
+ st.write(answer)
90
+ with st.expander("Show Context"):
91
+ st.write(context)
92
+ else:
93
+ st.warning("Please enter a question.")
MentalHealth/create_vectordb.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.document_loaders import PyPDFLoader
2
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
3
+ from sentence_transformers import SentenceTransformer
4
+ import faiss
5
+ import numpy as np
6
+ import pickle
7
+
8
+ # Load the PDF
9
+ pdf_path = "data\Mental Health Handbook English.pdf"
10
+ loader = PyPDFLoader(file_path=pdf_path)
11
+
12
+ # Load the content
13
+ documents = loader.load()
14
+
15
+ # Split the document into sections
16
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=200)
17
+ sections = text_splitter.split_documents(documents)
18
+
19
+ # Load the embedding model
20
+ model = SentenceTransformer('all-MiniLM-L6-v2')
21
+
22
+ # Generate embeddings for each section
23
+ section_texts = [section.page_content for section in sections]
24
+ embeddings = model.encode(section_texts)
25
+
26
+ print(embeddings.shape)
27
+
28
+ embeddings_np = np.array(embeddings).astype('float32')
29
+
30
+ # Create a FAISS index
31
+ dimension = embeddings_np.shape[1]
32
+ index = faiss.IndexFlatL2(dimension)
33
+
34
+ # Add vectors to the index
35
+ index.add(embeddings_np)
36
+
37
+ # Save the index to a file
38
+ faiss.write_index(index, "database/pdf_sections_index.faiss")
39
+
40
+ # When creating the index:
41
+ sections_data = [
42
+ {
43
+ 'content': section.page_content,
44
+ 'metadata': section.metadata
45
+ }
46
+ for section in sections
47
+ ]
48
+
49
+ # Save sections data
50
+ with open('database/pdf_sections_data.pkl', 'wb') as f:
51
+ pickle.dump(sections_data, f)
52
+
53
+ print("Embeddings stored in FAISS index and saved to file.")
MentalHealth/data/Mental Health Handbook English.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:19da603f69fff5a4bc28a04fde30cf977f8fdb8310e9e31f6d21f4c45240c14b
3
+ size 5413709
MentalHealth/database/pdf_sections_data.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b4ceb3d84f382b1162df9c1b91f285c167411642572d56999c6bd1cd6b0dd2d7
3
+ size 60012
MentalHealth/database/pdf_sections_index.faiss ADDED
Binary file (66.1 kB). View file
 
MentalHealth/rag.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sentence_transformers import SentenceTransformer
2
+ from langchain.prompts import PromptTemplate
3
+ from langchain.chains import LLMChain
4
+ from langchain_community.llms import Ollama
5
+ import faiss
6
+ import numpy as np
7
+ import pickle
8
+
9
+ # Load the FAISS index
10
+ try:
11
+ index = faiss.read_index("database/pdf_sections_index.faiss")
12
+ except FileNotFoundError:
13
+ print("FAISS index file not found. Please ensure 'pdf_sections_index.faiss' exists.")
14
+ exit(1)
15
+
16
+ # Load the embedding model
17
+ model = SentenceTransformer('all-MiniLM-L6-v2')
18
+
19
+ # Load sections data
20
+ try:
21
+ with open('database/pdf_sections_data.pkl', 'rb') as f:
22
+ sections_data = pickle.load(f)
23
+ except FileNotFoundError:
24
+ print("Sections data file not found. Please ensure 'pdf_sections_data.pkl' exists.")
25
+ exit(1)
26
+
27
+ def search_faiss(query, k=3):
28
+ query_vector = model.encode([query])[0].astype('float32')
29
+ query_vector = np.expand_dims(query_vector, axis=0)
30
+ distances, indices = index.search(query_vector, k)
31
+
32
+ results = []
33
+ for dist, idx in zip(distances[0], indices[0]):
34
+ results.append({
35
+ 'distance': dist,
36
+ 'content': sections_data[idx]['content'],
37
+ 'metadata': sections_data[idx]['metadata']
38
+ })
39
+
40
+ return results
41
+
42
+ # Create a prompt template
43
+ prompt_template = """
44
+ You are an AI assistant specialized in dietary guidelines. Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
45
+
46
+ Context:
47
+ {context}
48
+
49
+ Question: {question}
50
+
51
+ Answer:"""
52
+
53
+ prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
54
+
55
+ llm = Ollama(
56
+ model="llama3"
57
+ )
58
+
59
+ # Create the chain
60
+ chain = LLMChain(llm=llm, prompt=prompt)
61
+
62
+ def answer_question(query):
63
+ # Search for relevant context
64
+ search_results = search_faiss(query)
65
+
66
+ # Combine the content from the search results
67
+ context = "\n\n".join([result['content'] for result in search_results])
68
+
69
+ # Run the chain
70
+ response = chain.run(context=context, question=query)
71
+
72
+ return response
73
+
74
+ # Example usage
75
+ query = "What is Mental Health?"
76
+ answer = answer_question(query)
77
+
78
+ print(f"Question: {query}")
79
+ print(f"Answer: {answer}")
MentalHealth/requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ pypdf
3
+ langchain
4
+ sentence-transformers
5
+ langchain-community
6
+ opensearch-py
7
+ faiss-cpu
8
+
MentalHealth/simple_retrieval.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sentence_transformers import SentenceTransformer
2
+ import faiss
3
+ import numpy as np
4
+
5
+ # Load the FAISS_index
6
+ index = faiss.read_index("database/pdf_sections_index.faiss")
7
+
8
+ # Load the embedding model
9
+ model = SentenceTransformer('all-MiniLM-L6-v2')
10
+
11
+ def search_faiss(query, k=3):
12
+ query_vector = model.encode([query])[0].astype('float32')
13
+ query_vector = np.expand_dims(query_vector, axis=0)
14
+ distances, indices = index.search(query_vector, k)
15
+ return distances, indices
16
+
17
+ # Example usage
18
+ query = "What are the main dietary guidelines for protein intake?"
19
+ distances, indices = search_faiss(query)
20
+
21
+ print(f"Query: {query}")
22
+ print(f"Distances: {distances}")
23
+ print(f"Indices: {indices}")
app.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from sentence_transformers import SentenceTransformer
3
+ from langchain.prompts import PromptTemplate
4
+ from langchain.chains import LLMChain
5
+ from langchain_community.llms import Ollama
6
+ import faiss
7
+ import numpy as np
8
+ import pickle
9
+ import requests
10
+ import json
11
+
12
+ # Load the FAISS index
13
+ @st.cache(allow_output_mutation=True)
14
+ def load_faiss_index():
15
+ try:
16
+ return faiss.read_index("database/pdf_sections_index.faiss")
17
+ except FileNotFoundError:
18
+ st.error("FAISS index file not found. Please ensure 'pdf_sections_index.faiss' exists.")
19
+ st.stop()
20
+
21
+ # Load the embedding model
22
+ @st.cache(allow_output_mutation=True)
23
+ def load_embedding_model():
24
+ return SentenceTransformer('all-MiniLM-L6-v2')
25
+
26
+ # Load sections data
27
+ @st.cache(allow_output_mutation=True)
28
+ def load_sections_data():
29
+ try:
30
+ with open('database/pdf_sections_data.pkl', 'rb') as f:
31
+ return pickle.load(f)
32
+ except FileNotFoundError:
33
+ st.error("Sections data file not found. Please ensure 'pdf_sections_data.pkl' exists.")
34
+ st.stop()
35
+
36
+ # Initialize resources
37
+ index = load_faiss_index()
38
+ model = load_embedding_model()
39
+ sections_data = load_sections_data()
40
+
41
+ def search_faiss(query, k=3):
42
+ query_vector = model.encode([query])[0].astype('float32')
43
+ query_vector = np.expand_dims(query_vector, axis=0)
44
+ distances, indices = index.search(query_vector, k)
45
+
46
+ results = []
47
+ for dist, idx in zip(distances[0], indices[0]):
48
+ results.append({
49
+ 'distance': dist,
50
+ 'content': sections_data[idx]['content'],
51
+ 'metadata': sections_data[idx]['metadata']
52
+ })
53
+
54
+ return results
55
+
56
+ prompt_template = """
57
+ You are an AI assistant specialized in Mental Health & wellness guidelines. Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
58
+
59
+ Context:
60
+ {context}
61
+
62
+ Question: {question}
63
+
64
+ Answer:"""
65
+
66
+ prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
67
+
68
+ @st.cache(allow_output_mutation=True)
69
+ def load_llm():
70
+ return Ollama(model="phi3")
71
+
72
+ llm = load_llm()
73
+ chain = LLMChain(llm=llm, prompt=prompt)
74
+
75
+ def answer_question(query):
76
+ search_results = search_faiss(query)
77
+ context = "\n\n".join([result['content'] for result in search_results])
78
+ response = chain.run(context=context, question=query)
79
+ return response, context
80
+
81
+ # Streamlit UI
82
+ st.title("Mental Health & Wellness Assistant")
83
+
84
+ query = st.text_input("Enter your question about Mental Health:")
85
+
86
+ if st.button("Get Answer"):
87
+ if query:
88
+ with st.spinner("Searching, Thinking and generating answer..."):
89
+ answer, context = answer_question(query)
90
+ st.subheader("Answer:")
91
+ st.write(answer)
92
+ with st.expander("Show Context"):
93
+ st.write(context)
94
+ else:
95
+ st.warning("Please enter a question.")
96
+
97
+ # Footer section with social links
98
+ st.markdown("""
99
+ <div class="social-icons">
100
+ <a href="https://github.com/yadavadit" target="_blank"><img src="https://img.icons8.com/material-outlined/48/e50914/github.png"/></a>
101
+ <a href="https://www.linkedin.com/in/yaditi/" target="_blank"><img src="https://img.icons8.com/color/48/e50914/linkedin.png"/></a>
102
+ <a href="mailto:[email protected]"><img src="https://img.icons8.com/color/48/e50914/gmail.png"/></a>
103
+ </div>
104
+ """, unsafe_allow_html=True)
create_vectordb.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.document_loaders import PyPDFLoader
2
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
3
+ from sentence_transformers import SentenceTransformer
4
+ import faiss
5
+ import numpy as np
6
+ import pickle
7
+
8
+ # Load the PDF
9
+ pdf_path = "data\Mental Health Handbook English.pdf"
10
+ loader = PyPDFLoader(file_path=pdf_path)
11
+
12
+ # Load the content
13
+ documents = loader.load()
14
+
15
+ # Split the document into sections
16
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=200)
17
+ sections = text_splitter.split_documents(documents)
18
+
19
+ # Load the embedding model
20
+ model = SentenceTransformer('all-MiniLM-L6-v2')
21
+
22
+ # Generate embeddings for each section
23
+ section_texts = [section.page_content for section in sections]
24
+ embeddings = model.encode(section_texts)
25
+
26
+ print(embeddings.shape)
27
+
28
+ embeddings_np = np.array(embeddings).astype('float32')
29
+
30
+ # Create a FAISS index
31
+ dimension = embeddings_np.shape[1]
32
+ index = faiss.IndexFlatL2(dimension)
33
+
34
+ # Add vectors to the index
35
+ index.add(embeddings_np)
36
+
37
+ # Save the index to a file
38
+ faiss.write_index(index, "database/pdf_sections_index.faiss")
39
+
40
+ # When creating the index:
41
+ sections_data = [
42
+ {
43
+ 'content': section.page_content,
44
+ 'metadata': section.metadata
45
+ }
46
+ for section in sections
47
+ ]
48
+
49
+ # Save sections data
50
+ with open('database/pdf_sections_data.pkl', 'wb') as f:
51
+ pickle.dump(sections_data, f)
52
+
53
+ print("Embeddings stored in FAISS index and saved to file.")
data/Mental Health Handbook English.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:19da603f69fff5a4bc28a04fde30cf977f8fdb8310e9e31f6d21f4c45240c14b
3
+ size 5413709
data/MentalHealth_Dataset.xlsx ADDED
Binary file (17.7 kB). View file
 
database/pdf_sections_data.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b4ceb3d84f382b1162df9c1b91f285c167411642572d56999c6bd1cd6b0dd2d7
3
+ size 60012
database/pdf_sections_index.faiss ADDED
Binary file (66.1 kB). View file
 
rag.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sentence_transformers import SentenceTransformer
2
+ from langchain.prompts import PromptTemplate
3
+ from langchain.chains import LLMChain
4
+ from langchain_community.llms import Ollama
5
+ import faiss
6
+ import numpy as np
7
+ import pickle
8
+
9
+ # Load the FAISS index
10
+ try:
11
+ index = faiss.read_index("database/pdf_sections_index.faiss")
12
+ except FileNotFoundError:
13
+ print("FAISS index file not found. Please ensure 'pdf_sections_index.faiss' exists.")
14
+ exit(1)
15
+
16
+ # Load the embedding model
17
+ model = SentenceTransformer('all-MiniLM-L6-v2')
18
+
19
+ # Load sections data
20
+ try:
21
+ with open('database/pdf_sections_data.pkl', 'rb') as f:
22
+ sections_data = pickle.load(f)
23
+ except FileNotFoundError:
24
+ print("Sections data file not found. Please ensure 'pdf_sections_data.pkl' exists.")
25
+ exit(1)
26
+
27
+ def search_faiss(query, k=3):
28
+ query_vector = model.encode([query])[0].astype('float32')
29
+ query_vector = np.expand_dims(query_vector, axis=0)
30
+ distances, indices = index.search(query_vector, k)
31
+
32
+ results = []
33
+ for dist, idx in zip(distances[0], indices[0]):
34
+ results.append({
35
+ 'distance': dist,
36
+ 'content': sections_data[idx]['content'],
37
+ 'metadata': sections_data[idx]['metadata']
38
+ })
39
+
40
+ return results
41
+
42
+ # Create a prompt template
43
+ prompt_template = """
44
+ You are an AI assistant specialized in mental health and wellness guidelines. Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
45
+
46
+ Context:
47
+ {context}
48
+
49
+ Question: {question}
50
+
51
+ Answer:"""
52
+
53
+ prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
54
+
55
+ llm = Ollama(
56
+ model="phi3"
57
+ )
58
+
59
+ # Create the chain
60
+ chain = LLMChain(llm=llm, prompt=prompt)
61
+
62
+ def answer_question(query):
63
+ # Search for relevant context
64
+ search_results = search_faiss(query)
65
+
66
+ # Combine the content from the search results
67
+ context = "\n\n".join([result['content'] for result in search_results])
68
+
69
+ # Run the chain
70
+ response = chain.run(context=context, question=query)
71
+
72
+ return response
73
+
74
+ # Example usage
75
+ query = "What is mental health?"
76
+ answer = answer_question(query)
77
+
78
+ print(f"Question: {query}")
79
+ print(f"Answer: {answer}")
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ evaluate
3
+ pypdf
4
+ langchain
5
+ sentence-transformers
6
+ langchain-community
7
+ opensearch-py
8
+ faiss-cpu
9
+ accelerate
10
+ bert_score
11
+
simple_retrieval.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sentence_transformers import SentenceTransformer
2
+ import faiss
3
+ import numpy as np
4
+
5
+ # Load the FAISS index
6
+ index = faiss.read_index("database/pdf_sections_index.faiss")
7
+
8
+ # Load the embedding model
9
+ model = SentenceTransformer('all-MiniLM-L6-v2')
10
+
11
+ def search_faiss(query, k=3):
12
+ query_vector = model.encode([query])[0].astype('float32')
13
+ query_vector = np.expand_dims(query_vector, axis=0)
14
+ distances, indices = index.search(query_vector, k)
15
+ return distances, indices
16
+
17
+ # Example usage
18
+ query = "What is mental Health?"
19
+ distances, indices = search_faiss(query)
20
+
21
+ print(f"Query: {query}")
22
+ print(f"Distances: {distances}")
23
+ print(f"Indices: {indices}")