"""
🧠 Perplexity AI Integration for AI Dataset Studio

Automatically discovers relevant sources based on project descriptions.
"""

import os
import requests
import json
import logging
import time
import re
from typing import List, Dict, Optional
from urllib.parse import urlparse
from dataclasses import dataclass
from enum import Enum

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class SearchType(Enum):
    """Types of searches supported by Perplexity AI"""
    GENERAL = "general"
    ACADEMIC = "academic"
    NEWS = "news"
    SOCIAL = "social"
    TECHNICAL = "technical"


@dataclass
class SourceResult:
    """Structure for individual source results"""
    url: str
    title: str
    description: str
    relevance_score: float
    source_type: str
    domain: str
    publication_date: Optional[str] = None
    author: Optional[str] = None


@dataclass
class SearchResults:
    """Container for search results"""
    query: str
    sources: List[SourceResult]
    total_found: int
    search_time: float
    perplexity_response: str
    suggestions: List[str]
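

# Illustrative only: how the two result containers fit together. All values
# below are made up for demonstration.
#
#     example_source = SourceResult(
#         url="https://example.com/article",
#         title="Example Article",
#         description="A sample description of the page content.",
#         relevance_score=7.5,
#         source_type="blog",
#         domain="example.com",
#     )
#     example_results = SearchResults(
#         query="sample query",
#         sources=[example_source],
#         total_found=1,
#         search_time=0.42,
#         perplexity_response="...raw model output...",
#         suggestions=["related term"],
#     )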


class PerplexityClient:
    """
    🧠 Perplexity AI Client for Smart Source Discovery

    Features:
    - Intelligent source discovery based on project descriptions
    - Multiple search strategies (academic, news, technical, etc.)
    - Quality filtering and relevance scoring
    - Rate limiting and error handling
    - Domain validation and safety checks
    """

    def __init__(self, api_key: Optional[str] = None):
        """
        Initialize the Perplexity AI client.

        Args:
            api_key: Perplexity API key (if not provided, falls back to the
                PERPLEXITY_API_KEY environment variable)
        """
        self.api_key = api_key or os.getenv('PERPLEXITY_API_KEY')
        self.base_url = "https://api.perplexity.ai"
        self.session = requests.Session()

        if self.api_key:
            self.session.headers.update({
                'Authorization': f'Bearer {self.api_key}',
                'Content-Type': 'application/json',
                'User-Agent': 'AI-Dataset-Studio/1.0'
            })

        # Rate limiting state
        self.last_request_time = 0.0
        self.min_request_interval = 1.0  # seconds between requests

        # Request settings
        self.max_retries = 3
        self.timeout = 30  # seconds

        logger.info("🧠 Perplexity AI client initialized")

    def _validate_api_key(self) -> bool:
        """Validate that an API key is available"""
        if not self.api_key:
            logger.error("❌ No Perplexity API key found. Set the PERPLEXITY_API_KEY environment variable.")
            return False
        return True

    def _rate_limit(self):
        """Implement rate limiting to respect API limits"""
        current_time = time.time()
        time_since_last = current_time - self.last_request_time

        if time_since_last < self.min_request_interval:
            sleep_time = self.min_request_interval - time_since_last
            logger.debug(f"⏱️ Rate limiting: sleeping {sleep_time:.2f}s")
            time.sleep(sleep_time)

        self.last_request_time = time.time()
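
    # Illustrative tuning: the pacing can be adjusted per instance for stricter
    # quotas (the value below is an assumption, not a documented API limit):
    #
    #     client = PerplexityClient()
    #     client.min_request_interval = 2.0  # at most one request every 2s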

    def _make_request(self, payload: Dict) -> Optional[Dict]:
        """
        Make an API request to Perplexity with error handling.

        Args:
            payload: Request payload

        Returns:
            Parsed JSON response, or None if the request failed
        """
        if not self._validate_api_key():
            return None

        self._rate_limit()

        for attempt in range(self.max_retries):
            try:
                logger.debug(f"📡 Making Perplexity API request (attempt {attempt + 1})")

                response = self.session.post(
                    f"{self.base_url}/chat/completions",
                    json=payload,
                    timeout=self.timeout
                )

                if response.status_code == 200:
                    logger.debug("✅ Perplexity API request successful")
                    return response.json()
                elif response.status_code == 429:
                    logger.warning("🚦 Rate limit hit, waiting longer...")
                    time.sleep(2 ** attempt)
                    continue
                else:
                    logger.error(f"❌ API request failed: {response.status_code} - {response.text}")

            except requests.exceptions.Timeout:
                logger.warning(f"⏰ Request timeout (attempt {attempt + 1})")
            except requests.exceptions.RequestException as e:
                logger.error(f"🔌 Request error: {str(e)}")

            if attempt < self.max_retries - 1:
                time.sleep(2 ** attempt)

        logger.error("❌ All retry attempts failed")
        return None
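
    # Minimal sketch of the response shape this client relies on (inferred from
    # how discover_sources reads it below; the live API may include more fields):
    #
    #     {
    #         "choices": [
    #             {"message": {"role": "assistant", "content": "...sources..."}}
    #         ]
    #     }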

    def discover_sources(
        self,
        project_description: str,
        search_type: SearchType = SearchType.GENERAL,
        max_sources: int = 20,
        include_academic: bool = True,
        include_news: bool = True,
        domain_filter: Optional[List[str]] = None
    ) -> SearchResults:
        """
        🔍 Discover relevant sources based on a project description.

        Args:
            project_description: User's project description
            search_type: Type of search to perform
            max_sources: Maximum number of sources to return
            include_academic: Include academic sources
            include_news: Include news sources
            domain_filter: Optional list of domains to focus on

        Returns:
            SearchResults object with discovered sources
        """
        start_time = time.time()

        logger.info(f"🔍 Discovering sources for: {project_description[:100]}...")

        search_prompt = self._build_search_prompt(
            project_description,
            search_type,
            max_sources,
            include_academic,
            include_news,
            domain_filter
        )

        payload = {
            "model": "llama-3.1-sonar-large-128k-online",
            "messages": [
                {
                    "role": "system",
                    "content": "You are an expert research assistant specializing in finding high-quality, relevant sources for AI/ML dataset creation. Always provide specific URLs, titles, and descriptions."
                },
                {
                    "role": "user",
                    "content": search_prompt
                }
            ],
            "max_tokens": 4000,
            "temperature": 0.3,
            "top_p": 0.9
        }

        response = self._make_request(payload)

        if not response:
            logger.error("❌ Failed to get response from Perplexity API")
            return self._create_empty_results(project_description, time.time() - start_time)

        try:
            content = response['choices'][0]['message']['content']
            sources = self._parse_sources_from_response(content)
            suggestions = self._extract_suggestions(content)

            search_time = time.time() - start_time

            logger.info(f"✅ Found {len(sources)} sources in {search_time:.2f}s")

            return SearchResults(
                query=project_description,
                sources=sources[:max_sources],
                total_found=len(sources),
                search_time=search_time,
                perplexity_response=content,
                suggestions=suggestions
            )

        except Exception as e:
            logger.error(f"❌ Error parsing Perplexity response: {str(e)}")
            return self._create_empty_results(project_description, time.time() - start_time)
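
    # Illustrative usage (assumes a valid PERPLEXITY_API_KEY in the environment):
    #
    #     client = PerplexityClient()
    #     results = client.discover_sources(
    #         project_description="Collect product reviews for sentiment analysis",
    #         search_type=SearchType.GENERAL,
    #         max_sources=10,
    #     )
    #     for source in results.sources:
    #         print(f"{source.relevance_score:.1f}  {source.url}")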

    def _build_search_prompt(
        self,
        project_description: str,
        search_type: SearchType,
        max_sources: int,
        include_academic: bool,
        include_news: bool,
        domain_filter: Optional[List[str]]
    ) -> str:
        """Build an optimized search prompt for Perplexity AI"""

        prompt = f"""
Find {max_sources} high-quality, diverse sources for an AI/ML dataset creation project:

PROJECT DESCRIPTION: {project_description}

SEARCH REQUIREMENTS:
- Find sources with extractable text content suitable for ML training
- Prioritize sources with structured, high-quality content
- Include diverse perspectives and data types
- Focus on sources that are legally scrapable (respect robots.txt)

SEARCH TYPE: {search_type.value}
"""

        if include_academic:
            prompt += "\n- Include academic papers, research articles, and scholarly sources"

        if include_news:
            prompt += "\n- Include news articles, press releases, and journalistic content"

        if domain_filter:
            prompt += f"\n- Focus on these domains: {', '.join(domain_filter)}"

        prompt += """

OUTPUT FORMAT:
For each source, provide:
1. **URL**: Direct link to the content
2. **Title**: Clear, descriptive title
3. **Description**: 2-3 sentence summary of the content and why it's relevant
4. **Type**: [academic/news/blog/government/technical/forum/social]
5. **Quality Score**: 1-10 rating for dataset suitability

ADDITIONAL REQUIREMENTS:
- Verify URLs are accessible and contain substantial text
- Avoid paywalled or login-required content when possible
- Prioritize sources with consistent formatting
- Include publication dates when available
- Suggest related search terms for expanding the dataset

Please provide concrete, actionable sources that can be immediately scraped for dataset creation.
"""

        return prompt

    def _parse_sources_from_response(self, content: str) -> List[SourceResult]:
        """Parse source information from a Perplexity AI response"""
        sources = []

        # Match URLs while excluding trailing punctuation
        url_pattern = r'https?://[^\s<>"{}|\\^`\[\]]+[^\s<>"{}|\\^`\[\].,!?;:]'

        # Split the response into blank-line-separated sections
        sections = re.split(r'\n\s*\n', content)

        for section in sections:
            urls = re.findall(url_pattern, section, re.IGNORECASE)

            if urls:
                for url in urls:
                    try:
                        url = url.strip()

                        title = self._extract_title_from_section(section, url)
                        description = self._extract_description_from_section(section, url)
                        source_type = self._determine_source_type(url, section)
                        relevance_score = self._calculate_relevance_score(section, url)
                        domain = self._extract_domain(url)

                        if self._is_valid_url(url):
                            source = SourceResult(
                                url=url,
                                title=title,
                                description=description,
                                relevance_score=relevance_score,
                                source_type=source_type,
                                domain=domain
                            )
                            sources.append(source)

                    except Exception as e:
                        logger.debug(f"⚠️ Error parsing source: {str(e)}")
                        continue

        # De-duplicate by URL, keeping the first occurrence
        seen_urls = set()
        unique_sources = []

        for source in sources:
            if source.url not in seen_urls:
                seen_urls.add(source.url)
                unique_sources.append(source)

        # Highest-relevance sources first
        unique_sources.sort(key=lambda x: x.relevance_score, reverse=True)

        return unique_sources
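
    # Illustrative walkthrough (hypothetical input): given a response section
    #
    #     **Example Dataset Survey** https://arxiv.org/abs/0000.00000
    #     A survey of review datasets for sentiment analysis research.
    #
    # the parser extracts the URL, takes the bold text as the title, uses the
    # remaining long line as the description, classifies the source as
    # 'academic' from its domain, and scores it with _calculate_relevance_score.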

    def _extract_title_from_section(self, section: str, url: str) -> str:
        """Extract a title from section text"""
        lines = section.split('\n')

        for line in lines:
            if url in line:
                title_patterns = [
                    r'\*\*([^*]+)\*\*',                  # **Bold title**
                    r'#{1,6}\s*([^\n]+)',                # Markdown heading
                    r'Title:\s*([^\n]+)',                # "Title:" label
                    r'([^:\n]+):?\s*' + re.escape(url),  # "Some title: <url>"
                ]

                for pattern in title_patterns:
                    match = re.search(pattern, line, re.IGNORECASE)
                    if match:
                        return match.group(1).strip()

        # Fall back to the domain name when no title is found
        return self._extract_domain(url)

    def _extract_description_from_section(self, section: str, url: str) -> str:
        """Extract a description from section text"""
        lines = section.split('\n')
        description_lines = []

        for line in lines:
            if url not in line and line.strip():
                # Strip list markers, heading characters, and numbering
                clean_line = re.sub(r'^[#*\-\d\.]+\s*', '', line.strip())
                if len(clean_line) > 20:
                    description_lines.append(clean_line)

        description = ' '.join(description_lines)

        # Truncate overly long descriptions
        if len(description) > 200:
            description = description[:200] + "..."

        return description or "High-quality source for dataset creation"

    def _determine_source_type(self, url: str, section: str) -> str:
        """Determine the type of source based on URL and context"""
        url_lower = url.lower()
        section_lower = section.lower()

        # Academic publishers and archives
        if any(domain in url_lower for domain in [
            'arxiv.org', 'scholar.google', 'pubmed', 'ieee.org',
            'acm.org', 'springer.com', 'elsevier.com', 'nature.com',
            'sciencedirect.com', 'jstor.org'
        ]):
            return 'academic'

        # Major news outlets
        if any(domain in url_lower for domain in [
            'cnn.com', 'bbc.com', 'reuters.com', 'ap.org', 'nytimes.com',
            'washingtonpost.com', 'theguardian.com', 'bloomberg.com',
            'techcrunch.com', 'wired.com'
        ]):
            return 'news'

        # Government sources
        if '.gov' in url_lower or 'government' in section_lower:
            return 'government'

        # Technical documentation and developer sites
        if any(domain in url_lower for domain in [
            'docs.', 'documentation', 'github.com', 'stackoverflow.com',
            'medium.com', 'dev.to'
        ]):
            return 'technical'

        # Social platforms
        if any(domain in url_lower for domain in [
            'twitter.com', 'reddit.com', 'linkedin.com', 'facebook.com'
        ]):
            return 'social'

        # Default category
        return 'blog'

    def _calculate_relevance_score(self, section: str, url: str) -> float:
        """Calculate a relevance score for a source (0-10)"""
        score = 5.0  # neutral baseline

        section_lower = section.lower()
        url_lower = url.lower()

        # Reward quality-related keywords in the surrounding text
        quality_indicators = [
            'research', 'study', 'analysis', 'comprehensive', 'detailed',
            'expert', 'professional', 'authoritative', 'peer-reviewed',
            'dataset', 'data', 'machine learning', 'ai', 'artificial intelligence'
        ]

        for indicator in quality_indicators:
            if indicator in section_lower:
                score += 0.5

        # Boost trusted academic domains
        if any(domain in url_lower for domain in ['arxiv.org', 'scholar.google', 'pubmed']):
            score += 2.0

        # Boost government sources
        if '.gov' in url_lower:
            score += 1.5

        # Penalize social media sources
        if any(domain in url_lower for domain in ['twitter.com', 'facebook.com']):
            score -= 1.0

        return min(score, 10.0)
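
    # Worked example: an arxiv.org URL whose section text mentions "research"
    # scores 5.0 (baseline) + 0.5 (keyword) + 2.0 (academic domain) = 7.5.
    # Note the keyword check is a substring match, so e.g. "dataset" triggers
    # both the "dataset" and "data" indicators (+1.0 combined).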

    def _extract_domain(self, url: str) -> str:
        """Extract the domain from a URL"""
        try:
            parsed = urlparse(url)
            return parsed.netloc
        except Exception:
            return "unknown"

    def _is_valid_url(self, url: str) -> bool:
        """Validate URL format"""
        try:
            parsed = urlparse(url)
            return all([parsed.scheme, parsed.netloc])
        except Exception:
            return False

    def _extract_suggestions(self, content: str) -> List[str]:
        """Extract search suggestions from a Perplexity response"""
        suggestions = []

        # Phrases that commonly introduce suggested search terms
        suggestion_patterns = [
            r'related search terms?:?\s*([^\n]+)',
            r'you might also search for:?\s*([^\n]+)',
            r'additional keywords?:?\s*([^\n]+)',
            r'suggestions?:?\s*([^\n]+)'
        ]

        for pattern in suggestion_patterns:
            matches = re.findall(pattern, content, re.IGNORECASE)
            for match in matches:
                # Split on common separators and strip surrounding quotes
                terms = re.split(r'[,;|]', match)
                suggestions.extend([term.strip().strip('"\'') for term in terms if term.strip()])

        return suggestions[:10]
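
    # Illustrative (hypothetical input): a line such as
    #     "Related search terms: product reviews; customer feedback; e-commerce"
    # yields ['product reviews', 'customer feedback', 'e-commerce'].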

    def _create_empty_results(self, query: str, search_time: float) -> SearchResults:
        """Create an empty results object for failed searches"""
        return SearchResults(
            query=query,
            sources=[],
            total_found=0,
            search_time=search_time,
            perplexity_response="",
            suggestions=[]
        )

    def search_with_keywords(self, keywords: List[str], search_type: SearchType = SearchType.GENERAL) -> SearchResults:
        """
        🔍 Search using specific keywords.

        Args:
            keywords: List of search keywords
            search_type: Type of search to perform

        Returns:
            SearchResults object
        """
        query = " ".join(keywords)
        return self.discover_sources(
            project_description=f"Find sources related to: {query}",
            search_type=search_type
        )

    def get_domain_sources(self, domain: str, topic: str, max_sources: int = 10) -> SearchResults:
        """
        🔍 Find sources from a specific domain.

        Args:
            domain: Target domain (e.g., "nature.com")
            topic: Topic to search for
            max_sources: Maximum sources to return

        Returns:
            SearchResults object
        """
        return self.discover_sources(
            project_description=f"Find articles about {topic} from {domain}",
            domain_filter=[domain],
            max_sources=max_sources
        )
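
    # Illustrative usage of the convenience wrappers (hypothetical arguments):
    #
    #     client.search_with_keywords(["sentiment analysis", "product reviews"])
    #     client.get_domain_sources("nature.com", "climate change", max_sources=5)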

    def validate_sources(self, sources: List[SourceResult]) -> List[SourceResult]:
        """
        ✅ Validate and filter sources for quality and accessibility.

        Args:
            sources: List of source results to validate

        Returns:
            Filtered list of valid sources
        """
        valid_sources = []

        for source in sources:
            try:
                # Reject malformed URLs
                if not self._is_valid_url(source.url):
                    logger.debug(f"⚠️ Invalid URL: {source.url}")
                    continue

                # Reject sources without a resolvable domain
                domain = self._extract_domain(source.url)
                if not domain or domain == "unknown":
                    logger.debug(f"⚠️ Unknown domain: {source.url}")
                    continue

                # Reject low-relevance sources
                if source.relevance_score < 3.0:
                    logger.debug(f"⚠️ Low quality score: {source.url}")
                    continue

                valid_sources.append(source)

            except Exception as e:
                logger.debug(f"⚠️ Error validating source {source.url}: {str(e)}")
                continue

        logger.info(f"✅ Validated {len(valid_sources)} out of {len(sources)} sources")
        return valid_sources

    def export_sources(self, results: SearchResults, format: str = "json") -> str:
        """
        📄 Export search results to various formats.

        Args:
            results: SearchResults object to export
            format: Export format ("json", "csv", "markdown")

        Returns:
            Exported data as a string
        """
        if format.lower() == "json":
            return self._export_json(results)
        elif format.lower() == "csv":
            return self._export_csv(results)
        elif format.lower() == "markdown":
            return self._export_markdown(results)
        else:
            raise ValueError(f"Unsupported export format: {format}")
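
    # Illustrative usage (the output filename is hypothetical):
    #
    #     markdown = client.export_sources(results, format="markdown")
    #     with open("sources.md", "w", encoding="utf-8") as f:
    #         f.write(markdown)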

    def _export_json(self, results: SearchResults) -> str:
        """Export results as JSON"""
        data = {
            "query": results.query,
            "total_found": results.total_found,
            "search_time": results.search_time,
            "sources": [
                {
                    "url": source.url,
                    "title": source.title,
                    "description": source.description,
                    "relevance_score": source.relevance_score,
                    "source_type": source.source_type,
                    "domain": source.domain,
                    "publication_date": source.publication_date,
                    "author": source.author
                }
                for source in results.sources
            ],
            "suggestions": results.suggestions
        }
        return json.dumps(data, indent=2)

    def _export_csv(self, results: SearchResults) -> str:
        """Export results as CSV"""
        import csv
        from io import StringIO

        output = StringIO()
        writer = csv.writer(output)

        # Header row
        writer.writerow([
            "URL", "Title", "Description", "Relevance Score",
            "Source Type", "Domain", "Publication Date", "Author"
        ])

        # One row per source
        for source in results.sources:
            writer.writerow([
                source.url,
                source.title,
                source.description,
                source.relevance_score,
                source.source_type,
                source.domain,
                source.publication_date or "",
                source.author or ""
            ])

        return output.getvalue()

    def _export_markdown(self, results: SearchResults) -> str:
        """Export results as Markdown"""
        md = f"# Search Results for: {results.query}\n\n"
        md += f"**Total Sources Found:** {results.total_found}\n"
        md += f"**Search Time:** {results.search_time:.2f} seconds\n\n"

        md += "## Sources\n\n"

        for i, source in enumerate(results.sources, 1):
            md += f"### {i}. {source.title}\n\n"
            md += f"**URL:** {source.url}\n"
            md += f"**Type:** {source.source_type}\n"
            md += f"**Domain:** {source.domain}\n"
            md += f"**Relevance Score:** {source.relevance_score}/10\n"
            md += f"**Description:** {source.description}\n\n"

        if results.suggestions:
            md += "## Related Search Suggestions\n\n"
            for suggestion in results.suggestions:
                md += f"- {suggestion}\n"

        return md


def test_perplexity_client():
    """Test function for the Perplexity client"""
    client = PerplexityClient()

    if not client._validate_api_key():
        print("❌ No API key found. Set the PERPLEXITY_API_KEY environment variable.")
        return

    results = client.discover_sources(
        project_description="Create a dataset for sentiment analysis of product reviews",
        search_type=SearchType.GENERAL,
        max_sources=10
    )

    print(f"📊 Found {len(results.sources)} sources")
    for source in results.sources[:3]:
        print(f"  - {source.title}: {source.url}")

    json_export = client.export_sources(results, "json")
    print(f"📄 JSON export: {len(json_export)} characters")


if __name__ == "__main__":
    test_perplexity_client()