# File size: 1,092 Bytes
# 8a9d0f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import datasets
import numpy as np
import spaces
from sentence_transformers import CrossEncoder, SentenceTransformer

from table import BASE_REPO_ID

# Module-level resources, loaded once at import time and reused by every `search` call.
# Paper corpus; `search` reads its "embedding", "paper_id", "title" and "abstract" columns.
ds = datasets.load_dataset(BASE_REPO_ID, split="train")
# Build a FAISS index over the precomputed "embedding" column so `search` can query it
# with `ds.get_nearest_examples("embedding", ...)`.
# NOTE(review): `search` L2-normalizes the query vector; this assumes the stored
# embeddings were normalized the same way when the dataset was built — confirm.
ds.add_faiss_index(column="embedding")

# Stage 1: bi-encoder that embeds queries for nearest-neighbour retrieval.
bi_model = SentenceTransformer("BAAI/bge-base-en-v1.5")
# Stage 2: cross-encoder that scores (query, passage) pairs to rerank the retrieved pool.
ce_model = CrossEncoder("BAAI/bge-reranker-base")


@spaces.GPU(duration=10)
def search(query: str, candidate_pool_size: int = 100, retrieval_k: int = 50) -> list[dict]:
    """Retrieve papers for ``query`` with the bi-encoder, then rerank them with the cross-encoder.

    Stage 1 embeds the query with ``bi_model`` (prefixed with the retrieval
    instruction below) and fetches up to ``candidate_pool_size`` nearest rows
    from the FAISS index on ``ds["embedding"]``. Stage 2 scores every
    ``(query, "<title> <abstract>")`` pair with ``ce_model`` and keeps the
    ``retrieval_k`` highest-scoring rows.

    Args:
        query: Free-text search query.
        candidate_pool_size: Number of nearest neighbours fetched from the
            index before reranking.
        retrieval_k: Maximum number of reranked results to return.

    Returns:
        Up to ``retrieval_k`` dicts of the form
        ``{"paper_id": ..., "ce_score": float}``, ordered from highest to lowest
        cross-encoder score. Rows with equal scores keep their nearest-neighbour
        order. Returns an empty list when either size argument is not positive
        or the index yields no candidates.
    """
    if candidate_pool_size <= 0 or retrieval_k <= 0:
        return []

    # Query-side instruction for BGE retrieval; assumes the stored passage
    # embeddings were built without it — confirm against the indexing script.
    prefix = "Represent this sentence for searching relevant passages: "
    q_vec = bi_model.encode(prefix + query, normalize_embeddings=True)

    _, retrieved_ds = ds.get_nearest_examples("embedding", q_vec, k=candidate_pool_size)

    paper_ids = retrieved_ds["paper_id"]
    if len(paper_ids) == 0:
        # Nothing to rerank. Some sentence-transformers versions index
        # `sentences[0]` inside CrossEncoder.predict and fail on an empty batch.
        return []

    ce_inputs = [
        (query, f"{title} {abstract}")
        for title, abstract in zip(retrieved_ds["title"], retrieved_ds["abstract"])
    ]
    ce_scores = ce_model.predict(ce_inputs, batch_size=16)

    # Stable ascending sort on the negated scores gives highest-first order while
    # keeping ties in retrieval order and NaN scores last. `argsort(...)[::-1]`
    # would reverse tie order and move NaNs to the top.
    sorted_idx = np.argsort(-np.asarray(ce_scores), kind="stable")[:retrieval_k]
    return [{"paper_id": paper_ids[i], "ce_score": float(ce_scores[i])} for i in sorted_idx]