Spaces:
Runtime error
Runtime error
File size: 2,424 Bytes
d23aa66 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
import datasets
import evaluate
from typing import List, Union
import torch
import torch.nn.functional as F
import torch.nn as nn
# Metadata strings: passed to `evaluate.MetricInfo` in `CosSim._info` and
# prepended to the `CosSim` class docstring by `add_start_docstrings`.
_DESCRIPTION = """
Cosine similarity between paired embeddings, where each embedding represents the semantics of an object.
"""
_KWARGS_DESCRIPTION = """
Args:
predictions (`list` of `list` of `float`): a group of embeddings
references (`list` of `list` of `float`): the other group of embeddings paired with the predictions
Returns:
cos_similarity (`float`): average cosine similarity between the paired embeddings
Examples:
Example 1-A simple example
>>> cos_similarity_metrics = evaluate.load("ahnyeonchan/cosine_sim_btw_embeddings_of_same_semantics")
>>> results = cos_similarity_metrics.compute(references=[[1.0, 1.0], [0.0, 1.0]], predictions=[[1.0, 1.0], [0.0, 1.0]])
>>> print(results)
{'cos_similarity': 1.0}
"""
# No citation is associated with this metric.
_CITATION = """"""
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class CosSim(evaluate.Metric):
    """Metric that reports the mean cosine similarity of paired embeddings."""

    def __init__(self, *args, **kwargs):
        """Forward all arguments to `evaluate.Metric` and build the similarity module."""
        super().__init__(*args, **kwargs)
        # Constructed with PyTorch defaults; see the `nn.CosineSimilarity`
        # documentation for the reduction dimension and epsilon in effect.
        self.cossim = nn.CosineSimilarity()

    def _info(self):
        """Describe the metric: texts, citation and the expected input features."""
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            # Each prediction / reference is one embedding: a sequence of float32.
            features=datasets.Features(
                {
                    "predictions": datasets.Sequence(datasets.Value("float32")),
                    "references": datasets.Sequence(datasets.Value("float32")),
                }
            ),
            reference_urls=[],
        )

    @staticmethod
    def _as_float_tensor(values, name: str) -> torch.Tensor:
        """Convert a list or tensor of embeddings to a float32 tensor.

        Raises:
            NotImplementedError: if `values` is neither a `list` nor a
                `torch.Tensor`.
        """
        if not isinstance(values, (torch.Tensor, list)):
            raise NotImplementedError(
                f"`{name}` must be a list or a torch.Tensor, got {type(values).__name__}"
            )
        # `as_tensor` accepts both lists and existing tensors and casts to
        # float32, matching the feature dtype declared in `_info`.
        return torch.as_tensor(values, dtype=torch.float32)

    def _compute(
        self,
        predictions: Union[List[List], torch.Tensor],
        references: Union[List[List], torch.Tensor],
    ):
        """Return the average cosine similarity between paired embeddings.

        Args:
            predictions: embeddings, as a list of lists or a tensor.
            references: embeddings paired element-wise with `predictions`.

        Returns:
            A dict with the single key ``"cos_similarity"`` mapped to the mean
            similarity as a Python float.

        Raises:
            NotImplementedError: if either input is neither a `list` nor a
                `torch.Tensor`.
        """
        predictions = self._as_float_tensor(predictions, "predictions")
        references = self._as_float_tensor(references, "references")

        similarities = self.cossim(predictions, references)
        # The original read an undefined name here (`cossim`), which raised
        # NameError on every call; average the computed similarities instead.
        mean_similarity = torch.mean(similarities).item()
        return {"cos_similarity": float(mean_similarity)}
|