Rsr2425 commited on
Commit
45884d3
·
1 Parent(s): a6dd268

Updated Qdrant code to use Qdrant Cloud (untested)

Browse files
backend/app/vectorstore.py CHANGED
@@ -8,44 +8,156 @@ import os
8
  import requests
9
  import nltk
10
  import logging
11
- from typing import Optional
 
 
 
12
  from langchain_community.vectorstores import Qdrant
13
  from langchain_openai.embeddings import OpenAIEmbeddings
14
  from langchain_community.document_loaders import DirectoryLoader
15
  from langchain.text_splitter import RecursiveCharacterTextSplitter
16
  from langchain_huggingface import HuggingFaceEmbeddings
17
  from qdrant_client import QdrantClient
 
 
 
 
 
 
 
18
 
19
  nltk.download("punkt_tab")
20
  nltk.download("averaged_perceptron_tagger_eng")
21
 
22
  DEFAULT_EMBEDDING_MODEL_ID = "text-embedding-3-small"
 
 
 
23
  LOCAL_QDRANT_PATH = "/data/qdrant_db"
24
 
25
  logger = logging.getLogger(__name__)
26
 
27
  # Global variable to store the singleton instance
 
28
  _vector_db_instance: Optional[Qdrant] = None
29
  # TODO fix bug. There's a logical error where if you change the embedding model, the vector db instance might not updated
30
  # to match the new embedding model.
31
  _embedding_model_id: str = None
32
 
33
 
34
- def get_qdrant_client():
35
- if os.environ.get("QDRANT_URL") is None or os.environ.get("QDRANT_API_KEY") is None:
36
- logger.error(
37
- "QDRANT_URL or QDRANT_API_KEY is not set. Defaulting to local memory vector store."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  )
39
 
40
- os.makedirs(LOCAL_QDRANT_PATH, exist_ok=True)
41
- return QdrantClient(path=LOCAL_QDRANT_PATH)
42
-
43
- QDRANT_URL = os.environ.get("QDRANT_URL")
44
- QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")
45
-
46
- return QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
47
 
 
 
 
 
48
 
 
49
  def get_vector_db(embedding_model_id: str = None) -> Qdrant:
50
  """
51
  Factory function that returns a singleton instance of the vector database.
@@ -54,40 +166,21 @@ def get_vector_db(embedding_model_id: str = None) -> Qdrant:
54
  global _vector_db_instance
55
 
56
  if _vector_db_instance is None:
57
- # Create static/data directory if it doesn't exist
58
- os.makedirs("static/data", exist_ok=True)
59
-
60
- # Download and save the webpage if it doesn't exist
61
- html_path = "static/data/langchain_rag_tutorial.html"
62
- if not os.path.exists(html_path):
63
- url = "https://python.langchain.com/docs/tutorials/rag/"
64
- response = requests.get(url)
65
- with open(html_path, "w", encoding="utf-8") as f:
66
- f.write(response.text)
67
-
68
  embedding_model = None
69
  if embedding_model_id is None:
70
- embedding_model = OpenAIEmbeddings(modzŻel=DEFAULT_EMBEDDING_MODEL_ID)
71
  else:
72
  embedding_model = HuggingFaceEmbeddings(model_name=embedding_model_id)
73
 
74
- # Load HTML files from static/data directory
75
- loader = DirectoryLoader("static/data", glob="*.html")
76
- documents = loader.load()
77
-
78
- # Split documents into chunks
79
- text_splitter = RecursiveCharacterTextSplitter(
80
- chunk_size=1000, chunk_overlap=200
81
- )
82
- split_chunks = text_splitter.split_documents(documents)
83
 
84
- # Create vector store instance
85
- client = get_qdrant_client()
86
- _vector_db_instance = Qdrant.from_documents(
87
- split_chunks,
88
- embedding_model,
89
  client=client,
90
- collection_name="extending_context_window_llama_3",
91
  )
92
 
93
  return _vector_db_instance
 
8
  import requests
9
  import nltk
10
  import logging
11
+ import uuid
12
+ import hashlib
13
+
14
+ from typing import Optional, List
15
  from langchain_community.vectorstores import Qdrant
16
  from langchain_openai.embeddings import OpenAIEmbeddings
17
  from langchain_community.document_loaders import DirectoryLoader
18
  from langchain.text_splitter import RecursiveCharacterTextSplitter
19
  from langchain_huggingface import HuggingFaceEmbeddings
20
  from qdrant_client import QdrantClient
21
+ from qdrant_client.models import VectorParams, Distance
22
+ from langchain.schema import Document
23
+ from .vectorstore_helpers import (
24
+ get_document_hash_as_uuid,
25
+ enrich_document_metadata,
26
+ check_collection_exists,
27
+ )
28
 
29
  nltk.download("punkt_tab")
30
  nltk.download("averaged_perceptron_tagger_eng")
31
 
32
  DEFAULT_EMBEDDING_MODEL_ID = "text-embedding-3-small"
33
+ DEFAULT_VECTOR_DIMENSIONS = 1536
34
+ DEFAULT_VECTOR_DISTANCE = Distance.COSINE
35
+ PROBLEMS_REFERENCE_COLLECTION_NAME = "problems_reference_collection"
36
  LOCAL_QDRANT_PATH = "/data/qdrant_db"
37
 
38
  logger = logging.getLogger(__name__)
39
 
40
  # Global variable to store the singleton instance
41
+ _qdrant_client_instance: Optional[QdrantClient] = None
42
  _vector_db_instance: Optional[Qdrant] = None
43
  # TODO fix bug. There's a logical error where if you change the embedding model, the vector db instance might not updated
44
  # to match the new embedding model.
45
  _embedding_model_id: str = None
46
 
47
 
48
+ def _get_qdrant_client():
49
+ global _qdrant_client_instance
50
+
51
+ if _qdrant_client_instance is None:
52
+ if (
53
+ os.environ.get("QDRANT_URL") is None
54
+ or os.environ.get("QDRANT_API_KEY") is None
55
+ ):
56
+ logger.warning(
57
+ "QDRANT_URL or QDRANT_API_KEY is not set. Defaulting to local memory vector store."
58
+ )
59
+
60
+ os.makedirs(LOCAL_QDRANT_PATH, exist_ok=True)
61
+ _qdrant_client_instance = QdrantClient(path=LOCAL_QDRANT_PATH)
62
+
63
+ QDRANT_URL = os.environ.get("QDRANT_URL")
64
+ QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")
65
+
66
+ _qdrant_client_instance = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
67
+ return _qdrant_client_instance
68
+
69
+
70
+ def _initialize_vector_db(embedding_model):
71
+ os.makedirs("static/data", exist_ok=True)
72
+
73
+ html_path = "static/data/langchain_rag_tutorial.html"
74
+ if not os.path.exists(html_path):
75
+ url = "https://python.langchain.com/docs/tutorials/rag/"
76
+ response = requests.get(url)
77
+ with open(html_path, "w", encoding="utf-8") as f:
78
+ f.write(response.text)
79
+
80
+ loader = DirectoryLoader("static/data", glob="*.html")
81
+ documents = loader.load()
82
+
83
+ enriched_docs = [
84
+ enrich_document_metadata(
85
+ doc,
86
+ title="LangChain RAG Tutorial",
87
+ type="tutorial",
88
+ source_url="https://python.langchain.com/docs/tutorials/rag/",
89
+ description="Official LangChain tutorial on building RAG applications",
90
+ date_added="2024-03-21",
91
+ category="documentation",
92
+ version="1.0",
93
+ language="en",
94
+ original_source=doc.metadata.get("source"),
95
+ )
96
+ for doc in documents
97
+ ]
98
+
99
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
100
+ split_chunks = text_splitter.split_documents(enriched_docs)
101
+
102
+ client = _get_qdrant_client()
103
+ store_documents(
104
+ split_chunks,
105
+ PROBLEMS_REFERENCE_COLLECTION_NAME,
106
+ client,
107
+ )
108
+
109
+
110
+ def get_all_unique_source_docs_in_collection(
111
+ collection_name: str, client: QdrantClient, limit: int = 1000, offset: int = 0
112
+ ) -> List[Document]:
113
+ response = client.scroll(
114
+ collection_name=collection_name,
115
+ limit=limit,
116
+ offset=offset,
117
+ with_payload=["source"],
118
+ with_vectors=False,
119
+ )
120
+ result = set()
121
+ while len(response[0]) > 0:
122
+ for point in response[0]:
123
+ if "source" in point.payload:
124
+ result.add(point.payload["source"])
125
+ offset = response[1]
126
+ response = client.scroll(
127
+ collection_name=collection_name,
128
+ limit=limit,
129
+ offset=offset + limit,
130
+ )
131
+ return list(result)
132
+
133
+
134
+ def store_documents(
135
+ documents: List[Document],
136
+ collection_name: str,
137
+ client: QdrantClient,
138
+ embedding_model=None,
139
+ ):
140
+ if embedding_model is None:
141
+ embedding_model = OpenAIEmbeddings(model=DEFAULT_EMBEDDING_MODEL_ID)
142
+
143
+ if not check_collection_exists(client, collection_name):
144
+ client.create_collection(
145
+ collection_name,
146
+ vectors_config=VectorParams(
147
+ size=DEFAULT_VECTOR_DIMENSIONS, distance=DEFAULT_VECTOR_DISTANCE
148
+ ),
149
  )
150
 
151
+ vectorstore = Qdrant(
152
+ client=client, collection_name=collection_name, embeddings=embedding_model
153
+ )
 
 
 
 
154
 
155
+ vectorstore.add_documents(
156
+ documents=documents,
157
+ ids=[get_document_hash_as_uuid(doc) for doc in documents],
158
+ )
159
 
160
+ # TODO already probably exposing too much by returning a Qdrant object here
161
  def get_vector_db(embedding_model_id: str = None) -> Qdrant:
162
  """
163
  Factory function that returns a singleton instance of the vector database.
 
166
  global _vector_db_instance
167
 
168
  if _vector_db_instance is None:
 
 
 
 
 
 
 
 
 
 
 
169
  embedding_model = None
170
  if embedding_model_id is None:
171
+ embedding_model = OpenAIEmbeddings(model=DEFAULT_EMBEDDING_MODEL_ID)
172
  else:
173
  embedding_model = HuggingFaceEmbeddings(model_name=embedding_model_id)
174
 
175
+ client = _get_qdrant_client()
176
+ collection_info = client.get_collection(PROBLEMS_REFERENCE_COLLECTION_NAME)
177
+ if collection_info.vectors_count is None or collection_info.vectors_count == 0:
178
+ _initialize_vector_db(embedding_model)
 
 
 
 
 
179
 
180
+ _vector_db_instance = Qdrant.from_existing_collection(
181
+ collection_name=PROBLEMS_REFERENCE_COLLECTION_NAME,
182
+ embedding_model=embedding_model,
 
 
183
  client=client,
 
184
  )
185
 
186
  return _vector_db_instance
backend/app/vectorstore_helpers.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import uuid
3
+
4
+ from langchain.schema import Document
5
+ from qdrant_client import QdrantClient
6
+ from typing import List
7
+
8
+
9
+ def check_collection_exists(client: QdrantClient, collection_name: str) -> bool:
10
+ """Check if a collection exists in Qdrant."""
11
+ return client.get_collection(collection_name) is not None
12
+
13
+
14
+ def get_document_hash_as_uuid(doc):
15
+ content_hash = hashlib.sha256(doc.page_content.encode()).hexdigest()
16
+ uuid_from_hash = uuid.UUID(content_hash[:32])
17
+ return str(uuid_from_hash)
18
+
19
+
20
+ def enrich_document_metadata(doc: Document, **additional_metadata) -> Document:
21
+ doc.metadata.update(additional_metadata)
22
+ return doc