Spaces:
Running
Running
File size: 2,673 Bytes
f2932e2 51a7f02 9c7a6f3 f2932e2 1bec7d8 51a7f02 f2932e2 c7e10e4 9c7a6f3 1bec7d8 f2932e2 9c7a6f3 f2932e2 9c7a6f3 f2932e2 9c7a6f3 f2932e2 12a040e 9c7a6f3 1bec7d8 12a040e f2932e2 1bec7d8 f2932e2 1bec7d8 12a040e f2932e2 1bec7d8 f2932e2 1bec7d8 f2932e2 12a040e 1bec7d8 f2932e2 1bec7d8 f2932e2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
import os
from abc import ABC, abstractmethod

from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain.vectorstores import Qdrant, ElasticVectorSearch, VectorStore
from qdrant_client import QdrantClient
from qdrant_client.models import VectorParams, Distance

from db.embedding import Embedding, EMBEDDINGS
class ToyVectorStore(ABC):
    """Abstract base for the vector-store backends defined in this module.

    Holds the embedding configuration shared by every backend and provides
    environment-driven factories for choosing the embedding and the backend.

    Deriving from ``ABC`` is what makes the ``@abstractmethod`` markers below
    effective: without it the decorator is inert, the base class can be
    instantiated, and its placeholder methods silently return ``None``.
    """

    @staticmethod
    def get_embedding():
        """Return the embedding selected by the ``EMBEDDING`` environment variable.

        Falls back to ``EMBEDDINGS["OPEN_AI"]`` when the variable is unset or
        empty.

        Raises:
            KeyError: if the variable names a key that is not in ``EMBEDDINGS``.
        """
        # `or` covers both "unset" (None) and "set but empty" ("").
        return EMBEDDINGS[os.getenv("EMBEDDING") or "OPEN_AI"]

    @staticmethod
    def get_instance():
        """Build the backend selected by the ``STORE`` environment variable.

        ``"ELASTIC"`` yields an ``ElasticVectorStore`` and ``"QDRANT"`` a
        ``QdrantVectorStore``; both receive the result of ``get_embedding()``.

        Raises:
            ValueError: if ``STORE`` is unset or holds any other value.
        """
        vector_store = os.getenv("STORE")
        if vector_store == "ELASTIC":
            return ElasticVectorStore(ToyVectorStore.get_embedding())
        if vector_store == "QDRANT":
            return QdrantVectorStore(ToyVectorStore.get_embedding())
        raise ValueError(f"Invalid vector store {vector_store}")

    def __init__(self, embedding: Embedding):
        """Remember the embedding configuration used by this store."""
        self.embedding = embedding

    @abstractmethod
    def get_collection(self, collection: str = "test") -> VectorStore:
        """Return a vector store instance bound to ``collection``."""

    @abstractmethod
    def create_collection(self, collection: str) -> None:
        """Create the backing collection named ``collection``."""
class ElasticVectorStore(ToyVectorStore):
    """Vector-store backend built on langchain's ``ElasticVectorSearch``."""

    def __init__(self, embeddings):
        """Pass the embedding configuration through to ``ToyVectorStore``."""
        super().__init__(embeddings)

    def get_collection(self, collection: str = "test") -> ElasticVectorSearch:
        """Return an ``ElasticVectorSearch`` bound to the index ``collection``.

        The default of ``"test"`` mirrors ``ToyVectorStore.get_collection`` so
        that code written against the base interface can call this method
        without arguments. The server URL is read from the ``ES_URL``
        environment variable on every call.
        """
        return ElasticVectorSearch(
            elasticsearch_url=os.getenv("ES_URL"),
            index_name=collection,
            embedding=self.embedding.embedding,
        )

    def create_collection(self, collection: str) -> None:
        """Create the Elasticsearch index ``collection`` with a vector mapping.

        The index is created with an explicit mapping instead of an empty one.
        With an empty mapping the vector field is left to Elasticsearch's
        dynamic mapping and is not guaranteed to become a ``dense_vector``.
        """
        store = self.get_collection(collection)
        # NOTE(review): the field names "text" and "vector" are meant to match
        # the documents langchain's ElasticVectorSearch writes and queries --
        # confirm against the installed langchain version.
        mapping = {
            "properties": {
                "text": {"type": "text"},
                "vector": {
                    "type": "dense_vector",
                    "dims": self.embedding.dimension,
                },
            }
        }
        store.create_index(store.client, collection, mapping)
class QdrantVectorStore(ToyVectorStore):
    """Vector-store backend built on a Qdrant server."""

    def __init__(self, embeddings):
        """Store the embedding configuration and open a Qdrant client.

        The client is configured from the ``QDRANT_URL`` and
        ``QDRANT_API_KEY`` environment variables.
        """
        super().__init__(embeddings)
        self.client = QdrantClient(
            url=os.getenv("QDRANT_URL"),
            api_key=os.getenv("QDRANT_API_KEY"),
        )

    def get_collection(self, collection: str = "test") -> Qdrant:
        """Return a langchain ``Qdrant`` store bound to ``collection``.

        The default of ``"test"`` mirrors ``ToyVectorStore.get_collection`` so
        that code written against the base interface can call this method
        without arguments. The store reuses the client created in
        ``__init__``.
        """
        return Qdrant(
            client=self.client,
            collection_name=collection,
            embeddings=self.embedding.embedding,
        )

    def create_collection(self, collection: str) -> None:
        """Create the Qdrant collection ``collection``.

        Vectors are sized from ``self.embedding.dimension`` and compared with
        cosine distance.
        """
        vectors_config = VectorParams(
            size=self.embedding.dimension,
            distance=Distance.COSINE,
        )
        self.client.create_collection(
            collection_name=collection,
            vectors_config=vectors_config,
        )
|