Rsr2425 committed on
Commit
14044f3
·
1 Parent(s): 58d3e00

Got feedback endpoint working

Browse files
backend/app/main.py CHANGED
@@ -4,7 +4,9 @@ from fastapi.middleware.cors import CORSMiddleware
4
  from fastapi.responses import FileResponse
5
  from pydantic import BaseModel
6
  from backend.app.problem_generator import ProblemGenerationPipeline
 
7
  from typing import Dict, List
 
8
 
9
  app = FastAPI()
10
 
@@ -22,11 +24,15 @@ class UrlInput(BaseModel):
22
  class UserQuery(BaseModel):
23
  user_query: str
24
 
25
- class FeedbackInput(BaseModel):
 
26
  user_query: str
27
  problems: list[str]
28
  user_answers: list[str]
29
 
 
 
 
30
  @app.post("/api/crawl/")
31
  async def crawl_documentation(input_data: UrlInput):
32
  print(f"Received url {input_data.url}")
@@ -37,17 +43,28 @@ async def generate_problems(query: UserQuery):
37
  problems = ProblemGenerationPipeline().generate_problems(query.user_query)
38
  return {"Problems": problems}
39
 
40
- @app.post("/api/feedback/")
41
- async def submit_feedback(feedback: FeedbackInput):
42
- # check if problems len is equal to user_answers len
43
- if len(feedback.problems) != len(feedback.user_answers):
44
  raise HTTPException(status_code=400, detail="Problems and user answers must have the same length")
45
-
46
- for problem, user_answer in zip(feedback.problems, feedback.user_answers):
47
- print(f"Problem: {problem}")
48
- print(f"User answer: {user_answer}")
49
-
50
- return {"status": "success"}
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  # Serve static files
53
  app.mount("/static", StaticFiles(directory="/app/static/static"), name="static")
 
4
  from fastapi.responses import FileResponse
5
  from pydantic import BaseModel
6
  from backend.app.problem_generator import ProblemGenerationPipeline
7
+ from backend.app.problem_grader import ProblemGradingPipeline
8
  from typing import Dict, List
9
+ import asyncio
10
 
11
  app = FastAPI()
12
 
 
24
  class UserQuery(BaseModel):
25
  user_query: str
26
 
27
+ # TODO: Make this a list of {problem: str, answer: str}. Would be cleaner for data validation
28
+ class FeedbackRequest(BaseModel):
29
  user_query: str
30
  problems: list[str]
31
  user_answers: list[str]
32
 
33
+ class FeedbackResponse(BaseModel):
34
+ feedback: List[str]
35
+
36
  @app.post("/api/crawl/")
37
  async def crawl_documentation(input_data: UrlInput):
38
  print(f"Received url {input_data.url}")
 
43
  problems = ProblemGenerationPipeline().generate_problems(query.user_query)
44
  return {"Problems": problems}
45
 
46
+ @app.post("/api/feedback", response_model=FeedbackResponse)
47
+ async def get_feedback(request: FeedbackRequest):
48
+ if len(request.problems) != len(request.user_answers):
 
49
  raise HTTPException(status_code=400, detail="Problems and user answers must have the same length")
50
+ try:
51
+ grader = ProblemGradingPipeline()
52
+
53
+ grading_tasks = [
54
+ grader.grade(
55
+ query=request.user_query,
56
+ problem=problem,
57
+ answer=user_answer,
58
+ )
59
+ for problem, user_answer in zip(request.problems, request.user_answers)
60
+ ]
61
+
62
+ feedback_list = await asyncio.gather(*grading_tasks)
63
+
64
+ return FeedbackResponse(feedback=feedback_list)
65
+
66
+ except Exception as e:
67
+ raise HTTPException(status_code=500, detail=str(e))
68
 
69
  # Serve static files
70
  app.mount("/static", StaticFiles(directory="/app/static/static"), name="static")
backend/app/problem_generator.py CHANGED
@@ -36,7 +36,6 @@ class ProblemGenerationPipeline:
36
  self.llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.7)
37
  self.retriever = get_vector_db().as_retriever(search_kwargs={"k": 2})
38
 
39
- # Build the RAG chain
40
  self.rag_chain = (
41
  {"context": self.retriever, "query": RunnablePassthrough()}
42
  | self.chat_prompt
 
36
  self.llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.7)
37
  self.retriever = get_vector_db().as_retriever(search_kwargs={"k": 2})
38
 
 
39
  self.rag_chain = (
40
  {"context": self.retriever, "query": RunnablePassthrough()}
41
  | self.chat_prompt
backend/app/problem_grader.py CHANGED
@@ -40,7 +40,6 @@ class ProblemGradingPipeline:
40
  self.llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.3)
41
  self.retriever = get_vector_db().as_retriever(search_kwargs={"k": 2})
42
 
43
- # Build the RAG chain
44
  self.rag_chain = (
45
  {
46
  "context": self.retriever,
@@ -53,9 +52,9 @@ class ProblemGradingPipeline:
53
  | StrOutputParser()
54
  )
55
 
56
- def grade(self, query: str, problem: str, answer: str) -> str:
57
  """
58
- Grade a student's answer to a problem using RAG for context-aware evaluation.
59
 
60
  Args:
61
  query (str): The topic/context to use for grading
@@ -65,7 +64,7 @@ class ProblemGradingPipeline:
65
  Returns:
66
  str: Grading response indicating if the answer is correct and providing feedback
67
  """
68
- return self.rag_chain.invoke({
69
  "query": query,
70
  "problem": problem,
71
  "answer": answer
 
40
  self.llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.3)
41
  self.retriever = get_vector_db().as_retriever(search_kwargs={"k": 2})
42
 
 
43
  self.rag_chain = (
44
  {
45
  "context": self.retriever,
 
52
  | StrOutputParser()
53
  )
54
 
55
+ async def grade(self, query: str, problem: str, answer: str) -> str:
56
  """
57
+ Asynchronously grade a student's answer to a problem using RAG for context-aware evaluation.
58
 
59
  Args:
60
  query (str): The topic/context to use for grading
 
64
  Returns:
65
  str: Grading response indicating if the answer is correct and providing feedback
66
  """
67
+ return await self.rag_chain.ainvoke({
68
  "query": query,
69
  "problem": problem,
70
  "answer": answer
backend/tests/test_api.py CHANGED
@@ -1,5 +1,6 @@
1
  from fastapi.testclient import TestClient
2
  from backend.app.main import app
 
3
 
4
  client = TestClient(app)
5
 
@@ -19,4 +20,48 @@ def test_problems_endpoint():
19
  assert response.status_code == 200
20
  assert "Problems" in response.json()
21
  assert len(response.json()["Problems"]) == 5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
 
1
  from fastapi.testclient import TestClient
2
  from backend.app.main import app
3
+ import pytest
4
 
5
  client = TestClient(app)
6
 
 
20
  assert response.status_code == 200
21
  assert "Problems" in response.json()
22
  assert len(response.json()["Problems"]) == 5
23
+
24
+ def test_feedback_validation_error():
25
+ """Test that mismatched problems and answers lengths return 400"""
26
+ response = client.post(
27
+ "/api/feedback",
28
+ json={
29
+ "user_query": "Python lists",
30
+ "problems": ["What is a list?", "How do you append?"],
31
+ "user_answers": ["A sequence",] # Only one answer
32
+ }
33
+ )
34
+
35
+ assert response.status_code == 400
36
+ assert "same length" in response.json()["detail"]
37
+
38
+ @pytest.mark.asyncio
39
+ async def test_successful_feedback():
40
+ """Test successful grading of multiple problems"""
41
+ response = client.post(
42
+ "/api/feedback",
43
+ json={
44
+ "user_query": "RAG",
45
+ "problems": [
46
+ "What are the two main components of a typical RAG application?",
47
+ "What is the purpose of the indexing component in a RAG application?"
48
+ ],
49
+ "user_answers": [
50
+ "A list is a mutable sequence type that can store multiple items in Python",
51
+ "You use the append() method to add an element to the end of a list"
52
+ ]
53
+ }
54
+ )
55
+
56
+ assert response.status_code == 200
57
+ result = response.json()
58
+ assert "feedback" in result
59
+ assert len(result["feedback"]) == 2
60
+
61
+ # Check that responses start with either "Correct" or "Incorrect"
62
+ for feedback in result["feedback"]:
63
+ assert feedback.startswith(("Correct", "Incorrect"))
64
+ # Check that there's an explanation after the classification
65
+ assert len(feedback.split(". ")) >= 2
66
+
67