Spaces:
Sleeping
Sleeping
File size: 4,500 Bytes
a54024a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import abc
from typing import List, Union
from numpy.typing import NDArray
from sentence_transformers import SentenceTransformer
from .type_aliases import ENCODER_DEVICE_TYPE
class Encoder(abc.ABC):
    """Interface for components that map sentences to embedding vectors."""

    @abc.abstractmethod
    def encode(self, prediction: List[str]) -> NDArray:
        """
        Turn a batch of sentences into sentence embeddings.

        Concrete subclasses must override this method.

        Args:
            prediction (List[str]): Sentences to embed.

        Returns:
            NDArray: Embeddings with shape (num_sentences, embedding_dim).

        Raises:
            NotImplementedError: Always, when reached via the base class.
        """
        message = "Method 'encode' must be implemented in subclass."
        raise NotImplementedError(message)
class SBertEncoder(Encoder):
    """Encoder backed by a ``SentenceTransformer`` model.

    Encodes on a single device, or across a multi-process pool when ``device``
    is a list of devices.
    """

    def __init__(self, model: SentenceTransformer, device: ENCODER_DEVICE_TYPE, batch_size: int, verbose: bool):
        """
        Initialize SBertEncoder instance.

        Args:
            model (SentenceTransformer): The Sentence Transformer model instance to use for encoding.
            device (Union[str, int, List[Union[str, int]]]): Device specification for encoding.
                A list of devices selects multi-process encoding across all of them.
            batch_size (int): Batch size for encoding.
            verbose (bool): Whether to show a progress bar during encoding. Only forwarded
                on the single-device path (as ``show_progress_bar``).
        """
        self.model = model
        self.device = device
        self.batch_size = batch_size
        self.verbose = verbose

    def encode(self, prediction: List[str]) -> NDArray:
        """
        Encode a list of sentences into sentence embeddings.

        Args:
            prediction (List[str]): List of sentences to encode.

        Returns:
            NDArray: Array of sentence embeddings with shape (num_sentences, embedding_dim).
        """
        # SBert output is always Batch x Dim
        if isinstance(self.device, list):
            # Use multiprocess encoding for list of devices.
            pool = self.model.start_multi_process_pool(target_devices=self.device)
            try:
                embeddings = self.model.encode_multi_process(prediction, pool=pool, batch_size=self.batch_size)
            finally:
                # Always tear the pool down, even if encoding raises, so worker
                # processes are not leaked.
                self.model.stop_multi_process_pool(pool)
        else:
            # Single device encoding.
            embeddings = self.model.encode(
                prediction,
                device=self.device,
                batch_size=self.batch_size,
                show_progress_bar=self.verbose,
            )
        return embeddings
def get_encoder(
    sbert_model: SentenceTransformer,
    device: ENCODER_DEVICE_TYPE,
    batch_size: int,
    verbose: bool,
) -> Encoder:
    """
    Wrap a SentenceTransformer model in an SBertEncoder.

    Args:
        sbert_model (SentenceTransformer): An instance of SentenceTransformer model to use for encoding.
        device (Union[str, int, List[Union[str, int]]]): Device specification for the encoder
            (e.g., "cuda", 0 for GPU, "cpu"). A list of devices enables multi-process encoding.
        batch_size (int): Batch size to use for encoding.
        verbose (bool): Whether to print verbose information during encoding.

    Returns:
        SBertEncoder: Encoder configured with the given model, device, batch size and verbosity.

    Example:
        >>> sbert_model = get_sbert_encoder("paraphrase-distilroberta-base-v1")  # doctest: +SKIP
        >>> encoder = get_encoder(sbert_model, device="cuda", batch_size=32, verbose=True)  # doctest: +SKIP
        >>> embeddings = encoder.encode(["A sentence.", "Another sentence."])  # doctest: +SKIP
    """
    return SBertEncoder(sbert_model, device, batch_size, verbose)
def get_sbert_encoder(model_name: str, trust_remote_code: bool = True) -> SentenceTransformer:
    """
    Load a SentenceTransformer model by name.

    Args:
        model_name (str): Name of the model to instantiate. You can use any model on Huggingface/SentenceTransformer
            that is supported by SentenceTransformer.
        trust_remote_code (bool): Forwarded to ``SentenceTransformer``. Defaults to True to keep the
            historical behavior. When True, custom code shipped with the model repository may be
            executed locally, so pass False for models you do not trust.

    Returns:
        SentenceTransformer: The loaded model for ``model_name``.

    Raises:
        EnvironmentError: If loading fails with an EnvironmentError/OSError (e.g. an unsupported
            or unavailable model_name). The original exception is attached as ``__cause__``.
        RuntimeError: If any other error occurs while instantiating the model. The original
            exception is attached as ``__cause__``.
    """
    try:
        encoder = SentenceTransformer(model_name, trust_remote_code=trust_remote_code)
    except EnvironmentError as err:
        # Normalize to a plain EnvironmentError for callers, but keep the root cause
        # chained so the underlying traceback stays available for debugging.
        raise EnvironmentError(str(err)) from err
    except Exception as err:
        raise RuntimeError(str(err)) from err
    return encoder
|