File size: 2,964 Bytes
bd3532f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from typing import List, Dict, Union
from ..embedding_provider import EmbeddingProvider
import numpy as np

class SentenceTransformerEmbedding(EmbeddingProvider):
    """Embedding provider backed by a ``sentence_transformers`` model."""

    def __init__(
        self,
        model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
        device: Union[str, None] = None,
        batch_size: int = 32,
        normalize_embeddings: bool = True
    ) -> None:
        """Initialize sentence transformer embedding provider

        Args:
            model_name (str, optional): Name of the sentence transformer model. Defaults to "sentence-transformers/all-MiniLM-L6-v2".
            device (str, optional): Device to load the model on (e.g. "cpu" or "cuda").
                When None, "cuda" is chosen if torch reports it as available, otherwise "cpu".
            batch_size (int, optional): Batch size used by ``embed_documents``. Defaults to 32.
            normalize_embeddings (bool, optional): Forwarded to ``model.encode`` as
                ``normalize_embeddings``. Defaults to True.
        """
        # Imported lazily so the dependency is only required when this provider is used.
        from sentence_transformers import SentenceTransformer

        if device is None:
            import torch
            device = "cuda" if torch.cuda.is_available() else "cpu"

        # Always record the resolved device: get_model_info() reads self.device,
        # so it must exist whether or not the caller supplied one, and it must
        # be the same value the model is actually loaded with.
        self.device = device

        self.model = SentenceTransformer(model_name, device=self.device)
        self.model_name = model_name
        self.batch_size = batch_size
        self.normalize_embeddings = normalize_embeddings

    def embed_documents(self, documents: List[str]) -> np.ndarray:
        """Embed a list of documents

        Args:
            documents (List[str]): List of documents to embed

        Returns:
            np.ndarray: Embeddings returned by ``model.encode`` for the documents
        """
        return self.model.encode(
            documents,
            batch_size=self.batch_size,
            normalize_embeddings=self.normalize_embeddings
        )

    def embed_query(self, query: str) -> np.ndarray:
        """Embed a single query

        Args:
            query (str): Query to embed

        Returns:
            np.ndarray: Embedding vector
        """
        return self.model.encode(
            query,
            normalize_embeddings=self.normalize_embeddings
        )

    def get_model_info(self) -> Dict[str, Union[str, int]]:
        """
        Retrieve information about the current embedding model

        Returns:
            Dict: Model information with the keys "model_name", "device",
            "batch_size", "normalize_embeddings" and "embedding_dim"
        """
        return {
            "model_name": self.model_name,
            "device": self.device,
            "batch_size": self.batch_size,
            "normalize_embeddings": self.normalize_embeddings,
            "embedding_dim": self.model.get_sentence_embedding_dimension()
        }

    def list_available_models(self) -> List[str]:
        """
        List some popular Sentence Transformer models

        The list is a fixed, hand-picked selection; it is not queried from
        any model hub.

        Returns:
            List[str]: Available model names
        """
        popular_models = [
            "sentence-transformers/all-MiniLM-L6-v2",  # Small and fast
            "sentence-transformers/all-mpnet-base-v2",  # High performance
            "sentence-transformers/all-distilroberta-v1",  # Lightweight
            "sentence-transformers/multi-qa-MiniLM-L6-cos-v1",  # Question Answering
            "sentence-transformers/multi-qa-mpnet-base-cos-v1",  # Question Answering (larger mpnet variant)
            "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"  # Multilingual
        ]
        return popular_models