Upload retrievers_setup.py
rag_chain/retrievers_setup.py CHANGED
@@ -3,10 +3,13 @@ from functools import cache
 
 import qdrant_client
 import torch
+from langchain.prompts import PromptTemplate
 from langchain.retrievers import ContextualCompressionRetriever
 from langchain.retrievers.document_compressors import EmbeddingsFilter
+from langchain.retrievers.multi_query import MultiQueryRetriever
 from langchain_community.retrievers import QdrantSparseVectorRetriever
 from langchain_community.vectorstores import Qdrant
+from langchain_openai import ChatOpenAI
 from langchain_openai.embeddings import OpenAIEmbeddings
 from transformers import AutoModelForMaskedLM, AutoTokenizer
 
@@ -28,6 +31,7 @@ class DenseRetrieverClient:
         self.client = qdrant_client.QdrantClient(
             url=os.getenv("QDRANT_URL"),
             api_key=os.getenv("QDRANT_API_KEY"),
+            prefer_grpc=True,
         )
         self.qdrant_collection = self.load_qdrant_collection()
 
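The only functional change in this hunk is `prefer_grpc=True`, which tells the Qdrant client to route supported operations over gRPC rather than REST, typically lowering latency for search and upsert traffic. A minimal sketch of the resulting client construction, assuming the same `QDRANT_URL` and `QDRANT_API_KEY` environment variables and a reachable gRPC port (6334 by default):

```python
import os

import qdrant_client

# Same construction as in DenseRetrieverClient.__init__ above;
# prefer_grpc=True uses gRPC where possible and falls back to REST otherwise.
client = qdrant_client.QdrantClient(
    url=os.getenv("QDRANT_URL"),
    api_key=os.getenv("QDRANT_API_KEY"),
    prefer_grpc=True,
)
```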
@@ -91,6 +95,7 @@ class SparseRetrieverClient:
         self.client = qdrant_client.QdrantClient(url=os.getenv(
             "QDRANT_URL"), api_key=os.getenv("QDRANT_API_KEY"))
         self.model_id = splade_model_id
+        self.tokenizer, self.model = self.set_tokenizer_config()
         self.collection_name = collection_name
         self.vector_name = vector_name
         self.k = k
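This hunk moves tokenizer and model loading into `__init__` via `set_tokenizer_config()`, so the SPLADE weights are loaded once per client rather than on every encoding call. The method body is not part of this diff; the sketch below is a hypothetical reconstruction (written as a free function, with the checkpoint name an assumed default for the `splade_model_id` constructor argument), using the `functools.cache`, `AutoTokenizer`, and `AutoModelForMaskedLM` imports already present in the file:

```python
from functools import cache

from transformers import AutoModelForMaskedLM, AutoTokenizer

@cache
def set_tokenizer_config(model_id: str = "naver/splade-cocondenser-ensembledistil"):
    """Load the SPLADE tokenizer and masked-LM model once and reuse them."""
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForMaskedLM.from_pretrained(model_id)
    model.eval()  # inference only; the encoder runs under torch.no_grad()
    return tokenizer, model
```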
@@ -120,17 +125,23 @@ class SparseRetrieverClient:
         Returns:
             tuple[list[int], list[float]]: Indices and values of the sparse vector
         """
-
-
-
-
+        tokens = self.tokenizer(text, return_tensors="pt",
+                                max_length=512, padding="max_length", truncation=True)
+
+        with torch.no_grad():
+            output = self.model(**tokens)
+
         logits, attention_mask = output.logits, tokens.attention_mask
-
+
+        relu_log = torch.log1p(torch.relu(logits))
         weighted_log = relu_log * attention_mask.unsqueeze(-1)
+
         max_val, _ = torch.max(weighted_log, dim=1)
         vec = max_val.squeeze()
-
-
+
+        indices = torch.nonzero(vec, as_tuple=False).squeeze().cpu().numpy()
+        values = vec[indices].cpu().numpy()
+
         return indices.tolist(), values.tolist()
 
     def get_sparse_retriever(self):
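The rewritten body above is the standard SPLADE expansion: masked-LM logits pass through `log1p(relu(·))`, are masked by the attention mask, and are max-pooled over the sequence dimension, so every non-zero vocabulary position becomes one weighted term of the sparse vector. `get_sparse_retriever` itself is unchanged and its body is not shown in this diff; the sketch below is a hedged guess at how such an encoder plugs into LangChain's `QdrantSparseVectorRetriever`, with field values taken from the attributes set in `__init__` and `self.sparse_text_embedding` standing in for the encoding method, whose name this diff does not show:

```python
from langchain_community.retrievers import QdrantSparseVectorRetriever

# Sketch of a SparseRetrieverClient method, not a standalone function.
def get_sparse_retriever(self):
    # sparse_encoder must be a callable returning (indices, values),
    # exactly the tuple produced by the encoding method in the hunk above.
    return QdrantSparseVectorRetriever(
        client=self.client,
        collection_name=self.collection_name,
        sparse_vector_name=self.vector_name,
        sparse_encoder=self.sparse_text_embedding,  # assumed method name
        k=self.k,
    )
```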
@@ -172,3 +183,34 @@ def compression_retriever_setup(base_retriever, embeddings_model: str = "text-em
     )
 
     return compression_retriever
+
+
+def multi_query_retriever_setup(retriever) -> MultiQueryRetriever:
+    """ Configure a multi-query retriever using a base retriever and the LLM.
+
+    Args:
+        retriever:
+
+    Returns:
+        retriever: MultiQueryRetriever
+    """
+
+    QUERY_PROMPT = PromptTemplate(
+        input_variables=["question"],
+        template="""
+        Your task is to generate 3 different versions of the provided question, incorporating the user's location preference in each version. Each version must be separated by newlines. Ensure that no part of your response is enclosed in quotation marks. Do not modify any acronyms or unfamiliar terms. Keep your responses clear, concise, and limited to these alternatives.
+        Note: The text provided are queries to Tall Tree Health Centre's AI virtual assistant.
+
+        Question:
+        {question}
+
+        """,
+    )
+
+    llm = ChatOpenAI(model='gpt-3.5-turbo', temperature=0)
+    multi_query_retriever = MultiQueryRetriever.from_llm(
+        retriever=retriever, llm=llm, prompt=QUERY_PROMPT, include_original=True,
+    )
+
+    return multi_query_retriever
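The new `multi_query_retriever_setup` above wires up `MultiQueryRetriever.from_llm`: the prompt has the LLM rewrite each incoming question into three variants, the base retriever runs on every variant plus the original (`include_original=True`), and the deduplicated union of the retrieved documents comes back. A hedged usage sketch stacking it on the compression retriever from this same file (`base_retriever` stands in for a dense or sparse retriever built elsewhere):

```python
# Both helpers are defined in rag_chain/retrievers_setup.py.
compressed = compression_retriever_setup(base_retriever)
retriever = multi_query_retriever_setup(compressed)

docs = retriever.get_relevant_documents(
    "Do you offer physiotherapy in Victoria?"  # hypothetical query
)
```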