Rsr2425 commited on
Commit
b22f9a0
·
1 Parent(s): fafed13

Added topics endpoint

Browse files
backend/app/main.py CHANGED
@@ -10,6 +10,7 @@ import asyncio
10
  import logging
11
  import os
12
  from backend.app.crawler import DomainCrawler
 
13
 
14
  app = FastAPI()
15
 
@@ -41,8 +42,12 @@ class FeedbackResponse(BaseModel):
41
  feedback: List[str]
42
 
43
 
44
- @app.post("/api/crawl/")
45
- async def crawl_documentation(input_data: UrlInput):
 
 
 
 
46
  print(f"Received url {input_data.url}")
47
  return {"status": "received"}
48
 
@@ -85,6 +90,12 @@ async def get_feedback(request: FeedbackRequest):
85
  raise HTTPException(status_code=500, detail=str(e))
86
 
87
 
 
 
 
 
 
 
88
  # Serve static files
89
  app.mount("/static", StaticFiles(directory="/app/static/static"), name="static")
90
 
 
10
  import logging
11
  import os
12
  from backend.app.crawler import DomainCrawler
13
+ from backend.app.vectorstore import get_all_unique_source_of_docs_in_collection_DUMB
14
 
15
  app = FastAPI()
16
 
 
42
  feedback: List[str]
43
 
44
 
45
+ class TopicsResponse(BaseModel):
46
+ sources: List[str]
47
+
48
+
49
+ @app.post("/api/ingest/")
50
+ async def ingest_documentation(input_data: UrlInput):
51
  print(f"Received url {input_data.url}")
52
  return {"status": "received"}
53
 
 
90
  raise HTTPException(status_code=500, detail=str(e))
91
 
92
 
93
+ @app.post("/api/topics", response_model=TopicsResponse)
94
+ async def get_topics():
95
+ sources = get_all_unique_source_of_docs_in_collection_DUMB()
96
+ return {"sources": sources}
97
+
98
+
99
  # Serve static files
100
  app.mount("/static", StaticFiles(directory="/app/static/static"), name="static")
101
 
backend/app/vectorstore.py CHANGED
@@ -43,38 +43,6 @@ _embedding_model: Optional[Union[OpenAIEmbeddings, HuggingFaceEmbeddings]] = Non
43
  _embedding_model_id: str = None
44
 
45
 
46
- def _get_qdrant_client():
47
- global _qdrant_client_instance
48
-
49
- if _qdrant_client_instance is None:
50
- if (
51
- os.environ.get("QDRANT_URL") is None
52
- or os.environ.get("QDRANT_API_KEY") is None
53
- ):
54
- logger.warning(
55
- "QDRANT_URL or QDRANT_API_KEY is not set. Defaulting to local memory vector store."
56
- )
57
-
58
- os.makedirs(LOCAL_QDRANT_PATH, exist_ok=True)
59
- _qdrant_client_instance = QdrantClient(path=LOCAL_QDRANT_PATH)
60
- # _qdrant_client_instance = QdrantClient(":memory:")
61
- return _qdrant_client_instance
62
-
63
- logger.info(
64
- f"Attempting to connect to Qdrant at {os.environ.get("QDRANT_URL")}"
65
- )
66
- try:
67
- _qdrant_client_instance = QdrantClient(
68
- url=os.environ.get("QDRANT_URL"),
69
- api_key=os.environ.get("QDRANT_API_KEY"),
70
- )
71
- logger.info("Successfully connected to Qdrant Cloud")
72
- except Exception as e:
73
- logger.error(f"Failed to connect to Qdrant Cloud: {str(e)}")
74
- raise e
75
- return _qdrant_client_instance
76
-
77
-
78
  def _initialize_vector_db():
79
  os.makedirs("static/data", exist_ok=True)
80
 
@@ -112,10 +80,44 @@ def _initialize_vector_db():
112
  )
113
 
114
 
115
- def get_all_unique_source_docs_in_collection(
116
- collection_name: str, client: QdrantClient, limit: int = 1000, offset: int = 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  ) -> List[Document]:
118
- response = client.scroll(
119
  collection_name=collection_name,
120
  limit=limit,
121
  offset=offset,
@@ -128,7 +130,7 @@ def get_all_unique_source_docs_in_collection(
128
  if "source" in point.payload:
129
  result.add(point.payload["source"])
130
  offset = response[1]
131
- response = client.scroll(
132
  collection_name=collection_name,
133
  limit=limit,
134
  offset=offset + limit,
@@ -136,6 +138,23 @@ def get_all_unique_source_docs_in_collection(
136
  return list(result)
137
 
138
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  def store_documents(
140
  documents: List[Document],
141
  collection_name: str,
@@ -145,7 +164,7 @@ def store_documents(
145
  assert _vector_db_instance is not None, "Vector database instance not initialized"
146
 
147
  embedding_model = get_embedding_model(embedding_model_id)
148
- client = _get_qdrant_client()
149
 
150
  _vector_db_instance.add_documents(
151
  documents=documents,
@@ -181,7 +200,7 @@ def get_vector_db(embedding_model_id: str = None) -> QdrantVectorStore:
181
  need_to_initialize_db = False
182
  embedding_model = get_embedding_model(embedding_model_id)
183
 
184
- client = _get_qdrant_client()
185
 
186
  if not check_collection_exists(client, PROBLEMS_REFERENCE_COLLECTION_NAME):
187
  client.create_collection(
 
43
  _embedding_model_id: str = None
44
 
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  def _initialize_vector_db():
47
  os.makedirs("static/data", exist_ok=True)
48
 
 
80
  )
81
 
82
 
83
+ def get_qdrant_client():
84
+ global _qdrant_client_instance
85
+
86
+ if _qdrant_client_instance is None:
87
+ if (
88
+ os.environ.get("QDRANT_URL") is None
89
+ or os.environ.get("QDRANT_API_KEY") is None
90
+ ):
91
+ logger.warning(
92
+ "QDRANT_URL or QDRANT_API_KEY is not set. Defaulting to local memory vector store."
93
+ )
94
+
95
+ os.makedirs(LOCAL_QDRANT_PATH, exist_ok=True)
96
+ _qdrant_client_instance = QdrantClient(path=LOCAL_QDRANT_PATH)
97
+ # _qdrant_client_instance = QdrantClient(":memory:")
98
+ return _qdrant_client_instance
99
+
100
+ logger.info(
101
+ f"Attempting to connect to Qdrant at {os.environ.get("QDRANT_URL")}"
102
+ )
103
+ try:
104
+ _qdrant_client_instance = QdrantClient(
105
+ url=os.environ.get("QDRANT_URL"),
106
+ api_key=os.environ.get("QDRANT_API_KEY"),
107
+ )
108
+ logger.info("Successfully connected to Qdrant Cloud")
109
+ except Exception as e:
110
+ logger.error(f"Failed to connect to Qdrant Cloud: {str(e)}")
111
+ raise e
112
+ return _qdrant_client_instance
113
+
114
+
115
+ def get_all_unique_source_of_docs_in_collection(
116
+ collection_name: str = PROBLEMS_REFERENCE_COLLECTION_NAME,
117
+ limit: int = 1000,
118
+ offset: int = 0,
119
  ) -> List[Document]:
120
+ response = get_qdrant_client().scroll(
121
  collection_name=collection_name,
122
  limit=limit,
123
  offset=offset,
 
130
  if "source" in point.payload:
131
  result.add(point.payload["source"])
132
  offset = response[1]
133
+ response = get_qdrant_client().scroll(
134
  collection_name=collection_name,
135
  limit=limit,
136
  offset=offset + limit,
 
138
  return list(result)
139
 
140
 
141
+ # TODO This is a dumb hack to get around Qdrant client restrictions when using local file storage.
142
+ # Instead of using the client directly, we use QdrantVectorStore's similarity search
143
+ # with a dummy query to get all documents, then extract unique sources.
144
+ def get_all_unique_source_of_docs_in_collection_DUMB(
145
+ collection_name: str = PROBLEMS_REFERENCE_COLLECTION_NAME,
146
+ ) -> List[str]:
147
+ vector_store = get_vector_db()
148
+ # Use a very generic query that should match everything
149
+ docs = vector_store.similarity_search("",k=1000)
150
+
151
+ sources = set()
152
+ for doc in docs:
153
+ if doc.metadata and "title" in doc.metadata:
154
+ sources.add(doc.metadata["title"])
155
+ return list(sources)
156
+
157
+
158
  def store_documents(
159
  documents: List[Document],
160
  collection_name: str,
 
164
  assert _vector_db_instance is not None, "Vector database instance not initialized"
165
 
166
  embedding_model = get_embedding_model(embedding_model_id)
167
+ client = get_qdrant_client()
168
 
169
  _vector_db_instance.add_documents(
170
  documents=documents,
 
200
  need_to_initialize_db = False
201
  embedding_model = get_embedding_model(embedding_model_id)
202
 
203
+ client = get_qdrant_client()
204
 
205
  if not check_collection_exists(client, PROBLEMS_REFERENCE_COLLECTION_NAME):
206
  client.create_collection(
backend/tests/test_api.py CHANGED
@@ -6,7 +6,7 @@ client = TestClient(app)
6
 
7
 
8
  def test_crawl_endpoint():
9
- response = client.post("/api/crawl/", json={"url": "https://example.com"})
10
  assert response.status_code == 200
11
  assert response.json() == {"status": "received"}
12
 
@@ -61,3 +61,14 @@ def test_successful_feedback():
61
  for feedback in result["feedback"]:
62
  assert feedback.strip().startswith(("Correct", "Incorrect"))
63
  assert len(feedback.split(". ")) >= 2
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
 
8
  def test_crawl_endpoint():
9
+ response = client.post("/api/ingest/", json={"url": "https://example.com"})
10
  assert response.status_code == 200
11
  assert response.json() == {"status": "received"}
12
 
 
61
  for feedback in result["feedback"]:
62
  assert feedback.strip().startswith(("Correct", "Incorrect"))
63
  assert len(feedback.split(". ")) >= 2
64
+
65
+
66
+ def test_topics_endpoint():
67
+ """Test that topics endpoint returns expected sources"""
68
+ response = client.post("/api/topics")
69
+ assert response.status_code == 200
70
+ result = response.json()
71
+
72
+ assert "sources" in result
73
+ assert len(result["sources"]) == 1
74
+ assert result["sources"][0] == "LangChain RAG Tutorial"
backend/tests/test_vectorstore.py CHANGED
@@ -4,7 +4,7 @@ import pytest
4
  import requests
5
 
6
  from langchain.schema import Document
7
- from backend.app.vectorstore import get_vector_db, _get_qdrant_client
8
 
9
 
10
  def test_directory_creation():
@@ -72,7 +72,7 @@ def test_qdrant_cloud_connection():
72
  print(f"Port: {parsed_url.port}")
73
  print(f"Path: {parsed_url.path}")
74
 
75
- client = _get_qdrant_client()
76
  client.get_collections()
77
  assert True, "Connection successful"
78
  except Exception as e:
 
4
  import requests
5
 
6
  from langchain.schema import Document
7
+ from backend.app.vectorstore import get_vector_db, get_qdrant_client
8
 
9
 
10
  def test_directory_creation():
 
72
  print(f"Port: {parsed_url.port}")
73
  print(f"Path: {parsed_url.path}")
74
 
75
+ client = get_qdrant_client()
76
  client.get_collections()
77
  assert True, "Connection successful"
78
  except Exception as e: