# Spaces: Sleeping  (appears to be page-header residue from a web copy of this file; not part of the module)
from typing import Dict, List, Optional, Union

import numpy as np

from ..embedding_provider import EmbeddingProvider
class SentenceTransformerEmbedding(EmbeddingProvider):
    """Embedding provider backed by a ``sentence-transformers`` model.

    Wraps a ``SentenceTransformer`` instance and remembers the encoding
    options (batch size, normalization) that are passed to ``encode``.
    """

    def __init__(
        self,
        model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
        device: Optional[str] = None,
        batch_size: int = 32,
        normalize_embeddings: bool = True
    ) -> None:
        """Initialize sentence transformer embedding provider

        Args:
            model_name (str, optional): Name of the sentence transformer
                model. Defaults to "sentence-transformers/all-MiniLM-L6-v2".
            device (str, optional): Device to load the model on. When
                ``None``, "cuda" is used if ``torch.cuda.is_available()``
                returns true, otherwise "cpu".
            batch_size (int, optional): Batch size passed to ``encode`` when
                embedding documents. Defaults to 32.
            normalize_embeddings (bool, optional): Forwarded to ``encode`` as
                ``normalize_embeddings``. Defaults to True.
        """
        # Imported here rather than at module level, so importing this module
        # does not itself import sentence_transformers.
        from sentence_transformers import SentenceTransformer

        if device is None:
            import torch
            device = "cuda" if torch.cuda.is_available() else "cpu"

        # Always record the resolved device. Previously self.device was only
        # set when no device was given, so get_model_info() raised
        # AttributeError for callers that passed one explicitly.
        self.device = device
        # Pass the resolved device so the model's placement matches the value
        # reported by get_model_info().
        self.model = SentenceTransformer(model_name, device=self.device)
        self.model_name = model_name
        self.batch_size = batch_size
        self.normalize_embeddings = normalize_embeddings

    def embed_documents(self, documents: List[str]) -> np.ndarray:
        """Embed a list of documents

        Args:
            documents (List[str]): List of documents to embed

        Returns:
            np.ndarray: Embeddings as returned by ``SentenceTransformer.encode``
                for the given list.
        """
        return self.model.encode(
            documents,
            batch_size=self.batch_size,
            normalize_embeddings=self.normalize_embeddings
        )

    def embed_query(self, query: str) -> np.ndarray:
        """Embed a single query

        Args:
            query (str): Query to embed

        Returns:
            np.ndarray: Embedding vector
        """
        # A single string is encoded, so the configured batch size is not
        # passed here.
        return self.model.encode(
            query,
            normalize_embeddings=self.normalize_embeddings
        )

    def get_model_info(self) -> Dict[str, Union[str, int]]:
        """
        Retrieve information about the current embedding model

        Returns:
            Dict: Model information with the keys "model_name", "device",
                "batch_size", "normalize_embeddings" and "embedding_dim".
        """
        return {
            "model_name": self.model_name,
            "device": self.device,
            "batch_size": self.batch_size,
            "normalize_embeddings": self.normalize_embeddings,
            "embedding_dim": self.model.get_sentence_embedding_dimension()
        }

    def list_available_models(self) -> List[str]:
        """
        List some popular Sentence Transformer models

        The list is hard-coded; nothing is queried at runtime.

        Returns:
            List[str]: Available model names
        """
        popular_models = [
            "sentence-transformers/all-MiniLM-L6-v2",  # Small and fast
            "sentence-transformers/all-mpnet-base-v2",  # High performance
            "sentence-transformers/all-distilroberta-v1",  # Lightweight
            "sentence-transformers/multi-qa-MiniLM-L6-cos-v1",  # Question Answering
            "sentence-transformers/multi-qa-mpnet-base-cos-v1",  # Question Answering (mpnet base)
            "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"  # Multilingual
        ]
        return popular_models