# sparksearch-demo / SmartSearch / providers / SentenceTransformerEmbedding.py
# Uploaded by teddyllm ("Upload 20 files", commit bd3532f, verified)
from typing import List, Dict, Union
from ..embedding_provider import EmbeddingProvider
import numpy as np
class SentenceTransformerEmbedding(EmbeddingProvider):
    """Embedding provider backed by a ``sentence_transformers`` model."""

    def __init__(
        self,
        model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
        device: Union[str, None] = None,
        batch_size: int = 32,
        normalize_embeddings: bool = True
    ) -> None:
        """Initialize sentence transformer embedding provider

        Args:
            model_name (str, optional): Name of the sentence transformer model.
                Defaults to "sentence-transformers/all-MiniLM-L6-v2".
            device (str, optional): Device the model is loaded on. When None,
                "cuda" is used if torch reports it available, otherwise "cpu".
            batch_size (int, optional): Batch size passed to ``encode`` when
                embedding documents. Defaults to 32.
            normalize_embeddings (bool, optional): Forwarded to ``encode`` as
                ``normalize_embeddings``. Defaults to True.
        """
        # Imported here rather than at module level, so importing this module
        # does not itself import sentence_transformers.
        from sentence_transformers import SentenceTransformer

        if device is None:
            import torch
            device = "cuda" if torch.cuda.is_available() else "cpu"

        # Always record the resolved device (previously it was only set when
        # the caller passed None, so get_model_info() failed otherwise) and
        # hand the same value to the model so both stay in agreement.
        self.device = device
        self.model = SentenceTransformer(model_name, device=self.device)
        self.model_name = model_name
        self.batch_size = batch_size
        self.normalize_embeddings = normalize_embeddings

    def embed_documents(self, documents: List[str]) -> np.ndarray:
        """Embed a list of documents

        Args:
            documents (List[str]): List of documents to embed

        Returns:
            np.ndarray: Result of ``model.encode`` for the documents, encoded
                in batches of ``self.batch_size``.
        """
        return self.model.encode(
            documents,
            batch_size=self.batch_size,
            normalize_embeddings=self.normalize_embeddings
        )

    def embed_query(self, query: str) -> np.ndarray:
        """Embed a single query

        Args:
            query (str): Query to embed

        Returns:
            np.ndarray: Embedding vector
        """
        return self.model.encode(
            query,
            normalize_embeddings=self.normalize_embeddings
        )

    def get_model_info(self) -> Dict[str, Union[str, int]]:
        """
        Retrieve information about the current embedding model

        Returns:
            Dict: Model information with the keys "model_name", "device",
                "batch_size", "normalize_embeddings" and "embedding_dim".
        """
        return {
            "model_name": self.model_name,
            "device": self.device,
            "batch_size": self.batch_size,
            "normalize_embeddings": self.normalize_embeddings,
            "embedding_dim": self.model.get_sentence_embedding_dimension()
        }

    def list_available_models(self) -> List[str]:
        """
        List some popular Sentence Transformer models

        The list is hard-coded; nothing is queried at call time.

        Returns:
            List[str]: Available model names
        """
        return [
            "sentence-transformers/all-MiniLM-L6-v2",  # Small and fast
            "sentence-transformers/all-mpnet-base-v2",  # High performance
            "sentence-transformers/all-distilroberta-v1",  # Lightweight
            "sentence-transformers/multi-qa-MiniLM-L6-cos-v1",  # Question Answering
            "sentence-transformers/multi-qa-mpnet-base-cos-v1",  # Question Answering (mpnet base)
            "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"  # Multilingual
        ]