import os

import numpy as np
import torch

from colbert.infra import ColBERTConfig
from colbert.modeling.checkpoint import Checkpoint


class ColBERT:
    def __init__(self, name, **kwargs) -> None:
        print("ColBERT: Loading model", name)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        DOCKER = kwargs.get("env") == "docker"
        if DOCKER:
            # Inside the Docker image, a stale lock file from a previous build of
            # ColBERT's segmented_maxsim C++ extension can be left behind and block
            # loading the checkpoint, so remove it if it exists.
            lock_file = (
                "/root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/lock"
            )
            if os.path.exists(lock_file):
                os.remove(lock_file)

        self.ckpt = Checkpoint(
            name,
            colbert_config=ColBERTConfig(model_name=name),
        ).to(self.device)

    def calculate_similarity_scores(self, query_embeddings, document_embeddings):
        query_embeddings = query_embeddings.to(self.device)
        document_embeddings = document_embeddings.to(self.device)

        if query_embeddings.dim() != 3:
            raise ValueError(
                f"Expected query embeddings to have 3 dimensions, but got {query_embeddings.dim()}."
            )
        if document_embeddings.dim() != 3:
            raise ValueError(
                f"Expected document embeddings to have 3 dimensions, but got {document_embeddings.dim()}."
            )
        if query_embeddings.size(0) not in [1, document_embeddings.size(0)]:
            raise ValueError(
                "There must be either a single query or exactly one query per document."
            )

        # (num_queries, query_len, dim) -> (num_queries, dim, query_len)
        transposed_query_embeddings = query_embeddings.permute(0, 2, 1)

        # Token-level similarity matrix of shape (num_docs, doc_len, query_len);
        # a single query broadcasts across all documents.
        computed_scores = torch.matmul(document_embeddings, transposed_query_embeddings)

        # MaxSim: for each query token, keep its best-matching document token.
        maximum_scores = torch.max(computed_scores, dim=1).values

        # Sum the per-query-token maxima to get one relevance score per document.
        final_scores = maximum_scores.sum(dim=1)

        # Normalize the scores across documents into a distribution.
        normalized_scores = torch.softmax(final_scores, dim=0)

        return normalized_scores.detach().cpu().numpy().astype(np.float32)

    def predict(self, sentences):
        # `sentences` is a list of (query, document) pairs sharing the same query.
        query = sentences[0][0]
        docs = [pair[1] for pair in sentences]

        # Encode the documents and the query with the ColBERT checkpoint.
        embedded_docs = self.ckpt.docFromText(docs, bsize=32)[0]
        embedded_queries = self.ckpt.queryFromText([query], bsize=32)
        embedded_query = embedded_queries[0]

        scores = self.calculate_similarity_scores(
            embedded_query.unsqueeze(0), embedded_docs
        )
        return scores
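
# Example usage (a minimal sketch; the model name and the (query, document)
# pairs below are illustrative assumptions, not part of this module):
#
#     reranker = ColBERT("colbert-ir/colbertv2.0")
#     pairs = [
#         ("what is colbert?", "ColBERT is a late-interaction retrieval model."),
#         ("what is colbert?", "The Eiffel Tower is in Paris."),
#     ]
#     scores = reranker.predict(pairs)  # softmax-normalized, one score per document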