Rsr2425 committed
Commit 999f24c · 1 Parent(s): 6c5c116

Linted code with Black

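The diffs below are almost entirely mechanical: Black rewraps calls that exceed its 88-column default, normalizes string quotes, adds trailing commas inside multi-line brackets, and enforces two blank lines around top-level definitions. As a minimal sketch of the same behaviour (not part of this commit, and assuming the black>=25.1.0 dependency added to pyproject.toml at the end), the rule can be reproduced through Black's Python API:

import black

# One of the pre-lint lines from backend/app/main.py; past 88 columns Black
# splits the call and adds a trailing comma, matching the hunk shown below.
src = (
    'raise HTTPException(status_code=400, '
    'detail="Problems and user answers must have the same length")\n'
)
print(black.format_str(src, mode=black.Mode()))

In practice the repository was presumably formatted with the command-line tool (something like "black backend/") rather than through the API; the snippet is only illustrative.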
backend/app/main.py CHANGED
@@ -7,6 +7,9 @@ from backend.app.problem_generator import ProblemGenerationPipeline
 from backend.app.problem_grader import ProblemGradingPipeline
 from typing import Dict, List
 import asyncio
+import logging
+import os
+from crawler import DomainCrawler
 
 app = FastAPI()
 
@@ -18,38 +21,48 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
+
 class UrlInput(BaseModel):
     url: str
 
+
 class UserQuery(BaseModel):
     user_query: str
 
+
 # TODO: Make this a list of {problem: str, answer: str}. Would be cleaner for data validation
 class FeedbackRequest(BaseModel):
     user_query: str
     problems: list[str]
     user_answers: list[str]
 
+
 class FeedbackResponse(BaseModel):
     feedback: List[str]
 
+
 @app.post("/api/crawl/")
 async def crawl_documentation(input_data: UrlInput):
     print(f"Received url {input_data.url}")
     return {"status": "received"}
 
+
 @app.post("/api/problems/")
 async def generate_problems(query: UserQuery):
     problems = ProblemGenerationPipeline().generate_problems(query.user_query)
     return {"Problems": problems}
 
+
 @app.post("/api/feedback", response_model=FeedbackResponse)
 async def get_feedback(request: FeedbackRequest):
     if len(request.problems) != len(request.user_answers):
-        raise HTTPException(status_code=400, detail="Problems and user answers must have the same length")
+        raise HTTPException(
+            status_code=400,
+            detail="Problems and user answers must have the same length",
+        )
     try:
         grader = ProblemGradingPipeline()
-
+
         grading_tasks = [
             grader.grade(
                 query=request.user_query,
@@ -58,32 +71,59 @@ async def get_feedback(request: FeedbackRequest):
             )
             for problem, user_answer in zip(request.problems, request.user_answers)
         ]
-
+
         feedback_list = await asyncio.gather(*grading_tasks)
-
+
         return FeedbackResponse(feedback=feedback_list)
-
+
     except Exception as e:
         # log exception and stack trace
         import traceback
+
        print(f"Exception: {e}")
         print(f"Stack trace: {traceback.format_exc()}")
         raise HTTPException(status_code=500, detail=str(e))
 
+
 # Serve static files
 app.mount("/static", StaticFiles(directory="/app/static/static"), name="static")
 
+
 # Root path handler
 @app.get("/")
 async def serve_root():
     return FileResponse("/app/static/index.html")
 
+
 # Catch-all route for serving index.html
 @app.get("/{full_path:path}")
 async def serve_react(full_path: str):
     # Skip API routes
     if full_path.startswith("api/"):
         raise HTTPException(status_code=404, detail="Not found")
-
+
     # For all other routes, serve the React index.html
-    return FileResponse("/app/static/index.html")
+    return FileResponse("/app/static/index.html")
+
+
+def setup_logging():
+    """Configure logging for the entire application"""
+    # Create logs directory if it doesn't exist
+    logs_dir = "logs"
+    if not os.path.exists(logs_dir):
+        os.makedirs(logs_dir)
+
+    # Configure logging
+    logging.basicConfig(
+        level=logging.INFO,
+        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+        handlers=[
+            # Console handler
+            logging.StreamHandler(),
+            # File handler
+            logging.FileHandler(os.path.join(logs_dir, "crawler.log")),
+        ],
+    )
+
+
+setup_logging()
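Beyond the Black formatting, main.py also gains a setup_logging() helper that is invoked at import time. A hedged usage sketch (not part of the commit): because basicConfig above installs both a StreamHandler and a FileHandler on the root logger, any module using the standard logging idiom now writes to the console (stderr by default) and to logs/crawler.log.

import logging

# Hypothetical caller; it relies only on the basicConfig call made in main.py.
logger = logging.getLogger("backend.app.crawler")
logger.info("Crawl started")  # emitted to the console and appended to logs/crawler.log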
backend/app/problem_generator.py CHANGED
@@ -32,14 +32,15 @@ USER_ROLE_PROMPT = """
 
 class ProblemGenerationPipeline:
     def __init__(self, return_context: bool = False, embedding_model_id: str = None):
-        self.chat_prompt = ChatPromptTemplate.from_messages([
-            ("system", SYSTEM_ROLE_PROMPT),
-            ("user", USER_ROLE_PROMPT)
-        ])
-
+        self.chat_prompt = ChatPromptTemplate.from_messages(
+            [("system", SYSTEM_ROLE_PROMPT), ("user", USER_ROLE_PROMPT)]
+        )
+
         self.llm = ChatOpenAI(model=MODEL, temperature=0.7)
-        self.retriever = get_vector_db(embedding_model_id).as_retriever(search_kwargs={"k": 2})
-
+        self.retriever = get_vector_db(embedding_model_id).as_retriever(
+            search_kwargs={"k": 2}
+        )
+
         # TODO: This is a hack to get the context for the questions. Very messy interface.
         self.return_context = return_context
         if not return_context:
@@ -52,18 +53,24 @@ class ProblemGenerationPipeline:
         else:
             # response looks like: {response: str, context: List[Document]}
             self.rag_chain = (
-                {"context": itemgetter("query") | self.retriever, "query": itemgetter("query")}
+                {
+                    "context": itemgetter("query") | self.retriever,
+                    "query": itemgetter("query"),
+                }
                 | RunnablePassthrough.assign(context=itemgetter("context"))
-                | {"response": self.chat_prompt | self.llm | StrOutputParser(), "context": itemgetter("context")}
+                | {
+                    "response": self.chat_prompt | self.llm | StrOutputParser(),
+                    "context": itemgetter("context"),
+                }
             )
 
     def generate_problems(self, query: str, debug: bool = False) -> List[str]:
         """
         Generate problems based on the user's query using RAG.
-
+
         Args:
             query (str): The topic to generate questions about
-
+
         Returns:
             List[str]: A list of generated questions
         """
@@ -75,4 +82,4 @@ class ProblemGenerationPipeline:
             return raw_result
         # raw_result is a string when return_context is False
         else:
-            return json.loads(raw_result)["questions"]
+            return json.loads(raw_result)["questions"]
backend/app/problem_grader.py CHANGED
@@ -7,6 +7,7 @@ from langchain_core.runnables import RunnablePassthrough
 from langchain_core.output_parsers import StrOutputParser
 from backend.app.vectorstore import get_vector_db
 from operator import itemgetter
+
 MODEL = "gpt-3.5-turbo"
 
 SYSTEM_ROLE_PROMPT = """
@@ -35,23 +36,24 @@ USER_ROLE_PROMPT = """
 
 class ProblemGradingPipeline:
     def __init__(self):
-        self.chat_prompt = ChatPromptTemplate.from_messages([
-            ("system", SYSTEM_ROLE_PROMPT),
-            ("user", USER_ROLE_PROMPT)
-        ])
-
+        self.chat_prompt = ChatPromptTemplate.from_messages(
+            [("system", SYSTEM_ROLE_PROMPT), ("user", USER_ROLE_PROMPT)]
+        )
+
         self.llm = ChatOpenAI(model=MODEL, temperature=0.3)
         self.retriever = get_vector_db().as_retriever(search_kwargs={"k": 2})
-
-        self.rag_chain = (
+
+        self.rag_chain = (
             {
                 # Use the query to retrieve documents from the vectorstore
-                "context": itemgetter("query") | self.retriever | (lambda docs: "\n\n".join([doc.page_content for doc in docs])),
+                "context": itemgetter("query")
+                | self.retriever
+                | (lambda docs: "\n\n".join([doc.page_content for doc in docs])),
                 # Pass through all other inputs directly
                 "query": itemgetter("query"),
                 "problem": itemgetter("problem"),
-                "answer": itemgetter("answer")
-            }
+                "answer": itemgetter("answer"),
+            }
             | self.chat_prompt
             | self.llm
             | StrOutputParser()
@@ -60,18 +62,16 @@ class ProblemGradingPipeline:
     async def grade(self, query: str, problem: str, answer: str) -> str:
         """
         Asynchronously grade a student's answer to a problem using RAG for context-aware evaluation.
-
+
         Args:
             query (str): The topic/context to use for grading
             problem (str): The question being answered
             answer (str): The student's answer to evaluate
-
+
         Returns:
             str: Grading response indicating if the answer is correct and providing feedback
         """
         print(f"Grading problem: {problem} with answer: {answer} for query: {query}")
-        return await self.rag_chain.ainvoke({
-            "query": query,
-            "problem": problem,
-            "answer": answer
-        })
+        return await self.rag_chain.ainvoke(
+            {"query": query, "problem": problem, "answer": answer}
+        )
backend/app/vectorstore.py CHANGED
@@ -3,20 +3,26 @@ Super early version of a vector store. Just want to make something available for
 
 Vector store implementation with singleton pattern to ensure only one instance exists.
 """
+
 import os
 import requests
 import nltk
+import logging
 from typing import Optional
 from langchain_community.vectorstores import Qdrant
 from langchain_openai.embeddings import OpenAIEmbeddings
 from langchain_community.document_loaders import DirectoryLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_huggingface import HuggingFaceEmbeddings
+from qdrant_client import QdrantClient
 
-nltk.download('punkt_tab')
-nltk.download('averaged_perceptron_tagger_eng')
+nltk.download("punkt_tab")
+nltk.download("averaged_perceptron_tagger_eng")
 
 DEFAULT_EMBEDDING_MODEL_ID = "text-embedding-3-small"
+LOCAL_QDRANT_PATH = "/data/qdrant_db"
+
+logger = logging.getLogger(__name__)
 
 # Global variable to store the singleton instance
 _vector_db_instance: Optional[Qdrant] = None
@@ -24,13 +30,29 @@ _vector_db_instance: Optional[Qdrant] = None
 # to match the new embedding model.
 _embedding_model_id: str = None
 
+
+def get_qdrant_client():
+    if os.environ.get("QDRANT_URL") is None or os.environ.get("QDRANT_API_KEY") is None:
+        logger.error(
+            "QDRANT_URL or QDRANT_API_KEY is not set. Defaulting to local memory vector store."
+        )
+
+        os.makedirs(LOCAL_QDRANT_PATH, exist_ok=True)
+        return QdrantClient(path=LOCAL_QDRANT_PATH)
+
+    QDRANT_URL = os.environ.get("QDRANT_URL")
+    QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")
+
+    return QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
+
+
 def get_vector_db(embedding_model_id: str = None) -> Qdrant:
     """
     Factory function that returns a singleton instance of the vector database.
     Creates the instance if it doesn't exist.
     """
     global _vector_db_instance
-
+
     if _vector_db_instance is None:
         # Create static/data directory if it doesn't exist
         os.makedirs("static/data", exist_ok=True)
@@ -45,7 +67,7 @@ def get_vector_db(embedding_model_id: str = None) -> Qdrant:
 
         embedding_model = None
         if embedding_model_id is None:
-            embedding_model = OpenAIEmbeddings(model=DEFAULT_EMBEDDING_MODEL_ID)
+            embedding_model = OpenAIEmbeddings(model=DEFAULT_EMBEDDING_MODEL_ID)
         else:
             embedding_model = HuggingFaceEmbeddings(model_name=embedding_model_id)
 
@@ -55,16 +77,16 @@ def get_vector_db(embedding_model_id: str = None) -> Qdrant:
 
         # Split documents into chunks
         text_splitter = RecursiveCharacterTextSplitter(
-            chunk_size=1000,
-            chunk_overlap=200
+            chunk_size=1000, chunk_overlap=200
         )
         split_chunks = text_splitter.split_documents(documents)
 
         # Create vector store instance
+        client = get_qdrant_client()
         _vector_db_instance = Qdrant.from_documents(
             split_chunks,
             embedding_model,
-            location=":memory:",
+            client=client,
             collection_name="extending_context_window_llama_3",
         )
 
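The vectorstore changes are not purely cosmetic either: the in-memory location=":memory:" argument is replaced by a client obtained from the new get_qdrant_client() helper. A hedged sketch (not part of the commit) of how that fallback behaves, based only on the function shown above:

import os

from backend.app.vectorstore import get_qdrant_client

# Without credentials, the helper logs an error and persists to LOCAL_QDRANT_PATH.
os.environ.pop("QDRANT_URL", None)
os.environ.pop("QDRANT_API_KEY", None)
local_client = get_qdrant_client()

# With both variables set, a remote Qdrant deployment is used instead.
os.environ["QDRANT_URL"] = "https://example-cluster.qdrant.io"  # placeholder URL
os.environ["QDRANT_API_KEY"] = "dummy-key"  # placeholder key
remote_client = get_qdrant_client()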
backend/tests/test_api.py CHANGED
@@ -4,22 +4,19 @@ import pytest
 
 client = TestClient(app)
 
+
 def test_crawl_endpoint():
-    response = client.post(
-        "/api/crawl/",
-        json={"url": "https://example.com"}
-    )
+    response = client.post("/api/crawl/", json={"url": "https://example.com"})
     assert response.status_code == 200
     assert response.json() == {"status": "received"}
 
+
 def test_problems_endpoint():
-    response = client.post(
-        "/api/problems/",
-        json={"user_query": "RAG"}
-    )
+    response = client.post("/api/problems/", json={"user_query": "RAG"})
     assert response.status_code == 200
     assert "Problems" in response.json()
-    assert len(response.json()["Problems"]) == 5
+    assert len(response.json()["Problems"]) == 5
+
 
 def test_feedback_validation_error():
     """Test that mismatched problems and answers lengths return 400"""
@@ -28,13 +25,16 @@ def test_feedback_validation_error():
         json={
             "user_query": "Python lists",
             "problems": ["What is a list?", "How do you append?"],
-            "user_answers": ["A sequence",] # Only one answer
-        }
+            "user_answers": [
+                "A sequence",
+            ],  # Only one answer
+        },
     )
-
+
     assert response.status_code == 400
     assert "same length" in response.json()["detail"]
 
+
 @pytest.mark.asyncio
 async def test_successful_feedback():
     """Test successful grading of multiple problems"""
@@ -44,24 +44,22 @@ async def test_successful_feedback():
             "user_query": "RAG",
             "problems": [
                 "What are the two main components of a typical RAG application?",
-                "What is the purpose of the indexing component in a RAG application?"
+                "What is the purpose of the indexing component in a RAG application?",
             ],
             "user_answers": [
                 "A list is a mutable sequence type that can store multiple items in Python",
-                "You use the append() method to add an element to the end of a list"
-            ]
-        }
+                "You use the append() method to add an element to the end of a list",
+            ],
+        },
     )
-
+
     assert response.status_code == 200
     result = response.json()
     assert "feedback" in result
     assert len(result["feedback"]) == 2
-
+
     # Check that responses start with either "Correct" or "Incorrect"
     for feedback in result["feedback"]:
         assert feedback.startswith(("Correct", "Incorrect"))
         # Check that there's an explanation after the classification
         assert len(feedback.split(". ")) >= 2
-
-
backend/tests/test_vectorstore.py CHANGED
@@ -2,43 +2,47 @@ import os
 from langchain.schema import Document
 from backend.app.vectorstore import get_vector_db
 
+
 def test_directory_creation():
     get_vector_db()
     assert os.path.exists("static/data")
     assert os.path.exists("static/data/langchain_rag_tutorial.html")
 
+
 # TODO remove this test when data ingrestion layer is implemented
 def test_html_content():
     with open("static/data/langchain_rag_tutorial.html", "r", encoding="utf-8") as f:
         content = f.read()
-
+
     # Check for some expected content from the LangChain RAG tutorial
     assert "RAG" in content
     assert "LangChain" in content
 
+
 def test_vector_store_similarity_search():
     """Test that the vector store can perform similarity search"""
     # Test query
     query = "What is RAG?"
-
+
     # Get vector db instance and perform similarity search
     vector_db = get_vector_db()
     results = vector_db.similarity_search(query, k=2)
-
+
     # Verify we get results
     assert len(results) == 2
     assert isinstance(results[0], Document)
-
+
     # Verify the results contain relevant content
     combined_content = " ".join([doc.page_content for doc in results]).lower()
     assert "rag" in combined_content
     assert "retrieval" in combined_content
 
+
 def test_vector_db_singleton():
     """Test that get_vector_db returns the same instance each time"""
     # Get two instances
     instance1 = get_vector_db()
     instance2 = get_vector_db()
-
+
     # Verify they are the same object
-    assert instance1 is instance2
+    assert instance1 is instance2
pyproject.toml CHANGED
@@ -30,6 +30,7 @@ dependencies = [
     "wandb>=0.19.6",
     "datasets>=3.2.0",
     "ragas==0.2.10",
+    "black>=25.1.0",
 ]
 
 [tool.setuptools]