File size: 553 Bytes
c542da9
 
 
 
 
 
a2e05fb
c542da9
 
 
a2e05fb
 
c542da9
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from sentence_transformers import SentenceTransformer, util

class Mapper:
    """Rank candidate strings by embedding similarity to a query.

    Wraps a ``SentenceTransformer`` model and scores each candidate with the
    dot product between its embedding and the query embedding.
    """

    def __init__(self, repo: str, model: str):
        """Load the embedding model.

        Args:
            repo: First component of the model identifier.
            model: Second component; joined with *repo* as
                ``"{repo}/{model}"`` and passed to ``SentenceTransformer``
                (presumably a Hugging Face Hub id or local path — confirm
                against callers).
        """
        self.__model = SentenceTransformer(f"{repo}/{model}")

    def __call__(self, query: str, data: list[str]) -> list[tuple[int, float]]:
        """Score every entry of *data* against *query*, best match first.

        Args:
            query: Text to match against.
            data: Candidate texts to rank.

        Returns:
            ``(index, score)`` pairs, where ``index`` is the position of the
            candidate in *data* and ``score`` is the dot product of the query
            and candidate embeddings. Pairs are sorted by score, highest
            first; equal scores keep their original order (stable sort).
            An empty *data* yields an empty list without invoking the model.
        """
        # Nothing to rank: return early instead of encoding an empty batch
        # and scoring the query against it.
        if not data:
            return []

        query_emb = self.__model.encode(query)
        data_emb = self.__model.encode(data)

        # dot_score yields a (num_queries, num_candidates) score matrix; with
        # a single query, row 0 holds one score per entry of ``data``.
        scores: list[float] = util.dot_score(query_emb, data_emb)[0].cpu().tolist()

        return sorted(enumerate(scores), key=lambda pair: pair[1], reverse=True)