Got feedback endpoint working
Files changed:
- backend/app/main.py (+28, -11)
- backend/app/problem_generator.py (+0, -1)
- backend/app/problem_grader.py (+3, -4)
- backend/tests/test_api.py (+45, -0)
backend/app/main.py

@@ -4,7 +4,9 @@ from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import FileResponse
 from pydantic import BaseModel
 from backend.app.problem_generator import ProblemGenerationPipeline
+from backend.app.problem_grader import ProblemGradingPipeline
 from typing import Dict, List
+import asyncio

 app = FastAPI()

@@ -22,11 +24,15 @@ class UrlInput(BaseModel):
 class UserQuery(BaseModel):
     user_query: str

-
+# TODO: Make this a list of {problem: str, answer: str}. Would be cleaner for data validation
+class FeedbackRequest(BaseModel):
     user_query: str
     problems: list[str]
     user_answers: list[str]

+class FeedbackResponse(BaseModel):
+    feedback: List[str]
+
 @app.post("/api/crawl/")
 async def crawl_documentation(input_data: UrlInput):
     print(f"Received url {input_data.url}")
@@ -37,17 +43,28 @@ async def generate_problems(query: UserQuery):
     problems = ProblemGenerationPipeline().generate_problems(query.user_query)
     return {"Problems": problems}

-@app.post("/api/feedback")
-async def
-
-    if len(feedback.problems) != len(feedback.user_answers):
+@app.post("/api/feedback", response_model=FeedbackResponse)
+async def get_feedback(request: FeedbackRequest):
+    if len(request.problems) != len(request.user_answers):
         raise HTTPException(status_code=400, detail="Problems and user answers must have the same length")
-
-
-
-
-
-
+    try:
+        grader = ProblemGradingPipeline()
+
+        grading_tasks = [
+            grader.grade(
+                query=request.user_query,
+                problem=problem,
+                answer=user_answer,
+            )
+            for problem, user_answer in zip(request.problems, request.user_answers)
+        ]
+
+        feedback_list = await asyncio.gather(*grading_tasks)
+
+        return FeedbackResponse(feedback=feedback_list)
+
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))

 # Serve static files
 app.mount("/static", StaticFiles(directory="/app/static/static"), name="static")
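The TODO on FeedbackRequest is worth taking: pairing each problem with its answer makes a length mismatch unrepresentable, so the manual 400 check becomes unnecessary. A minimal sketch of that shape (hypothetical names, not part of this commit):

from pydantic import BaseModel

class ProblemAnswer(BaseModel):
    problem: str
    answer: str

class PairedFeedbackRequest(BaseModel):  # hypothetical replacement for FeedbackRequest
    user_query: str
    items: list[ProblemAnswer]  # one entry per problem/answer pair; lengths cannot diverge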
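For reference, a round trip against the new endpoint looks roughly like this (a sketch: the requests dependency, the localhost:8000 base URL, and a running server are assumptions, not part of the commit):

import requests

resp = requests.post(
    "http://localhost:8000/api/feedback",  # assumed local dev address
    json={
        "user_query": "RAG",
        "problems": ["What are the two main components of a typical RAG application?"],
        "user_answers": ["Indexing, and retrieval plus generation"],
    },
)
resp.raise_for_status()
print(resp.json()["feedback"])  # one feedback string per problem/answer pair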
backend/app/problem_generator.py

@@ -36,7 +36,6 @@ class ProblemGenerationPipeline:
         self.llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.7)
         self.retriever = get_vector_db().as_retriever(search_kwargs={"k": 2})

-        # Build the RAG chain
         self.rag_chain = (
             {"context": self.retriever, "query": RunnablePassthrough()}
             | self.chat_prompt
backend/app/problem_grader.py

@@ -40,7 +40,6 @@ class ProblemGradingPipeline:
         self.llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.3)
         self.retriever = get_vector_db().as_retriever(search_kwargs={"k": 2})

-        # Build the RAG chain
         self.rag_chain = (
             {
                 "context": self.retriever,
@@ -53,9 +52,9 @@ class ProblemGradingPipeline:
             | StrOutputParser()
         )

-    def grade(self, query: str, problem: str, answer: str) -> str:
+    async def grade(self, query: str, problem: str, answer: str) -> str:
         """
-
+        Asynchronously grade a student's answer to a problem using RAG for context-aware evaluation.

         Args:
             query (str): The topic/context to use for grading
@@ -65,7 +64,7 @@ class ProblemGradingPipeline:
         Returns:
             str: Grading response indicating if the answer is correct and providing feedback
         """
-        return self.rag_chain.invoke({
+        return await self.rag_chain.ainvoke({
             "query": query,
             "problem": problem,
             "answer": answer
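The invoke-to-ainvoke switch is what lets asyncio.gather in the endpoint overlap the per-problem LLM calls instead of running them back to back. A self-contained sketch of that effect, with a stand-in coroutine in place of the real grader (fake_grade is invented for illustration):

import asyncio
import time

async def fake_grade(problem: str, answer: str) -> str:
    # Stand-in for ProblemGradingPipeline.grade; the sleep models one LLM round trip.
    await asyncio.sleep(1.0)
    return f"Feedback for {problem!r}"

async def main() -> None:
    pairs = [("P1", "A1"), ("P2", "A2"), ("P3", "A3")]
    start = time.perf_counter()
    feedback = await asyncio.gather(*(fake_grade(p, a) for p, a in pairs))
    # All three stand-in calls overlap, so this prints ~1s rather than ~3s.
    print(f"{len(feedback)} gradings in {time.perf_counter() - start:.1f}s")

asyncio.run(main())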
backend/tests/test_api.py

@@ -1,5 +1,6 @@
 from fastapi.testclient import TestClient
 from backend.app.main import app
+import pytest

 client = TestClient(app)

@@ -19,4 +20,48 @@ def test_problems_endpoint():
     assert response.status_code == 200
     assert "Problems" in response.json()
     assert len(response.json()["Problems"]) == 5
+
+def test_feedback_validation_error():
+    """Test that mismatched problems and answers lengths return 400"""
+    response = client.post(
+        "/api/feedback",
+        json={
+            "user_query": "Python lists",
+            "problems": ["What is a list?", "How do you append?"],
+            "user_answers": ["A sequence",]  # Only one answer
+        }
+    )
+
+    assert response.status_code == 400
+    assert "same length" in response.json()["detail"]
+
+@pytest.mark.asyncio
+async def test_successful_feedback():
+    """Test successful grading of multiple problems"""
+    response = client.post(
+        "/api/feedback",
+        json={
+            "user_query": "RAG",
+            "problems": [
+                "What are the two main components of a typical RAG application?",
+                "What is the purpose of the indexing component in a RAG application?"
+            ],
+            "user_answers": [
+                "A list is a mutable sequence type that can store multiple items in Python",
+                "You use the append() method to add an element to the end of a list"
+            ]
+        }
+    )
+
+    assert response.status_code == 200
+    result = response.json()
+    assert "feedback" in result
+    assert len(result["feedback"]) == 2
+
+    # Check that responses start with either "Correct" or "Incorrect"
+    for feedback in result["feedback"]:
+        assert feedback.startswith(("Correct", "Incorrect"))
+        # Check that there's an explanation after the classification
+        assert len(feedback.split(". ")) >= 2
+
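Two notes on test_successful_feedback. First, TestClient.post is synchronous and drives the app's event loop itself, so the @pytest.mark.asyncio marker (and the pytest-asyncio plugin it requires) looks unnecessary here: the test body never awaits anything. Second, the user_answers describe Python lists rather than the RAG problems above them, so the grader will presumably mark them Incorrect; the assertions still pass because they only check the Correct/Incorrect prefix. A plain synchronous variant should behave identically (a sketch; the test name is invented):

from fastapi.testclient import TestClient
from backend.app.main import app

client = TestClient(app)

def test_feedback_shape():
    response = client.post(
        "/api/feedback",
        json={
            "user_query": "RAG",
            "problems": ["What are the two main components of a typical RAG application?"],
            "user_answers": ["Indexing, and retrieval plus generation"],
        },
    )
    assert response.status_code == 200
    assert len(response.json()["feedback"]) == 1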