File size: 2,439 Bytes
58973c7 999f24c 58973c7 26b3c2a 58973c7 e344fab 58973c7 999f24c 58973c7 999f24c 58973c7 999f24c 14044f3 999f24c 14044f3 999f24c 14044f3 999f24c d042c43 a6dd268 14044f3 999f24c 14044f3 999f24c 14044f3 999f24c 14044f3 999f24c 14044f3 d042c43 14044f3 b22f9a0 ca73689 b22f9a0 7d1e4e7 b22f9a0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
from fastapi.testclient import TestClient
from backend.app.main import app
# Module-level test client shared by every test below; it wraps the FastAPI `app` imported above.
client = TestClient(app)
def test_crawl_endpoint():
    """Posting a URL and topic to the ingest endpoint is acknowledged as received."""
    payload = {
        "url": "https://example.com",
        "topic": "LangChain RAG Tutorial",
    }
    resp = client.post("/api/ingest/", json=payload)

    assert resp.status_code == 200
    assert resp.json() == {"status": "RECEIVED"}
def test_problems_endpoint():
    """Requesting problems for a query returns exactly five of them under "Problems"."""
    resp = client.post("/api/problems/", json={"user_query": "RAG"})
    assert resp.status_code == 200

    # Decode the body once and reuse it for every check below.
    body = resp.json()
    assert "Problems" in body
    assert len(body["Problems"]) == 5
def test_feedback_validation_error():
    """A feedback request whose problem and answer lists differ in length is rejected with 400."""
    questions = ["What is a list?", "How do you append?"]
    answers = ["A sequence"]  # deliberately one entry shorter than `questions`
    resp = client.post(
        "/api/feedback",
        json={
            "user_query": "Python lists",
            "problems": questions,
            "user_answers": answers,
        },
    )

    assert resp.status_code == 400
    assert "same length" in resp.json()["detail"]
# NOTE: this test can be a bit flaky because it inspects the free-form content of the response (e.g. "Correct"/"Incorrect" may be preceded by "\n" or other whitespace). That is acceptable for now.
def test_successful_feedback():
    """A well-formed feedback request returns one verdict per submitted answer.

    Each feedback item must begin with "Correct" or "Incorrect" once
    surrounding whitespace is stripped. It must also contain at least one
    ". " separator, i.e. a verdict followed by further explanation.
    """
    problems = [
        "What are the two main components of a typical RAG application?",
        "What is the purpose of the indexing component in a RAG application?",
    ]
    # Keep these answers on-topic for the RAG problems above so the endpoint is
    # exercised with answers that actually address their questions.
    user_answers = [
        "Indexing, which loads, splits and stores the source documents, and "
        "retrieval and generation, which fetches relevant chunks at query time "
        "and passes them to the model to produce an answer.",
        "Indexing ingests documents, splits them into chunks, embeds them and "
        "stores them in a vector store so they can be searched later.",
    ]
    response = client.post(
        "/api/feedback",
        json={
            "user_query": "RAG",
            "problems": problems,
            "user_answers": user_answers,
        },
    )

    assert response.status_code == 200
    result = response.json()
    assert "feedback" in result
    feedback_items = result["feedback"]
    # Expect exactly one feedback entry per problem/answer pair.
    assert len(feedback_items) == len(problems)
    for item in feedback_items:
        # The feedback text may start with stray whitespace such as "\n".
        assert item.strip().startswith(("Correct", "Incorrect"))
        assert len(item.split(". ")) >= 2
def test_topics_endpoint():
    """The topics endpoint lists exactly one source: "LangChain RAG Tutorial".

    NOTE(review): this presumably relies on test_crawl_endpoint having ingested
    that topic earlier in the same session, and on no other sources being
    stored. TODO confirm, and consider a fixture to remove the ordering
    dependency.
    """
    resp = client.get("/api/topics")
    assert resp.status_code == 200

    body = resp.json()
    assert "sources" in body
    sources = body["sources"]
    assert len(sources) == 1
    assert sources[0] == "LangChain RAG Tutorial"
|