Upload 22 files
- .gitattributes +2 -0
- Evaluation_MH/.ipynb_checkpoints/Evaluation-checkpoint.ipynb +1403 -0
- Evaluation_MH/Evaluation.ipynb +1403 -0
- Evaluation_MH/Mental Health Evaluation Report.pdf +0 -0
- LICENSE +21 -0
- MentalHealth/LICENSE +21 -0
- MentalHealth/app.py +93 -0
- MentalHealth/create_vectordb.py +53 -0
- MentalHealth/data/Mental Health Handbook English.pdf +3 -0
- MentalHealth/database/pdf_sections_data.pkl +3 -0
- MentalHealth/database/pdf_sections_index.faiss +0 -0
- MentalHealth/rag.py +79 -0
- MentalHealth/requirements.txt +8 -0
- MentalHealth/simple_retrieval.py +23 -0
- app.py +104 -0
- create_vectordb.py +53 -0
- data/Mental Health Handbook English.pdf +3 -0
- data/MentalHealth_Dataset.xlsx +0 -0
- database/pdf_sections_data.pkl +3 -0
- database/pdf_sections_index.faiss +0 -0
- rag.py +79 -0
- requirements.txt +11 -0
- simple_retrieval.py +23 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33   *.zip filter=lfs diff=lfs merge=lfs -text
34   *.zst filter=lfs diff=lfs merge=lfs -text
35   *tfevents* filter=lfs diff=lfs merge=lfs -text
36 + data/Mental[[:space:]]Health[[:space:]]Handbook[[:space:]]English.pdf filter=lfs diff=lfs merge=lfs -text
37 + MentalHealth/data/Mental[[:space:]]Health[[:space:]]Handbook[[:space:]]English.pdf filter=lfs diff=lfs merge=lfs -text
Evaluation_MH/.ipynb_checkpoints/Evaluation-checkpoint.ipynb
ADDED
@@ -0,0 +1,1403 @@
# Imports

# pip install -U sentence-transformers

from sentence_transformers import SentenceTransformer
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain_community.llms import Ollama
from langchain.evaluation import load_evaluator
import faiss
import pandas as pd
import numpy as np
import pickle
import time
from tqdm import tqdm
# Intialization

# Load the FAISS index
index = faiss.read_index("database/pdf_sections_index.faiss")

model = SentenceTransformer('all-MiniLM-L6-v2')

with open('database/pdf_sections_data.pkl', 'rb') as f:
    sections_data = pickle.load(f)
# RAG functions

def search_faiss(query, k=3):
    query_vector = model.encode([query])[0].astype('float32')
    query_vector = np.expand_dims(query_vector, axis=0)
    distances, indices = index.search(query_vector, k)

    results = []
    for dist, idx in zip(distances[0], indices[0]):
        results.append({
            'distance': dist,
            'content': sections_data[idx]['content'],
            'metadata': sections_data[idx]['metadata']
        })

    return results

# Create a prompt template
prompt_template = """
You are an AI assistant specialized in Mental Health guidelines.
Use the following pieces of context to answer the question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.

Context:
{context}

Question: {question}

Answer:"""

prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])

llm = Ollama(
    model="llama3"
)

# Create the chain
chain = LLMChain(llm=llm, prompt=prompt)

def answer_question(query):
    # Search for relevant context
    search_results = search_faiss(query)

    # Combine the content from the search results
    context = "\n\n".join([result['content'] for result in search_results])

    # Run the chain
    response = chain.run(context=context, question=query)

    return response
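A minimal usage sketch of the two functions above, assuming the earlier cells (FAISS index, sections_data, embedding model and the Ollama llama3 chain) have already been run; the question string is illustrative and not taken from the dataset:

# Usage sketch (illustrative query; requires the cells above and a local Ollama llama3 server)
question = "What are the early warning signs of depression?"
hits = search_faiss(question, k=3)
print([round(float(h['distance']), 3) for h in hits])  # retrieval distances, smaller = closer
print(answer_question(question))                       # grounded answer produced by the chain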
# Reading GT

df = pd.read_csv('data/MentalHealth_Dataset.csv')

time_list=[]
response_list=[]
for i in tqdm(range(len(df))):
    query = df['Questions'].values[i]
    start = time.time()
    response = answer_question(query)
    end = time.time()
    time_list.append(end-start)
    response_list.append(response)

100%|██████████| 10/10 [01:45<00:00, 10.55s/it]

df['latency'] = time_list
df['response'] = response_list
# Evaluation

eval_llm = Ollama(
    model="phi3"
)

metrics = ['correctness', 'relevance', 'coherence', 'conciseness']

for metric in metrics:
    evaluator = load_evaluator("labeled_criteria", criteria=metric, llm=eval_llm)

    reasoning = []
    value = []
    score = []

    for i in tqdm(range(len(df))):
        eval_result = evaluator.evaluate_strings(
            prediction=df.response.values[i],
            input=df.Questions.values[i],
            reference=df.Answers.values[i]
        )
        reasoning.append(eval_result['reasoning'])
        value.append(eval_result['value'])
        score.append(eval_result['score'])

    df[metric+'_reasoning'] = reasoning
    df[metric+'_value'] = value
    df[metric+'_score'] = score

100%|██████████| 10/10 [01:15<00:00, 7.51s/it]
100%|██████████| 10/10 [00:59<00:00, 5.99s/it]
100%|██████████| 10/10 [00:50<00:00, 5.10s/it]
100%|██████████| 10/10 [00:48<00:00, 4.88s/it]

df.head()

   Questions                                            latency  correctness_score  relevance_score  coherence_score  conciseness_score
0  What is Mental Health                              11.974234                  1                1                1                  1
1  What are the most common mental disorders ment...   5.863329                  1                1                1                  1
2  What are the early warning signs and symptoms ...  13.434543                  1                1                1                  1
3  How can someone help a person who suffers from...  13.838464                  1                1                1                  0
4  What are the causes of mental illness listed i...   6.871735                  0                1                1                  1
(the Answers, response and per-metric value/reasoning columns are shown truncated in the notebook output)

df[['correctness_score','relevance_score','coherence_score','conciseness_score','latency']].mean()

correctness_score     0.800000
relevance_score       0.900000
coherence_score       1.000000
conciseness_score     0.800000
latency              10.544803
dtype: float64
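For a single example, the labeled_criteria evaluator used above returns a dict with 'reasoning', 'value' and 'score'; a one-off sketch of inspecting it, with an illustrative prediction string and assuming eval_llm is defined as above:

# One-off labeled-criteria check on a single (question, answer, reference) triple
evaluator = load_evaluator("labeled_criteria", criteria="correctness", llm=eval_llm)
result = evaluator.evaluate_strings(
    prediction="Mental health is a state of well-being.",  # candidate answer (made up here)
    input="What is Mental Health",                          # question
    reference=df.Answers.values[0],                         # ground-truth answer from the dataset
)
print(result['value'], result['score'])  # e.g. 'Y', 1
print(result['reasoning'][:200])         # judge model's reasoning, truncated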
irr_q=pd.read_csv('data/Unrelated_questions.csv')

time_list=[]
response_list=[]
for i in tqdm(range(len(irr_q))):
    query = irr_q['Questions'].values[i]
    start = time.time()
    response = answer_question(query)
    end = time.time()
    time_list.append(end-start)
    response_list.append(response)

100%|██████████| 10/10 [01:02<00:00, 6.30s/it]

irr_q['response']=response_list
irr_q['latency']=time_list

irr_q.head()

   Questions                                  response                                             latency    irrelevant_score
0  What is the capital of Mars?               I don't know. The provided context does not se...  12.207266   True
1  How many unicorns live in New York City?   I don't know. The information provided does no...   2.368774   True
2  What is the color of happiness?            I don't know! The provided context only talks ...   5.480067   True
3  Can cats fly on Tuesdays?                  I don't know the answer to this question as it...   5.272529   True
4  How much does a thought weigh?             I don't know. The context provided is about me...   5.253224   True

irr_q['latency']

0    12.207266
1     2.368774
2     5.480067
3     5.272529
4     5.253224
5     5.351224
6     8.118429
7     7.288261
8     3.856500
9     7.745016
Name: latency, dtype: float64

irr_q['irrelevant_score'] = irr_q['response'].str.contains("I don't know")

irr_q[['irrelevant_score','latency']].mean()

irrelevant_score    0.900000
latency             6.294129
dtype: float64
# Improvement

new_prompt_template = """
You are an AI assistant specialized in Mental Health guidelines.
Use the provided context to answer the question short and accurately.
If you don't know the answer, simply say, "I don't know."

Context:
{context}

Question: {question}

Answer:"""

prompt = PromptTemplate(template=new_prompt_template, input_variables=["context", "question"])

llm = Ollama(
    model="llama3"
)

# Create the chain
chain = LLMChain(llm=llm, prompt=prompt)

def answer_question_new(query):
    # Search for relevant context
    search_results = search_faiss(query)

    # Combine the content from the search results
    context = "\n\n".join([result['content'] for result in search_results])

    # Run the chain
    response = chain.run(context=context, question=query)

    return response

df2=df.copy()

time_list=[]
response_list=[]
for i in tqdm(range(len(df2))):
    query = df2['Questions'].values[i]
    start = time.time()
    response = answer_question(query)
    end = time.time()
    time_list.append(end-start)
    response_list.append(response)

100%|██████████| 10/10 [01:34<00:00, 9.40s/it]

df2['latency'] = time_list
df2['response'] = response_list

for metric in metrics:
    evaluator = load_evaluator("labeled_criteria", criteria=metric, llm=eval_llm)

    reasoning = []
    value = []
    score = []

    for i in tqdm(range(len(df2))):
        eval_result = evaluator.evaluate_strings(
            prediction=df2.response.values[i],
            input=df2.Questions.values[i],
            reference=df2.Answers.values[i]
        )
        reasoning.append(eval_result['reasoning'])
        value.append(eval_result['value'])
        score.append(eval_result['score'])

    df2[metric+'_reasoning'] = reasoning
    df2[metric+'_value'] = value
    df2[metric+'_score'] = score

100%|██████████| 10/10 [01:00<00:00, 6.01s/it]
100%|██████████| 10/10 [00:53<00:00, 5.35s/it]
100%|██████████| 10/10 [00:47<00:00, 4.77s/it]
100%|██████████| 10/10 [00:55<00:00, 5.60s/it]

df2.head()

   Questions                                            latency  correctness_score  relevance_score  coherence_score  conciseness_score
0  What is Mental Health                              11.046327                  0                0                1                  1
1  What are the most common mental disorders ment...   4.509713                  0                1                0                  0
2  What are the early warning signs and symptoms ...   8.501180                  1                1                1                  1
3  How can someone help a person who suffers from...  10.611402                  1                1                1                  1
4  What are the causes of mental illness listed i...   6.299272                  0                0                1                  0
(the Answers, response and per-metric value/reasoning columns are shown truncated in the notebook output)

df2[['correctness_score','relevance_score','coherence_score','conciseness_score','latency']].mean()

correctness_score    0.500000
relevance_score      0.888889
coherence_score      0.888889
conciseness_score    0.900000
latency              8.190205
dtype: float64
# Query relevance

def new_search_faiss(query, k=3, threshold=1.5):
    query_vector = model.encode([query])[0].astype('float32')
    query_vector = np.expand_dims(query_vector, axis=0)
    distances, indices = index.search(query_vector, k)

    results = []
    for dist, idx in zip(distances[0], indices[0]):
        if dist < threshold:  # Only include results within the threshold distance
            results.append({
                'distance': dist,
                'content': sections_data[idx]['content'],
                'metadata': sections_data[idx]['metadata']
            })

    return results

new_prompt_template = """
You are an AI assistant specialized in Mental Health guidelines.
Use the provided context to answer the question short and accurately.
If you don't know the answer, simply say, "I don't know."

Context:
{context}

Question: {question}

Answer:"""

prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])

llm = Ollama(
    model="llama3"
)

# Create the chain
chain = LLMChain(llm=llm, prompt=prompt)

def new_answer_question(query):
    # Search for relevant context
    search_results = new_search_faiss(query)

    if search_results==[]:
        response="I don't know, sorry"
    else:
        context = "\n\n".join([result['content'] for result in search_results])
        response = chain.run(context=context, question=query)

    return response

irr_q2=irr_q.copy()

time_list=[]
response_list=[]
for i in tqdm(range(len(irr_q2))):
    query = irr_q['Questions'].values[i]
    start = time.time()
    response = new_answer_question(query)
    end = time.time()
    time_list.append(end-start)
    response_list.append(response)

100%|██████████| 10/10 [00:00<00:00, 61.93it/s]

irr_q2['response']=response_list
irr_q2['latency']=time_list

irr_q2.head()

   Questions                                  response             latency   irrelevant_score
0  What is the capital of Mars?               I don't know, sorry  0.061378  True
1  How many unicorns live in New York City?   I don't know, sorry  0.012511  True
2  What is the color of happiness?            I don't know, sorry  0.011900  True
3  Can cats fly on Tuesdays?                  I don't know, sorry  0.011438  True
4  How much does a thought weigh?             I don't know, sorry  0.010644  True

irr_q2['irrelevant_score'] = irr_q2['response'].str.contains("I don't know")

irr_q2[['irrelevant_score','latency']].mean()

irrelevant_score    1.000000
latency             0.016068
dtype: float64

(kernel: Python 3 (ipykernel), Python 3.11.5, nbformat 4)
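One way to sanity-check the distance threshold of 1.5 used in new_search_faiss is to compare the top-1 FAISS distances of the in-scope and out-of-scope question sets already loaded in the notebook; this is a rough illustrative sketch under that assumption:

# Compare top-1 FAISS distances for dataset questions vs. unrelated questions
def top1_distance(query):
    vec = np.expand_dims(model.encode([query])[0].astype('float32'), axis=0)
    distances, _ = index.search(vec, 1)
    return float(distances[0][0])

in_scope = [top1_distance(q) for q in df['Questions']]
out_of_scope = [top1_distance(q) for q in irr_q['Questions']]
print("in-scope max:", max(in_scope), "| out-of-scope min:", min(out_of_scope))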
Evaluation_MH/Evaluation.ipynb
ADDED
@@ -0,0 +1,1403 @@
(1,403 added lines; apart from the Python version recorded in the notebook metadata, the content duplicates Evaluation_MH/.ipynb_checkpoints/Evaluation-checkpoint.ipynb above.)
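The notebook's closing "Query relevance" improvement, visible in the checkpoint above, gates retrieval with a FAISS distance threshold: when no handbook section is close enough to the query, the LLM call is skipped and a canned "I don't know, sorry" is returned, which is what produces the 1.0 refusal score at near-zero latency. A self-contained sketch of that gate, assuming a flat L2 index over a toy two-sentence corpus (the 1.5 threshold and the MiniLM encoder follow the notebook; the corpus and function name are illustrative):

```python
import faiss
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('all-MiniLM-L6-v2')

# Toy corpus standing in for the handbook sections.
sections = [
    "Depression can cause persistent sadness and loss of interest.",
    "Supporting someone with anxiety starts with listening without judgement.",
]
index = faiss.IndexFlatL2(model.get_sentence_embedding_dimension())
index.add(model.encode(sections).astype('float32'))

def gated_search(query, k=2, threshold=1.5):
    # Distances above the threshold are treated as "no relevant context".
    distances, indices = index.search(model.encode([query]).astype('float32'), k)
    return [sections[i] for d, i in zip(distances[0], indices[0]) if d < threshold]

for q in ["How can I help a friend with anxiety?", "What is the capital of Mars?"]:
    hits = gated_search(q)
    print(q, "->", hits if hits else "I don't know, sorry")
```

The threshold value is tied to the index metric (squared L2 here), so it would need re-tuning if the index were built with inner-product search or normalized embeddings.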
Evaluation_MH/Mental Health Evaluation Report.pdf
ADDED
Binary file (72.9 kB)
LICENSE
ADDED
@@ -0,0 +1,21 @@
1 + MIT License
2 + 
3 + Copyright (c) 2024 Aditi Yadav
4 + 
5 + Permission is hereby granted, free of charge, to any person obtaining a copy
6 + of this software and associated documentation files (the "Software"), to deal
7 + in the Software without restriction, including without limitation the rights
8 + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 + copies of the Software, and to permit persons to whom the Software is
10 + furnished to do so, subject to the following conditions:
11 + 
12 + The above copyright notice and this permission notice shall be included in all
13 + copies or substantial portions of the Software.
14 + 
15 + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 + SOFTWARE.
MentalHealth/LICENSE
ADDED
@@ -0,0 +1,21 @@
(21 added lines, identical to the LICENSE file added at the repository root above.)
MentalHealth/app.py
ADDED
@@ -0,0 +1,93 @@
import streamlit as st
from sentence_transformers import SentenceTransformer
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain_community.llms import Ollama
import faiss
import numpy as np
import pickle

# Load the FAISS index
# (st.cache_resource replaces the deprecated st.cache(allow_output_mutation=True))
@st.cache_resource
def load_faiss_index():
    try:
        return faiss.read_index("database/pdf_sections_index.faiss")
    except FileNotFoundError:
        st.error("FAISS index file not found. Please ensure 'pdf_sections_index.faiss' exists.")
        st.stop()

# Load the embedding model
@st.cache_resource
def load_embedding_model():
    return SentenceTransformer('all-MiniLM-L6-v2')

# Load sections data
@st.cache_resource
def load_sections_data():
    try:
        with open('database/pdf_sections_data.pkl', 'rb') as f:
            return pickle.load(f)
    except FileNotFoundError:
        st.error("Sections data file not found. Please ensure 'pdf_sections_data.pkl' exists.")
        st.stop()

# Initialize resources
index = load_faiss_index()
model = load_embedding_model()
sections_data = load_sections_data()

def search_faiss(query, k=3):
    # Embed the query and retrieve the k nearest sections from the index
    query_vector = model.encode([query])[0].astype('float32')
    query_vector = np.expand_dims(query_vector, axis=0)
    distances, indices = index.search(query_vector, k)

    results = []
    for dist, idx in zip(distances[0], indices[0]):
        results.append({
            'distance': dist,
            'content': sections_data[idx]['content'],
            'metadata': sections_data[idx]['metadata']
        })

    return results

# Prompt template for the RAG chain
prompt_template = """
You are an AI assistant specialized in mental health and wellness guidelines. Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.

Context:
{context}

Question: {question}

Answer:"""

prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])

@st.cache_resource
def load_llm():
    return Ollama(model="llama3")

llm = load_llm()
chain = LLMChain(llm=llm, prompt=prompt)

def answer_question(query):
    search_results = search_faiss(query)
    context = "\n\n".join([result['content'] for result in search_results])
    response = chain.run(context=context, question=query)
    return response, context

# Streamlit UI
st.title("Mental Health Guidelines Q&A")

query = st.text_input("Enter your question about Mental Health guidelines:")

if st.button("Get Answer"):
    if query:
        with st.spinner("Searching and generating answer..."):
            answer, context = answer_question(query)
        st.subheader("Answer:")
        st.write(answer)
        with st.expander("Show Context"):
            st.write(context)
    else:
        st.warning("Please enter a question.")
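Note: the app above assumes a local Ollama server with the llama3 model already pulled. A minimal pre-flight sketch, assuming Ollama's default endpoint at http://localhost:11434 and its /api/tags route:

import requests

def ollama_model_available(name="llama3", host="http://localhost:11434"):
    # Ask the local Ollama server which models are installed.
    try:
        tags = requests.get(f"{host}/api/tags", timeout=5).json()
    except requests.RequestException:
        return False
    return any(m.get("name", "").startswith(name) for m in tags.get("models", []))

if __name__ == "__main__":
    print("llama3 available:", ollama_model_available())

If this prints False, pull the model first (ollama pull llama3) before launching the UI with streamlit run MentalHealth/app.py.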
MentalHealth/create_vectordb.py
ADDED
@@ -0,0 +1,53 @@
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import pickle

# Load the PDF (forward slash so the path also works outside Windows)
pdf_path = "data/Mental Health Handbook English.pdf"
loader = PyPDFLoader(file_path=pdf_path)

# Load the content
documents = loader.load()

# Split the document into sections
text_splitter = RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=200)
sections = text_splitter.split_documents(documents)

# Load the embedding model
model = SentenceTransformer('all-MiniLM-L6-v2')

# Generate embeddings for each section
section_texts = [section.page_content for section in sections]
embeddings = model.encode(section_texts)

print(embeddings.shape)

embeddings_np = np.array(embeddings).astype('float32')

# Create a FAISS index
dimension = embeddings_np.shape[1]
index = faiss.IndexFlatL2(dimension)

# Add vectors to the index
index.add(embeddings_np)

# Save the index to a file
faiss.write_index(index, "database/pdf_sections_index.faiss")

# Keep the section text and metadata alongside the index
sections_data = [
    {
        'content': section.page_content,
        'metadata': section.metadata
    }
    for section in sections
]

# Save sections data
with open('database/pdf_sections_data.pkl', 'wb') as f:
    pickle.dump(sections_data, f)

print("Embeddings stored in FAISS index and saved to file.")
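Note: the script writes the FAISS index and the pickled sections separately, so an interrupted run can leave them out of sync. A minimal consistency check, reusing the file paths from the script above:

import pickle
import faiss

# The number of vectors in the index should match the number of pickled sections.
index = faiss.read_index("database/pdf_sections_index.faiss")
with open("database/pdf_sections_data.pkl", "rb") as f:
    sections_data = pickle.load(f)

assert index.ntotal == len(sections_data), (
    f"index has {index.ntotal} vectors but {len(sections_data)} sections were pickled"
)
print(f"OK: {index.ntotal} sections indexed")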
MentalHealth/data/Mental Health Handbook English.pdf
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:19da603f69fff5a4bc28a04fde30cf977f8fdb8310e9e31f6d21f4c45240c14b
size 5413709
MentalHealth/database/pdf_sections_data.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b4ceb3d84f382b1162df9c1b91f285c167411642572d56999c6bd1cd6b0dd2d7
size 60012
MentalHealth/database/pdf_sections_index.faiss
ADDED
Binary file (66.1 kB)
MentalHealth/rag.py
ADDED
@@ -0,0 +1,79 @@
from sentence_transformers import SentenceTransformer
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain_community.llms import Ollama
import faiss
import numpy as np
import pickle

# Load the FAISS index
try:
    index = faiss.read_index("database/pdf_sections_index.faiss")
except FileNotFoundError:
    print("FAISS index file not found. Please ensure 'pdf_sections_index.faiss' exists.")
    exit(1)

# Load the embedding model
model = SentenceTransformer('all-MiniLM-L6-v2')

# Load sections data
try:
    with open('database/pdf_sections_data.pkl', 'rb') as f:
        sections_data = pickle.load(f)
except FileNotFoundError:
    print("Sections data file not found. Please ensure 'pdf_sections_data.pkl' exists.")
    exit(1)

def search_faiss(query, k=3):
    # Embed the query and retrieve the k nearest sections from the index
    query_vector = model.encode([query])[0].astype('float32')
    query_vector = np.expand_dims(query_vector, axis=0)
    distances, indices = index.search(query_vector, k)

    results = []
    for dist, idx in zip(distances[0], indices[0]):
        results.append({
            'distance': dist,
            'content': sections_data[idx]['content'],
            'metadata': sections_data[idx]['metadata']
        })

    return results

# Create a prompt template
prompt_template = """
You are an AI assistant specialized in mental health and wellness guidelines. Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.

Context:
{context}

Question: {question}

Answer:"""

prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])

llm = Ollama(
    model="llama3"
)

# Create the chain
chain = LLMChain(llm=llm, prompt=prompt)

def answer_question(query):
    # Search for relevant context
    search_results = search_faiss(query)

    # Combine the content from the search results
    context = "\n\n".join([result['content'] for result in search_results])

    # Run the chain
    response = chain.run(context=context, question=query)

    return response

# Example usage
query = "What is Mental Health?"
answer = answer_question(query)

print(f"Question: {query}")
print(f"Answer: {answer}")
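Note: answer_question above puts all k retrieved sections into the prompt regardless of how far they are from the query. A hedged sketch of a distance cut-off (the 1.2 threshold is purely illustrative, not tuned for this index):

def filtered_context(query, max_distance=1.2, k=3):
    # Keep only sections whose L2 distance to the query is below the threshold.
    results = search_faiss(query, k=k)
    kept = [r for r in results if r['distance'] <= max_distance]
    if not kept:
        # Nothing close enough; an empty context lets the prompt's "say you don't know" rule apply.
        return ""
    return "\n\n".join(r['content'] for r in kept)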
MentalHealth/requirements.txt
ADDED
@@ -0,0 +1,8 @@
streamlit
pypdf
langchain
sentence-transformers
langchain-community
opensearch-py
faiss-cpu
MentalHealth/simple_retrieval.py
ADDED
@@ -0,0 +1,23 @@
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np

# Load the FAISS index
index = faiss.read_index("database/pdf_sections_index.faiss")

# Load the embedding model
model = SentenceTransformer('all-MiniLM-L6-v2')

def search_faiss(query, k=3):
    # Embed the query and return the distances and row indices of the k nearest sections
    query_vector = model.encode([query])[0].astype('float32')
    query_vector = np.expand_dims(query_vector, axis=0)
    distances, indices = index.search(query_vector, k)
    return distances, indices

# Example usage
query = "What is mental health?"
distances, indices = search_faiss(query)

print(f"Query: {query}")
print(f"Distances: {distances}")
print(f"Indices: {indices}")
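Note: this script only prints raw distances and row indices. A minimal sketch that maps the indices back to text through the pickled sections written by create_vectordb.py, building on search_faiss above (PyPDFLoader metadata normally carries the source path and page number):

import pickle

with open("database/pdf_sections_data.pkl", "rb") as f:
    sections_data = pickle.load(f)

distances, indices = search_faiss("What is mental health?")
for dist, idx in zip(distances[0], indices[0]):
    section = sections_data[idx]
    # Show the score, the page the chunk came from, and the start of its text.
    print(f"distance={dist:.3f} page={section['metadata'].get('page')} text={section['content'][:80]!r}")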
app.py
ADDED
@@ -0,0 +1,104 @@
import streamlit as st
from sentence_transformers import SentenceTransformer
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain_community.llms import Ollama
import faiss
import numpy as np
import pickle
import requests  # imported but not used below
import json      # imported but not used below

# Load the FAISS index
# (st.cache_resource replaces the deprecated st.cache(allow_output_mutation=True))
@st.cache_resource
def load_faiss_index():
    try:
        return faiss.read_index("database/pdf_sections_index.faiss")
    except FileNotFoundError:
        st.error("FAISS index file not found. Please ensure 'pdf_sections_index.faiss' exists.")
        st.stop()

# Load the embedding model
@st.cache_resource
def load_embedding_model():
    return SentenceTransformer('all-MiniLM-L6-v2')

# Load sections data
@st.cache_resource
def load_sections_data():
    try:
        with open('database/pdf_sections_data.pkl', 'rb') as f:
            return pickle.load(f)
    except FileNotFoundError:
        st.error("Sections data file not found. Please ensure 'pdf_sections_data.pkl' exists.")
        st.stop()

# Initialize resources
index = load_faiss_index()
model = load_embedding_model()
sections_data = load_sections_data()

def search_faiss(query, k=3):
    # Embed the query and retrieve the k nearest sections from the index
    query_vector = model.encode([query])[0].astype('float32')
    query_vector = np.expand_dims(query_vector, axis=0)
    distances, indices = index.search(query_vector, k)

    results = []
    for dist, idx in zip(distances[0], indices[0]):
        results.append({
            'distance': dist,
            'content': sections_data[idx]['content'],
            'metadata': sections_data[idx]['metadata']
        })

    return results

prompt_template = """
You are an AI assistant specialized in Mental Health & wellness guidelines. Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.

Context:
{context}

Question: {question}

Answer:"""

prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])

@st.cache_resource
def load_llm():
    return Ollama(model="phi3")

llm = load_llm()
chain = LLMChain(llm=llm, prompt=prompt)

def answer_question(query):
    search_results = search_faiss(query)
    context = "\n\n".join([result['content'] for result in search_results])
    response = chain.run(context=context, question=query)
    return response, context

# Streamlit UI
st.title("Mental Health & Wellness Assistant")

query = st.text_input("Enter your question about Mental Health:")

if st.button("Get Answer"):
    if query:
        with st.spinner("Searching, thinking, and generating answer..."):
            answer, context = answer_question(query)
        st.subheader("Answer:")
        st.write(answer)
        with st.expander("Show Context"):
            st.write(context)
    else:
        st.warning("Please enter a question.")

# Footer section with social links
st.markdown("""
<div class="social-icons">
    <a href="https://github.com/yadavadit" target="_blank"><img src="https://img.icons8.com/material-outlined/48/e50914/github.png"/></a>
    <a href="https://www.linkedin.com/in/yaditi/" target="_blank"><img src="https://img.icons8.com/color/48/e50914/linkedin.png"/></a>
    <a href="mailto:[email protected]"><img src="https://img.icons8.com/color/48/e50914/gmail.png"/></a>
</div>
""", unsafe_allow_html=True)
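Note: the "Show Context" expander above prints the concatenated context as one block. A hedged alternative that lists each retrieved section with its distance and page metadata (it re-runs search_faiss for simplicity, so retrieval happens twice per question):

results = search_faiss(query)
with st.expander("Show Sources"):
    for r in results:
        page = r['metadata'].get('page', 'n/a')
        st.markdown(f"**Page {page}** (distance {r['distance']:.3f})")
        st.write(r['content'])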
create_vectordb.py
ADDED
@@ -0,0 +1,53 @@
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import pickle

# Load the PDF (forward slash so the path also works outside Windows)
pdf_path = "data/Mental Health Handbook English.pdf"
loader = PyPDFLoader(file_path=pdf_path)

# Load the content
documents = loader.load()

# Split the document into sections
text_splitter = RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=200)
sections = text_splitter.split_documents(documents)

# Load the embedding model
model = SentenceTransformer('all-MiniLM-L6-v2')

# Generate embeddings for each section
section_texts = [section.page_content for section in sections]
embeddings = model.encode(section_texts)

print(embeddings.shape)

embeddings_np = np.array(embeddings).astype('float32')

# Create a FAISS index
dimension = embeddings_np.shape[1]
index = faiss.IndexFlatL2(dimension)

# Add vectors to the index
index.add(embeddings_np)

# Save the index to a file
faiss.write_index(index, "database/pdf_sections_index.faiss")

# Keep the section text and metadata alongside the index
sections_data = [
    {
        'content': section.page_content,
        'metadata': section.metadata
    }
    for section in sections
]

# Save sections data
with open('database/pdf_sections_data.pkl', 'wb') as f:
    pickle.dump(sections_data, f)

print("Embeddings stored in FAISS index and saved to file.")
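Note: IndexFlatL2 ranks sections by Euclidean distance. With L2-normalized embeddings, an inner-product index gives cosine-similarity ranking instead, a common choice for sentence-transformer vectors. A hedged variant of the indexing step, reusing the embeddings array computed by the script above:

import faiss
import numpy as np

embeddings_np = np.array(embeddings).astype('float32')
faiss.normalize_L2(embeddings_np)                  # in-place L2 normalization
index = faiss.IndexFlatIP(embeddings_np.shape[1])  # inner product equals cosine on unit vectors
index.add(embeddings_np)
# Query vectors must be normalized the same way before calling index.search(...).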
data/Mental Health Handbook English.pdf
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:19da603f69fff5a4bc28a04fde30cf977f8fdb8310e9e31f6d21f4c45240c14b
size 5413709
data/MentalHealth_Dataset.xlsx
ADDED
Binary file (17.7 kB)
database/pdf_sections_data.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b4ceb3d84f382b1162df9c1b91f285c167411642572d56999c6bd1cd6b0dd2d7
size 60012
database/pdf_sections_index.faiss
ADDED
Binary file (66.1 kB)
rag.py
ADDED
@@ -0,0 +1,79 @@
from sentence_transformers import SentenceTransformer
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain_community.llms import Ollama
import faiss
import numpy as np
import pickle

# Load the FAISS index
try:
    index = faiss.read_index("database/pdf_sections_index.faiss")
except FileNotFoundError:
    print("FAISS index file not found. Please ensure 'pdf_sections_index.faiss' exists.")
    exit(1)

# Load the embedding model
model = SentenceTransformer('all-MiniLM-L6-v2')

# Load sections data
try:
    with open('database/pdf_sections_data.pkl', 'rb') as f:
        sections_data = pickle.load(f)
except FileNotFoundError:
    print("Sections data file not found. Please ensure 'pdf_sections_data.pkl' exists.")
    exit(1)

def search_faiss(query, k=3):
    # Embed the query and retrieve the k nearest sections from the index
    query_vector = model.encode([query])[0].astype('float32')
    query_vector = np.expand_dims(query_vector, axis=0)
    distances, indices = index.search(query_vector, k)

    results = []
    for dist, idx in zip(distances[0], indices[0]):
        results.append({
            'distance': dist,
            'content': sections_data[idx]['content'],
            'metadata': sections_data[idx]['metadata']
        })

    return results

# Create a prompt template
prompt_template = """
You are an AI assistant specialized in mental health and wellness guidelines. Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.

Context:
{context}

Question: {question}

Answer:"""

prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])

llm = Ollama(
    model="phi3"
)

# Create the chain
chain = LLMChain(llm=llm, prompt=prompt)

def answer_question(query):
    # Search for relevant context
    search_results = search_faiss(query)

    # Combine the content from the search results
    context = "\n\n".join([result['content'] for result in search_results])

    # Run the chain
    response = chain.run(context=context, question=query)

    return response

# Example usage
query = "What is mental health?"
answer = answer_question(query)

print(f"Question: {query}")
print(f"Answer: {answer}")
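Note: the repository also ships data/MentalHealth_Dataset.xlsx and an evaluation notebook. A heavily hedged sketch of batch-running answer_question over a question column; the column name "Question" and the pandas/openpyxl dependencies are assumptions and do not appear in requirements.txt:

import pandas as pd  # assumed extra dependency (plus openpyxl for .xlsx files)

df = pd.read_excel("data/MentalHealth_Dataset.xlsx")
df["model_answer"] = [answer_question(q) for q in df["Question"]]  # "Question" column is assumed
df.to_csv("model_answers.csv", index=False)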
requirements.txt
ADDED
@@ -0,0 +1,11 @@
streamlit
evaluate
pypdf
langchain
sentence-transformers
langchain-community
opensearch-py
faiss-cpu
accelerate
bert_score
simple_retrieval.py
ADDED
@@ -0,0 +1,23 @@
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np

# Load the FAISS index
index = faiss.read_index("database/pdf_sections_index.faiss")

# Load the embedding model
model = SentenceTransformer('all-MiniLM-L6-v2')

def search_faiss(query, k=3):
    # Embed the query and return the distances and row indices of the k nearest sections
    query_vector = model.encode([query])[0].astype('float32')
    query_vector = np.expand_dims(query_vector, axis=0)
    distances, indices = index.search(query_vector, k)
    return distances, indices

# Example usage
query = "What is mental health?"
distances, indices = search_faiss(query)

print(f"Query: {query}")
print(f"Distances: {distances}")
print(f"Indices: {indices}")