# Source: colpali-vespa / backend/vespa_app.py
# Uploaded by J94 ("Upload folder using huggingface_hub"), commit 7c88df9 (verified)
import os
import time
from typing import Any, Dict, Tuple
import asyncio
import numpy as np
import torch
from dotenv import load_dotenv
from vespa.application import Vespa
from vespa.io import VespaQueryResponse
from .colpali import SimMapGenerator
import backend.stopwords
import logging
class VespaQueryClient:
    """Client for querying a Vespa application that serves the ``pdf_page`` schema.

    Handles connection setup (mTLS or token authentication) and the BM25,
    ColPali and hybrid query flavours used by the UI, plus helpers for
    fetching full images, similarity maps and question suggestions.
    """

    MAX_QUERY_TERMS = 64
    VESPA_SCHEMA_NAME = "pdf_page"
    SELECT_FIELDS = "id,title,url,blur_image,page_number,snippet,text"
    # Characters with special meaning in a regular expression; _escape_regex
    # backslash-escapes them so user text is matched literally.
    _REGEX_METACHARS = frozenset(".^$*+?()[]{}|\\")

    def __init__(self, logger: logging.Logger):
        """
        Initialize the VespaQueryClient by loading environment variables and establishing a connection to the Vespa application.

        Authentication mode is selected by the ``USE_MTLS`` environment variable:
        ``"true"`` uses mTLS (``VESPA_APP_MTLS_URL``, ``VESPA_CLOUD_MTLS_KEY``,
        ``VESPA_CLOUD_MTLS_CERT``); any other value uses token auth
        (``VESPA_APP_TOKEN_URL``, ``VESPA_CLOUD_SECRET_TOKEN``).

        Args:
            logger (logging.Logger): Logger used for connection and query diagnostics.

        Raises:
            ValueError: If a required environment variable for the selected mode is missing.
        """
        load_dotenv()
        self.logger = logger
        if os.environ.get("USE_MTLS") == "true":
            self.logger.info("Connected using mTLS")
            mtls_key = os.environ.get("VESPA_CLOUD_MTLS_KEY")
            mtls_cert = os.environ.get("VESPA_CLOUD_MTLS_CERT")
            self.vespa_app_url = os.environ.get("VESPA_APP_MTLS_URL")
            if not self.vespa_app_url:
                raise ValueError(
                    "Please set the VESPA_APP_MTLS_URL environment variable"
                )
            if not mtls_cert or not mtls_key:
                raise ValueError(
                    "USE_MTLS was true, but VESPA_CLOUD_MTLS_KEY and VESPA_CLOUD_MTLS_CERT were not set"
                )
            # The Vespa client is given file paths, so write the key and cert
            # to disk. The private key must not be readable by other users.
            mtls_key_path = "/tmp/vespa-data-plane-private-key.pem"
            self._write_private_file(mtls_key_path, mtls_key)
            mtls_cert_path = "/tmp/vespa-data-plane-public-cert.pem"
            self._write_private_file(mtls_cert_path, mtls_cert)
            # Instantiate Vespa connection
            self.app = Vespa(
                url=self.vespa_app_url, key=mtls_key_path, cert=mtls_cert_path
            )
        else:
            self.logger.info("Connected using token")
            self.vespa_app_url = os.environ.get("VESPA_APP_TOKEN_URL")
            if not self.vespa_app_url:
                raise ValueError(
                    "Please set the VESPA_APP_TOKEN_URL environment variable"
                )
            self.vespa_cloud_secret_token = os.environ.get("VESPA_CLOUD_SECRET_TOKEN")
            if not self.vespa_cloud_secret_token:
                raise ValueError(
                    "Please set the VESPA_CLOUD_SECRET_TOKEN environment variable"
                )
            # Instantiate Vespa connection
            self.app = Vespa(
                url=self.vespa_app_url,
                vespa_cloud_secret_token=self.vespa_cloud_secret_token,
            )
        self.app.wait_for_application_up()
        self.logger.info(f"Connected to Vespa at {self.vespa_app_url}")

    def _write_private_file(self, path: str, content: str) -> None:
        """Write ``content`` to ``path`` with owner-only (0600) permissions.

        Args:
            path (str): Destination file path; created or truncated.
            content (str): Text to write.
        """
        fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, "w") as f:
            # The mode given to os.open only applies when the file is newly
            # created; also tighten a pre-existing file before writing to it.
            try:
                os.chmod(path, 0o600)
            except OSError as err:
                self.logger.warning(
                    f"Could not restrict permissions on {path}: {err}"
                )
            f.write(content)

    @staticmethod
    def _escape_yql_string(value: str) -> str:
        """Escape backslashes and double quotes so ``value`` can be embedded
        inside a double-quoted YQL string literal."""
        return value.replace("\\", "\\\\").replace('"', '\\"')

    @classmethod
    def _escape_regex(cls, value: str) -> str:
        """Backslash-escape regular-expression metacharacters in ``value`` so
        it matches literally."""
        return "".join(
            "\\" + ch if ch in cls._REGEX_METACHARS else ch for ch in value
        )

    def get_fields(self, sim_map: bool = False):
        """Return the YQL select list.

        Args:
            sim_map (bool, optional): When True, select only ``summaryfeatures``
                (used for similarity maps); otherwise the regular display fields.

        Returns:
            str: Comma-separated field list for the YQL ``select`` clause.
        """
        return "summaryfeatures" if sim_map else self.SELECT_FIELDS

    def format_query_results(
        self, query: str, response: "VespaQueryResponse", hits: int = 5
    ) -> dict:
        """
        Format the Vespa query results.

        Args:
            query (str): The query text.
            response (VespaQueryResponse): The response from Vespa.
            hits (int, optional): Number of hits to display. Defaults to 5.
                Currently unused.

        Returns:
            dict: The JSON content of the response.
        """
        query_time = response.json.get("timing", {}).get("searchtime", -1)
        query_time = round(query_time, 2)
        count = response.json.get("root", {}).get("fields", {}).get("totalCount", 0)
        result_text = f"Query text: '{query}', query time {query_time}s, count={count}, top results:\n"
        self.logger.debug(result_text)
        return response.json

    async def query_vespa_bm25(
        self,
        query: str,
        q_emb: torch.Tensor,
        hits: int = 3,
        timeout: str = "10s",
        sim_map: bool = False,
        **kwargs,
    ) -> dict:
        """
        Query Vespa using the BM25 ranking profile.
        This corresponds to the "BM25" radio button in the UI.

        Args:
            query (str): The query text.
            q_emb (torch.Tensor): Query embeddings, sent as ``input.query(qt)``.
            hits (int, optional): Number of hits to retrieve. Defaults to 3.
            timeout (str, optional): Query timeout. Defaults to "10s".
            sim_map (bool, optional): Use the ``bm25_sim`` profile and select
                only summary features. Defaults to False.
            **kwargs: Extra request parameters merged into the query body.

        Returns:
            dict: The formatted query results.
        """
        query_embedding = self.format_q_embs(q_emb)
        async with self.app.asyncio(connections=1) as session:
            start = time.perf_counter()
            response: VespaQueryResponse = await session.query(
                body={
                    "yql": (
                        f"select {self.get_fields(sim_map=sim_map)} from {self.VESPA_SCHEMA_NAME} where userQuery();"
                    ),
                    "ranking": self.get_rank_profile("bm25", sim_map),
                    "query": query,
                    "timeout": timeout,
                    "hits": hits,
                    "input.query(qt)": query_embedding,
                    "presentation.timing": True,
                    **kwargs,
                },
            )
            assert response.is_successful(), response.json
            stop = time.perf_counter()
            self.logger.debug(
                f"Query time + data transfer took: {stop - start} s, Vespa reported searchtime was "
                f"{response.json.get('timing', {}).get('searchtime', -1)} s"
            )
        return self.format_query_results(query, response)

    def float_to_binary_embedding(self, float_query_embedding: dict) -> dict:
        """
        Convert float query embeddings to binary embeddings.

        Each vector is thresholded at 0 (positive -> 1, else 0), bit-packed with
        ``np.packbits`` and reinterpreted as int8. At most ``MAX_QUERY_TERMS``
        entries are kept (in iteration order); a warning is logged only when
        entries are actually dropped.

        Args:
            float_query_embedding (dict): Dictionary of float embeddings.

        Returns:
            dict: Dictionary of binary embeddings with the same keys.
        """
        binary_query_embeddings = {}
        for key, vector in float_query_embedding.items():
            # Check before inserting so a query with exactly MAX_QUERY_TERMS
            # terms is not reported as truncated.
            if len(binary_query_embeddings) >= self.MAX_QUERY_TERMS:
                self.logger.warning(
                    f"Warning: Query has more than {self.MAX_QUERY_TERMS} terms. Truncating."
                )
                break
            binary_query_embeddings[key] = (
                np.packbits(np.where(np.array(vector) > 0, 1, 0))
                .astype(np.int8)
                .tolist()
            )
        return binary_query_embeddings

    def create_nn_query_strings(
        self, binary_query_embeddings: dict, target_hits_per_query_tensor: int = 20
    ) -> Tuple[str, dict]:
        """
        Create nearest neighbor query strings for Vespa.

        Args:
            binary_query_embeddings (dict): Binary query embeddings; expected to be
                keyed by consecutive ints starting at 0.
            target_hits_per_query_tensor (int, optional): Target hits per query tensor. Defaults to 20.

        Returns:
            Tuple[str, dict]: Nearest neighbor query string (one OR-ed
            ``nearestNeighbor`` clause per tensor, empty when there are none)
            and the ``input.query(rqN)`` tensor dictionary.
        """
        nn_query_dict = {}
        clauses = []
        for i in range(len(binary_query_embeddings)):
            nn_query_dict[f"input.query(rq{i})"] = binary_query_embeddings[i]
            clauses.append(
                f"({{targetHits:{target_hits_per_query_tensor}}}nearestNeighbor(embedding,rq{i}))"
            )
        return " OR ".join(clauses), nn_query_dict

    def format_q_embs(self, q_embs: torch.Tensor) -> dict:
        """
        Convert query embeddings to a dictionary of lists.

        Args:
            q_embs (torch.Tensor): Query embeddings tensor, iterated along its first dimension.

        Returns:
            dict: Dictionary where each key is an index and value is the embedding list.
        """
        return {idx: emb.tolist() for idx, emb in enumerate(q_embs)}

    async def get_result_from_query(
        self,
        query: str,
        q_embs: torch.Tensor,
        ranking: str,
        idx_to_token: dict,
    ) -> Dict[str, Any]:
        """
        Get query results from Vespa based on the ranking method.

        Args:
            query (str): The query text.
            q_embs (torch.Tensor): Query embeddings.
            ranking (str): The ranking method: ``colpali``, ``hybrid`` or ``bm25``,
                optionally suffixed with ``_sim`` to request similarity-map features.
            idx_to_token (dict): Index to token mapping (currently unused here).

        Returns:
            Dict[str, Any]: The query results; ``result["root"]["children"]`` is
            always present (an empty list when Vespa returned no hits).

        Raises:
            ValueError: If the ranking method is not supported.
        """
        # Remove stopwords from the query to avoid visual emphasis on irrelevant words (e.g., "the", "and", "of")
        query = backend.stopwords.filter(query)
        parts = ranking.split("_")
        rank_method = parts[0]
        sim_map: bool = len(parts) > 1 and parts[1] == "sim"
        if rank_method in ("colpali", "hybrid"):  # ColPali / Hybrid ColPali+BM25
            result = await self.query_vespa_colpali(
                query=query, ranking=rank_method, q_emb=q_embs, sim_map=sim_map
            )
        elif rank_method == "bm25":
            result = await self.query_vespa_bm25(query, q_embs, sim_map=sim_map)
        else:
            raise ValueError(f"Unsupported ranking: {rank_method}")
        # Vespa omits "children" when there are no hits; add an empty list
        # without discarding other root data such as fields.totalCount.
        children = result.setdefault("root", {}).setdefault("children", [])
        for single_result in children:
            self.logger.debug(single_result.get("fields", {}).keys())
        return result

    def get_sim_maps_from_query(
        self, query: str, q_embs: torch.Tensor, ranking: str, idx_to_token: dict
    ):
        """
        Get similarity maps from Vespa based on the ranking method.

        Runs ``get_result_from_query`` synchronously via ``asyncio.run``, so it
        must not be called from within a running event loop.

        Args:
            query (str): The query text.
            q_embs (torch.Tensor): Query embeddings.
            ranking (str): The ranking method to use.
            idx_to_token (dict): Index to token mapping.

        Returns:
            list: The ``summaryfeatures`` value of each hit, in hit order.

        Raises:
            ValueError: If any hit lacks ``summaryfeatures``.
        """
        result = asyncio.run(
            self.get_result_from_query(query, q_embs, ranking, idx_to_token)
        )
        vespa_sim_maps = []
        for single_result in result["root"]["children"]:
            vespa_sim_map = single_result["fields"].get("summaryfeatures", None)
            if vespa_sim_map is None:
                raise ValueError("No sim_map found in Vespa response")
            vespa_sim_maps.append(vespa_sim_map)
        return vespa_sim_maps

    async def get_full_image_from_vespa(self, doc_id: str) -> str:
        """
        Retrieve the full image from Vespa for a given document ID.

        Args:
            doc_id (str): The document ID; escaped before being placed in the YQL string.

        Returns:
            str: The full image data (the ``full_image`` field of the first hit).
        """
        yql = (
            f"select full_image from {self.VESPA_SCHEMA_NAME} "
            f'where id contains "{self._escape_yql_string(doc_id)}"'
        )
        async with self.app.asyncio(connections=1) as session:
            start = time.perf_counter()
            response: VespaQueryResponse = await session.query(
                body={
                    "yql": yql,
                    "ranking": "unranked",
                    "presentation.timing": True,
                    "ranking.matching.numThreadsPerSearch": 1,
                },
            )
            assert response.is_successful(), response.json
            stop = time.perf_counter()
            self.logger.debug(
                f"Getting image from Vespa took: {stop - start} s, Vespa reported searchtime was "
                f"{response.json.get('timing', {}).get('searchtime', -1)} s"
            )
        return response.json["root"]["children"][0]["fields"]["full_image"]

    def get_results_children(self, result: Dict[str, Any]) -> list:
        """Return the hit list (``result["root"]["children"]``) from a Vespa JSON result dict."""
        return result["root"]["children"]

    def results_to_search_results(
        self, result: Dict[str, Any], idx_to_token: dict
    ) -> list:
        """Add placeholder ``sim_map_<token>_<idx>`` fields to every hit.

        Each field is initialised to None for every token that
        ``SimMapGenerator.should_filter_token`` does not filter out. The hits
        are mutated in place.

        Args:
            result (Dict[str, Any]): Vespa JSON result (indexed as ``result["root"]["children"]``).
            idx_to_token (dict): Index to token mapping.

        Returns:
            list: The (mutated) hit list.
        """
        fields_to_add = [
            f"sim_map_{token}_{idx}"
            for idx, token in idx_to_token.items()
            if not SimMapGenerator.should_filter_token(token)
        ]
        for child in result["root"]["children"]:
            for sim_map_key in fields_to_add:
                child["fields"][sim_map_key] = None
        return self.get_results_children(result)

    async def get_suggestions(self, query: str) -> list:
        """Return unique stored questions that contain ``query`` as a substring.

        The query text is escaped so it is matched literally: regex
        metacharacters cannot alter the pattern and quotes cannot break the
        YQL string literal.

        Args:
            query (str): Partial user input.

        Returns:
            list: Unique matching questions (unordered), excluding the "string"
            data-generation artifact.
        """
        pattern = f".*{self._escape_regex(query)}.*"
        yql = (
            f"select questions from {self.VESPA_SCHEMA_NAME} "
            f'where questions matches ("{self._escape_yql_string(pattern)}")'
        )
        async with self.app.asyncio(connections=1) as session:
            start = time.perf_counter()
            response: VespaQueryResponse = await session.query(
                body={
                    "yql": yql,
                    "query": query,
                    "ranking": "unranked",
                    "presentation.timing": True,
                    "presentation.summary": "suggestions",
                    "ranking.matching.numThreadsPerSearch": 1,
                },
            )
            assert response.is_successful(), response.json
            stop = time.perf_counter()
            self.logger.debug(
                f"Getting suggestions from Vespa took: {stop - start} s, Vespa reported searchtime was "
                f"{response.json.get('timing', {}).get('searchtime', -1)} s"
            )
        search_results = response.json.get("root", {}).get("children", [])
        questions = [
            result["fields"]["questions"]
            for result in search_results
            if "questions" in result["fields"]
        ]
        unique_questions = {item for sublist in questions for item in sublist}
        # remove an artifact from our data generation
        unique_questions.discard("string")
        return list(unique_questions)

    def get_rank_profile(self, ranking: str, sim_map: bool) -> str:
        """Return ``ranking``, suffixed with ``_sim`` when similarity maps are requested."""
        return f"{ranking}_sim" if sim_map else ranking

    async def query_vespa_colpali(
        self,
        query: str,
        ranking: str,
        q_emb: torch.Tensor,
        target_hits_per_query_tensor: int = 100,
        hnsw_explore_additional_hits: int = 300,
        hits: int = 3,
        timeout: str = "10s",
        sim_map: bool = False,
        **kwargs,
    ) -> dict:
        """
        Query Vespa using nearest neighbor search with mixed tensors for MaxSim calculations.
        This corresponds to the "ColPali" radio button in the UI.

        Args:
            query (str): The query text (also matched via ``userQuery()``).
            ranking (str): Base rank profile name (e.g. "colpali" or "hybrid");
                "_sim" is appended when ``sim_map`` is True.
            q_emb (torch.Tensor): Query embeddings.
            target_hits_per_query_tensor (int, optional): Target hits per query tensor. Defaults to 100.
            hnsw_explore_additional_hits (int, optional): Value for ``hnsw.exploreAdditionalHits``. Defaults to 300.
            hits (int, optional): Number of hits to retrieve. Defaults to 3.
            timeout (str, optional): Query timeout. Defaults to "10s".
            sim_map (bool, optional): Request similarity-map features. Defaults to False.
            **kwargs: Extra request parameters merged into the query body.

        Returns:
            dict: The formatted query results.
        """
        float_query_embedding = self.format_q_embs(q_emb)
        binary_query_embeddings = self.float_to_binary_embedding(
            float_query_embedding
        )
        # Mixed tensors for MaxSim calculations
        query_tensors = {
            "input.query(qtb)": binary_query_embeddings,
            "input.query(qt)": float_query_embedding,
        }
        nn_string, nn_query_dict = self.create_nn_query_strings(
            binary_query_embeddings, target_hits_per_query_tensor
        )
        query_tensors.update(nn_query_dict)
        # With no query tensors there are no nearestNeighbor clauses; avoid
        # emitting the invalid "where  or userQuery()".
        where_clause = f"{nn_string} or userQuery()" if nn_string else "userQuery()"
        async with self.app.asyncio(connections=1) as session:
            response: VespaQueryResponse = await session.query(
                body={
                    **query_tensors,
                    "presentation.timing": True,
                    "yql": (
                        f"select {self.get_fields(sim_map=sim_map)} from {self.VESPA_SCHEMA_NAME} where {where_clause}"
                    ),
                    "ranking.profile": self.get_rank_profile(
                        ranking=ranking, sim_map=sim_map
                    ),
                    "timeout": timeout,
                    "hits": hits,
                    "query": query,
                    "hnsw.exploreAdditionalHits": hnsw_explore_additional_hits,
                    "ranking.rerankCount": 100,
                    **kwargs,
                },
            )
            assert response.is_successful(), response.json
        return self.format_query_results(query, response)

    async def keepalive(self) -> bool:
        """
        Query Vespa to keep the connection alive.

        Returns:
            bool: True if the connection is alive.

        Raises:
            AssertionError: If the keepalive query is not successful.
        """
        async with self.app.asyncio(connections=1) as session:
            response: VespaQueryResponse = await session.query(
                body={
                    "yql": f"select title from {self.VESPA_SCHEMA_NAME} where true limit 1;",
                    "ranking": "unranked",
                    "query": "keepalive",
                    "timeout": "3s",
                    "hits": 1,
                },
            )
            assert response.is_successful(), response.json
        return True