Rsr2425 committed on
Commit
1ef298a
·
1 Parent(s): eae1098

Did some refactoring in BE API

Browse files
backend/app/main.py CHANGED
@@ -1,7 +1,7 @@
1
  from fastapi import FastAPI
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from pydantic import BaseModel
4
- import random
5
 
6
  app = FastAPI()
7
 
@@ -27,12 +27,5 @@ async def crawl_documentation(input_data: UrlInput):
27
 
28
  @app.post("/problems/")
29
  async def generate_problems(query: UserQuery):
30
- # For MVP, returning random sample questions
31
- sample_questions = [
32
- "What is the main purpose of this framework?",
33
- "How do you install this tool?",
34
- "What are the key components?",
35
- "Explain the basic workflow",
36
- "What are the best practices?"
37
- ]
38
- return {"Problems": sample_questions}
 
1
  from fastapi import FastAPI
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from pydantic import BaseModel
4
+ from backend.app.problem_generator import ProblemGenerator
5
 
6
  app = FastAPI()
7
 
 
27
 
28
  @app.post("/problems/")
29
  async def generate_problems(query: UserQuery):
30
+ problems = ProblemGenerator().generate_problems(query.user_query)
31
+ return {"Problems": problems}
 
 
 
 
 
 
 
backend/app/problem_generator.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ class ProblemGenerator:
4
+ def generate_problems(self, query: str) -> List[str]:
5
+ """
6
+ Generate problems based on the user's query.
7
+ """
8
+ # For MVP, returning a fixed list of sample questions
9
+ sample_questions = [
10
+ "What is the main purpose of this framework?",
11
+ "How do you install this tool?",
12
+ "What are the key components?",
13
+ "Explain the basic workflow",
14
+ "What are the best practices?"
15
+ ]
16
+ return sample_questions
backend/app/vectorstore.py CHANGED
@@ -1,44 +1,62 @@
1
  """
2
  Super early version of a vector store. Just want to make something available for the rest of the app to use.
 
 
3
  """
4
  import os
5
  import requests
6
  import nltk
7
-
8
  from langchain_community.vectorstores import Qdrant
9
  from langchain_openai.embeddings import OpenAIEmbeddings
 
 
10
 
11
  nltk.download('punkt_tab')
12
  nltk.download('averaged_perceptron_tagger_eng')
13
 
14
- embedding_model = OpenAIEmbeddings(model="text-embedding-3-small")
15
-
16
- # Create static/data directory if it doesn't exist
17
- os.makedirs("static/data", exist_ok=True)
18
-
19
- # Download and save the webpage
20
- url = "https://python.langchain.com/docs/tutorials/rag/"
21
- response = requests.get(url)
22
- with open("static/data/langchain_rag_tutorial.html", "w", encoding="utf-8") as f:
23
- f.write(response.text)
24
-
25
- from langchain_community.document_loaders import DirectoryLoader
26
- from langchain.text_splitter import RecursiveCharacterTextSplitter
27
-
28
- # Load HTML files from static/data directory
29
- loader = DirectoryLoader("static/data", glob="*.html")
30
- documents = loader.load()
31
-
32
- # Split documents into chunks
33
- text_splitter = RecursiveCharacterTextSplitter(
34
- chunk_size=1000,
35
- chunk_overlap=200
36
- )
37
- split_chunks = text_splitter.split_documents(documents)
38
-
39
- vector_db = Qdrant.from_documents(
40
- split_chunks,
41
- embedding_model,
42
- location=":memory:",
43
- collection_name="extending_context_window_llama_3",
44
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
  Super early version of a vector store. Just want to make something available for the rest of the app to use.
3
+
4
+ Vector store implementation with singleton pattern to ensure only one instance exists.
5
  """
6
  import os
7
  import requests
8
  import nltk
9
+ from typing import Optional
10
  from langchain_community.vectorstores import Qdrant
11
  from langchain_openai.embeddings import OpenAIEmbeddings
12
+ from langchain_community.document_loaders import DirectoryLoader
13
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
14
 
15
  nltk.download('punkt_tab')
16
  nltk.download('averaged_perceptron_tagger_eng')
17
 
18
+ # Global variable to store the singleton instance
19
+ _vector_db_instance: Optional[Qdrant] = None
20
+
21
+ def get_vector_db() -> Qdrant:
22
+ """
23
+ Factory function that returns a singleton instance of the vector database.
24
+ Creates the instance if it doesn't exist.
25
+ """
26
+ global _vector_db_instance
27
+
28
+ if _vector_db_instance is None:
29
+ # Create static/data directory if it doesn't exist
30
+ os.makedirs("static/data", exist_ok=True)
31
+
32
+ # Download and save the webpage if it doesn't exist
33
+ html_path = "static/data/langchain_rag_tutorial.html"
34
+ if not os.path.exists(html_path):
35
+ url = "https://python.langchain.com/docs/tutorials/rag/"
36
+ response = requests.get(url)
37
+ with open(html_path, "w", encoding="utf-8") as f:
38
+ f.write(response.text)
39
+
40
+ # Initialize embedding model
41
+ embedding_model = OpenAIEmbeddings(model="text-embedding-3-small")
42
+
43
+ # Load HTML files from static/data directory
44
+ loader = DirectoryLoader("static/data", glob="*.html")
45
+ documents = loader.load()
46
+
47
+ # Split documents into chunks
48
+ text_splitter = RecursiveCharacterTextSplitter(
49
+ chunk_size=1000,
50
+ chunk_overlap=200
51
+ )
52
+ split_chunks = text_splitter.split_documents(documents)
53
+
54
+ # Create vector store instance
55
+ _vector_db_instance = Qdrant.from_documents(
56
+ split_chunks,
57
+ embedding_model,
58
+ location=":memory:",
59
+ collection_name="extending_context_window_llama_3",
60
+ )
61
+
62
+ return _vector_db_instance
backend/tests/{test_quiz.py → test_api.py} RENAMED
@@ -18,4 +18,5 @@ def test_problems_endpoint():
18
  )
19
  assert response.status_code == 200
20
  assert "Problems" in response.json()
21
- assert len(response.json()["Problems"]) == 5
 
 
18
  )
19
  assert response.status_code == 200
20
  assert "Problems" in response.json()
21
+ assert len(response.json()["Problems"]) == 5
22
+
backend/tests/test_vectorstore.py CHANGED
@@ -1,15 +1,14 @@
1
- import pytest
2
  import os
3
  from langchain.schema import Document
4
- from backend.app import vectorstore
5
 
6
  def test_directory_creation():
7
- """Test that the static/data directory is created"""
8
  assert os.path.exists("static/data")
9
  assert os.path.exists("static/data/langchain_rag_tutorial.html")
10
 
 
11
  def test_html_content():
12
- """Test that the HTML content was downloaded and contains expected content"""
13
  with open("static/data/langchain_rag_tutorial.html", "r", encoding="utf-8") as f:
14
  content = f.read()
15
 
@@ -22,8 +21,9 @@ def test_vector_store_similarity_search():
22
  # Test query
23
  query = "What is RAG?"
24
 
25
- # Perform similarity search
26
- results = vectorstore.vector_db.similarity_search(query, k=2)
 
27
 
28
  # Verify we get results
29
  assert len(results) == 2
@@ -32,4 +32,13 @@ def test_vector_store_similarity_search():
32
  # Verify the results contain relevant content
33
  combined_content = " ".join([doc.page_content for doc in results]).lower()
34
  assert "rag" in combined_content
35
- assert "retrieval" in combined_content
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  from langchain.schema import Document
3
+ from backend.app.vectorstore import get_vector_db
4
 
5
  def test_directory_creation():
6
+ get_vector_db()
7
  assert os.path.exists("static/data")
8
  assert os.path.exists("static/data/langchain_rag_tutorial.html")
9
 
10
+ # TODO remove this test when data ingestion layer is implemented
11
  def test_html_content():
 
12
  with open("static/data/langchain_rag_tutorial.html", "r", encoding="utf-8") as f:
13
  content = f.read()
14
 
 
21
  # Test query
22
  query = "What is RAG?"
23
 
24
+ # Get vector db instance and perform similarity search
25
+ vector_db = get_vector_db()
26
+ results = vector_db.similarity_search(query, k=2)
27
 
28
  # Verify we get results
29
  assert len(results) == 2
 
32
  # Verify the results contain relevant content
33
  combined_content = " ".join([doc.page_content for doc in results]).lower()
34
  assert "rag" in combined_content
35
+ assert "retrieval" in combined_content
36
+
37
+ def test_vector_db_singleton():
38
+ """Test that get_vector_db returns the same instance each time"""
39
+ # Get two instances
40
+ instance1 = get_vector_db()
41
+ instance2 = get_vector_db()
42
+
43
+ # Verify they are the same object
44
+ assert instance1 is instance2