File size: 2,673 Bytes
f2932e2
51a7f02
 
9c7a6f3
f2932e2
 
1bec7d8
51a7f02
 
f2932e2
c7e10e4
9c7a6f3
 
 
1bec7d8
 
 
 
f2932e2
 
 
9c7a6f3
f2932e2
9c7a6f3
f2932e2
9c7a6f3
f2932e2
 
12a040e
9c7a6f3
1bec7d8
 
12a040e
f2932e2
 
 
 
 
 
 
1bec7d8
f2932e2
 
 
 
 
 
 
 
 
1bec7d8
 
12a040e
f2932e2
 
1bec7d8
 
f2932e2
 
 
 
 
 
 
1bec7d8
 
f2932e2
 
 
 
12a040e
1bec7d8
f2932e2
 
 
1bec7d8
 
f2932e2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import os
from abc import ABC, abstractmethod

from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain.vectorstores import Qdrant, ElasticVectorSearch, VectorStore
from qdrant_client import QdrantClient
from qdrant_client.models import VectorParams, Distance

from db.embedding import Embedding, EMBEDDINGS


class ToyVectorStore(ABC):
    """Abstract base for the project's vector-store backends.

    Concrete subclasses must implement ``get_collection`` and
    ``create_collection``. Use ``get_instance`` to build the backend
    selected by the ``STORE`` environment variable.
    """

    @staticmethod
    def get_embedding():
        """Return the ``EMBEDDINGS`` entry named by the ``EMBEDDING`` env var.

        Falls back to ``"OPEN_AI"`` when the variable is unset or empty.

        Raises:
            KeyError: if the name is not a key of ``EMBEDDINGS``. The message
                lists the available names so misconfiguration is obvious.
        """
        name = os.getenv("EMBEDDING") or "OPEN_AI"
        try:
            return EMBEDDINGS[name]
        except KeyError:
            # Re-raise the same exception type, but with an actionable message
            # instead of just the bare missing key.
            raise KeyError(
                f"Unknown embedding {name!r}; expected one of {list(EMBEDDINGS)}"
            ) from None

    @staticmethod
    def get_instance():
        """Build the vector store selected by the ``STORE`` env var.

        Returns:
            An ``ElasticVectorStore`` for ``"ELASTIC"`` or a
            ``QdrantVectorStore`` for ``"QDRANT"``, constructed with
            ``get_embedding()``.

        Raises:
            ValueError: if ``STORE`` is unset or names an unknown backend.
        """
        # Subclasses are defined later in this module; they are resolved at
        # call time, not at class-definition time.
        backends = {
            "ELASTIC": ElasticVectorStore,
            "QDRANT": QdrantVectorStore,
        }
        vector_store = os.getenv("STORE")
        backend = backends.get(vector_store)
        if backend is None:
            # Validate before resolving the embedding, as the original did.
            raise ValueError(f"Invalid vector store {vector_store}")
        return backend(ToyVectorStore.get_embedding())

    def __init__(self, embedding: Embedding):
        """Store the embedding wrapper used by the concrete backend.

        Args:
            embedding: the project ``Embedding`` wrapper. Subclasses read its
                ``embedding`` attribute and, when creating collections, its
                ``dimension`` attribute.
        """
        self.embedding = embedding

    @abstractmethod
    def get_collection(self, collection: str="test") -> VectorStore:
        """
        Return a LangChain ``VectorStore`` bound to the named collection.
        """
        pass

    @abstractmethod
    def create_collection(self, collection: str) -> None:
        """
        Create the named collection in the backing store.
        """
        pass

class ElasticVectorStore(ToyVectorStore):
    """Elasticsearch-backed vector store: one index per collection.

    Connects to the cluster given by the ``ES_URL`` environment variable.
    """

    def __init__(self, embeddings):
        super().__init__(embeddings)

    def get_collection(self, collection: str = "test") -> ElasticVectorSearch:
        """Return a LangChain ``ElasticVectorSearch`` bound to index *collection*.

        The ``"test"`` default matches ``ToyVectorStore.get_collection``.
        """
        return ElasticVectorSearch(
            elasticsearch_url=os.getenv("ES_URL"),
            index_name=collection,
            embedding=self.embedding.embedding,
        )

    def create_collection(self, collection: str) -> None:
        """Create Elasticsearch index *collection* with an explicit vector mapping.

        ``vector`` is declared as a ``dense_vector`` sized to the configured
        embedding's ``dimension``. Without the declaration, Elasticsearch would
        dynamically map the vectors as plain float arrays.
        """
        store = self.get_collection(collection)
        # NOTE(review): field names mirror LangChain ElasticVectorSearch's
        # default text mapping ("text", "vector"); verify against the
        # installed langchain version.
        mapping = {
            "properties": {
                "text": {"type": "text"},
                "vector": {
                    "type": "dense_vector",
                    "dims": self.embedding.dimension,
                },
            }
        }
        store.create_index(store.client, collection, mapping)


class QdrantVectorStore(ToyVectorStore):
    """Qdrant-backed vector store: one Qdrant collection per collection name.

    The client is created once, at construction, from the ``QDRANT_URL`` and
    ``QDRANT_API_KEY`` environment variables.
    """

    def __init__(self, embeddings):
        super().__init__(embeddings)
        self.client = QdrantClient(
            url=os.getenv("QDRANT_URL"),
            api_key=os.getenv("QDRANT_API_KEY"),
        )

    def get_collection(self, collection: str = "test") -> Qdrant:
        """Return a LangChain ``Qdrant`` store bound to *collection*.

        The ``"test"`` default matches ``ToyVectorStore.get_collection``.
        """
        return Qdrant(
            client=self.client,
            collection_name=collection,
            embeddings=self.embedding.embedding,
        )

    def create_collection(self, collection: str) -> None:
        """Create Qdrant collection *collection* using cosine distance.

        The vector size is the configured embedding's ``dimension``.
        """
        self.client.create_collection(
            collection_name=collection,
            vectors_config=VectorParams(
                size=self.embedding.dimension,
                distance=Distance.COSINE,
            ),
        )