"""search_agent.py | |
Usage: | |
search_agent.py | |
[--domain=domain] | |
[--provider=provider] | |
[--model=model] | |
[--temperature=temp] | |
[--max_pages=num] | |
[--output=text] | |
SEARCH_QUERY | |
search_agent.py --version | |
Options: | |
-h --help Show this screen. | |
--version Show version. | |
-d domain --domain=domain Limit search to a specific domain | |
-t temp --temperature=temp Set the temperature of the LLM [default: 0.0] | |
-p provider --provider=provider Use a specific LLM (choices: bedrock,openai,groq,ollama) [default: openai] | |
-m model --model=model Use a specific model | |
-n num --max_pages=num Max number of pages to retrieve [default: 10] | |
-o text --output=text Output format (choices: text, markdown) [default: markdown] | |
""" | |
import json
import os
from concurrent.futures import ThreadPoolExecutor
from urllib.parse import quote

import dotenv
import requests
from bs4 import BeautifulSoup
from docopt import docopt
from langchain.callbacks import LangChainTracer
from langchain.schema import SystemMessage, HumanMessage
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.chat_models import ChatOllama
from langchain_community.chat_models.bedrock import BedrockChat
from langchain_community.vectorstores.faiss import FAISS
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langsmith import Client
from rich.console import Console
from rich.rule import Rule
from rich.markdown import Markdown
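
# Configuration comes from the environment (loaded below via dotenv):
# BRAVE_SEARCH_API_KEY for search, CREDENTIALS_PROFILE_NAME for Bedrock,
# LANGCHAIN_API_KEY (optional) for LangSmith tracing, plus whatever key the
# chosen provider's LangChain client reads (e.g. OPENAI_API_KEY).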

def get_chat_llm(provider, model, temperature=0.0):
    """Return a LangChain chat model for the given provider, picking a default model when none is given."""
    match provider:
        case 'bedrock':
            if model is None:
                model = "anthropic.claude-3-sonnet-20240229-v1:0"
            chat_llm = BedrockChat(
                credentials_profile_name=os.getenv('CREDENTIALS_PROFILE_NAME'),
                model_id=model,
                model_kwargs={"temperature": temperature},
            )
        case 'openai':
            if model is None:
                model = "gpt-3.5-turbo"
            chat_llm = ChatOpenAI(model_name=model, temperature=temperature)
        case 'groq':
            if model is None:
                model = 'mixtral-8x7b-32768'
            chat_llm = ChatGroq(model_name=model, temperature=temperature)
        case 'ollama':
            if model is None:
                model = 'llama2'
            chat_llm = ChatOllama(model=model, temperature=temperature)
        case _:
            raise ValueError(f"Unknown LLM provider {provider}")
    console.log(f"Using {model} on {provider} with temperature {temperature}")
    return chat_llm
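
# Example (assumes OPENAI_API_KEY is set): get_chat_llm("openai", None)
# falls back to gpt-3.5-turbo at temperature 0.0.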

def optimize_search_query(query):
    """Ask the LLM to rewrite the user's question into a better web search query."""
    from messages import get_optimized_search_messages
    messages = get_optimized_search_messages(query)
    response = chat.invoke(messages, config={"callbacks": callbacks})
    optimized_search_query = response.content
    return optimized_search_query.strip('"').strip("**")
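
# `messages` is a local prompt-template module that ships with this project.
# A minimal sketch of the shape get_optimized_search_messages returns (the
# actual prompt wording is the project's own, not shown here):
#
#     def get_optimized_search_messages(query):
#         return [
#             SystemMessage(content="Rewrite the question as a concise web search query."),
#             HumanMessage(content=query),
#         ]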

def get_sources(query, max_pages=10, domain=None):
    """Query the Brave Search API and return a list of result dicts."""
    search_query = query
    if domain:
        search_query += f" site:{domain}"
    url = f"https://api.search.brave.com/res/v1/web/search?q={quote(search_query)}&count={max_pages}"
    headers = {
        'Accept': 'application/json',
        'Accept-Encoding': 'gzip',
        'X-Subscription-Token': os.getenv("BRAVE_SEARCH_API_KEY")
    }
    try:
        response = requests.get(url, headers=headers)
        if response.status_code != 200:
            raise Exception(f"HTTP error! status: {response.status_code}")
        json_response = response.json()
        if 'web' not in json_response or 'results' not in json_response['web']:
            raise Exception('Invalid API response format')
        final_results = [{
            'title': result['title'],
            'link': result['url'],
            'snippet': result['description'],
            'favicon': result.get('profile', {}).get('img', '')
        } for result in json_response['web']['results']]
        return final_results
    except Exception as error:
        console.log(f"Error fetching search results: {error}")
        raise
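
# Example: get_sources("langchain tutorial", max_pages=5, domain="python.langchain.com")
# returns dicts of the form {'title': ..., 'link': ..., 'snippet': ..., 'favicon': ...}.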

def fetch_with_timeout(url, timeout=8):
    """Fetch a URL, returning None (instead of raising) on any request error."""
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        return response
    except requests.RequestException as error:
        # console.log(f"Skipping {url}! Error: {error}")
        return None

def extract_main_content(html):
    """Strip scripts, styles, and page chrome from an HTML page and return its visible text."""
    try:
        soup = BeautifulSoup(html, 'html.parser')
        for element in soup(["script", "style", "head", "nav", "footer", "iframe", "img"]):
            element.extract()
        main_content = ' '.join(soup.body.get_text().split())
        return main_content
    except Exception as error:
        # console.log(f"Error extracting main content: {error}")
        return None

def process_source(source):
    """Download one source and attach its extracted text under the 'html' key."""
    response = fetch_with_timeout(source['link'], 8)
    if response:
        html = response.text
        main_content = extract_main_content(html)
        return {**source, 'html': main_content}
    return None

def get_links_contents(sources):
    """Fetch all sources in parallel, dropping any that could not be retrieved."""
    with ThreadPoolExecutor() as executor:
        results = list(executor.map(process_source, sources))
    # Filter out None results
    return [result for result in results if result is not None]
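
# Fetches run concurrently on the default thread pool; a page that times out
# or errors simply drops out of the list rather than failing the whole run.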

def process_and_vectorize_content(
    contents,
    query,
    text_chunk_size=1000,
    text_chunk_overlap=200,
    number_of_similarity_results=5
):
    """
    Process and vectorize content using LangChain.

    Args:
        contents (list): List of dictionaries containing 'title', 'link', and 'html' keys.
        query (str): Query string for similarity search.
        text_chunk_size (int): Size of each text chunk.
        text_chunk_overlap (int): Overlap between text chunks.
        number_of_similarity_results (int): Number of most similar results to return.

    Returns:
        list: List of most similar documents.
    """
    # The splitter and embeddings client are stateless, so build them once.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=text_chunk_size,
        chunk_overlap=text_chunk_overlap
    )
    embeddings = OpenAIEmbeddings()
    documents = []
    for content in contents:
        if content['html']:
            try:
                # Split text into chunks
                texts = text_splitter.split_text(content['html'])
                # Create metadata for each text chunk
                metadatas = [{'title': content['title'], 'link': content['link']} for _ in texts]
                # Create vector store
                docsearch = FAISS.from_texts(texts, embedding=embeddings, metadatas=metadatas)
                # Perform similarity search
                docs = docsearch.similarity_search(query, k=number_of_similarity_results)
                doc_dicts = [{'page_content': doc.page_content, 'metadata': doc.metadata} for doc in docs]
                documents.extend(doc_dicts)
            except Exception as e:
                console.log(f"[gray]Error processing content for {content['link']}: {e}")
    return documents
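
# Note: a fresh FAISS index is built per source, so similarity_search keeps up
# to `number_of_similarity_results` chunks from each page, not the top k
# across all pages combined.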

def answer_query_with_sources(query, relevant_docs):
    """Ask the LLM to answer the query grounded in the retrieved extracts."""
    from messages import get_query_with_sources_messages
    messages = get_query_with_sources_messages(query, relevant_docs)
    response = chat.invoke(messages, config={"callbacks": callbacks})
    return response

console = Console()
dotenv.load_dotenv()

# Optional LangSmith tracing, enabled when LANGCHAIN_API_KEY is present.
callbacks = []
if os.getenv("LANGCHAIN_API_KEY"):
    callbacks.append(
        LangChainTracer(
            project_name="search agent",
            client=Client(
                api_url="https://api.smith.langchain.com",
            )
        )
    )

if __name__ == '__main__':
    arguments = docopt(__doc__, version='Search Agent 0.1')
    provider = arguments["--provider"]
    model = arguments["--model"]
    temperature = float(arguments["--temperature"])
    domain = arguments["--domain"]
    max_pages = int(arguments["--max_pages"])
    output = arguments["--output"]
    query = arguments["SEARCH_QUERY"]

    chat = get_chat_llm(provider, model, temperature)

    with console.status(f"[bold green]Optimizing query for search: {query}"):
        optimized_search_query = optimize_search_query(query)
    console.log(f"Optimized search query: [bold blue]{optimized_search_query}")

    with console.status(f"[bold green]Searching sources using the optimized query: {optimized_search_query}"):
        sources = get_sources(optimized_search_query, max_pages=max_pages, domain=domain)
    console.log(f"Found {len(sources)} sources {'on ' + domain if domain else ''}")

    with console.status(f"[bold green]Fetching content for {len(sources)} sources", spinner="growVertical"):
        contents = get_links_contents(sources)
    console.log(f"Managed to extract content from {len(contents)} sources")

    with console.status(
        f"[bold green]Processing {len(contents)} contents and finding relevant extracts",
        spinner="dots8Bit"
    ):
        relevant_docs = process_and_vectorize_content(contents, query)
    console.log(f"Filtered {len(relevant_docs)} relevant content extracts")

    with console.status(f"[bold green]Querying LLM with {len(relevant_docs)} relevant extracts", spinner='dots8Bit'):
        response = answer_query_with_sources(query, relevant_docs)

    console.rule(f"[bold green]Response from {provider}")
    if output == "text":
        console.print(response.content)
    else:
        console.print(Markdown(response.content))
    console.rule("[bold green]")