# File size: 2,604 bytes (revision a73d4bf)
import torch
import numpy as np
from typing import List
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer, CrossEncoder
from chunker import chunk_documents
class Retriever:
def __init__(self, docs: List[str], score: int) -> None:
self.docs = chunk_documents(docs=docs)
self.score = score
tokenized_docs = [doc.lower().split(" ") for doc in self.docs]
self.bm25 = BM25Okapi(tokenized_docs)
self.sbert = SentenceTransformer(
'sentence-transformers/all-distilroberta-v1'
)
self.doc_embeddings = self.sbert.encode(
self.docs, show_progress_bar=True
)
self.cross_encoder = CrossEncoder("cross-encoder/stsb-roberta-base")
def get_docs(self, query: str, n: int = 5, score: int = 2) -> List[str]:
match score:
case 0:
bm25_scores = self._get_bm25_scores(query=query)
sorted_indices = torch.Tensor.tolist(
np.argsort(bm25_scores)
)[::-1]
case 1:
semantic_scores = self._get_semantic_scores(query=query)
sorted_indices = torch.Tensor.tolist(
np.argsort(semantic_scores)
)[::-1]
case 2:
bm25_scores = self._get_bm25_scores(query=query)
semantic_scores = self._get_semantic_scores(query=query)
scores = torch.tensor(0.3 * bm25_scores) + 0.7 * semantic_scores
sorted_indices = torch.Tensor.tolist(np.argsort(scores))[::-1]
preselected_docs = [self.docs[i] for i in sorted_indices][:n]
result = self.rerank(query=query, docs=preselected_docs)
return result
def _get_bm25_scores(self, query: str) -> np.ndarray[float]:
tokenized_query = query.lower().split(" ")
bm25_scores = self.bm25.get_scores(tokenized_query)
return bm25_scores
def _get_semantic_scores(self, query: str) -> torch.Tensor:
query_embeddings = self.sbert.encode(query)
semantic_scores = self.sbert.similarity(
query_embeddings, self.doc_embeddings
)
return semantic_scores[0]
def rerank(self, query: str, docs: List[str]) -> List[str]:
pairs = [(query, doc) for doc in docs]
rerank_scores = self.cross_encoder.predict(pairs)
reranked_docs = [doc for _, doc in sorted(zip(rerank_scores, docs), reverse=True)]
return reranked_docs |