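"""Integration tests for a local crawl API server.

Each test drives the server's HTTP endpoints (POST /crawl, GET /task/{task_id},
GET /health) against NBC News pages and asserts on the completed task payload.
"""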
import asyncio
import json
import os
import time
from typing import Any, Dict, Optional

import aiohttp
class NBCNewsAPITest:
    """Thin async client for the crawl server's /crawl, /task, and /health endpoints."""

    def __init__(self, base_url: str = "http://localhost:8000"):
        self.base_url = base_url
        self.session: Optional[aiohttp.ClientSession] = None

    async def __aenter__(self):
        self.session = aiohttp.ClientSession()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        if self.session:
            await self.session.close()

    async def submit_crawl(self, request_data: Dict[str, Any]) -> str:
        """Submit a crawl request and return the server-assigned task id."""
        async with self.session.post(f"{self.base_url}/crawl", json=request_data) as response:
            result = await response.json()
            return result["task_id"]

    async def get_task_status(self, task_id: str) -> Dict[str, Any]:
        async with self.session.get(f"{self.base_url}/task/{task_id}") as response:
            return await response.json()

    async def wait_for_task(self, task_id: str, timeout: int = 300, poll_interval: int = 2) -> Dict[str, Any]:
        """Poll the task until it reaches a terminal state or the timeout expires."""
        start_time = time.time()
        while True:
            if time.time() - start_time > timeout:
                raise TimeoutError(f"Task {task_id} did not complete within {timeout} seconds")
            status = await self.get_task_status(task_id)
            if status["status"] in ["completed", "failed"]:
                return status
            await asyncio.sleep(poll_interval)

    async def check_health(self) -> Dict[str, Any]:
        async with self.session.get(f"{self.base_url}/health") as response:
            return await response.json()
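
# Each test below is self-contained: it opens its own client session,
# submits one crawl request, and waits for the resulting task to finish.

# Plain crawl of a single page; checks that markdown content comes back.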
async def test_basic_crawl():
    print("\n=== Testing Basic Crawl ===")
    async with NBCNewsAPITest() as api:
        request = {
            "urls": "https://www.nbcnews.com/business",
            "priority": 10
        }
        task_id = await api.submit_crawl(request)
        result = await api.wait_for_task(task_id)
        print(f"Basic crawl result length: {len(result['result']['markdown'])}")
        assert result["status"] == "completed"
        assert "result" in result
        assert result["result"]["success"]
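
# Exercise js_code and wait_for: click the "Load More" button, then wait
# until a tenth tease card appears before capturing content.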
async def test_js_execution():
    print("\n=== Testing JS Execution ===")
    async with NBCNewsAPITest() as api:
        request = {
            "urls": "https://www.nbcnews.com/business",
            "priority": 8,
            "js_code": [
                "const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"
            ],
            "wait_for": "article.tease-card:nth-child(10)",
            "crawler_params": {
                "headless": True
            }
        }
        task_id = await api.submit_crawl(request)
        result = await api.wait_for_task(task_id)
        print(f"JS execution result length: {len(result['result']['markdown'])}")
        assert result["status"] == "completed"
        assert result["result"]["success"]
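
# Limit extraction to elements matching a single CSS selector.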
async def test_css_selector():
    print("\n=== Testing CSS Selector ===")
    async with NBCNewsAPITest() as api:
        request = {
            "urls": "https://www.nbcnews.com/business",
            "priority": 7,
            "css_selector": ".wide-tease-item__description"
        }
        task_id = await api.submit_crawl(request)
        result = await api.wait_for_task(task_id)
        print(f"CSS selector result length: {len(result['result']['markdown'])}")
        assert result["status"] == "completed"
        assert result["result"]["success"]
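
# Structured extraction with a JSON-CSS schema: pull title, description,
# and link from each article card instead of raw markdown.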
async def test_structured_extraction():
    print("\n=== Testing Structured Extraction ===")
    async with NBCNewsAPITest() as api:
        schema = {
            "name": "NBC News Articles",
            "baseSelector": "article.tease-card",
            "fields": [
                {
                    "name": "title",
                    "selector": "h2",
                    "type": "text"
                },
                {
                    "name": "description",
                    "selector": ".tease-card__description",
                    "type": "text"
                },
                {
                    "name": "link",
                    "selector": "a",
                    "type": "attribute",
                    "attribute": "href"
                }
            ]
        }
        request = {
            "urls": "https://www.nbcnews.com/business",
            "priority": 9,
            "extraction_config": {
                "type": "json_css",
                "params": {
                    "schema": schema
                }
            }
        }
        task_id = await api.submit_crawl(request)
        result = await api.wait_for_task(task_id)
        extracted = json.loads(result["result"]["extracted_content"])
        print(f"Extracted {len(extracted)} articles")
        assert result["status"] == "completed"
        assert result["result"]["success"]
        assert len(extracted) > 0
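
# Crawl several URLs in one request; the finished task exposes a
# "results" list with one entry per URL.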
async def test_batch_crawl():
    print("\n=== Testing Batch Crawl ===")
    async with NBCNewsAPITest() as api:
        request = {
            "urls": [
                "https://www.nbcnews.com/business",
                "https://www.nbcnews.com/business/consumer",
                "https://www.nbcnews.com/business/economy"
            ],
            "priority": 6,
            "crawler_params": {
                "headless": True
            }
        }
        task_id = await api.submit_crawl(request)
        result = await api.wait_for_task(task_id)
        print(f"Batch crawl completed, got {len(result['results'])} results")
        assert result["status"] == "completed"
        assert "results" in result
        assert len(result["results"]) == 3
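
# LLM-based extraction: the crawled page content is handed to the
# configured provider, which fills the JSON schema below.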
async def test_llm_extraction():
    print("\n=== Testing LLM Extraction ===")
    async with NBCNewsAPITest() as api:
        schema = {
            "type": "object",
            "properties": {
                "article_title": {
                    "type": "string",
                    "description": "The main title of the news article"
                },
                "summary": {
                    "type": "string",
                    "description": "A brief summary of the article content"
                },
                "main_topics": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Main topics or themes discussed in the article"
                }
            },
            "required": ["article_title", "summary", "main_topics"]
        }
        request = {
            "urls": "https://www.nbcnews.com/business",
            "priority": 8,
            "extraction_config": {
                "type": "llm",
                "params": {
                    "provider": "openai/gpt-4o-mini",
                    # API key must correspond to the provider above.
                    "api_key": os.getenv("OPENAI_API_KEY"),
                    "schema": schema,
                    "extraction_type": "schema",
                    "instruction": """Extract the main article information including title, a brief summary, and main topics discussed.
                    Focus on the primary business news article on the page."""
                }
            },
            "crawler_params": {
                "headless": True,
                "word_count_threshold": 1
            }
        }
        task_id = await api.submit_crawl(request)
        result = await api.wait_for_task(task_id)
        if result["status"] == "completed":
            extracted = json.loads(result["result"]["extracted_content"])
            print("Extracted article analysis:")
            print(json.dumps(extracted, indent=2))
        assert result["status"] == "completed"
        assert result["result"]["success"]
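
# Request a page screenshot alongside the crawl result.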
async def test_screenshot():
    print("\n=== Testing Screenshot ===")
    async with NBCNewsAPITest() as api:
        request = {
            "urls": "https://www.nbcnews.com/business",
            "priority": 5,
            "screenshot": True,
            "crawler_params": {
                "headless": True
            }
        }
        task_id = await api.submit_crawl(request)
        result = await api.wait_for_task(task_id)
        print("Screenshot captured:", bool(result["result"]["screenshot"]))
        assert result["status"] == "completed"
        assert result["result"]["success"]
        assert result["result"]["screenshot"] is not None
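
# Submit a low-priority task before a high-priority one and verify that
# both complete.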
async def test_priority_handling():
    print("\n=== Testing Priority Handling ===")
    async with NBCNewsAPITest() as api:
        # Submit low priority task first
        low_priority = {
            "urls": "https://www.nbcnews.com/business",
            "priority": 1,
            "crawler_params": {"headless": True}
        }
        low_task_id = await api.submit_crawl(low_priority)
        # Submit high priority task
        high_priority = {
            "urls": "https://www.nbcnews.com/business/consumer",
            "priority": 10,
            "crawler_params": {"headless": True}
        }
        high_task_id = await api.submit_crawl(high_priority)
        # Get both results
        high_result = await api.wait_for_task(high_task_id)
        low_result = await api.wait_for_task(low_task_id)
        print("Both tasks completed")
        assert high_result["status"] == "completed"
        assert low_result["status"] == "completed"
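
# Entry point: verify the server is healthy, then run the enabled tests.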
async def main():
    try:
        # Start with health check
        async with NBCNewsAPITest() as api:
            health = await api.check_health()
            print("Server health:", health)

        # Run all tests (uncomment to enable more of them)
        # await test_basic_crawl()
        # await test_js_execution()
        # await test_css_selector()
        # await test_structured_extraction()
        await test_llm_extraction()
        # await test_batch_crawl()
        # await test_screenshot()
        # await test_priority_handling()
    except Exception as e:
        print(f"Test failed: {str(e)}")
        raise
if __name__ == "__main__":
    asyncio.run(main())