# Spaces: Runtime error
# (Hosting-page status banner captured together with this file; it is not part of the module.)
import os | |
import time | |
from typing import Any, Dict, Tuple | |
import asyncio | |
import numpy as np | |
import torch | |
from dotenv import load_dotenv | |
from vespa.application import Vespa | |
from vespa.io import VespaQueryResponse | |
from .colpali import SimMapGenerator | |
import backend.stopwords | |
import logging | |
class VespaQueryClient:
    """
    Client for querying the Vespa application that stores the ``pdf_page`` schema.

    The client connects either with mTLS or with a secret token (selected through
    environment variables, see ``__init__``) and exposes coroutines for BM25,
    ColPali and hybrid retrieval, full-image lookup, query suggestions and a
    keepalive probe.
    """

    MAX_QUERY_TERMS = 64
    VESPA_SCHEMA_NAME = "pdf_page"
    SELECT_FIELDS = "id,title,url,blur_image,page_number,snippet,text"
    # Characters that have a special meaning inside a regular expression; they are
    # backslash-escaped before user text is embedded in a ``matches`` pattern.
    _REGEX_METACHARACTERS = frozenset("\\.^$*+?{}[]|()")

    def __init__(self, logger: logging.Logger):
        """
        Initialize the VespaQueryClient by loading environment variables and establishing a connection to the Vespa application.

        When ``USE_MTLS`` equals ``"true"`` the key and certificate from
        ``VESPA_CLOUD_MTLS_KEY`` / ``VESPA_CLOUD_MTLS_CERT`` are written to files under
        ``/tmp`` and used to connect to ``VESPA_APP_MTLS_URL``. Otherwise the client
        connects to ``VESPA_APP_TOKEN_URL`` using ``VESPA_CLOUD_SECRET_TOKEN``.

        Args:
            logger (logging.Logger): Logger used for all messages emitted by the client.

        Raises:
            ValueError: If a required environment variable is not set.
        """
        load_dotenv()
        self.logger = logger
        if os.environ.get("USE_MTLS") == "true":
            self.logger.info("Connected using mTLS")
            mtls_key = os.environ.get("VESPA_CLOUD_MTLS_KEY")
            mtls_cert = os.environ.get("VESPA_CLOUD_MTLS_CERT")
            self.vespa_app_url = os.environ.get("VESPA_APP_MTLS_URL")
            if not self.vespa_app_url:
                raise ValueError(
                    "Please set the VESPA_APP_MTLS_URL environment variable"
                )
            if not mtls_cert or not mtls_key:
                raise ValueError(
                    "USE_MTLS was true, but VESPA_CLOUD_MTLS_KEY and VESPA_CLOUD_MTLS_CERT were not set"
                )
            # Write the key and cert to files; owner-only permissions keep the
            # private key from being readable by other local users.
            mtls_key_path = "/tmp/vespa-data-plane-private-key.pem"
            self._write_private_file(mtls_key_path, mtls_key)
            mtls_cert_path = "/tmp/vespa-data-plane-public-cert.pem"
            self._write_private_file(mtls_cert_path, mtls_cert)
            # Instantiate Vespa connection
            self.app = Vespa(
                url=self.vespa_app_url, key=mtls_key_path, cert=mtls_cert_path
            )
        else:
            self.logger.info("Connected using token")
            self.vespa_app_url = os.environ.get("VESPA_APP_TOKEN_URL")
            if not self.vespa_app_url:
                raise ValueError(
                    "Please set the VESPA_APP_TOKEN_URL environment variable"
                )
            self.vespa_cloud_secret_token = os.environ.get("VESPA_CLOUD_SECRET_TOKEN")
            if not self.vespa_cloud_secret_token:
                raise ValueError(
                    "Please set the VESPA_CLOUD_SECRET_TOKEN environment variable"
                )
            # Instantiate Vespa connection
            self.app = Vespa(
                url=self.vespa_app_url,
                vespa_cloud_secret_token=self.vespa_cloud_secret_token,
            )
        self.app.wait_for_application_up()
        self.logger.info(f"Connected to Vespa at {self.vespa_app_url}")

    @staticmethod
    def _write_private_file(path: str, content: str) -> None:
        """Write ``content`` to ``path`` with permissions restricted to the owner (0600)."""
        fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, "w") as handle:
            # The mode passed to os.open only applies to newly created files;
            # tighten a pre-existing file before the secret is written into it.
            os.chmod(path, 0o600)
            handle.write(content)

    @staticmethod
    def _escape_yql_string(value: str) -> str:
        """Escape backslashes and double quotes so ``value`` is safe inside a double-quoted YQL literal."""
        return value.replace("\\", "\\\\").replace('"', '\\"')

    @classmethod
    def _escape_regex_literal(cls, value: str) -> str:
        """Backslash-escape regular-expression metacharacters so ``value`` matches literally."""
        return "".join(
            "\\" + char if char in cls._REGEX_METACHARACTERS else char
            for char in value
        )

    @staticmethod
    def _raise_if_failed(response: "VespaQueryResponse") -> None:
        """
        Raise ``AssertionError`` carrying the response JSON when the query failed.

        An explicit raise is used instead of an ``assert`` statement so the check
        still runs when Python is started with ``-O``; the exception type is the
        one the ``assert`` produced, so existing callers are unaffected.
        """
        if not response.is_successful():
            raise AssertionError(response.json)

    def get_fields(self, sim_map: bool = False):
        """
        Return the YQL select list for a query.

        Args:
            sim_map (bool, optional): When True only ``summaryfeatures`` is selected. Defaults to False.

        Returns:
            str: ``SELECT_FIELDS`` for regular queries, ``"summaryfeatures"`` for similarity-map queries.
        """
        return "summaryfeatures" if sim_map else self.SELECT_FIELDS

    def format_query_results(
        self, query: str, response: "VespaQueryResponse", hits: int = 5
    ) -> dict:
        """
        Format the Vespa query results.

        Logs the query text, the search time reported by Vespa and the total hit
        count at debug level, then returns the response JSON unchanged.

        Args:
            query (str): The query text.
            response (VespaQueryResponse): The response from Vespa.
            hits (int, optional): Number of hits to display. Currently unused; kept for
                interface compatibility. Defaults to 5.

        Returns:
            dict: The JSON content of the response.
        """
        payload = response.json
        query_time = round(payload.get("timing", {}).get("searchtime", -1), 2)
        count = payload.get("root", {}).get("fields", {}).get("totalCount", 0)
        result_text = f"Query text: '{query}', query time {query_time}s, count={count}, top results:\n"
        self.logger.debug(result_text)
        return response.json

    async def query_vespa_bm25(
        self,
        query: str,
        q_emb: torch.Tensor,
        hits: int = 3,
        timeout: str = "10s",
        sim_map: bool = False,
        **kwargs,
    ) -> dict:
        """
        Query Vespa using the BM25 ranking profile.
        This corresponds to the "BM25" radio button in the UI.

        Args:
            query (str): The query text.
            q_emb (torch.Tensor): Query embeddings.
            hits (int, optional): Number of hits to retrieve. Defaults to 3.
            timeout (str, optional): Query timeout. Defaults to "10s".
            sim_map (bool, optional): Use the ``bm25_sim`` rank profile and select only
                ``summaryfeatures``. Defaults to False.
            **kwargs: Extra request parameters merged into the query body (they
                override the defaults set here).

        Returns:
            dict: The formatted query results.

        Raises:
            AssertionError: If Vespa reports the query as unsuccessful.
        """
        async with self.app.asyncio(connections=1) as session:
            query_embedding = self.format_q_embs(q_emb)
            start = time.perf_counter()
            response: VespaQueryResponse = await session.query(
                body={
                    "yql": (
                        f"select {self.get_fields(sim_map=sim_map)} from {self.VESPA_SCHEMA_NAME} where userQuery();"
                    ),
                    "ranking": self.get_rank_profile("bm25", sim_map),
                    "query": query,
                    "timeout": timeout,
                    "hits": hits,
                    "input.query(qt)": query_embedding,
                    "presentation.timing": True,
                    **kwargs,
                },
            )
            self._raise_if_failed(response)
            stop = time.perf_counter()
            self.logger.debug(
                f"Query time + data transfer took: {stop - start} s, Vespa reported searchtime was "
                f"{response.json.get('timing', {}).get('searchtime', -1)} s"
            )
        return self.format_query_results(query, response)

    def float_to_binary_embedding(self, float_query_embedding: dict) -> dict:
        """
        Convert float query embeddings to binary embeddings.

        Each vector is binarized (value > 0 becomes bit 1), packed eight bits per
        byte and reinterpreted as signed 8-bit integers. At most ``MAX_QUERY_TERMS``
        entries are converted; a warning is logged only when further entries are
        actually dropped.

        Args:
            float_query_embedding (dict): Dictionary of float embeddings.

        Returns:
            dict: Dictionary of binary embeddings.
        """
        binary_query_embeddings = {}
        for key, vector in float_query_embedding.items():
            # Check before converting so that a query with exactly
            # MAX_QUERY_TERMS terms is not reported as truncated.
            if len(binary_query_embeddings) >= self.MAX_QUERY_TERMS:
                self.logger.warning(
                    f"Warning: Query has more than {self.MAX_QUERY_TERMS} terms. Truncating."
                )
                break
            packed = np.packbits(np.asarray(vector) > 0)
            binary_query_embeddings[key] = packed.astype(np.int8).tolist()
        return binary_query_embeddings

    def create_nn_query_strings(
        self, binary_query_embeddings: dict, target_hits_per_query_tensor: int = 20
    ) -> Tuple[str, dict]:
        """
        Create nearest neighbor query strings for Vespa.

        One ``nearestNeighbor`` clause and one ``input.query(rq<key>)`` tensor is
        produced per entry, named after the entry's key (``0..n-1`` for embeddings
        produced by ``format_q_embs``).

        Args:
            binary_query_embeddings (dict): Binary query embeddings.
            target_hits_per_query_tensor (int, optional): Target hits per query tensor. Defaults to 20.

        Returns:
            Tuple[str, dict]: Nearest neighbor query string and query tensor dictionary.
                The string is empty when no embeddings are given.
        """
        nn_query_dict = {
            f"input.query(rq{key})": vector
            for key, vector in binary_query_embeddings.items()
        }
        nn = " OR ".join(
            f"({{targetHits:{target_hits_per_query_tensor}}}nearestNeighbor(embedding,rq{key}))"
            for key in binary_query_embeddings
        )
        return nn, nn_query_dict

    def format_q_embs(self, q_embs: torch.Tensor) -> dict:
        """
        Convert query embeddings to a dictionary of lists.

        Args:
            q_embs (torch.Tensor): Query embeddings tensor.

        Returns:
            dict: Dictionary where each key is an index and value is the embedding list.
        """
        return {idx: emb.tolist() for idx, emb in enumerate(q_embs)}

    async def get_result_from_query(
        self,
        query: str,
        q_embs: torch.Tensor,
        ranking: str,
        idx_to_token: dict,
    ) -> Dict[str, Any]:
        """
        Get query results from Vespa based on the ranking method.

        ``ranking`` has the form ``<method>`` or ``<method>_sim`` where ``<method>``
        is ``colpali``, ``hybrid`` or ``bm25``; the ``_sim`` suffix requests the
        similarity-map variant of the rank profile.

        Args:
            query (str): The query text.
            q_embs (torch.Tensor): Query embeddings.
            ranking (str): The ranking method to use.
            idx_to_token (dict): Index to token mapping. Not used by this method.

        Returns:
            Dict[str, Any]: The query results; ``result["root"]["children"]`` always exists.

        Raises:
            ValueError: If the ranking method is not supported.
        """
        # Remove stopwords from the query to avoid visual emphasis on irrelevant words (e.g., "the", "and", "of")
        query = backend.stopwords.filter(query)
        ranking_parts = ranking.split("_")
        rank_method = ranking_parts[0]
        sim_map: bool = len(ranking_parts) > 1 and ranking_parts[1] == "sim"
        if rank_method in ("colpali", "hybrid"):  # ColPali / Hybrid ColPali+BM25
            result = await self.query_vespa_colpali(
                query=query, ranking=rank_method, q_emb=q_embs, sim_map=sim_map
            )
        elif rank_method == "bm25":
            result = await self.query_vespa_bm25(query, q_embs, sim_map=sim_map)
        else:
            raise ValueError(f"Unsupported ranking: {rank_method}")
        if "root" not in result or "children" not in result["root"]:
            # Normalize hit-less responses so callers can always iterate the children.
            result["root"] = {"children": []}
            return result
        for single_result in result["root"]["children"]:
            # Debug logging must not fail on a hit that carries no fields.
            self.logger.debug(single_result.get("fields", {}).keys())
        return result

    def get_sim_maps_from_query(
        self, query: str, q_embs: torch.Tensor, ranking: str, idx_to_token: dict
    ):
        """
        Get similarity maps from Vespa based on the ranking method.

        This is a synchronous wrapper that drives ``get_result_from_query`` with
        ``asyncio.run``, so it must not be called from a thread that already has a
        running event loop.

        Args:
            query (str): The query text.
            q_embs (torch.Tensor): Query embeddings.
            ranking (str): The ranking method to use.
            idx_to_token (dict): Index to token mapping.

        Returns:
            list: The ``summaryfeatures`` value of every hit, in hit order.

        Raises:
            ValueError: If any hit lacks ``summaryfeatures``.
        """
        # Get the result by calling asyncio.run
        result = asyncio.run(
            self.get_result_from_query(query, q_embs, ranking, idx_to_token)
        )
        vespa_sim_maps = []
        for single_result in result["root"]["children"]:
            vespa_sim_map = single_result["fields"].get("summaryfeatures", None)
            if vespa_sim_map is None:
                raise ValueError("No sim_map found in Vespa response")
            vespa_sim_maps.append(vespa_sim_map)
        return vespa_sim_maps

    async def get_full_image_from_vespa(self, doc_id: str) -> str:
        """
        Retrieve the full image from Vespa for a given document ID.

        Args:
            doc_id (str): The document ID.

        Returns:
            str: The full image data.

        Raises:
            AssertionError: If Vespa reports the query as unsuccessful.
            KeyError, IndexError: If the response contains no matching hit.
        """
        # doc_id ends up inside a double-quoted YQL literal; escape it so a crafted
        # id cannot terminate the literal and alter the query.
        safe_doc_id = self._escape_yql_string(doc_id)
        async with self.app.asyncio(connections=1) as session:
            start = time.perf_counter()
            response: VespaQueryResponse = await session.query(
                body={
                    "yql": f'select full_image from {self.VESPA_SCHEMA_NAME} where id contains "{safe_doc_id}"',
                    "ranking": "unranked",
                    "presentation.timing": True,
                    "ranking.matching.numThreadsPerSearch": 1,
                },
            )
            self._raise_if_failed(response)
            stop = time.perf_counter()
            self.logger.debug(
                f"Getting image from Vespa took: {stop - start} s, Vespa reported searchtime was "
                f"{response.json.get('timing', {}).get('searchtime', -1)} s"
            )
        return response.json["root"]["children"][0]["fields"]["full_image"]

    def get_results_children(self, result: Dict[str, Any]) -> list:
        """
        Return the list of hits from a Vespa result dictionary.

        Args:
            result (Dict[str, Any]): Response JSON containing ``root.children``.

        Returns:
            list: The hits stored under ``result["root"]["children"]``.
        """
        return result["root"]["children"]

    def results_to_search_results(
        self, result: Dict[str, Any], idx_to_token: dict
    ) -> list:
        """
        Add empty ``sim_map_<token>_<idx>`` placeholders to every hit and return the hits.

        Tokens for which ``SimMapGenerator.should_filter_token`` returns true get no
        placeholder. The hits are modified in place.

        Args:
            result (Dict[str, Any]): Response JSON containing ``root.children``.
            idx_to_token (dict): Index to token mapping.

        Returns:
            list: The hits of ``result`` with the placeholder fields set to None.
        """
        # Initialize sim_map_ fields in the result
        fields_to_add = [
            f"sim_map_{token}_{idx}"
            for idx, token in idx_to_token.items()
            if not SimMapGenerator.should_filter_token(token)
        ]
        for child in result["root"]["children"]:
            for sim_map_key in fields_to_add:
                child["fields"][sim_map_key] = None
        return self.get_results_children(result)

    async def get_suggestions(self, query: str) -> list:
        """
        Return stored questions that contain the given query text.

        Args:
            query (str): Text typed by the user; it is matched literally.

        Returns:
            list: Unique matching questions (unordered).

        Raises:
            AssertionError: If Vespa reports the query as unsuccessful.
        """
        # The text is placed inside a regular expression that itself sits inside a
        # double-quoted YQL literal: neutralize regex metacharacters first, then
        # escape the result for the YQL string.
        pattern = self._escape_yql_string(self._escape_regex_literal(query))
        async with self.app.asyncio(connections=1) as session:
            start = time.perf_counter()
            yql = f'select questions from {self.VESPA_SCHEMA_NAME} where questions matches (".*{pattern}.*")'
            response: VespaQueryResponse = await session.query(
                body={
                    "yql": yql,
                    "query": query,
                    "ranking": "unranked",
                    "presentation.timing": True,
                    "presentation.summary": "suggestions",
                    "ranking.matching.numThreadsPerSearch": 1,
                },
            )
            self._raise_if_failed(response)
            stop = time.perf_counter()
            self.logger.debug(
                f"Getting suggestions from Vespa took: {stop - start} s, Vespa reported searchtime was "
                f"{response.json.get('timing', {}).get('searchtime', -1)} s"
            )
            search_results = (
                response.json["root"]["children"]
                if "root" in response.json and "children" in response.json["root"]
                else []
            )
            unique_questions = {
                question
                for result in search_results
                if "questions" in result["fields"]
                for question in result["fields"]["questions"]
            }
            # remove an artifact from our data generation
            unique_questions.discard("string")
            return list(unique_questions)

    def get_rank_profile(self, ranking: str, sim_map: bool) -> str:
        """
        Return the name of the rank profile to use.

        Args:
            ranking (str): Base rank profile name.
            sim_map (bool): Whether the similarity-map variant is requested.

        Returns:
            str: ``"<ranking>_sim"`` when ``sim_map`` is true, otherwise ``ranking``.
        """
        return f"{ranking}_sim" if sim_map else ranking

    async def query_vespa_colpali(
        self,
        query: str,
        ranking: str,
        q_emb: torch.Tensor,
        target_hits_per_query_tensor: int = 100,
        hnsw_explore_additional_hits: int = 300,
        hits: int = 3,
        timeout: str = "10s",
        sim_map: bool = False,
        **kwargs,
    ) -> dict:
        """
        Query Vespa using nearest neighbor search with mixed tensors for MaxSim calculations.
        This corresponds to the "ColPali" radio button in the UI.

        Args:
            query (str): The query text.
            ranking (str): Base rank profile name (``_sim`` is appended when ``sim_map`` is true).
            q_emb (torch.Tensor): Query embeddings.
            target_hits_per_query_tensor (int, optional): Target hits per query tensor. Defaults to 100.
            hnsw_explore_additional_hits (int, optional): Value sent as
                ``hnsw.exploreAdditionalHits``. Defaults to 300.
            hits (int, optional): Number of hits to retrieve. Defaults to 3.
            timeout (str, optional): Query timeout. Defaults to "10s".
            sim_map (bool, optional): Request the similarity-map variant. Defaults to False.
            **kwargs: Extra request parameters merged into the query body (they
                override the defaults set here).

        Returns:
            dict: The formatted query results.

        Raises:
            AssertionError: If Vespa reports the query as unsuccessful.
        """
        async with self.app.asyncio(connections=1) as session:
            float_query_embedding = self.format_q_embs(q_emb)
            binary_query_embeddings = self.float_to_binary_embedding(
                float_query_embedding
            )
            # Mixed tensors for MaxSim calculations
            query_tensors = {
                "input.query(qtb)": binary_query_embeddings,
                "input.query(qt)": float_query_embedding,
            }
            nn_string, nn_query_dict = self.create_nn_query_strings(
                binary_query_embeddings, target_hits_per_query_tensor
            )
            query_tensors.update(nn_query_dict)
            # Without any query embeddings there is no nearestNeighbor clause;
            # fall back to the text query alone instead of emitting "where  or ...".
            where_clause = (
                f"{nn_string} or userQuery()" if nn_string else "userQuery()"
            )
            response: VespaQueryResponse = await session.query(
                body={
                    **query_tensors,
                    "presentation.timing": True,
                    "yql": (
                        f"select {self.get_fields(sim_map=sim_map)} from {self.VESPA_SCHEMA_NAME} where {where_clause}"
                    ),
                    "ranking.profile": self.get_rank_profile(
                        ranking=ranking, sim_map=sim_map
                    ),
                    "timeout": timeout,
                    "hits": hits,
                    "query": query,
                    "hnsw.exploreAdditionalHits": hnsw_explore_additional_hits,
                    "ranking.rerankCount": 100,
                    **kwargs,
                },
            )
            self._raise_if_failed(response)
        return self.format_query_results(query, response)

    async def keepalive(self) -> bool:
        """
        Query Vespa to keep the connection alive.

        Returns:
            bool: True if the connection is alive.

        Raises:
            AssertionError: If Vespa reports the query as unsuccessful.
        """
        async with self.app.asyncio(connections=1) as session:
            response: VespaQueryResponse = await session.query(
                body={
                    "yql": f"select title from {self.VESPA_SCHEMA_NAME} where true limit 1;",
                    "ranking": "unranked",
                    "query": "keepalive",
                    "timeout": "3s",
                    "hits": 1,
                },
            )
            self._raise_if_failed(response)
            return True