search_agent / search_agent.py
CyranoB's picture
Added Ollama and model option.
4c66227
raw
history blame
9.88 kB
"""search_agent.py
Usage:
search_agent.py
[--domain=domain]
[--provider=provider]
[--model=model]
[--temperature=temp]
[--max_pages=num]
[--output=text]
SEARCH_QUERY
search_agent.py --version
Options:
-h --help Show this screen.
--version Show version.
-d domain --domain=domain Limit search to a specific domain
-t temp --temperature=temp Set the temperature of the LLM [default: 0.0]
-p provider --provider=provider Use a specific LLM (choices: bedrock,openai,groq,ollama) [default: openai]
-m model --model=model Use a specific model
-n num --max_pages=num Max number of pages to retrieve [default: 10]
-o text --output=text Output format (choices: text, markdown) [default: markdown]
"""
import json
import os
from concurrent.futures import ThreadPoolExecutor
from urllib.parse import quote
from bs4 import BeautifulSoup
from docopt import docopt
import dotenv
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import SystemMessage, HumanMessage
from langchain.callbacks import LangChainTracer
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
from langchain_community.chat_models import ChatOllama
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores.faiss import FAISS
from langchain_community.chat_models.bedrock import BedrockChat
from langsmith import Client
import requests
from rich.console import Console
from rich.rule import Rule
from rich.markdown import Markdown
def get_chat_llm(provider, model, temperature=0.0):
match provider:
case 'bedrock':
if(model == None):
model = "anthropic.claude-3-sonnet-20240229-v1:0"
chat_llm = BedrockChat(
credentials_profile_name=os.getenv('CREDENTIALS_PROFILE_NAME'),
model_id=model,
model_kwargs={"temperature": temperature },
)
case 'openai':
if(model == None):
model = "gpt-3.5-turbo"
chat_llm = ChatOpenAI(model_name=model, temperature=temperature)
case 'groq':
if(model == None):
model = 'mixtral-8x7b-32768'
chat_llm = ChatGroq(model_name=model, temperature=temperature)
case 'ollama':
if(model == None):
model = 'llam2'
chat_llm = ChatOllama(model=model, temperature=temperature)
case _:
raise ValueError(f"Unknown LLM provider {provider}")
console.log(f"Using {model} on {provider} with temperature {temperature}")
return chat_llm
def optimize_search_query(query, chat_llm=None):
    """Ask the chat model to rewrite ``query`` into a web-search query.

    Args:
        query: The user's original question.
        chat_llm: Chat model to invoke. Defaults to the module-level ``chat``
            that the ``__main__`` block creates; pass one explicitly when this
            module is imported rather than run as a script.

    Returns:
        The rewritten query with surrounding whitespace, double quotes and
        markdown ``*`` emphasis markers removed.
    """
    from messages import get_optimized_search_messages
    llm = chat if chat_llm is None else chat_llm
    messages = get_optimized_search_messages(query)
    response = llm.invoke(messages, config={"callbacks": callbacks})
    # Strip quotes and '*' markers in a single pass so the result does not
    # depend on their nesting order ('**"q"**' and '"**q**"' both -> 'q').
    return response.content.strip().strip('"*').strip()
def get_sources(query, max_pages=10, domain=None, timeout=15):
    """Search the web via the Brave Search API and return the top results.

    Args:
        query: Search terms.
        max_pages: Number of results to request (sent as ``count``).
        domain: If given, restrict results to this site via ``site:``.
        timeout: Seconds to wait for the API before giving up.

    Returns:
        A list of dicts with ``title``, ``link``, ``snippet`` and ``favicon``
        keys. ``snippet`` and ``favicon`` are ``''`` when the API omits them.

    Raises:
        RuntimeError: If BRAVE_SEARCH_API_KEY is unset, the API answers with a
            non-200 status, or the payload lacks ``web.results``.
        requests.RequestException: On network failure or timeout.
    """
    api_key = os.getenv("BRAVE_SEARCH_API_KEY")
    if not api_key:
        # Fail with an actionable message instead of an opaque HTTP status.
        raise RuntimeError("BRAVE_SEARCH_API_KEY is not set")

    search_query = f"{query} site:{domain}" if domain else query
    response = requests.get(
        "https://api.search.brave.com/res/v1/web/search",
        params={"q": search_query, "count": max_pages},
        headers={
            'Accept': 'application/json',
            'Accept-Encoding': 'gzip',
            'X-Subscription-Token': api_key,
        },
        # Without a timeout a stalled connection would block forever.
        timeout=timeout,
    )
    if response.status_code != 200:
        raise RuntimeError(f"HTTP error! status: {response.status_code}")

    json_response = response.json()
    if 'web' not in json_response or 'results' not in json_response['web']:
        raise RuntimeError('Invalid API response format')

    return [
        {
            'title': result['title'],
            'link': result['url'],
            'snippet': result.get('description', ''),
            'favicon': result.get('profile', {}).get('img', ''),
        }
        for result in json_response['web']['results']
    ]
def fetch_with_timeout(url, timeout=8):
    """GET ``url`` and return the response, or ``None`` if the request fails.

    Connection errors, timeouts and HTTP error statuses (raised by
    ``raise_for_status``) are all ``requests.RequestException`` subclasses and
    are treated the same way: the page is skipped.
    """
    try:
        resp = requests.get(url, timeout=timeout)
        resp.raise_for_status()
    except requests.RequestException:
        # Best-effort: one unreachable page must not abort the whole search.
        return None
    return resp
def extract_main_content(html):
    """Return the readable text of an HTML document as one whitespace-collapsed string.

    <script>, <style>, <head>, <nav>, <footer>, <iframe> and <img> elements
    are removed first. Text is then taken from <body> when the document has
    one, otherwise from the whole parsed document.

    Returns:
        The extracted text (possibly ``''``), or ``None`` if parsing or
        extraction raises.
    """
    try:
        soup = BeautifulSoup(html, 'html.parser')
        for element in soup(["script", "style", "head", "nav", "footer", "iframe", "img"]):
            element.extract()
        # soup.body is None when the markup has no <body> tag (html.parser
        # does not synthesise one). Dereferencing it used to raise and drop
        # the page's text entirely.
        root = soup.body if soup.body is not None else soup
        return ' '.join(root.get_text().split())
    except Exception:
        # Best-effort: an unparsable page yields None and contributes no text.
        return None
def process_source(source):
    """Download one search result and attach its extracted page text.

    Returns a copy of ``source`` with an added ``'html'`` key holding whatever
    ``extract_main_content`` produced for the page (text, or ``None``). Returns
    ``None`` when the page could not be fetched.
    """
    page = fetch_with_timeout(source['link'], timeout=8)
    if not page:
        return None
    return {**source, 'html': extract_main_content(page.text)}
def get_links_contents(sources):
    """Run ``process_source`` on every source concurrently; keep the successes.

    A ``ThreadPoolExecutor`` with its default worker count fetches the pages
    in parallel. The output preserves the order of ``sources``, with the
    entries whose fetch failed (``None``) dropped.
    """
    with ThreadPoolExecutor() as pool:
        fetched = pool.map(process_source, sources)
        return [item for item in fetched if item is not None]
def process_and_vectorize_content(
    contents,
    query,
    text_chunk_size=1000,
    text_chunk_overlap=200,
    number_of_similarity_results=5
):
    """
    Split each page into chunks and keep the chunks most similar to the query.

    Each page is indexed in its own FAISS store and searched separately, so up
    to ``number_of_similarity_results`` chunks are returned *per page*, not in
    total.

    Args:
        contents (list): Dicts with 'title', 'link' and 'html' keys. 'html'
            holds the page text; pages where it is empty or None are skipped.
        query (str): Query string for the similarity search.
        text_chunk_size (int): Chunk size passed to the text splitter.
        text_chunk_overlap (int): Overlap between consecutive chunks.
        number_of_similarity_results (int): Chunks to keep from each page.

    Returns:
        list: Dicts with 'page_content' and 'metadata' ({'title', 'link'}) keys.
        Pages that fail to process are logged and contribute nothing.
    """
    # The splitter does not depend on the page being processed: build it once.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=text_chunk_size,
        chunk_overlap=text_chunk_overlap
    )
    # Created on first use (inside the per-page try) so that a construction
    # failure is logged like any other per-page error, yet built only once.
    embeddings = None
    documents = []
    for content in contents:
        if not content['html']:
            continue
        try:
            if embeddings is None:
                embeddings = OpenAIEmbeddings()
            texts = text_splitter.split_text(content['html'])
            metadatas = [{'title': content['title'], 'link': content['link']} for _ in texts]
            docsearch = FAISS.from_texts(texts, embedding=embeddings, metadatas=metadatas)
            docs = docsearch.similarity_search(query, k=number_of_similarity_results)
            documents.extend(
                {'page_content': doc.page_content, 'metadata': doc.metadata}
                for doc in docs
            )
        except Exception as e:
            # Best-effort: report the failing page and continue with the rest.
            console.log(f"[gray]Error processing content for {content['link']}: {e}")
    return documents
def answer_query_with_sources(query, relevant_docs, chat_llm=None):
    """Ask the chat model to answer ``query`` grounded in ``relevant_docs``.

    Args:
        query: The user's original question.
        relevant_docs: Extracts as returned by ``process_and_vectorize_content``.
        chat_llm: Chat model to invoke. Defaults to the module-level ``chat``
            that the ``__main__`` block creates; pass one explicitly when this
            module is imported rather than run as a script.

    Returns:
        The model's response object (the script prints its ``.content``).
    """
    from messages import get_query_with_sources_messages
    llm = chat if chat_llm is None else chat_llm
    messages = get_query_with_sources_messages(query, relevant_docs)
    return llm.invoke(messages, config={"callbacks": callbacks})
console = Console()
# Load variables from a .env file into the environment before the lookup below.
dotenv.load_dotenv()

# LangSmith tracing is opt-in: a tracer is attached only when
# LANGCHAIN_API_KEY is set; otherwise no callbacks are used.
callbacks = (
    [
        LangChainTracer(
            project_name="search agent",
            client=Client(api_url="https://api.smith.langchain.com"),
        )
    ]
    if os.getenv("LANGCHAIN_API_KEY")
    else []
)
if __name__ == '__main__':
    arguments = docopt(__doc__, version='Search Agent 0.1')

    provider = arguments["--provider"]
    model = arguments["--model"]
    temperature = float(arguments["--temperature"])
    domain = arguments["--domain"]
    # docopt returns strings; get_sources expects a count (default int 10),
    # so convert here and reject non-numeric values early.
    max_pages = int(arguments["--max_pages"])
    output = arguments["--output"]
    query = arguments["SEARCH_QUERY"]

    # Kept at module level on purpose: optimize_search_query() and
    # answer_query_with_sources() read the global `chat`.
    chat = get_chat_llm(provider, model, temperature)

    with console.status(f"[bold green]Optimizing query for search: {query}"):
        # A distinct name, so the optimize_search_query function is not rebound.
        optimized_query = optimize_search_query(query)
    console.log(f"Optimized search query: [bold blue]{optimized_query}")

    with console.status(f"[bold green]Searching sources using the optimized query: {optimized_query}"):
        sources = get_sources(optimized_query, max_pages=max_pages, domain=domain)
    console.log(f"Found {len(sources)} sources {'on ' + domain if domain else ''}")

    with console.status(f"[bold green]Fetching content for {len(sources)} sources", spinner="growVertical"):
        contents = get_links_contents(sources)
    console.log(f"Managed to extract content from {len(contents)} sources")

    with console.status(
        f"[bold green]Processing {len(contents)} contents and finding relevant extracts",
        spinner="dots8Bit"
    ):
        relevant_docs = process_and_vectorize_content(contents, query)
    console.log(f"Filtered {len(relevant_docs)} relevant content extracts")

    with console.status(f"[bold green]Querying LLM with {len(relevant_docs)} relevant extracts", spinner='dots8Bit'):
        response = answer_query_with_sources(query, relevant_docs)

    console.rule(f"[bold green]Response from {provider}")
    if output == "text":
        console.print(response.content)
    else:
        console.print(Markdown(response.content))
    console.rule("[bold green]")