playmak3r commited on
Commit
254f7e7
·
1 Parent(s): 0e5dcd2

fix similarity output format

Browse files
Files changed (1) hide show
  1. similarity.py +1 -1
similarity.py CHANGED
@@ -20,7 +20,7 @@ def get_similarity_batched(texts1: List[str], texts2: List[str]):
20
  embeddings1 = st_model.encode(texts1, convert_to_tensor=True, show_progress_bar=False)
21
  embeddings2 = st_model.encode(texts2, convert_to_tensor=True, show_progress_bar=False)
22
  cosine_scores = util.cos_sim(embeddings1, embeddings2)
23
- return cosine_scores.diag()
24
 
25
  def clean_text_batch(texts1: List[str], texts2: List[str]):
26
  if len(texts1) == len(texts2):
 
20
  embeddings1 = st_model.encode(texts1, convert_to_tensor=True, show_progress_bar=False)
21
  embeddings2 = st_model.encode(texts2, convert_to_tensor=True, show_progress_bar=False)
22
  cosine_scores = util.cos_sim(embeddings1, embeddings2)
23
+ return cosine_scores.diag().tolist()
24
 
25
  def clean_text_batch(texts1: List[str], texts2: List[str]):
26
  if len(texts1) == len(texts2):