Spaces:
Running
Running
File size: 556 Bytes
bdafe83 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 |
import numpy as np
import torch
def pairwise_cos_sim(A: torch.Tensor, B: torch.Tensor, device="cuda:0"):
if isinstance(A, np.ndarray):
A = torch.from_numpy(A).to(device)
if isinstance(B, np.ndarray):
B = torch.from_numpy(B).to(device)
from torch.nn.functional import normalize
A_norm = normalize(A, dim=1) # Normalize the rows of A
B_norm = normalize(B, dim=1) # Normalize the rows of B
cos_sim = torch.matmul(A_norm,
B_norm.t()) # Calculate the cosine similarity
return cos_sim |