janar commited on
Commit
12a040e
·
1 Parent(s): f2932e2

can add different vector stores

Browse files
Files changed (1) hide show
  1. api/db/vector_store.py +9 -5
api/db/vector_store.py CHANGED
@@ -1,11 +1,10 @@
1
  from abc import abstractmethod
2
  import os
3
  from qdrant_client import QdrantClient
4
- from langchain.embeddings import OpenAIEmbeddings
5
  from langchain.vectorstores import Qdrant, ElasticVectorSearch, VectorStore
6
  from qdrant_client.models import VectorParams, Distance
7
 
8
- embeddings = OpenAIEmbeddings()
9
 
10
  class ToyVectorStore:
11
 
@@ -18,7 +17,10 @@ class ToyVectorStore:
18
  return QdrantVectorStore()
19
  else:
20
  raise ValueError(f"Invalid vector store {vector_store}")
21
-
 
 
 
22
  @abstractmethod
23
  def get_collection(self, collection: str="test") -> VectorStore:
24
  """
@@ -36,9 +38,10 @@ class ToyVectorStore:
36
  pass
37
 
38
  class ElasticVectorStore(ToyVectorStore):
 
39
  def get_collection(self, collection:str) -> ElasticVectorSearch:
40
  return ElasticVectorSearch(elasticsearch_url= os.getenv("ES_URL"),
41
- index_name= collection, embedding=embeddings)
42
 
43
  def create_collection(self, collection: str) -> None:
44
  store = self.get_collection(collection)
@@ -52,7 +55,8 @@ class QdrantVectorStore(ToyVectorStore):
52
  api_key=os.getenv("QDRANT_API_KEY"))
53
 
54
  def get_collection(self, collection: str) -> Qdrant:
55
- return Qdrant(client=self.client,collection_name=collection,embeddings=embeddings)
 
56
 
57
  def create_collection(self, collection: str) -> None:
58
  self.client.create_collection(collection_name=collection,
 
1
  from abc import abstractmethod
2
  import os
3
  from qdrant_client import QdrantClient
4
+ from langchain.embeddings import OpenAIEmbeddings, ElasticsearchEmbeddings
5
  from langchain.vectorstores import Qdrant, ElasticVectorSearch, VectorStore
6
  from qdrant_client.models import VectorParams, Distance
7
 
 
8
 
9
  class ToyVectorStore:
10
 
 
17
  return QdrantVectorStore()
18
  else:
19
  raise ValueError(f"Invalid vector store {vector_store}")
20
+
21
+ def __init__(self):
22
+ self.embeddings = OpenAIEmbeddings()
23
+
24
  @abstractmethod
25
  def get_collection(self, collection: str="test") -> VectorStore:
26
  """
 
38
  pass
39
 
40
  class ElasticVectorStore(ToyVectorStore):
41
+
42
  def get_collection(self, collection:str) -> ElasticVectorSearch:
43
  return ElasticVectorSearch(elasticsearch_url= os.getenv("ES_URL"),
44
+ index_name= collection, embedding=self.embeddings)
45
 
46
  def create_collection(self, collection: str) -> None:
47
  store = self.get_collection(collection)
 
55
  api_key=os.getenv("QDRANT_API_KEY"))
56
 
57
  def get_collection(self, collection: str) -> Qdrant:
58
+ return Qdrant(client=self.client,collection_name=collection,
59
+ embeddings=self.embeddings)
60
 
61
  def create_collection(self, collection: str) -> None:
62
  self.client.create_collection(collection_name=collection,