import io
import os
from concurrent.futures import ThreadPoolExecutor
from urllib.parse import quote

import pdfplumber
import requests
from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter
from langchain_community.vectorstores.faiss import FAISS
from langchain_core.documents.base import Document
from langchain_experimental.text_splitter import SemanticChunker
from langsmith import traceable
from selenium.common.exceptions import TimeoutException
from trafilatura import extract

def get_sources(query, max_pages=10, domain=None):
    """
    Fetch search results from the Brave Search API based on the given query.

    Args:
        query (str): The search query.
        max_pages (int): Maximum number of results to retrieve.
        domain (str, optional): Limit the search to a specific domain.

    Returns:
        list: A list of search results with title, link, snippet, and favicon.
    """
    search_query = query
    if domain:
        search_query += f" site:{domain}"

    url = f"https://api.search.brave.com/res/v1/web/search?q={quote(search_query)}&count={max_pages}"
    headers = {
        'Accept': 'application/json',
        'Accept-Encoding': 'gzip',
        'X-Subscription-Token': os.getenv("BRAVE_SEARCH_API_KEY")
    }

    try:
        response = requests.get(url, headers=headers, timeout=30)
        if response.status_code != 200:
            return []

        json_response = response.json()
        if 'web' not in json_response or 'results' not in json_response['web']:
            print(response.text)
            raise Exception('Invalid API response format')

        final_results = [{
            'title': result['title'],
            'link': result['url'],
            'snippet': extract(result['description'], output_format='txt',
                               include_tables=False, include_images=False,
                               include_formatting=True),
            'favicon': result.get('profile', {}).get('img', '')
        } for result in json_response['web']['results']]

        return final_results
    except Exception as error:
        print('Error fetching search results:', error)
        raise
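
# Example usage (illustrative query only; requires BRAVE_SEARCH_API_KEY in the environment):
#     sources = get_sources("retrieval augmented generation", max_pages=5, domain="arxiv.org")
#     sources[0]  # -> {'title': ..., 'link': ..., 'snippet': ..., 'favicon': ...}
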
def fetch_with_selenium(url, driver, timeout=8):
    """
    Fetch the HTML content of a webpage using Selenium.

    Args:
        url (str): The URL of the webpage.
        driver: Selenium WebDriver instance.
        timeout (int): Page load timeout in seconds.

    Returns:
        str: The HTML content of the page, or None if the page load timed out.
    """
    try:
        driver.set_page_load_timeout(timeout)
        driver.get(url)
        # Scroll to the bottom so lazily loaded content has a chance to render.
        driver.execute_script("window.scrollTo(0, document.body.scrollHeight);")
        html = driver.page_source
    except TimeoutException:
        print(f"Page load timed out after {timeout} seconds.")
        html = None
    finally:
        driver.quit()
    return html

def fetch_with_timeout(url, timeout=8):
    """
    Fetch a webpage with a specified timeout.

    Args:
        url (str): The URL of the webpage.
        timeout (int): Request timeout in seconds.

    Returns:
        Response: The HTTP response object, or None if an error occurred.
    """
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        return response
    except requests.RequestException:
        return None

def process_source(source):
    """
    Process a single source to extract its content.

    Args:
        source (dict): A dictionary containing the source's link and other metadata.

    Returns:
        dict: The source with its extracted page content.
    """
    url = source['link']
    # Use a short timeout; slow pages can be retried with Selenium later.
    response = fetch_with_timeout(url, 2)
    if response:
        content_type = response.headers.get('Content-Type')
        if content_type:
            if content_type.startswith('application/pdf'):
                # The response is a PDF file
                pdf_content = response.content
                # Create a file-like object from the bytes
                pdf_file = io.BytesIO(pdf_content)
                # Extract text from the PDF using pdfplumber
                with pdfplumber.open(pdf_file) as pdf:
                    text = ""
                    for page in pdf.pages:
                        # extract_text() returns None for pages without extractable text
                        text += page.extract_text() or ""
                return {**source, 'page_content': text}
            elif content_type.startswith('text/html'):
                # The response is an HTML page
                html = response.text
                main_content = extract(html, output_format='txt', include_links=True)
                return {**source, 'page_content': main_content}
            else:
                print(f"Skipping {url}! Unsupported content type: {content_type}")
                return {**source, 'page_content': source['snippet']}
        else:
            print(f"Skipping {url}! No content type")
            return {**source, 'page_content': source['snippet']}
    return {**source, 'page_content': None}

def get_links_contents(sources, get_driver_func=None, use_selenium=False):
    """
    Retrieve and process the content of multiple sources.

    Args:
        sources (list): A list of source dictionaries.
        get_driver_func (callable, optional): Function to get a Selenium WebDriver.
        use_selenium (bool): Whether to use Selenium for fetching content.

    Returns:
        list: A list of processed sources with their page content.
    """
    with ThreadPoolExecutor() as executor:
        results = list(executor.map(process_source, sources))

    if get_driver_func is None or not use_selenium:
        # Without Selenium, drop sources whose content could not be fetched.
        return [result for result in results if result is not None and result['page_content']]

    # With Selenium, retry the sources the plain HTTP fetch could not load.
    for result in results:
        if result['page_content'] is None:
            url = result['link']
            print(f"Fetching with selenium {url}")
            driver = get_driver_func()
            html = fetch_with_selenium(url, driver)
            if html:
                main_content = extract(html, output_format='txt', include_links=True)
                if main_content:
                    result['page_content'] = main_content
    return results
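
# Example WebDriver factory for the Selenium fallback above. A minimal sketch:
# headless Chrome and these options are assumptions, not requirements of this
# module; any zero-argument callable returning a WebDriver will do.
def make_headless_chrome_driver():
    from selenium import webdriver  # imported lazily to keep the example self-contained

    options = webdriver.ChromeOptions()
    options.add_argument("--headless=new")
    return webdriver.Chrome(options=options)

# Usage: get_links_contents(sources, get_driver_func=make_headless_chrome_driver, use_selenium=True)
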
def vectorize(contents, embedding_model):
    """
    Vectorize the contents using the specified embedding model.

    Args:
        contents (list): A list of content dictionaries.
        embedding_model: The embedding model to use.

    Returns:
        FAISS: A FAISS vector store containing the vectorized documents.
    """
    documents = []
    for content in contents:
        try:
            page_content = content['page_content']
            if page_content:
                metadata = {'title': content['title'], 'source': content['link']}
                doc = Document(page_content=page_content, metadata=metadata)
                documents.append(doc)
        except Exception as e:
            print(f"Error processing content for {content['link']}: {e}")

    # Split documents into overlapping chunks
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    split_documents = text_splitter.split_documents(documents)

    # Build the vector store in batches
    vector_store = None
    batch_size = 250  # Slightly less than 256 to be safe
    for i in range(0, len(split_documents), batch_size):
        batch = split_documents[i:i + batch_size]
        if vector_store is None:
            vector_store = FAISS.from_documents(batch, embedding_model)
        else:
            texts = [doc.page_content for doc in batch]
            metadatas = [doc.metadata for doc in batch]
            embeddings = embedding_model.embed_documents(texts)
            vector_store.add_embeddings(
                list(zip(texts, embeddings)),
                metadatas
            )
    return vector_store
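
if __name__ == "__main__":
    # End-to-end sketch: search -> fetch -> vectorize -> query. The embedding
    # model below is an assumption for illustration (it requires the
    # sentence-transformers package); any LangChain-compatible embeddings
    # object can be passed to vectorize().
    from langchain_community.embeddings import HuggingFaceEmbeddings

    demo_query = "how does FAISS similarity search work"
    sources = get_sources(demo_query, max_pages=5)
    contents = get_links_contents(sources)
    vector_store = vectorize(contents, HuggingFaceEmbeddings())
    if vector_store:
        for doc in vector_store.similarity_search(demo_query, k=3):
            print(doc.metadata['source'], '-', doc.page_content[:120])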