{ "cells": [ { "cell_type": "markdown", "id": "f7b87c2c", "metadata": {}, "source": [ "# Imports" ] }, { "cell_type": "code", "execution_count": 5, "id": "c22401c2-2fd2-4459-9ee8-71bc3bd362c8", "metadata": {}, "outputs": [], "source": [ "# pip install -U sentence-transformers" ] }, { "cell_type": "code", "execution_count": 1, "id": "8a7cc9d8", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/arnabchakraborty/anaconda3/lib/python3.11/site-packages/sentence_transformers/cross_encoder/CrossEncoder.py:11: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n", " from tqdm.autonotebook import tqdm, trange\n" ] } ], "source": [ "from sentence_transformers import SentenceTransformer\n", "from langchain.prompts import PromptTemplate\n", "from langchain.chains import LLMChain\n", "from langchain_community.llms import Ollama\n", "from langchain.evaluation import load_evaluator\n", "import faiss\n", "import pandas as pd\n", "import numpy as np\n", "import pickle\n", "import time\n", "from tqdm import tqdm" ] }, { "cell_type": "markdown", "id": "b6efca1d", "metadata": {}, "source": [ "# Initialization\n", "\n", "Load the three artifacts the RAG pipeline needs: the FAISS vector index, the sentence-embedding model used to encode queries, and the section texts/metadata that the index row ids refer to." ] }, { "cell_type": "code", "execution_count": 2, "id": "cc9a49d2", "metadata": {}, "outputs": [], "source": [ "# Load the FAISS index\n", "index = faiss.read_index(\"database/pdf_sections_index.faiss\")" ] }, { "cell_type": "code", "execution_count": 3, "id": "9af39b55", "metadata": {}, "outputs": [], "source": [ "# NOTE(review): presumably the same model that embedded the indexed sections -- confirm,\n", "# otherwise query vectors are not comparable with the vectors stored in `index`.\n", "model = SentenceTransformer('all-MiniLM-L6-v2')" ] }, { "cell_type": "code", "execution_count": 4, "id": "fee8cdfd", "metadata": {}, "outputs": [], "source": [ "# The search functions read sections_data[row_id] for each FAISS hit, so this sequence\n", "# must be in the same order as the vectors in `index`.\n", "# Security: pickle.load can execute arbitrary code -- only load a file you produced yourself.\n", "with open('database/pdf_sections_data.pkl', 'rb') as f:\n", "    sections_data = pickle.load(f)" ] }, { "cell_type": "markdown", "id": "d6a1ba6a", "metadata": {}, "source": [ "# RAG functions" ] }, { "cell_type": "code", "execution_count": 5, "id": "182bdbd8", "metadata": {}, 
"outputs": [], "source": [ "def search_faiss(query, k=3):\n", "    \"\"\"Embed `query` and return up to `k` nearest sections from the FAISS index.\n", "\n", "    Each hit is a dict with 'distance' (the raw value reported by index.search),\n", "    'content' and 'metadata' (looked up in sections_data by FAISS row id).\n", "    FAISS pads the result with id -1 when the index holds fewer than k vectors;\n", "    those slots are skipped instead of silently returning sections_data[-1].\n", "    \"\"\"\n", "    query_vector = model.encode([query])[0].astype('float32')\n", "    # index.search expects a 2-D batch of queries: shape (1, dim)\n", "    query_vector = np.expand_dims(query_vector, axis=0)\n", "    distances, indices = index.search(query_vector, k)\n", "\n", "    results = []\n", "    for dist, idx in zip(distances[0], indices[0]):\n", "        if idx < 0:  # padding slot, no real neighbour\n", "            continue\n", "        results.append({\n", "            'distance': dist,\n", "            'content': sections_data[idx]['content'],\n", "            'metadata': sections_data[idx]['metadata']\n", "        })\n", "\n", "    return results" ] }, { "cell_type": "code", "execution_count": 15, "id": "67edc46a", "metadata": {}, "outputs": [], "source": [ "# Create a prompt template\n", "prompt_template = \"\"\"\n", "You are an AI assistant specialized in Mental Health guidelines. \n", "Use the following pieces of context to answer the question. \n", "If you don't know the answer, just say that you don't know, don't try to make up an answer.\n", "\n", "Context:\n", "{context}\n", "\n", "Question: {question}\n", "\n", "Answer:\"\"\"\n", "\n", "prompt = PromptTemplate(template=prompt_template, input_variables=[\"context\", \"question\"])\n", "\n", "llm = Ollama(\n", "    model=\"llama3\"\n", ")\n", "\n", "# Create the chain\n", "chain = LLMChain(llm=llm, prompt=prompt)\n", "\n", "def answer_question(query, k=3):\n", "    \"\"\"Answer `query` with `chain`, using the top-k retrieved sections as context.\n", "\n", "    `chain` is looked up globally at call time, so rebinding that name in a later\n", "    cell changes what this function does. Returns the text produced by chain.run.\n", "    \"\"\"\n", "    # Search for relevant context\n", "    search_results = search_faiss(query, k=k)\n", "\n", "    # Combine the content from the search results\n", "    context = \"\\n\\n\".join(result['content'] for result in search_results)\n", "\n", "    # Run the chain\n", "    response = chain.run(context=context, question=query)\n", "\n", "    return response" ] }, { "cell_type": "markdown", "id": "3b176af9", "metadata": {}, "source": [ "# Reading GT (ground-truth question/answer pairs)" ] }, { "cell_type": "code", "execution_count": 16, "id": "4ab68dff", "metadata": {}, "outputs": [], "source": [ "df = pd.read_csv('data/MentalHealth_Dataset.csv')" ] }, { "cell_type": "code", "execution_count": 17, "id": "4e7e22d7", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": 
"stream", "text": [ "100%|███████████████████████████████████████████| 10/10 [01:45<00:00, 10.55s/it]\n" ] } ], "source": [ "# Time each end-to-end RAG call (retrieval + generation) with a monotonic, high-resolution clock\n", "time_list=[]\n", "response_list=[]\n", "for query in tqdm(df['Questions'].values):\n", "    start = time.perf_counter()\n", "    response = answer_question(query)\n", "    end = time.perf_counter()\n", "    time_list.append(end-start)\n", "    response_list.append(response)" ] }, { "cell_type": "code", "execution_count": 18, "id": "2b327e90", "metadata": {}, "outputs": [], "source": [ "df['latency'] = time_list\n", "df['response'] = response_list" ] }, { "cell_type": "markdown", "id": "3c147204", "metadata": {}, "source": [ "# Evaluation\n", "\n", "A second local model (`phi3`) acts as judge: LangChain's `labeled_criteria` evaluator grades every response against the reference answer in `df.Answers`, once per metric." ] }, { "cell_type": "code", "execution_count": 29, "id": "d799e541", "metadata": {}, "outputs": [], "source": [ "eval_llm = Ollama(\n", "    model=\"phi3\"\n", ")" ] }, { "cell_type": "code", "execution_count": 30, "id": "c2f788dc", "metadata": {}, "outputs": [], "source": [ "metrics = ['correctness', 'relevance', 'coherence', 'conciseness']" ] }, { "cell_type": "code", "execution_count": 31, "id": "83ec2b8d", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|███████████████████████████████████████████| 10/10 [01:15<00:00, 7.51s/it]\n", "100%|███████████████████████████████████████████| 10/10 [00:59<00:00, 5.99s/it]\n", "100%|███████████████████████████████████████████| 10/10 [00:50<00:00, 5.10s/it]\n", "100%|███████████████████████████████████████████| 10/10 [00:48<00:00, 4.88s/it]\n" ] } ], "source": [ "for metric in metrics:\n", "    # One evaluator per criterion; eval_llm grades each response against its reference answer\n", "    evaluator = load_evaluator(\"labeled_criteria\", criteria=metric, llm=eval_llm)\n", "\n", "    reasoning = []\n", "    value = []\n", "    score = []\n", "\n", "    rows = zip(df.response.values, df.Questions.values, df.Answers.values)\n", "    for prediction, question, reference in tqdm(rows, total=len(df)):\n", "        eval_result = evaluator.evaluate_strings(\n", "            prediction=prediction,\n", "            input=question,\n", "            reference=reference\n", "        )\n", "        reasoning.append(eval_result['reasoning'])\n", "        value.append(eval_result['value'])\n",
"        # NOTE(review): when the judge's verdict cannot be parsed, score is presumably None;\n", "        # it becomes NaN in the column and .mean() skips it silently -- check counts before comparing runs.\n", "        score.append(eval_result['score'])\n", "\n", "    df[metric+'_reasoning'] = reasoning\n", "    df[metric+'_value'] = value\n", "    df[metric+'_score'] = score" ] }, { "cell_type": "code", "execution_count": 78, "id": "f1673a31", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
QuestionsAnswerslatencyresponsecorrectness_reasoningcorrectness_valuecorrectness_scorerelevance_reasoningrelevance_valuerelevance_scorecoherence_reasoningcoherence_valuecoherence_scoreconciseness_reasoningconciseness_valueconciseness_score
0What is Mental HealthMental Health is a \" state of well-being in wh...11.974234Based on the provided context, specifically fr...The submission refers to the provided input wh...Y1Step 1: Evaluate relevance criterion\\nThe subm...Y1Step 1: Assess coherence\\nThe submission direc...Y11. The submission directly answers the questio...Y1
1What are the most common mental disorders ment...The most common mental disorders include depre...5.863329Based on the provided context, the mental diso...Step 1: Check if the submission is factually a...Y1Step 1: Analyze the relevance criterion\\nThe s...Y1The submission begins with an appropriate ques...Y1Step 1: Review conciseness criterion\\nThe subm...Y1
2What are the early warning signs and symptoms ...Early warning signs and symptoms of depression...13.434543Based on the provided context, I found a refer...Step 1: Evaluate Correctness\\nThe submission a...Y1Step 1: Identify the relevant criterion from t...Y1Step 1: Evaluate coherence\\nThe submission is ...Y1Step 1: Evaluate conciseness - The submission ...Y1
3How can someone help a person who suffers from...To help someone with anxiety, one can support ...13.838464According to the provided context, specificall...Step 1: Correctness\\nThe submission accurately...Y1Step 1: Analyze relevance criterion\\nThe submi...Y1Step 1: Evaluate coherence\\nThe submission dis...Y1Step 1: Evaluate conciseness - The submission ...N0
4What are the causes of mental illness listed i...Causes of mental illness include abnormal func...6.871735According to the provided context, the causes ...The submission lists factors that align with t...N0Step 1: Review relevance criterion - Check if ...Y1Step 1: Compare the submission with the provid...Y1Step 1: Assess conciseness\\nThe submission is ...Y1
\n", "
" ], "text/plain": [ " Questions \\\n", "0 What is Mental Health \n", "1 What are the most common mental disorders ment... \n", "2 What are the early warning signs and symptoms ... \n", "3 How can someone help a person who suffers from... \n", "4 What are the causes of mental illness listed i... \n", "\n", " Answers latency \\\n", "0 Mental Health is a \" state of well-being in wh... 11.974234 \n", "1 The most common mental disorders include depre... 5.863329 \n", "2 Early warning signs and symptoms of depression... 13.434543 \n", "3 To help someone with anxiety, one can support ... 13.838464 \n", "4 Causes of mental illness include abnormal func... 6.871735 \n", "\n", " response \\\n", "0 Based on the provided context, specifically fr... \n", "1 Based on the provided context, the mental diso... \n", "2 Based on the provided context, I found a refer... \n", "3 According to the provided context, specificall... \n", "4 According to the provided context, the causes ... \n", "\n", " correctness_reasoning correctness_value \\\n", "0 The submission refers to the provided input wh... Y \n", "1 Step 1: Check if the submission is factually a... Y \n", "2 Step 1: Evaluate Correctness\\nThe submission a... Y \n", "3 Step 1: Correctness\\nThe submission accurately... Y \n", "4 The submission lists factors that align with t... N \n", "\n", " correctness_score relevance_reasoning \\\n", "0 1 Step 1: Evaluate relevance criterion\\nThe subm... \n", "1 1 Step 1: Analyze the relevance criterion\\nThe s... \n", "2 1 Step 1: Identify the relevant criterion from t... \n", "3 1 Step 1: Analyze relevance criterion\\nThe submi... \n", "4 0 Step 1: Review relevance criterion - Check if ... \n", "\n", " relevance_value relevance_score \\\n", "0 Y 1 \n", "1 Y 1 \n", "2 Y 1 \n", "3 Y 1 \n", "4 Y 1 \n", "\n", " coherence_reasoning coherence_value \\\n", "0 Step 1: Assess coherence\\nThe submission direc... Y \n", "1 The submission begins with an appropriate ques... 
Y \n", "2 Step 1: Evaluate coherence\\nThe submission is ... Y \n", "3 Step 1: Evaluate coherence\\nThe submission dis... Y \n", "4 Step 1: Compare the submission with the provid... Y \n", "\n", " coherence_score conciseness_reasoning \\\n", "0 1 1. The submission directly answers the questio... \n", "1 1 Step 1: Review conciseness criterion\\nThe subm... \n", "2 1 Step 1: Evaluate conciseness - The submission ... \n", "3 1 Step 1: Evaluate conciseness - The submission ... \n", "4 1 Step 1: Assess conciseness\\nThe submission is ... \n", "\n", " conciseness_value conciseness_score \n", "0 Y 1 \n", "1 Y 1 \n", "2 Y 1 \n", "3 N 0 \n", "4 Y 1 " ] }, "execution_count": 78, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.head()" ] }, { "cell_type": "code", "execution_count": 32, "id": "7797a360", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "correctness_score 0.800000\n", "relevance_score 0.900000\n", "coherence_score 1.000000\n", "conciseness_score 0.800000\n", "latency 10.544803\n", "dtype: float64" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Mean of each 0/1 criterion score (NaN entries are skipped) and mean latency in seconds\n", "df[['correctness_score','relevance_score','coherence_score','conciseness_score','latency']].mean()" ] }, { "cell_type": "code", "execution_count": 34, "id": "fe667926", "metadata": {}, "outputs": [], "source": [ "# Off-topic questions: the assistant is expected to decline these with \"I don't know\"\n", "irr_q=pd.read_csv('data/Unrelated_questions.csv')" ] }, { "cell_type": "code", "execution_count": 35, "id": "189f8a0f", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|███████████████████████████████████████████| 10/10 [01:02<00:00, 6.30s/it]\n" ] } ], "source": [ "# Same timing scheme as the baseline run: monotonic clock around each RAG call\n", "time_list=[]\n", "response_list=[]\n", "for query in tqdm(irr_q['Questions'].values):\n", "    start = time.perf_counter()\n", "    response = answer_question(query)\n", "    end = time.perf_counter()\n", "    time_list.append(end-start)\n", "    response_list.append(response)" ] }, { "cell_type": "code", "execution_count": 36, "id": 
"b0244ea0", "metadata": {}, "outputs": [], "source": [ "irr_q['response']=response_list\n", "irr_q['latency']=time_list" ] }, { "cell_type": "code", "execution_count": 79, "id": "dc3b1ade", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Questionsresponselatencyirrelevant_score
0What is the capital of Mars?I don't know. The provided context does not se...12.207266True
1How many unicorns live in New York City?I don't know. The information provided does no...2.368774True
2What is the color of happiness?I don't know! The provided context only talks ...5.480067True
3Can cats fly on Tuesdays?I don't know the answer to this question as it...5.272529True
4How much does a thought weigh?I don't know. The context provided is about me...5.253224True
\n", "
" ], "text/plain": [ " Questions \\\n", "0 What is the capital of Mars? \n", "1 How many unicorns live in New York City? \n", "2 What is the color of happiness? \n", "3 Can cats fly on Tuesdays? \n", "4 How much does a thought weigh? \n", "\n", " response latency \\\n", "0 I don't know. The provided context does not se... 12.207266 \n", "1 I don't know. The information provided does no... 2.368774 \n", "2 I don't know! The provided context only talks ... 5.480067 \n", "3 I don't know the answer to this question as it... 5.272529 \n", "4 I don't know. The context provided is about me... 5.253224 \n", "\n", " irrelevant_score \n", "0 True \n", "1 True \n", "2 True \n", "3 True \n", "4 True " ] }, "execution_count": 79, "metadata": {}, "output_type": "execute_result" } ], "source": [ "irr_q.head()" ] }, { "cell_type": "code", "execution_count": 37, "id": "8620e50c", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0 12.207266\n", "1 2.368774\n", "2 5.480067\n", "3 5.272529\n", "4 5.253224\n", "5 5.351224\n", "6 8.118429\n", "7 7.288261\n", "8 3.856500\n", "9 7.745016\n", "Name: latency, dtype: float64" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "irr_q['latency']" ] }, { "cell_type": "code", "execution_count": 39, "id": "debd3461", "metadata": {}, "outputs": [], "source": [ "# A refusal counts as success on off-topic questions. Match case-insensitively and accept the\n", "# typographic apostrophe (’) in case the LLM emits it; NaN responses count as False.\n", "irr_q['irrelevant_score'] = irr_q['response'].str.contains(r\"i don['’]t know\", case=False, regex=True, na=False)" ] }, { "cell_type": "code", "execution_count": 40, "id": "bef1d3a4", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "irrelevant_score 0.900000\n", "latency 6.294129\n", "dtype: float64" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ "irr_q[['irrelevant_score','latency']].mean()" ] }, { "cell_type": "markdown", "id": "c1610a70", "metadata": {}, "source": [ "# Improvement\n", "\n", "Same retrieval, but a shorter prompt that asks for brief answers and a plain \"I don't know.\" when the context does not cover the question." ] },
{ "cell_type": "code", "execution_count": 48, "id": "ff6614f9", "metadata": {}, "outputs": [], "source": [ "new_prompt_template = \"\"\"\n", "You are an AI assistant specialized in Mental Health guidelines.\n", "Use the provided context to answer the question short and accurately. \n", "If you don't know the answer, simply say, \"I don't know.\"\n", "\n", "Context:\n", "{context}\n", "\n", "Question: {question}\n", "\n", "Answer:\"\"\"\n", "\n", "new_prompt = PromptTemplate(template=new_prompt_template, input_variables=[\"context\", \"question\"])\n", "\n", "# Dedicated chain: rebinding the global `chain` here would silently change answer_question too\n", "new_chain = LLMChain(llm=llm, prompt=new_prompt)\n", "\n", "def answer_question_new(query, k=3):\n", "    \"\"\"Answer `query` like answer_question, but with the shorter `new_prompt_template`.\n", "\n", "    Retrieval is the same (search_faiss, top-k section texts joined by blank lines);\n", "    only the prompt differs. Returns the text produced by new_chain.run.\n", "    \"\"\"\n", "    # Search for relevant context\n", "    search_results = search_faiss(query, k=k)\n", "\n", "    # Combine the content from the search results\n", "    context = \"\\n\\n\".join(result['content'] for result in search_results)\n", "\n", "    # Run the chain\n", "    response = new_chain.run(context=context, question=query)\n", "\n", "    return response" ] }, { "cell_type": "code", "execution_count": 49, "id": "20580d50", "metadata": {}, "outputs": [], "source": [ "df2=df.copy()" ] }, { "cell_type": "code", "execution_count": 50, "id": "b1b3d725", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|███████████████████████████████████████████| 10/10 [01:34<00:00, 9.40s/it]\n" ] } ], "source": [ "time_list=[]\n", "response_list=[]\n", "for query in tqdm(df2['Questions'].values):\n", "    start = time.perf_counter()\n", "    # The new-prompt pipeline defined above; answer_question keeps the baseline prompt\n", "    response = answer_question_new(query)\n", "    end = time.perf_counter()\n", "    time_list.append(end-start)\n", "    response_list.append(response)" ] }, { "cell_type": "code", "execution_count": 51, "id": "63f41256", "metadata": {}, "outputs": [], "source": [ "df2['latency'] = time_list\n", "df2['response'] = response_list" ] },
{ "cell_type": "code", "execution_count": 52, "id": "0d8a6065", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|███████████████████████████████████████████| 10/10 [01:00<00:00, 6.01s/it]\n", "100%|███████████████████████████████████████████| 10/10 [00:53<00:00, 5.35s/it]\n", "100%|███████████████████████████████████████████| 10/10 [00:47<00:00, 4.77s/it]\n", "100%|███████████████████████████████████████████| 10/10 [00:55<00:00, 5.60s/it]\n" ] } ], "source": [ "for metric in metrics:\n", "    # One evaluator per criterion; eval_llm grades each response against its reference answer\n", "    evaluator = load_evaluator(\"labeled_criteria\", criteria=metric, llm=eval_llm)\n", "\n", "    reasoning = []\n", "    value = []\n", "    score = []\n", "\n", "    rows = zip(df2.response.values, df2.Questions.values, df2.Answers.values)\n", "    for prediction, question, reference in tqdm(rows, total=len(df2)):\n", "        eval_result = evaluator.evaluate_strings(\n", "            prediction=prediction,\n", "            input=question,\n", "            reference=reference\n", "        )\n", "        reasoning.append(eval_result['reasoning'])\n", "        value.append(eval_result['value'])\n", "        # NOTE(review): an unparseable verdict presumably yields score None -> NaN, which\n", "        # .mean() skips silently (the 0.888889 means over 10 rows below suggest this happened).\n", "        score.append(eval_result['score'])\n", "\n", "    df2[metric+'_reasoning'] = reasoning\n", "    df2[metric+'_value'] = value\n", "    df2[metric+'_score'] = score" ] }, { "cell_type": "code", "execution_count": 77, "id": "c648632c", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
QuestionsAnswerslatencyresponsecorrectness_reasoningcorrectness_valuecorrectness_scorerelevance_reasoningrelevance_valuerelevance_scorecoherence_reasoningcoherence_valuecoherence_scoreconciseness_reasoningconciseness_valueconciseness_score
0What is Mental HealthMental Health is a \" state of well-being in wh...11.046327Based on the context provided, mental health r...Step 1: Evaluate if the submission is factuall...N0Step 1: Analyze the relevance criterion\\nThe s...N0The submission discusses mental health in rela...Y1Step 1: Analyze conciseness criterion\\nThe sub...Y1
1What are the most common mental disorders ment...The most common mental disorders include depre...4.509713The handbook mentions several mental illnesses...The submission mentions depression and schizop...N0Step 1: Analyze relevance criterion - Check if...Y1Step 1: Assess coherence\\nThe submission menti...N0Step 1: Analyze conciseness criterion\\nThe sub...N0
2What are the early warning signs and symptoms ...Early warning signs and symptoms of depression...8.501180According to the provided context, specificall...The submission matches the reference data in t...Y1The submission refers directly to information ...Y1Step 1: Evaluate coherence - The submission is...Y1The submission is concise and includes most of...Y1
3How can someone help a person who suffers from...To help someone with anxiety, one can support ...10.611402According to the Mental Health Handbook, when ...The submission seems consistent with the refer...Y1Step 1: Review relevance criterion\\nThe submis...Y1The submission is coherent, well-structured, a...Y1The submission is relatively concise and cover...Y1
4What are the causes of mental illness listed i...Causes of mental illness include abnormal func...6.299272According to the context, the causes of mental...The submission lists causes such as neglect, s...N0The submission mentions factors that are part ...N0The submission is coherent and well-structured...Y1Step 1: Read and understand both the input dat...N0
\n", "
" ], "text/plain": [ " Questions \\\n", "0 What is Mental Health \n", "1 What are the most common mental disorders ment... \n", "2 What are the early warning signs and symptoms ... \n", "3 How can someone help a person who suffers from... \n", "4 What are the causes of mental illness listed i... \n", "\n", " Answers latency \\\n", "0 Mental Health is a \" state of well-being in wh... 11.046327 \n", "1 The most common mental disorders include depre... 4.509713 \n", "2 Early warning signs and symptoms of depression... 8.501180 \n", "3 To help someone with anxiety, one can support ... 10.611402 \n", "4 Causes of mental illness include abnormal func... 6.299272 \n", "\n", " response \\\n", "0 Based on the context provided, mental health r... \n", "1 The handbook mentions several mental illnesses... \n", "2 According to the provided context, specificall... \n", "3 According to the Mental Health Handbook, when ... \n", "4 According to the context, the causes of mental... \n", "\n", " correctness_reasoning correctness_value \\\n", "0 Step 1: Evaluate if the submission is factuall... N \n", "1 The submission mentions depression and schizop... N \n", "2 The submission matches the reference data in t... Y \n", "3 The submission seems consistent with the refer... Y \n", "4 The submission lists causes such as neglect, s... N \n", "\n", " correctness_score relevance_reasoning \\\n", "0 0 Step 1: Analyze the relevance criterion\\nThe s... \n", "1 0 Step 1: Analyze relevance criterion - Check if... \n", "2 1 The submission refers directly to information ... \n", "3 1 Step 1: Review relevance criterion\\nThe submis... \n", "4 0 The submission mentions factors that are part ... \n", "\n", " relevance_value relevance_score \\\n", "0 N 0 \n", "1 Y 1 \n", "2 Y 1 \n", "3 Y 1 \n", "4 N 0 \n", "\n", " coherence_reasoning coherence_value \\\n", "0 The submission discusses mental health in rela... Y \n", "1 Step 1: Assess coherence\\nThe submission menti... 
N \n", "2 Step 1: Evaluate coherence - The submission is... Y \n", "3 The submission is coherent, well-structured, a... Y \n", "4 The submission is coherent and well-structured... Y \n", "\n", " coherence_score conciseness_reasoning \\\n", "0 1 Step 1: Analyze conciseness criterion\\nThe sub... \n", "1 0 Step 1: Analyze conciseness criterion\\nThe sub... \n", "2 1 The submission is concise and includes most of... \n", "3 1 The submission is relatively concise and cover... \n", "4 1 Step 1: Read and understand both the input dat... \n", "\n", " conciseness_value conciseness_score \n", "0 Y 1 \n", "1 N 0 \n", "2 Y 1 \n", "3 Y 1 \n", "4 N 0 " ] }, "execution_count": 77, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df2.head()" ] }, { "cell_type": "code", "execution_count": 47, "id": "2d1002b2", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "correctness_score 0.500000\n", "relevance_score 0.888889\n", "coherence_score 0.888889\n", "conciseness_score 0.900000\n", "latency 8.190205\n", "dtype: float64" ] }, "execution_count": 47, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# NOTE: means of 0.888889 over 10 rows imply one NaN score per affected column, skipped by .mean()\n", "df2[['correctness_score','relevance_score','coherence_score','conciseness_score','latency']].mean()" ] }, { "cell_type": "markdown", "id": "e808bdcf", "metadata": {}, "source": [ "# Query relevance\n", "\n", "Reject queries whose nearest indexed sections are too far away (distance threshold) before calling the LLM at all." ] }, { "cell_type": "code", "execution_count": 66, "id": "6b541f3d", "metadata": {}, "outputs": [], "source": [ "def new_search_faiss(query, k=3, threshold=1.5):\n", "    \"\"\"Like search_faiss, but keep only hits whose distance is below `threshold`.\n", "\n", "    An empty list means nothing in the index is close enough to `query`.\n", "    NOTE(review): `dist < threshold` assumes smaller distance = closer (an L2-type\n", "    index); for an inner-product index the comparison would have to be inverted.\n", "    Slots that FAISS pads with id -1 (fewer than k vectors) are skipped.\n", "    \"\"\"\n", "    query_vector = model.encode([query])[0].astype('float32')\n", "    # index.search expects a 2-D batch of queries: shape (1, dim)\n", "    query_vector = np.expand_dims(query_vector, axis=0)\n", "    distances, indices = index.search(query_vector, k)\n", "\n", "    results = []\n", "    for dist, idx in zip(distances[0], indices[0]):\n", "        # Only include real hits within the threshold distance\n", "        if idx >= 0 and dist < threshold:\n", "            results.append({\n", "                'distance': dist,\n", "                'content': sections_data[idx]['content'],\n", "                'metadata': sections_data[idx]['metadata']\n", "            })\n", "\n", "    return results" ] },
{ "cell_type": "code", "execution_count": 70, "id": "4f579654", "metadata": {}, "outputs": [], "source": [ "new_prompt_template = \"\"\"\n", "You are an AI assistant specialized in Mental Health guidelines.\n", "Use the provided context to answer the question short and accurately. \n", "If you don't know the answer, simply say, \"I don't know.\"\n", "\n", "Context:\n", "{context}\n", "\n", "Question: {question}\n", "\n", "Answer:\"\"\"\n", "\n", "# Build the prompt from the shorter template defined just above (not the baseline prompt_template),\n", "# and keep it in its own chain so the global `chain` used by answer_question is not rebound.\n", "relevance_prompt = PromptTemplate(template=new_prompt_template, input_variables=[\"context\", \"question\"])\n", "relevance_chain = LLMChain(llm=llm, prompt=relevance_prompt)\n", "\n", "def new_answer_question(query, k=3, threshold=1.5):\n", "    \"\"\"Answer `query` only when retrieval finds sufficiently close context.\n", "\n", "    If new_search_faiss returns no hits, reply \"I don't know, sorry\" immediately\n", "    without calling the LLM; otherwise run relevance_chain on the joined section texts.\n", "    \"\"\"\n", "    # Search for relevant context\n", "    search_results = new_search_faiss(query, k=k, threshold=threshold)\n", "\n", "    if not search_results:\n", "        return \"I don't know, sorry\"\n", "\n", "    context = \"\\n\\n\".join(result['content'] for result in search_results)\n", "    return relevance_chain.run(context=context, question=query)" ] }, { "cell_type": "code", "execution_count": 71, "id": "1f83ef1b", "metadata": {}, "outputs": [], "source": [ "irr_q2=irr_q.copy()" ] }, { "cell_type": "code", "execution_count": 72, "id": "f06474e3", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|███████████████████████████████████████████| 10/10 [00:00<00:00, 61.93it/s]\n" ] } ], "source": [ "time_list=[]\n", "response_list=[]\n", "for query in tqdm(irr_q2['Questions'].values):\n", "    start = time.perf_counter()\n", "    response = new_answer_question(query)\n", "    end = time.perf_counter()\n", "    time_list.append(end-start)\n", "    response_list.append(response)" ] }, { "cell_type": "code", "execution_count": 73, "id": "52db6b82", "metadata": {}, "outputs": [], "source": [ "irr_q2['response']=response_list\n", "irr_q2['latency']=time_list" ] }, { "cell_type": "code", "execution_count": 80, 
"id": "80a178ee", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Questionsresponselatencyirrelevant_score
0What is the capital of Mars?I don't know, sorry0.061378True
1How many unicorns live in New York City?I don't know, sorry0.012511True
2What is the color of happiness?I don't know, sorry0.011900True
3Can cats fly on Tuesdays?I don't know, sorry0.011438True
4How much does a thought weigh?I don't know, sorry0.010644True
\n", "
" ], "text/plain": [ " Questions response latency \\\n", "0 What is the capital of Mars? I don't know, sorry 0.061378 \n", "1 How many unicorns live in New York City? I don't know, sorry 0.012511 \n", "2 What is the color of happiness? I don't know, sorry 0.011900 \n", "3 Can cats fly on Tuesdays? I don't know, sorry 0.011438 \n", "4 How much does a thought weigh? I don't know, sorry 0.010644 \n", "\n", " irrelevant_score \n", "0 True \n", "1 True \n", "2 True \n", "3 True \n", "4 True " ] }, "execution_count": 80, "metadata": {}, "output_type": "execute_result" } ], "source": [ "irr_q2.head()" ] }, { "cell_type": "code", "execution_count": 74, "id": "4508de9e", "metadata": {}, "outputs": [], "source": [ "# Refusal check: case-insensitive, accepts either ' or the typographic ’, NaN responses count as False\n", "irr_q2['irrelevant_score'] = irr_q2['response'].str.contains(r\"i don['’]t know\", case=False, regex=True, na=False)" ] }, { "cell_type": "code", "execution_count": 75, "id": "3d34ba06", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "irrelevant_score 1.000000\n", "latency 0.016068\n", "dtype: float64" ] }, "execution_count": 75, "metadata": {}, "output_type": "execute_result" } ], "source": [ "irr_q2[['irrelevant_score','latency']].mean()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.0" } }, "nbformat": 4, "nbformat_minor": 5 }