Spaces:
Paused
Paused
zephyr on t4
Browse files- benchmark/__main__.py +77 -0
- benchmark/questions.json +38 -0
- config/prompt_templates/llama2.txt +5 -4
- config/prompt_templates/mistral-7b-instruct.txt +0 -7
- config/prompt_templates/zephyr_7b.txt +10 -0
- data/benchmark/.gitkeep +0 -0
- data/indexing_benchmark.ipynb +387 -0
- qa_engine/mocks.py +5 -19
- qa_engine/qa_engine.py +37 -3
benchmark/__main__.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import json
|
3 |
+
|
4 |
+
import wandb
|
5 |
+
import gradio as gr
|
6 |
+
|
7 |
+
from qa_engine import logger, Config, QAEngine
|
8 |
+
|
9 |
+
|
10 |
+
QUESTIONS_FILENAME = 'benchmark/questions.json'
|
11 |
+
|
12 |
+
config = Config()
|
13 |
+
qa_engine = QAEngine(
|
14 |
+
llm_model_id=config.question_answering_model_id,
|
15 |
+
embedding_model_id=config.embedding_model_id,
|
16 |
+
index_repo_id=config.index_repo_id,
|
17 |
+
prompt_template=config.prompt_template,
|
18 |
+
use_docs_for_context=config.use_docs_for_context,
|
19 |
+
add_sources_to_response=config.add_sources_to_response,
|
20 |
+
use_messages_for_context=config.use_messages_in_context,
|
21 |
+
debug=config.debug
|
22 |
+
)
|
23 |
+
|
24 |
+
|
25 |
+
def main():
|
26 |
+
filtered_config = config.asdict()
|
27 |
+
disallowed_config_keys = [
|
28 |
+
"DISCORD_TOKEN", "NUM_LAST_MESSAGES", "USE_NAMES_IN_CONTEXT",
|
29 |
+
"ENABLE_COMMANDS", "APP_MODE", "DEBUG"
|
30 |
+
]
|
31 |
+
for key in disallowed_config_keys:
|
32 |
+
filtered_config.pop(key, None)
|
33 |
+
|
34 |
+
wandb.init(
|
35 |
+
project='HF-Docs-QA',
|
36 |
+
entity='hf-qa-bot',
|
37 |
+
name=f'{config.question_answering_model_id} - {config.embedding_model_id} - {config.index_repo_id}',
|
38 |
+
mode='run', # run/disabled
|
39 |
+
config=filtered_config
|
40 |
+
)
|
41 |
+
|
42 |
+
with open(QUESTIONS_FILENAME, 'r') as f:
|
43 |
+
questions = json.load(f)
|
44 |
+
|
45 |
+
table = wandb.Table(
|
46 |
+
columns=[
|
47 |
+
"id", "question", "messages_context", "answer", "sources", "time"
|
48 |
+
]
|
49 |
+
)
|
50 |
+
for i, q in enumerate(questions):
|
51 |
+
logger.info(f"Question {i+1}/{len(questions)}")
|
52 |
+
|
53 |
+
question = q['question']
|
54 |
+
messages_context = q['messages_context']
|
55 |
+
|
56 |
+
time_start = time.perf_counter()
|
57 |
+
response = qa_engine.get_response(
|
58 |
+
question=question,
|
59 |
+
messages_context=messages_context
|
60 |
+
)
|
61 |
+
time_end = time.perf_counter()
|
62 |
+
|
63 |
+
table.add_data(
|
64 |
+
i,
|
65 |
+
question,
|
66 |
+
messages_context,
|
67 |
+
response.get_answer(),
|
68 |
+
response.get_sources_as_text(),
|
69 |
+
time_end - time_start
|
70 |
+
)
|
71 |
+
|
72 |
+
wandb.log({"answers": table})
|
73 |
+
wandb.finish()
|
74 |
+
|
75 |
+
|
76 |
+
if __name__ == '__main__':
|
77 |
+
main()
|
benchmark/questions.json
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[
|
2 |
+
{
|
3 |
+
"question": "How to create audio dataset with Hugging Face?",
|
4 |
+
"messages_context": " "
|
5 |
+
},
|
6 |
+
{
|
7 |
+
"question": "I want to check if 2 sentences are similar semantically. How can I do it?",
|
8 |
+
"messages_context": " "
|
9 |
+
},
|
10 |
+
{
|
11 |
+
"question": "What are the benefits of Gradio?",
|
12 |
+
"messages_context": " "
|
13 |
+
},
|
14 |
+
{
|
15 |
+
"question": "How to deploy a text-to-image model?",
|
16 |
+
"messages_context": " "
|
17 |
+
},
|
18 |
+
{
|
19 |
+
"question": "Does Hugging Face offer any distributed training assistance? followup: Can you give me an example setup of it?",
|
20 |
+
"messages_context": " "
|
21 |
+
},
|
22 |
+
{
|
23 |
+
"question": "I want to detect cars on video recording. How should I do it and what models do you recommend?",
|
24 |
+
"messages_context": " "
|
25 |
+
},
|
26 |
+
{
|
27 |
+
"question": "Is there any tool for evaluating models in Hugging Face? followup: Can you give me an example setup of it?",
|
28 |
+
"messages_context": " "
|
29 |
+
},
|
30 |
+
{
|
31 |
+
"question": "What are some advantages of the Hugging Face Hub?",
|
32 |
+
"messages_context": " "
|
33 |
+
},
|
34 |
+
{
|
35 |
+
"question": "How would I use a model in 8 bit in transformers?",
|
36 |
+
"messages_context": " "
|
37 |
+
}
|
38 |
+
]
|
config/prompt_templates/llama2.txt
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
-
<<SYS>>
|
2 |
-
|
3 |
-
|
|
|
4 |
<</SYS>>
|
5 |
|
6 |
-
[INST]
|
|
|
1 |
+
<<SYS>>Using the information contained in the context,
|
2 |
+
give a comprehensive answer to the question.
|
3 |
+
Respond only to the question asked, response should be concise and relevant to the question.
|
4 |
+
If the answer cannot be deduced from the context, do not give an answer.
|
5 |
<</SYS>>
|
6 |
|
7 |
+
[INST] Context: {context} [/INST] User: {question}
|
config/prompt_templates/mistral-7b-instruct.txt
DELETED
@@ -1,7 +0,0 @@
|
|
1 |
-
[INST]
|
2 |
-
You are a helpful assistant for question-answering over Huggging Face documentation. Use the following chunks of relevant information to answer the question. If you don't know the answer, just say that you don't know. You can output code snippets that appear in the documentation to enrich the answer.
|
3 |
-
<QUESTION>:
|
4 |
-
{question}
|
5 |
-
<CONTEXT>:
|
6 |
-
{context}
|
7 |
-
[/INST]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
config/prompt_templates/zephyr_7b.txt
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<|system|>Using the information contained in the context,
|
2 |
+
give a comprehensive answer to the question.
|
3 |
+
Respond only to the question asked, response should be concise and relevant to the question.
|
4 |
+
If the answer cannot be deduced from the context, do not give an answer.</s>
|
5 |
+
<|user|>
|
6 |
+
Context:
|
7 |
+
{context}
|
8 |
+
Question: {question}
|
9 |
+
</s>
|
10 |
+
<|assistant|>
|
data/benchmark/.gitkeep
DELETED
File without changes
|
data/indexing_benchmark.ipynb
ADDED
@@ -0,0 +1,387 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 37,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"import math\n",
|
10 |
+
"import numpy as np\n",
|
11 |
+
"from pathlib import Path\n",
|
12 |
+
"from typing import List, Union, Any\n",
|
13 |
+
"from tqdm import tqdm\n",
|
14 |
+
"from sentence_transformers import CrossEncoder\n",
|
15 |
+
"from langchain.chains import RetrievalQA\n",
|
16 |
+
"from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings\n",
|
17 |
+
"from langchain.document_loaders import TextLoader\n",
|
18 |
+
"from langchain.indexes import VectorstoreIndexCreator\n",
|
19 |
+
"from langchain.text_splitter import CharacterTextSplitter\n",
|
20 |
+
"from langchain.vectorstores import FAISS\n",
|
21 |
+
"from sentence_transformers import CrossEncoder"
|
22 |
+
]
|
23 |
+
},
|
24 |
+
{
|
25 |
+
"cell_type": "code",
|
26 |
+
"execution_count": 31,
|
27 |
+
"metadata": {},
|
28 |
+
"outputs": [],
|
29 |
+
"source": [
|
30 |
+
"class AverageInstructEmbeddings(HuggingFaceInstructEmbeddings):\n",
|
31 |
+
" max_length: int = None\n",
|
32 |
+
" def __init__(self, max_length: int = 512, **kwargs: Any):\n",
|
33 |
+
" super().__init__(**kwargs)\n",
|
34 |
+
" self.max_length = max_length\n",
|
35 |
+
" if self.max_length < 0:\n",
|
36 |
+
" print('max_length is not specified, using model default max_seq_length')\n",
|
37 |
+
"\n",
|
38 |
+
" def embed_documents(self, texts: List[str]) -> List[List[float]]:\n",
|
39 |
+
" all_embeddings = []\n",
|
40 |
+
" for text in tqdm(texts, desc=\"Embedding documents\"):\n",
|
41 |
+
" if len(text) > self.max_length and self.max_length > -1:\n",
|
42 |
+
" n_chunks = math.ceil(len(text)/self.max_length)\n",
|
43 |
+
" chunks = [\n",
|
44 |
+
" text[i*self.max_length:(i+1)*self.max_length]\n",
|
45 |
+
" for i in range(n_chunks)\n",
|
46 |
+
" ]\n",
|
47 |
+
" instruction_pairs = [[self.embed_instruction, chunk] for chunk in chunks]\n",
|
48 |
+
" chunk_embeddings = self.client.encode(instruction_pairs)\n",
|
49 |
+
" avg_embedding = np.mean(chunk_embeddings, axis=0)\n",
|
50 |
+
" all_embeddings.append(avg_embedding.tolist())\n",
|
51 |
+
" else:\n",
|
52 |
+
" instruction_pairs = [[self.embed_instruction, text]]\n",
|
53 |
+
" embeddings = self.client.encode(instruction_pairs)\n",
|
54 |
+
" all_embeddings.append(embeddings[0].tolist())\n",
|
55 |
+
"\n",
|
56 |
+
" return all_embeddings\n",
|
57 |
+
"\n",
|
58 |
+
"\n",
|
59 |
+
"class BenchDataST:\n",
|
60 |
+
" def __init__(self, path: str, percentage: float = 0.005, chunk_size: int = 512, chunk_overlap: int = 100):\n",
|
61 |
+
" self.path = path\n",
|
62 |
+
" self.percentage = percentage\n",
|
63 |
+
" self.docs = []\n",
|
64 |
+
" self.metadata = []\n",
|
65 |
+
" self.load()\n",
|
66 |
+
" self.text_splitter = CharacterTextSplitter(separator=\"\", chunk_size=chunk_size, chunk_overlap=chunk_overlap)\n",
|
67 |
+
" self.docs_processed = self.text_splitter.create_documents(self.docs, self.metadata)\n",
|
68 |
+
"\n",
|
69 |
+
" def load(self):\n",
|
70 |
+
" for p in Path(self.path).iterdir():\n",
|
71 |
+
" if not p.is_dir():\n",
|
72 |
+
" with open(p) as f:\n",
|
73 |
+
" source = f.readline().strip().replace('source: ', '')\n",
|
74 |
+
" self.docs.append(f.read())\n",
|
75 |
+
" self.metadata.append({\"source\": source})\n",
|
76 |
+
" self.docs = self.docs[:int(len(self.docs) * self.percentage)]\n",
|
77 |
+
" self.metadata = self.metadata[:int(len(self.metadata) * self.percentage)]\n",
|
78 |
+
"\n",
|
79 |
+
" def __len__(self):\n",
|
80 |
+
" return len(self.docs)\n",
|
81 |
+
"\n",
|
82 |
+
" def __getitem__(self, idx):\n",
|
83 |
+
" return self.docs[idx], self.metadata[idx]\n",
|
84 |
+
"\n",
|
85 |
+
" def __iter__(self):\n",
|
86 |
+
" for doc, metadata in zip(self.docs, self.metadata):\n",
|
87 |
+
" yield doc, metadata\n",
|
88 |
+
"\n",
|
89 |
+
" def __repr__(self):\n",
|
90 |
+
" return f'BenchDataST({len(self)} docs) at {self.path} with {self.percentage} percentage \\nSources: {self.metadata} \\nChunks: {self.text_splitter}'\n",
|
91 |
+
" \n",
|
92 |
+
"\n",
|
93 |
+
"class BenchmarkST:\n",
|
94 |
+
" def __init__(self, data: BenchDataST, baseline_model: Union[HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings, AverageInstructEmbeddings], embedding_models: List[Union[HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings, AverageInstructEmbeddings]]):\n",
|
95 |
+
" self.data = data\n",
|
96 |
+
" self.baseline_model = baseline_model\n",
|
97 |
+
" self.embedding_models = embedding_models\n",
|
98 |
+
" self.baseline_index, self.indexes = self.build_indexes()\n",
|
99 |
+
"\n",
|
100 |
+
" def build_indexes(self):\n",
|
101 |
+
" indexes = []\n",
|
102 |
+
" for model in [self.baseline_model] + self.embedding_models:\n",
|
103 |
+
" print(f\"Building index for {model}\")\n",
|
104 |
+
" index = FAISS.from_documents(self.data.docs_processed, model)\n",
|
105 |
+
" indexes.append(index)\n",
|
106 |
+
" return indexes[0], indexes[1:]\n",
|
107 |
+
" \n",
|
108 |
+
" def add_index(self, index: FAISS):\n",
|
109 |
+
" self.indexes.append(index)\n",
|
110 |
+
" \n",
|
111 |
+
" def evaluate(self, query: str, k: int = 3):\n",
|
112 |
+
" baseline_results = self.baseline_index.similarity_search_with_score(query, k=k)\n",
|
113 |
+
" results = []\n",
|
114 |
+
" for index in self.indexes:\n",
|
115 |
+
" results.append(index.similarity_search_with_score(query, k=k))\n",
|
116 |
+
" return baseline_results, results"
|
117 |
+
]
|
118 |
+
},
|
119 |
+
{
|
120 |
+
"cell_type": "code",
|
121 |
+
"execution_count": 48,
|
122 |
+
"metadata": {},
|
123 |
+
"outputs": [
|
124 |
+
{
|
125 |
+
"name": "stdout",
|
126 |
+
"output_type": "stream",
|
127 |
+
"text": [
|
128 |
+
"load INSTRUCTOR_Transformer\n",
|
129 |
+
"max_seq_length 512\n"
|
130 |
+
]
|
131 |
+
},
|
132 |
+
{
|
133 |
+
"name": "stderr",
|
134 |
+
"output_type": "stream",
|
135 |
+
"text": [
|
136 |
+
"No sentence-transformers model found with name /Users/michalwilinski/.cache/torch/sentence_transformers/cross-encoder_ms-marco-MiniLM-L-12-v2. Creating a new one with MEAN pooling.\n",
|
137 |
+
"Some weights of the model checkpoint at /Users/michalwilinski/.cache/torch/sentence_transformers/cross-encoder_ms-marco-MiniLM-L-12-v2 were not used when initializing BertModel: ['classifier.bias', 'classifier.weight']\n",
|
138 |
+
"- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
|
139 |
+
"- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
|
140 |
+
]
|
141 |
+
},
|
142 |
+
{
|
143 |
+
"name": "stdout",
|
144 |
+
"output_type": "stream",
|
145 |
+
"text": [
|
146 |
+
"Building index for client=INSTRUCTOR(\n",
|
147 |
+
" (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: T5EncoderModel \n",
|
148 |
+
" (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False})\n",
|
149 |
+
" (2): Dense({'in_features': 768, 'out_features': 768, 'bias': False, 'activation_function': 'torch.nn.modules.linear.Identity'})\n",
|
150 |
+
" (3): Normalize()\n",
|
151 |
+
") model_name='hkunlp/instructor-base' cache_folder=None model_kwargs={} encode_kwargs={} embed_instruction='Represent this piece of text for searching relevant information:' query_instruction='Query the most relevant piece of information from the Hugging Face documentation' max_length=512\n"
|
152 |
+
]
|
153 |
+
},
|
154 |
+
{
|
155 |
+
"name": "stderr",
|
156 |
+
"output_type": "stream",
|
157 |
+
"text": [
|
158 |
+
"Embedding documents: 100%|██████████| 278/278 [00:19<00:00, 14.11it/s]\n"
|
159 |
+
]
|
160 |
+
},
|
161 |
+
{
|
162 |
+
"name": "stdout",
|
163 |
+
"output_type": "stream",
|
164 |
+
"text": [
|
165 |
+
"Building index for client=SentenceTransformer(\n",
|
166 |
+
" (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: BertModel \n",
|
167 |
+
" (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})\n",
|
168 |
+
") model_name='cross-encoder/ms-marco-MiniLM-L-12-v2' cache_folder=None model_kwargs={} encode_kwargs={} multi_process=False\n"
|
169 |
+
]
|
170 |
+
}
|
171 |
+
],
|
172 |
+
"source": [
|
173 |
+
"data = BenchDataST(\n",
|
174 |
+
" path=\"./datasets/huggingface_docs/\",\n",
|
175 |
+
" percentage=0.005,\n",
|
176 |
+
" chunk_size=512,\n",
|
177 |
+
" chunk_overlap=100\n",
|
178 |
+
")\n",
|
179 |
+
"\n",
|
180 |
+
"baseline_embedding_model = AverageInstructEmbeddings(\n",
|
181 |
+
" model_name=\"hkunlp/instructor-base\",\n",
|
182 |
+
" embed_instruction=\"Represent this piece of text for searching relevant information:\",\n",
|
183 |
+
" query_instruction=\"Query the most relevant piece of information from the Hugging Face documentation\",\n",
|
184 |
+
" max_length=512,\n",
|
185 |
+
")\n",
|
186 |
+
"\n",
|
187 |
+
"embedding_model = HuggingFaceEmbeddings(\n",
|
188 |
+
" model_name=\"intfloat/e5-large-v2\",\n",
|
189 |
+
")\n",
|
190 |
+
"\n",
|
191 |
+
"cross_encoder = HuggingFaceEmbeddings(model_name=\"cross-encoder/ms-marco-MiniLM-L-12-v2\")\n",
|
192 |
+
"\n",
|
193 |
+
"benchmark = BenchmarkST(\n",
|
194 |
+
" data=data,\n",
|
195 |
+
" baseline_model=baseline_embedding_model,\n",
|
196 |
+
" embedding_models=[cross_encoder]\n",
|
197 |
+
")"
|
198 |
+
]
|
199 |
+
},
|
200 |
+
{
|
201 |
+
"cell_type": "code",
|
202 |
+
"execution_count": 54,
|
203 |
+
"metadata": {},
|
204 |
+
"outputs": [
|
205 |
+
{
|
206 |
+
"name": "stdout",
|
207 |
+
"output_type": "stream",
|
208 |
+
"text": [
|
209 |
+
"Baseline results:\n",
|
210 |
+
"{'source': 'https://github.com/huggingface/course/blob/main/chapters/en/chapter6/4.mdx'} 0.23610792\n",
|
211 |
+
"{'source': 'https://github.com/huggingface/course/blob/main/chapters/en/chapter6/4.mdx'} 0.24087097\n",
|
212 |
+
"{'source': 'https://github.com/huggingface/course/blob/main/chapters/en/chapter6/4.mdx'} 0.24181677\n",
|
213 |
+
"{'source': 'https://github.com/huggingface/course/blob/main/chapters/en/chapter6/4.mdx'} 0.24541612\n",
|
214 |
+
"{'source': 'https://github.com/huggingface/course/blob/main/chapters/en/chapter6/4.mdx'} 0.24639006\n",
|
215 |
+
"{'source': 'https://github.com/huggingface/course/blob/main/chapters/en/chapter6/4.mdx'} 0.24780047\n",
|
216 |
+
"{'source': 'https://github.com/huggingface/course/blob/main/chapters/en/chapter6/4.mdx'} 0.2535807\n",
|
217 |
+
"{'source': 'https://github.com/huggingface/optimum/blob/main/docs/source/exporters/onnx/usage_guides/export_a_model.mdx'} 0.25887597\n",
|
218 |
+
"{'source': 'https://github.com/huggingface/course/blob/main/chapters/en/chapter6/4.mdx'} 0.27293646\n",
|
219 |
+
"{'source': 'https://github.com/huggingface/course/blob/main/chapters/en/chapter6/4.mdx'} 0.27374876\n",
|
220 |
+
"{'source': 'https://github.com/huggingface/course/blob/main/chapters/en/chapter6/4.mdx'} 0.27710187\n",
|
221 |
+
"{'source': 'https://github.com/huggingface/course/blob/main/chapters/en/chapter6/4.mdx'} 0.28146794\n",
|
222 |
+
"{'source': 'https://github.com/huggingface/course/blob/main/chapters/en/chapter6/4.mdx'} 0.29536068\n",
|
223 |
+
"{'source': 'https://github.com/huggingface/course/blob/main/chapters/en/chapter6/4.mdx'} 0.29784447\n",
|
224 |
+
"{'source': 'https://github.com/huggingface/course/blob/main/chapters/en/chapter6/4.mdx'} 0.30452335\n",
|
225 |
+
"{'source': 'https://github.com/huggingface/course/blob/main/chapters/en/chapter6/4.mdx'} 0.3061711\n",
|
226 |
+
"{'source': 'https://github.com/huggingface/course/blob/main/chapters/en/chapter6/4.mdx'} 0.31600478\n",
|
227 |
+
"{'source': 'https://github.com/huggingface/course/blob/main/chapters/en/chapter6/4.mdx'} 0.3166225\n",
|
228 |
+
"{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.33345556\n",
|
229 |
+
"{'source': 'https://github.com/huggingface/optimum/blob/main/docs/source/exporters/onnx/usage_guides/export_a_model.mdx'} 0.3469957\n",
|
230 |
+
"{'source': 'https://github.com/huggingface/optimum/blob/main/docs/source/exporters/onnx/usage_guides/export_a_model.mdx'} 0.35222226\n",
|
231 |
+
"{'source': 'https://github.com/huggingface/course/blob/main/chapters/en/chapter6/4.mdx'} 0.36451602\n",
|
232 |
+
"{'source': 'https://github.com/huggingface/course/blob/main/chapters/en/chapter6/4.mdx'} 0.36925688\n",
|
233 |
+
"{'source': 'https://github.com/huggingface/course/blob/main/chapters/en/chapter6/4.mdx'} 0.37025565\n",
|
234 |
+
"{'source': 'https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/README.md'} 0.37112093\n",
|
235 |
+
"{'source': 'https://github.com/huggingface/optimum/blob/main/docs/source/exporters/onnx/usage_guides/export_a_model.mdx'} 0.37146708\n",
|
236 |
+
"{'source': 'https://github.com/huggingface/optimum/blob/main/docs/source/exporters/onnx/usage_guides/export_a_model.mdx'} 0.3766507\n",
|
237 |
+
"{'source': 'https://github.com/huggingface/optimum/blob/main/docs/source/exporters/onnx/usage_guides/export_a_model.mdx'} 0.37794292\n",
|
238 |
+
"{'source': 'https://github.com/huggingface/optimum/blob/main/docs/source/exporters/onnx/usage_guides/export_a_model.mdx'} 0.37923962\n",
|
239 |
+
"{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.38359642\n",
|
240 |
+
"{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.3878625\n",
|
241 |
+
"{'source': 'https://github.com/huggingface/optimum/blob/main/docs/source/exporters/onnx/usage_guides/export_a_model.mdx'} 0.39796114\n",
|
242 |
+
"{'source': 'https://github.com/huggingface/optimum/blob/main/docs/source/exporters/onnx/usage_guides/export_a_model.mdx'} 0.40057343\n",
|
243 |
+
"{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.40114868\n",
|
244 |
+
"{'source': 'https://github.com/huggingface/optimum/blob/main/docs/source/exporters/onnx/usage_guides/export_a_model.mdx'} 0.40156174\n",
|
245 |
+
"{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.40341228\n",
|
246 |
+
"{'source': 'https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/README.md'} 0.40720195\n",
|
247 |
+
"{'source': 'https://github.com/huggingface/optimum/blob/main/docs/source/exporters/onnx/usage_guides/export_a_model.mdx'} 0.41241395\n",
|
248 |
+
"{'source': 'https://github.com/huggingface/optimum/blob/main/docs/source/exporters/onnx/usage_guides/export_a_model.mdx'} 0.4134417\n",
|
249 |
+
"{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.4134435\n",
|
250 |
+
"{'source': 'https://github.com/huggingface/optimum/blob/main/docs/source/exporters/onnx/usage_guides/export_a_model.mdx'} 0.41754264\n",
|
251 |
+
"{'source': 'https://github.com/huggingface/optimum/blob/main/docs/source/exporters/onnx/usage_guides/export_a_model.mdx'} 0.41917825\n",
|
252 |
+
"{'source': 'https://github.com/huggingface/optimum/blob/main/docs/source/exporters/onnx/usage_guides/export_a_model.mdx'} 0.41928726\n",
|
253 |
+
"{'source': 'https://github.com/huggingface/optimum/blob/main/docs/source/exporters/onnx/usage_guides/export_a_model.mdx'} 0.41988587\n",
|
254 |
+
"{'source': 'https://github.com/huggingface/optimum/blob/main/docs/source/exporters/onnx/usage_guides/export_a_model.mdx'} 0.42029166\n",
|
255 |
+
"{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.42128915\n",
|
256 |
+
"{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.4226097\n",
|
257 |
+
"{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.42302307\n",
|
258 |
+
"{'source': 'https://github.com/gradio-app/gradio/blob/main/demo/stt_or_tts/run.ipynb'} 0.4252566\n",
|
259 |
+
"{'source': 'https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/README.md'} 0.42704937\n",
|
260 |
+
"{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.4297651\n",
|
261 |
+
"{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.43067485\n",
|
262 |
+
"{'source': 'https://github.com/huggingface/optimum/blob/main/docs/source/exporters/onnx/usage_guides/export_a_model.mdx'} 0.43116528\n",
|
263 |
+
"{'source': 'https://github.com/huggingface/blog/blob/main/bloom.md'} 0.43272027\n",
|
264 |
+
"{'source': 'https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/README_sdxl.md'} 0.43434155\n",
|
265 |
+
"{'source': 'https://github.com/huggingface/optimum/blob/main/docs/source/exporters/onnx/usage_guides/export_a_model.mdx'} 0.43486434\n",
|
266 |
+
"{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.43524152\n",
|
267 |
+
"{'source': 'https://github.com/huggingface/optimum/blob/main/docs/source/exporters/onnx/usage_guides/export_a_model.mdx'} 0.43530554\n",
|
268 |
+
"{'source': 'https://github.com/huggingface/optimum/blob/main/docs/source/exporters/onnx/usage_guides/export_a_model.mdx'} 0.4371896\n",
|
269 |
+
"{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.43753576\n",
|
270 |
+
"{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.43824\n",
|
271 |
+
"{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.4384127\n",
|
272 |
+
"{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.43900505\n",
|
273 |
+
"{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.43903238\n",
|
274 |
+
"{'source': 'https://github.com/huggingface/blog/blob/main/accelerate-deepspeed.md'} 0.44034868\n",
|
275 |
+
"{'source': 'https://github.com/huggingface/optimum/blob/main/docs/source/exporters/onnx/usage_guides/export_a_model.mdx'} 0.44217598\n",
|
276 |
+
"{'source': 'https://github.com/huggingface/diffusers/blob/main/docs/source/en/api/schedulers/euler_ancestral.md'} 0.4426194\n",
|
277 |
+
"{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.44303834\n",
|
278 |
+
"{'source': 'https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/README_sdxl.md'} 0.4452571\n",
|
279 |
+
"{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.44619536\n",
|
280 |
+
"{'source': 'https://github.com/huggingface/optimum/blob/main/docs/source/exporters/onnx/usage_guides/export_a_model.mdx'} 0.44652176\n",
|
281 |
+
"{'source': 'https://github.com/gradio-app/gradio/blob/main/demo/stt_or_tts/run.ipynb'} 0.44683564\n",
|
282 |
+
"{'source': 'https://github.com/huggingface/blog/blob/main/accelerate-deepspeed.md'} 0.44743723\n",
|
283 |
+
"{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.44768596\n",
|
284 |
+
"{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.4477852\n",
|
285 |
+
"{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.44906363\n",
|
286 |
+
"{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.45155957\n",
|
287 |
+
"{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.45215163\n",
|
288 |
+
"{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.45415214\n",
|
289 |
+
"{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.4541726\n",
|
290 |
+
"{'source': 'https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/README_sdxl.md'} 0.4542602\n",
|
291 |
+
"{'source': 'https://github.com/huggingface/blog/blob/main/accelerate-deepspeed.md'} 0.4544394\n",
|
292 |
+
"{'source': 'https://github.com/huggingface/transformers/blob/main/docs/source/en/model_doc/open-llama.md'} 0.45448524\n",
|
293 |
+
"{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.454512\n",
|
294 |
+
"{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.45478693\n",
|
295 |
+
"{'source': 'https://github.com/huggingface/diffusers/blob/main/docs/source/en/api/schedulers/euler_ancestral.md'} 0.45494407\n",
|
296 |
+
"{'source': 'https://github.com/huggingface/transformers/blob/main/docs/source/en/model_doc/open-llama.md'} 0.45494407\n",
|
297 |
+
"{'source': 'https://github.com/gradio-app/gradio/blob/main/js/accordion/CHANGELOG.md'} 0.45520714\n",
|
298 |
+
"{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.4559689\n",
|
299 |
+
"{'source': 'https://github.com/huggingface/blog/blob/main/bloom.md'} 0.4568352\n",
|
300 |
+
"{'source': 'https://github.com/huggingface/optimum/blob/main/docs/source/exporters/onnx/usage_guides/export_a_model.mdx'} 0.4577096\n",
|
301 |
+
"{'source': 'https://github.com/huggingface/simulate/blob/main/docs/source/api/lights.mdx'} 0.4577096\n",
|
302 |
+
"{'source': 'https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/README_sdxl.md'} 0.45773098\n",
|
303 |
+
"{'source': 'https://github.com/huggingface/blog/blob/main/bloom.md'} 0.45818624\n",
|
304 |
+
"{'source': 'https://github.com/huggingface/optimum/blob/main/docs/source/exporters/onnx/usage_guides/export_a_model.mdx'} 0.45871085\n",
|
305 |
+
"{'source': 'https://github.com/huggingface/blog/blob/main/bloom.md'} 0.4591412\n",
|
306 |
+
"{'source': 'https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/README_sdxl.md'} 0.46033093\n",
|
307 |
+
"{'source': 'https://github.com/huggingface/blog/blob/main/accelerate-deepspeed.md'} 0.4605264\n",
|
308 |
+
"{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.46091354\n",
|
309 |
+
"{'source': 'https://github.com/huggingface/transformers/blob/main/docs/source/en/model_doc/open-llama.md'} 0.46182537\n",
|
310 |
+
"Cross encoder results:\n",
|
311 |
+
"{'source': 'https://github.com/huggingface/course/blob/main/chapters/en/chapter6/4.mdx'} 6.840022\n",
|
312 |
+
"{'source': 'https://github.com/huggingface/course/blob/main/chapters/en/chapter6/4.mdx'} -0.98426485\n",
|
313 |
+
"{'source': 'https://github.com/huggingface/course/blob/main/chapters/en/chapter6/4.mdx'} -1.9345549\n",
|
314 |
+
"bye\n"
|
315 |
+
]
|
316 |
+
}
|
317 |
+
],
|
318 |
+
"source": [
|
319 |
+
"query = \"textual inversion\"\n",
|
320 |
+
"k = 100\n",
|
321 |
+
"baseline_results, results = benchmark.evaluate(query=query, k=k)\n",
|
322 |
+
"print(\"Baseline results:\")\n",
|
323 |
+
"[print(doc.metadata,score) for (doc,score) in baseline_results]\n",
|
324 |
+
"cross_encoder = CrossEncoder(\"cross-encoder/ms-marco-MiniLM-L-12-v2\")\n",
|
325 |
+
"cross_encoder_results = cross_encoder.predict([(query, doc.page_content) for doc in data.docs_processed])\n",
|
326 |
+
"# rerank results\n",
|
327 |
+
"cross_encoder_results = sorted(zip(data.docs_processed, cross_encoder_results), key=lambda x: x[1], reverse=True)\n",
|
328 |
+
"print(\"Cross encoder results:\")\n",
|
329 |
+
"final_results = cross_encoder_results[:3]\n",
|
330 |
+
"[print(doc.metadata, score) for (doc,score) in final_results]\n",
|
331 |
+
"print(\"bye\")"
|
332 |
+
]
|
333 |
+
},
|
334 |
+
{
|
335 |
+
"cell_type": "code",
|
336 |
+
"execution_count": 55,
|
337 |
+
"metadata": {},
|
338 |
+
"outputs": [
|
339 |
+
{
|
340 |
+
"name": "stdout",
|
341 |
+
"output_type": "stream",
|
342 |
+
"text": [
|
343 |
+
"es where the space character is not used (like Chinese or Japanese).\n",
|
344 |
+
"\n",
|
345 |
+
"The other main feature of SentencePiece is *reversible tokenization*: since there is no special treatment of spaces, decoding the tokens is done simply by concatenating them and replacing the `_`s with spaces -- this results in the normalized text. As we saw earlier, the BERT tokenizer removes repeating spaces, so its tokenization is not reversible.\n",
|
346 |
+
"\n",
|
347 |
+
"## Algorithm overview[[algorithm-overview]]\n",
|
348 |
+
"\n",
|
349 |
+
"In the following sections, we'll dive into t\n"
|
350 |
+
]
|
351 |
+
}
|
352 |
+
],
|
353 |
+
"source": [
|
354 |
+
"print(final_results[0][0].page_content)"
|
355 |
+
]
|
356 |
+
},
|
357 |
+
{
|
358 |
+
"cell_type": "code",
|
359 |
+
"execution_count": null,
|
360 |
+
"metadata": {},
|
361 |
+
"outputs": [],
|
362 |
+
"source": []
|
363 |
+
}
|
364 |
+
],
|
365 |
+
"metadata": {
|
366 |
+
"kernelspec": {
|
367 |
+
"display_name": "hf_qa_bot",
|
368 |
+
"language": "python",
|
369 |
+
"name": "python3"
|
370 |
+
},
|
371 |
+
"language_info": {
|
372 |
+
"codemirror_mode": {
|
373 |
+
"name": "ipython",
|
374 |
+
"version": 3
|
375 |
+
},
|
376 |
+
"file_extension": ".py",
|
377 |
+
"mimetype": "text/x-python",
|
378 |
+
"name": "python",
|
379 |
+
"nbconvert_exporter": "python",
|
380 |
+
"pygments_lexer": "ipython3",
|
381 |
+
"version": "3.11.3"
|
382 |
+
},
|
383 |
+
"orig_nbformat": 4
|
384 |
+
},
|
385 |
+
"nbformat": 4,
|
386 |
+
"nbformat_minor": 2
|
387 |
+
}
|
qa_engine/mocks.py
CHANGED
@@ -6,36 +6,22 @@ from langchain.llms.base import LLM
|
|
6 |
|
7 |
class MockLocalBinaryModel(LLM):
|
8 |
"""
|
9 |
-
Mock Local Binary Model class
|
10 |
-
|
11 |
-
Args:
|
12 |
-
model_id (str): The ID of the model to be mocked.
|
13 |
-
|
14 |
-
Attributes:
|
15 |
-
model_path (str): The path to the model to be mocked.
|
16 |
-
llm (str): The string "a".
|
17 |
-
|
18 |
-
Raises:
|
19 |
-
ValueError: If the model_path does not exist.
|
20 |
"""
|
21 |
|
22 |
model_path: str = None
|
23 |
-
llm: str = '
|
24 |
|
25 |
-
def __init__(self
|
26 |
super().__init__()
|
27 |
-
self.model_path = f'bot/question_answering/{model_id}'
|
28 |
-
if not os.path.exists(self.model_path):
|
29 |
-
raise ValueError(f'{self.model_path} does not exist')
|
30 |
-
|
31 |
|
32 |
def _call(self, prompt: str, stop: Optional[list[str]] = None) -> str:
|
33 |
return self.llm
|
34 |
|
35 |
@property
|
36 |
def _identifying_params(self) -> Mapping[str, Any]:
|
37 |
-
return {'name_of_model':
|
38 |
|
39 |
@property
|
40 |
def _llm_type(self) -> str:
|
41 |
-
return
|
|
|
6 |
|
7 |
class MockLocalBinaryModel(LLM):
|
8 |
"""
|
9 |
+
Mock Local Binary Model class.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
"""
|
11 |
|
12 |
model_path: str = None
|
13 |
+
llm: str = 'Mocked Response'
|
14 |
|
15 |
+
def __init__(self):
|
16 |
super().__init__()
|
|
|
|
|
|
|
|
|
17 |
|
18 |
def _call(self, prompt: str, stop: Optional[list[str]] = None) -> str:
|
19 |
return self.llm
|
20 |
|
21 |
@property
|
22 |
def _identifying_params(self) -> Mapping[str, Any]:
|
23 |
+
return {'name_of_model': 'mock'}
|
24 |
|
25 |
@property
|
26 |
def _llm_type(self) -> str:
|
27 |
+
return 'mock'
|
qa_engine/qa_engine.py
CHANGED
@@ -18,6 +18,7 @@ from sentence_transformers import CrossEncoder
|
|
18 |
|
19 |
from qa_engine import logger
|
20 |
from qa_engine.response import Response
|
|
|
21 |
|
22 |
|
23 |
class LocalBinaryModel(LLM):
|
@@ -191,6 +192,9 @@ class QAEngine():
|
|
191 |
model_url=llm_model_id.replace('api_models/', ''),
|
192 |
debug=self.debug
|
193 |
)
|
|
|
|
|
|
|
194 |
else:
|
195 |
logger.info('using transformers pipeline model')
|
196 |
self.llm_model = TransformersPipelineModel(
|
@@ -224,6 +228,33 @@ class QAEngine():
|
|
224 |
self.reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')
|
225 |
|
226 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
227 |
def get_response(self, question: str, messages_context: str = '') -> Response:
|
228 |
"""
|
229 |
Generate an answer to the specified question.
|
@@ -267,8 +298,10 @@ class QAEngine():
|
|
267 |
response.set_sources(sources=[str(m['source']) for m in metadata])
|
268 |
|
269 |
logger.info('Running LLM chain')
|
270 |
-
|
271 |
-
|
|
|
|
|
272 |
logger.info('Received answer')
|
273 |
|
274 |
if self.debug:
|
@@ -277,7 +310,8 @@ class QAEngine():
|
|
277 |
logger.info(f'question len: {len(question)} {sep}')
|
278 |
logger.info(f'question: {question} {sep}')
|
279 |
logger.info(f'answer len: {len(response.get_answer())} {sep}')
|
280 |
-
logger.info(f'answer: {
|
|
|
281 |
logger.info(f'{response.get_sources_as_text()} {sep}')
|
282 |
logger.info(f'messages_contex: {messages_context} {sep}')
|
283 |
logger.info(f'relevant_docs: {relevant_docs} {sep}')
|
|
|
18 |
|
19 |
from qa_engine import logger
|
20 |
from qa_engine.response import Response
|
21 |
+
from qa_engine.mocks import MockLocalBinaryModel
|
22 |
|
23 |
|
24 |
class LocalBinaryModel(LLM):
|
|
|
192 |
model_url=llm_model_id.replace('api_models/', ''),
|
193 |
debug=self.debug
|
194 |
)
|
195 |
+
elif llm_model_id == 'mock':
|
196 |
+
logger.info('using mock model')
|
197 |
+
self.llm_model = MockLocalBinaryModel()
|
198 |
else:
|
199 |
logger.info('using transformers pipeline model')
|
200 |
self.llm_model = TransformersPipelineModel(
|
|
|
228 |
self.reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')
|
229 |
|
230 |
|
231 |
+
@staticmethod
|
232 |
+
def _preprocess_question(question: str) -> str:
|
233 |
+
if question[-1] != '?':
|
234 |
+
question += '?'
|
235 |
+
return question
|
236 |
+
|
237 |
+
|
238 |
+
@staticmethod
|
239 |
+
def _postprocess_answer(answer: str) -> str:
|
240 |
+
'''
|
241 |
+
Preprocess the answer by removing unnecessary sequences and stop sequences.
|
242 |
+
'''
|
243 |
+
SEQUENCES_TO_REMOVE = [
|
244 |
+
'Factually: ', 'Answer: ', '<<SYS>>', '<</SYS>>', '[INST]', '[/INST]'
|
245 |
+
]
|
246 |
+
SEQUENCES_TO_STOP = [
|
247 |
+
'User:', 'You:', 'Question:'
|
248 |
+
]
|
249 |
+
for seq in SEQUENCES_TO_REMOVE:
|
250 |
+
answer = answer.replace(seq, '')
|
251 |
+
for seq in SEQUENCES_TO_STOP:
|
252 |
+
if seq in answer:
|
253 |
+
answer = answer[:answer.index(seq)]
|
254 |
+
answer = answer.strip()
|
255 |
+
return answer
|
256 |
+
|
257 |
+
|
258 |
def get_response(self, question: str, messages_context: str = '') -> Response:
|
259 |
"""
|
260 |
Generate an answer to the specified question.
|
|
|
298 |
response.set_sources(sources=[str(m['source']) for m in metadata])
|
299 |
|
300 |
logger.info('Running LLM chain')
|
301 |
+
question_processed = QAEngine._preprocess_question(question)
|
302 |
+
answer = self.llm_chain.run(question=question_processed, context=context)
|
303 |
+
answer_postprocessed = QAEngine._postprocess_answer(answer)
|
304 |
+
response.set_answer(answer_postprocessed)
|
305 |
logger.info('Received answer')
|
306 |
|
307 |
if self.debug:
|
|
|
310 |
logger.info(f'question len: {len(question)} {sep}')
|
311 |
logger.info(f'question: {question} {sep}')
|
312 |
logger.info(f'answer len: {len(response.get_answer())} {sep}')
|
313 |
+
logger.info(f'answer original: {answer} {sep}')
|
314 |
+
logger.info(f'answer postprocessed: {response.get_answer()} {sep}')
|
315 |
logger.info(f'{response.get_sources_as_text()} {sep}')
|
316 |
logger.info(f'messages_contex: {messages_context} {sep}')
|
317 |
logger.info(f'relevant_docs: {relevant_docs} {sep}')
|