File size: 1,765 Bytes
74b1bac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import numpy as np
from sentence_transformers import SentenceTransformer  # pylint: disable=C0413

from abc import ABCMeta, abstractmethod


class BaseEmbedding(metaclass=ABCMeta):
    """
    Abstract interface shared by embedding backends.

    Both ``dimension`` and ``to_embeddings`` are abstract, so a subclass has to
    override each of them before it can be instantiated.
    """

    @property
    @abstractmethod
    def dimension(self) -> int:
        """Embedding dimension.

        The base implementation returns ``0``.
        """
        return 0

    @abstractmethod
    def to_embeddings(self, data, **kwargs):
        """Turn ``data`` into its embedding.

        The base implementation does nothing and returns ``None``.
        """



class SBERT(BaseEmbedding):
    """Generate sentence embeddings for text using pretrained models of Sentence Transformers.

    :param model: model name, defaults to 'all-MiniLM-L6-v2'.
    :type model: str

    Example:
        .. code-block:: python

            from gptcache.embedding import SBERT

            test_sentence = 'Hello, world.'
            encoder = SBERT('all-MiniLM-L6-v2')
            embed = encoder.to_embeddings(test_sentence)
    """

    def __init__(self, model: str = "all-MiniLM-L6-v2"):
        self.model = SentenceTransformer(model)
        # Switch the model to evaluation mode; this class only runs inference.
        self.model.eval()
        # Filled in lazily, either by the first non-empty ``to_embeddings``
        # call or by the first read of ``dimension``.
        self.__dimension = None

    def to_embeddings(self, data, **_) -> np.ndarray:
        """Generate embeddings for the given text input.

        :param data: a single text, or a list of texts. Anything that is not
            a list is wrapped in a one-element list before encoding.
        :type data: str or list

        :return: a float32 array of shape (n, dim) with one row per input
            text; a single string therefore yields shape (1, dim), and an
            empty list yields shape (0, dim).
        """
        if not isinstance(data, list):
            data = [data]
        if not data:
            # Nothing to encode. Return an empty, correctly shaped result
            # rather than failing while unpacking the encoder output's shape.
            return np.empty((0, self.dimension), dtype="float32")
        emb = self.model.encode(data)
        if self.__dimension is None:
            # Remember the embedding width so ``dimension`` needs no extra encode.
            self.__dimension = emb.shape[1]
        return np.array(emb, dtype="float32")

    @property
    def dimension(self) -> int:
        """Embedding dimension.

        :return: embedding dimension
        """
        if self.__dimension is None:
            # Not known yet: encode a throwaway sentence and read the width
            # of the resulting (1, dim) array.
            probe = self.model.encode(["foo"])
            self.__dimension = probe.shape[1]
        return self.__dimension