CyranoB committed on
Commit
7406911
·
1 Parent(s): bb1d601

Enhance model support, improve documentation, and refactor core components


- Added support for new language models, including DeepSeek and Qwen, and expanded the available model options
- Updated README with comprehensive project structure and usage details
- Refactored core modules (copywriter.py, models.py, nlp_rag.py) with improved type hints and docstrings
- Updated default models and embedding configurations
- Enhanced search agent with more flexible model selection and verbose output options
- Improved requirements.txt with latest library dependencies

Files changed (8)
  1. README.md +47 -15
  2. copywriter.py +4 -9
  3. models.py +29 -11
  4. nlp_rag.py +41 -39
  5. requirements.txt +4 -1
  6. search_agent.py +42 -44
  7. web_crawler.py +67 -7
  8. web_rag.py +53 -20
README.md CHANGED
@@ -37,7 +37,7 @@ To run the script, users need to provide their API keys for the desired language
37
 
38
  ## Features
39
 
40
- - Supports multiple language model providers (Bedrock, OpenAI, Groq, Cohere, and Ollama)
41
  - Optimizes search queries using a language model
42
  - Fetches web pages and extracts main content (HTML and PDF)
43
  - Vectorizes the content for efficient retrieval
@@ -55,8 +55,9 @@ To run the script, users need to provide their API keys for the desired language
55
 
56
  3. Set up API keys:
57
 
58
- - You will need API keys for the Brave Search API and LLM API.
59
- - Add your API keys to the `.env` file. Use `dotenv.sample` to create this file.
 
60
 
61
  ## Usage
62
 
@@ -68,20 +69,29 @@ python search_agent.py [OPTIONS] SEARCH_QUERY
68
 
69
  ### Options:
70
 
71
- - `-h`, `--help`: Show this help message and exit.
72
- - `--version`: Show the program's version number and exit.
73
- - `-c`, `--copywrite`: First produce a draft, review it, and rewrite for a final text.
74
- - `-d DOMAIN`, `--domain=DOMAIN`: Limit search to a specific domain.
75
- - `-t TEMP`, `--temperature=TEMP`: Set the temperature of the LLM [default: 0.0].
76
- - `-m MODEL`, `--model=MODEL`: Use a specific model [default: openai:gpt-4o-mini].
77
- - `-e MODEL`, `--embedding_model=MODEL`: Use a specific embedding model [default: same provider as model].
78
- - `-n NUM`, `--max_pages=NUM`: Max number of pages to retrieve [default: 10].
79
- - `-x NUM`, `--max_extracts=NUM`: Max number of page extracts to consider [default: 7].
80
- - `-s`, `--use_selenium`: Use selenium to fetch content from the web [default: False].
81
- - `-o TEXT`, `--output=TEXT`: Output format (choices: text, markdown) [default: markdown].
 
 
 
 
 
82
 
83
  ### Examples
84
 
 
 
 
 
85
  ```bash
86
  python search_agent.py -m openai:gpt-4o-mini "Write a linked post about the current state of M&A for startups. Write in the style of Russ from Silicon Valley TV show."
87
  ```
@@ -98,4 +108,26 @@ python search_agent.py -m openai:gpt-4o-mini "Write a linked post about the curr
98
 
99
  This project is licensed under the Apache License Version 2.0. See the `LICENSE` file for details.
100
 
101
- Let me know if you have any other questions! The key components are using a web search API to find relevant information, extracting the key snippets from the search results, passing that as context to a large language model, and having the LLM generate a natural language answer based on the web search context.
37
 
38
  ## Features
39
 
40
+ - Supports multiple language model providers (HuggingFace, Bedrock, OpenAI, Groq, Cohere, and Ollama)
41
  - Optimizes search queries using a language model
42
  - Fetches web pages and extracts main content (HTML and PDF)
43
  - Vectorizes the content for efficient retrieval
 
55
 
56
  3. Set up API keys:
57
 
58
+ - Create a `.env` file and add your API keys, using `dotenv.sample` as a template.
59
+ - Get a Brave Search API key from https://brave.com/search/api/
60
+ - Optionally, add API keys for other LLM providers (a short sketch of how the keys are loaded follows this list).
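At runtime the keys are simply read from the environment; a minimal, untested sketch of the loading step as it is done in `search_agent.py` (the exact key name for each provider is listed in `dotenv.sample`):

```python
import os
import dotenv

dotenv.load_dotenv()  # search_agent.py loads .env at startup

# Optional: LangSmith tracing is enabled only when this key is present
if os.getenv("LANGCHAIN_API_KEY"):
    print("LangSmith tracing enabled")
```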
61
 
62
  ## Usage
63
 
 
69
 
70
  ### Options:
71
 
72
+ -h --help Show this screen.
73
+ --version Show version.
74
+ -c --copywrite First produce a draft, review it and rewrite for a final text
75
+ -d domain --domain=domain Limit search to a specific domain
76
+ -t temp --temperature=temp Set the temperature of the LLM [default: 0.0]
77
+ -m model --model=model Use a specific model [default: hf:Qwen/Qwen2.5-72B-Instruct]
78
+ -e model --embedding_model=model Use an embedding model
79
+ -n num --max_pages=num Max number of pages to retrieve [default: 10]
80
+ -x num --max_extracts=num Max number of page extracts to consider [default: 7]
81
+ -b --use_browser Use browser to fetch content from the web [default: False]
82
+ -o text --output=text Output format (choices: text, markdown) [default: markdown]
83
+ -v --verbose Print verbose output [default: False]
84
+
85
+ The model is specified as a language model provider and a model name separated by a colon, e.g. `openai:gpt-4o-mini`.
86
+ If an embedding model is not specified, spaCy is used for semantic search.
87
+
88
 
89
  ### Examples
90
 
91
+ ```bash
92
+ python search_agent.py 'What is the radioactive anomaly in the Pacific Ocean?'
93
+ ```
94
+
95
  ```bash
96
  python search_agent.py -m openai:gpt-4o-mini "Write a linked post about the current state of M&A for startups. Write in the style of Russ from Silicon Valley TV show."
97
  ```
 
108
 
109
  This project is licensed under the Apache License Version 2.0. See the `LICENSE` file for details.
110
 
111
+ The key components are: a web search API to find relevant information, extraction of the key snippets from the search results, passing those snippets as context to a large language model, and having the LLM generate a natural language answer grounded in that context.
112
+
113
+ ## Project Structure
114
+
115
+ The project consists of several key components (a sketch of how they fit together follows this list):
116
+
117
+ - `search_agent.py`: The main script that handles the core search agent functionality
118
+ - `search_agent_ui.py`: Streamlit-based user interface for the search agent
119
+ - `web_crawler.py`: Handles web content fetching and processing
120
+ - `web_rag.py`: Implements the Retrieval-Augmented Generation (RAG) functionality
121
+ - `nlp_rag.py`: Natural language processing utilities for RAG
122
+ - `models.py`: Contains model definitions and configurations
123
+ - `copywriter.py`: Implements content rewriting and optimization features
124
+
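To make the division of labour concrete, here is a minimal, untested sketch of how these modules are chained, following the flow of `main()` in `search_agent.py` (the model names are only examples; the corresponding API keys are assumed to be set in `.env`):

```python
import models as md
import web_crawler as wc
import web_rag as wr

chat = md.get_model("openai:gpt-4o-mini", temperature=0.0)
embeddings = md.get_embedding_model("cohere:embed-multilingual-v3.0")

query = "What is the radioactive anomaly in the Pacific Ocean?"
optimized = wr.optimize_search_query(chat, query)   # rewrite the question as a search string
sources = wc.get_sources(optimized, max_pages=10)   # Brave Search API
contents = wc.get_links_contents(sources)           # fetch and extract page content (HTML/PDF)
store = wc.vectorize(contents, embeddings)          # FAISS vector store
print(wr.query_rag(chat, query, optimized, store, top_k=7))
```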
125
+ ## Additional Tools
126
+
127
+ The project includes several development and configuration files:
128
+
129
+ - `requirements.txt`: Lists all Python dependencies
130
+ - `.env`: Configuration file for API keys and settings (use `dotenv.sample` as a template)
131
+ - `.gitignore`: Specifies which files Git should ignore
132
+ - `LICENSE`: Apache License Version 2.0
133
+ - `.devcontainer/`: Contains development container configuration for consistent development environments
copywriter.py CHANGED
@@ -1,10 +1,4 @@
1
  from langchain.schema import SystemMessage, HumanMessage
2
- from langchain.prompts.chat import (
3
- HumanMessagePromptTemplate,
4
- SystemMessagePromptTemplate,
5
- ChatPromptTemplate
6
- )
7
- from langchain.prompts.prompt import PromptTemplate
8
  from langsmith import traceable
9
 
10
 
@@ -21,7 +15,6 @@ def get_comments_prompt(query, draft):
21
  5. Ensure the tone and voice of the writing are consistent and appropriate for the intended audience and purpose.
22
  6. Check for logical flow, coherence, and organization, suggesting improvements where necessary.
23
  7. Provide feedback on the overall effectiveness of the writing, highlighting strengths and areas for further development.
24
-
25
  Your suggestions should be constructive, insightful, and designed to help the user elevate the quality of their writing.
26
  You never generate the corrected text by itself. *Only* give the comment.
27
  """
@@ -35,12 +28,14 @@ def get_comments_prompt(query, draft):
35
  )
36
  return [system_message, human_message]
37
 
 
38
  @traceable(run_type="llm", name="generate_comments")
39
  def generate_comments(chat_llm, query, draft, callbacks=[]):
40
  messages = get_comments_prompt(query, draft)
41
  response = chat_llm.invoke(messages, config={"callbacks": callbacks})
42
  return response.content
43
 
 
44
  def get_final_text_prompt(query, draft, comments):
45
  system_message = SystemMessage(
46
  content="""
@@ -73,7 +68,7 @@ def get_final_text_prompt(query, draft, comments):
73
  def generate_final_text(chat_llm, query, draft, comments, callbacks=[]):
74
  messages = get_final_text_prompt(query, draft, comments)
75
  response = chat_llm.invoke(messages, config={"callbacks": callbacks})
76
- return response.content
77
 
78
 
79
  def get_compare_texts_prompts(query, draft_text, final_text):
@@ -109,4 +104,4 @@ def get_compare_texts_prompts(query, draft_text, final_text):
109
  def compare_text(chat_llm, query, draft, final, callbacks=[]):
110
  messages = get_compare_texts_prompts(query, draft_text=draft, final_text=final)
111
  response = chat_llm.invoke(messages, config={"callbacks": callbacks})
112
- return response.content
 
1
  from langchain.schema import SystemMessage, HumanMessage
 
 
 
 
 
 
2
  from langsmith import traceable
3
 
4
 
 
15
  5. Ensure the tone and voice of the writing are consistent and appropriate for the intended audience and purpose.
16
  6. Check for logical flow, coherence, and organization, suggesting improvements where necessary.
17
  7. Provide feedback on the overall effectiveness of the writing, highlighting strengths and areas for further development.
 
18
  Your suggestions should be constructive, insightful, and designed to help the user elevate the quality of their writing.
19
  You never generate the corrected text by itself. *Only* give the comment.
20
  """
 
28
  )
29
  return [system_message, human_message]
30
 
31
+
32
  @traceable(run_type="llm", name="generate_comments")
33
  def generate_comments(chat_llm, query, draft, callbacks=[]):
34
  messages = get_comments_prompt(query, draft)
35
  response = chat_llm.invoke(messages, config={"callbacks": callbacks})
36
  return response.content
37
 
38
+
39
  def get_final_text_prompt(query, draft, comments):
40
  system_message = SystemMessage(
41
  content="""
 
68
  def generate_final_text(chat_llm, query, draft, comments, callbacks=[]):
69
  messages = get_final_text_prompt(query, draft, comments)
70
  response = chat_llm.invoke(messages, config={"callbacks": callbacks})
71
+ return response.content
72
 
73
 
74
  def get_compare_texts_prompts(query, draft_text, final_text):
 
104
  def compare_text(chat_llm, query, draft, final, callbacks=[]):
105
  messages = get_compare_texts_prompts(query, draft_text=draft, final_text=final)
106
  response = chat_llm.invoke(messages, config={"callbacks": callbacks})
107
+ return response.content
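The helpers above implement the draft → review → rewrite loop behind the `--copywrite` flag; a minimal, untested sketch, assuming a chat model obtained from `models.get_model` and a hypothetical draft string:

```python
import models as md
import copywriter as cw

chat = md.get_model("openai:gpt-4o-mini", temperature=0.0)
query = "Write a short post about retrieval-augmented generation."
draft = "RAG combines web search results with an LLM to answer questions."  # hypothetical draft

comments = cw.generate_comments(chat, query, draft)                # reviewer feedback only, no rewrite
final_text = cw.generate_final_text(chat, query, draft, comments)  # rewrite incorporating the comments
report = cw.compare_text(chat, query, draft, final_text)           # compare draft and final text
```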
models.py CHANGED
@@ -30,6 +30,10 @@ from langchain.chat_models.base import BaseChatModel
30
  from langchain.embeddings.base import Embeddings
31
 
32
  def split_provider_model(provider_model: str) -> Tuple[str, Optional[str]]:
 
 
 
 
33
  parts = provider_model.split(":", 1)
34
  provider = parts[0]
35
  if len(parts) > 1:
@@ -48,7 +52,7 @@ def get_model(provider_model: str, temperature: float = 0.7) -> BaseChatModel:
48
  match provider.lower():
49
  case 'anthropic':
50
  if model is None:
51
- model = "claude-3-sonnet-20240229"
52
  chat_llm = ChatAnthropic(model=model, temperature=temperature)
53
  case 'bedrock':
54
  if model is None:
@@ -58,22 +62,32 @@ def get_model(provider_model: str, temperature: float = 0.7) -> BaseChatModel:
58
  if model is None:
59
  model = 'command-r-plus'
60
  chat_llm = ChatCohere(model=model, temperature=temperature)
 
 
 
 
 
 
 
 
 
 
61
  case 'fireworks':
62
  if model is None:
63
- model = 'accounts/fireworks/models/llama-v3p1-8b-instruct'
64
  chat_llm = ChatFireworks(model_name=model, temperature=temperature, max_tokens=120000)
65
  case 'googlegenerativeai':
66
  if model is None:
67
- model = "gemini-1.5-flash"
68
  chat_llm = ChatGoogleGenerativeAI(model=model, temperature=temperature,
69
  max_tokens=None, timeout=None, max_retries=2,)
70
  case 'groq':
71
  if model is None:
72
- model = 'llama-3.1-8b-instant'
73
  chat_llm = ChatGroq(model_name=model, temperature=temperature)
74
  case 'huggingface' | 'hf':
75
  if model is None:
76
- model = 'mistralai/Mistral-Nemo-Instruct-2407'
77
  llm = HuggingFaceEndpoint(
78
  repo_id=model,
79
  max_length=8192,
@@ -91,23 +105,23 @@ def get_model(provider_model: str, temperature: float = 0.7) -> BaseChatModel:
91
  chat_llm = ChatOpenAI(model=model, temperature=temperature)
92
  case 'openrouter':
93
  if model is None:
94
- model = "google/gemini-flash-1.5-exp"
95
  chat_llm = ChatOpenAI(model=model, temperature=temperature, base_url="https://openrouter.ai/api/v1", api_key=os.getenv("OPENROUTER_API_KEY"))
96
  case 'mistralai' | 'mistral':
97
  if model is None:
98
- model = "open-mistral-nemo"
99
  chat_llm = ChatMistralAI(model=model, temperature=temperature)
100
  case 'perplexity':
101
  if model is None:
102
- model = 'llama-3.1-sonar-small-128k-online'
103
  chat_llm = ChatPerplexity(model=model, temperature=temperature)
104
  case 'together':
105
  if model is None:
106
- model = 'meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo'
107
  chat_llm = ChatTogether(model=model, temperature=temperature)
108
  case 'xai':
109
  if model is None:
110
- model = 'grok-beta'
111
  chat_llm = ChatOpenAI(model=model,api_key=os.getenv("XAI_API_KEY"), base_url="https://api.x.ai/v1", temperature=temperature)
112
  case _:
113
  raise ValueError(f"Unknown LLM provider {provider}")
@@ -118,6 +132,10 @@ def get_model(provider_model: str, temperature: float = 0.7) -> BaseChatModel:
118
 
119
 
120
  def get_embedding_model(provider_model: str) -> Embeddings:
 
 
 
 
121
  provider, model = split_provider_model(provider_model)
122
  match provider.lower():
123
  case 'bedrock':
@@ -126,7 +144,7 @@ def get_embedding_model(provider_model: str) -> Embeddings:
126
  embedding_model = BedrockEmbeddings(model_id=model)
127
  case 'cohere':
128
  if model is None:
129
- model = "embed-multilingual-v3"
130
  embedding_model = CohereEmbeddings(model=model)
131
  case 'fireworks':
132
  if model is None:
 
30
  from langchain.embeddings.base import Embeddings
31
 
32
  def split_provider_model(provider_model: str) -> Tuple[str, Optional[str]]:
33
+ """
34
+ Split the provider and model name from a string.
35
+ returns Tuple[str, Optional[str]]
36
+ """
37
  parts = provider_model.split(":", 1)
38
  provider = parts[0]
39
  if len(parts) > 1:
 
52
  match provider.lower():
53
  case 'anthropic':
54
  if model is None:
55
+ model = "claude-3-5-haiku-20241022"
56
  chat_llm = ChatAnthropic(model=model, temperature=temperature)
57
  case 'bedrock':
58
  if model is None:
 
62
  if model is None:
63
  model = 'command-r-plus'
64
  chat_llm = ChatCohere(model=model, temperature=temperature)
65
+
66
+ case 'deepseek':
67
+ if model is None:
68
+ model='deepseek-chat'
69
+ chat_llm = ChatOpenAI(
70
+ model=model,
71
+ openai_api_key=os.getenv("DEEPSEEK_API_KEY"),
72
+ openai_api_base='https://api.deepseek.com',
73
+ max_tokens=8192
74
+ )
75
  case 'fireworks':
76
  if model is None:
77
+ model = 'accounts/fireworks/models/llama-v3p3-70b-instruct'
78
  chat_llm = ChatFireworks(model_name=model, temperature=temperature, max_tokens=120000)
79
  case 'googlegenerativeai':
80
  if model is None:
81
+ model = "gemini-2.0-flash-exp"
82
  chat_llm = ChatGoogleGenerativeAI(model=model, temperature=temperature,
83
  max_tokens=None, timeout=None, max_retries=2,)
84
  case 'groq':
85
  if model is None:
86
+ model = 'qwen-2.5-32b'
87
  chat_llm = ChatGroq(model_name=model, temperature=temperature)
88
  case 'huggingface' | 'hf':
89
  if model is None:
90
+ model = 'Qwen/Qwen2.5-72B-Instruct'
91
  llm = HuggingFaceEndpoint(
92
  repo_id=model,
93
  max_length=8192,
 
105
  chat_llm = ChatOpenAI(model=model, temperature=temperature)
106
  case 'openrouter':
107
  if model is None:
108
+ model = "cognitivecomputations/dolphin3.0-mistral-24b:free"
109
  chat_llm = ChatOpenAI(model=model, temperature=temperature, base_url="https://openrouter.ai/api/v1", api_key=os.getenv("OPENROUTER_API_KEY"))
110
  case 'mistralai' | 'mistral':
111
  if model is None:
112
+ model = "mistral-small-latest"
113
  chat_llm = ChatMistralAI(model=model, temperature=temperature)
114
  case 'perplexity':
115
  if model is None:
116
+ model = 'sonar'
117
  chat_llm = ChatPerplexity(model=model, temperature=temperature)
118
  case 'together':
119
  if model is None:
120
+ model = 'meta-llama/Llama-3.3-70B-Instruct-Turbo-Free'
121
  chat_llm = ChatTogether(model=model, temperature=temperature)
122
  case 'xai':
123
  if model is None:
124
+ model = 'grok-2-1212'
125
  chat_llm = ChatOpenAI(model=model,api_key=os.getenv("XAI_API_KEY"), base_url="https://api.x.ai/v1", temperature=temperature)
126
  case _:
127
  raise ValueError(f"Unknown LLM provider {provider}")
 
132
 
133
 
134
  def get_embedding_model(provider_model: str) -> Embeddings:
135
+ """
136
+ Get an embedding model from a provider and model name.
137
+ returns Embeddings
138
+ """
139
  provider, model = split_provider_model(provider_model)
140
  match provider.lower():
141
  case 'bedrock':
 
144
  embedding_model = BedrockEmbeddings(model_id=model)
145
  case 'cohere':
146
  if model is None:
147
+ model = "embed-multilingual-v3.0"
148
  embedding_model = CohereEmbeddings(model=model)
149
  case 'fireworks':
150
  if model is None:
nlp_rag.py CHANGED
@@ -6,6 +6,12 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
6
  import numpy as np
7
 
8
  def get_nlp_model():
 
 
 
 
 
 
9
  if not spacy.util.is_package("en_core_web_md"):
10
  print("Downloading en_core_web_md model...")
11
  spacy.cli.download("en_core_web_md")
@@ -15,6 +21,17 @@ def get_nlp_model():
15
 
16
 
17
  def recursive_split_documents(contents, max_chunk_size=1000, overlap=100):
 
 
 
 
 
 
 
 
 
 
 
18
  from langchain_core.documents.base import Document
19
  from langchain.text_splitter import RecursiveCharacterTextSplitter
20
 
@@ -51,6 +68,19 @@ def recursive_split_documents(contents, max_chunk_size=1000, overlap=100):
51
 
52
 
53
  def semantic_search(query, chunks, nlp, similarity_threshold=0.5, top_n=10):
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  # Precompute query vector and its norm
55
  query_vector = nlp(query).vector
56
  query_norm = np.linalg.norm(query_vector) + 1e-8 # Add epsilon to avoid division by zero
@@ -84,47 +114,19 @@ def semantic_search(query, chunks, nlp, similarity_threshold=0.5, top_n=10):
84
  return relevant_chunks[:top_n]
85
 
86
 
87
- # Perform semantic search using spaCy
88
- def semantic_search(query, chunks, nlp, similarity_threshold=0.5, top_n=10):
89
- import numpy as np
90
- from concurrent.futures import ThreadPoolExecutor
91
-
92
- # Precompute query vector and its norm with epsilon to prevent division by zero
93
- with nlp.disable_pipes(*[pipe for pipe in nlp.pipe_names if pipe != 'tok2vec']):
94
- query_vector = nlp(query).vector
95
- query_norm = np.linalg.norm(query_vector) + 1e-8 # Add epsilon
96
-
97
- # Prepare texts from chunks
98
- texts = [chunk['text'] for chunk in chunks]
99
-
100
- # Function to process each text and compute its vector
101
- def compute_vector(text):
102
- with nlp.disable_pipes(*[pipe for pipe in nlp.pipe_names if pipe != 'tok2vec']):
103
- doc = nlp(text)
104
- vector = doc.vector
105
- return vector
106
-
107
- # Process texts in parallel using ThreadPoolExecutor
108
- with ThreadPoolExecutor() as executor:
109
- chunk_vectors = list(executor.map(compute_vector, texts))
110
-
111
- chunk_vectors = np.array(chunk_vectors)
112
- chunk_norms = np.linalg.norm(chunk_vectors, axis=1) + 1e-8 # Add epsilon
113
-
114
- # Compute similarities using vectorized operations
115
- similarities = np.dot(chunk_vectors, query_vector) / (chunk_norms * query_norm)
116
-
117
- # Filter and sort results
118
- relevant_chunks = [
119
- (chunk, sim) for chunk, sim in zip(chunks, similarities) if sim > similarity_threshold
120
- ]
121
- relevant_chunks.sort(key=lambda x: x[1], reverse=True)
122
-
123
- return relevant_chunks[:top_n]
124
-
125
-
126
  @traceable(run_type="llm", name="nlp_rag")
127
  def query_rag(chat_llm, query, relevant_results):
 
 
 
 
 
 
 
 
 
 
 
128
  import web_rag as wr
129
 
130
  formatted_chunks = ""
 
6
  import numpy as np
7
 
8
  def get_nlp_model():
9
+ """
10
+ Load and return the spaCy NLP model. Downloads the model if not already installed.
11
+
12
+ Returns:
13
+ nlp: The loaded spaCy NLP model.
14
+ """
15
  if not spacy.util.is_package("en_core_web_md"):
16
  print("Downloading en_core_web_md model...")
17
  spacy.cli.download("en_core_web_md")
 
21
 
22
 
23
  def recursive_split_documents(contents, max_chunk_size=1000, overlap=100):
24
+ """
25
+ Split documents into smaller chunks using a recursive character text splitter.
26
+
27
+ Args:
28
+ contents (list): List of content dictionaries with 'page_content', 'title', and 'link'.
29
+ max_chunk_size (int): Maximum size of each chunk.
30
+ overlap (int): Overlap between chunks.
31
+
32
+ Returns:
33
+ list: List of chunks with text and metadata.
34
+ """
35
  from langchain_core.documents.base import Document
36
  from langchain.text_splitter import RecursiveCharacterTextSplitter
37
 
 
68
 
69
 
70
  def semantic_search(query, chunks, nlp, similarity_threshold=0.5, top_n=10):
71
+ """
72
+ Perform semantic search to find relevant chunks based on similarity to the query.
73
+
74
+ Args:
75
+ query (str): The search query.
76
+ chunks (list): List of text chunks with vectors.
77
+ nlp: The spaCy NLP model.
78
+ similarity_threshold (float): Minimum similarity score to consider a chunk relevant.
79
+ top_n (int): Number of top relevant chunks to return.
80
+
81
+ Returns:
82
+ list: List of relevant chunks and their similarity scores.
83
+ """
84
  # Precompute query vector and its norm
85
  query_vector = nlp(query).vector
86
  query_norm = np.linalg.norm(query_vector) + 1e-8 # Add epsilon to avoid division by zero
 
114
  return relevant_chunks[:top_n]
115
 
116
117
  @traceable(run_type="llm", name="nlp_rag")
118
  def query_rag(chat_llm, query, relevant_results):
119
+ """
120
+ Generate a response using retrieval-augmented generation (RAG) based on relevant results.
121
+
122
+ Args:
123
+ chat_llm: The chat language model to use.
124
+ query (str): The user's query.
125
+ relevant_results (list): List of relevant chunks and their similarity scores.
126
+
127
+ Returns:
128
+ str: The generated response.
129
+ """
130
  import web_rag as wr
131
 
132
  formatted_chunks = ""
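A small, untested usage sketch of the spaCy path documented above; the content dictionaries follow the keys named in the `recursive_split_documents` docstring, and the chat model is assumed to come from `models.get_model`:

```python
import models as md
import nlp_rag as nr

chat = md.get_model("openai:gpt-4o-mini", temperature=0.0)
nlp = nr.get_nlp_model()  # downloads en_core_web_md on first use

contents = [
    {"title": "Example page", "link": "https://example.com", "page_content": "Some extracted text ..."},
]
chunks = nr.recursive_split_documents(contents, max_chunk_size=1000, overlap=100)
relevant = nr.semantic_search("example query", chunks, nlp, similarity_threshold=0.5, top_n=5)
print(nr.query_rag(chat, "example query", relevant))
```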
requirements.txt CHANGED
@@ -3,18 +3,21 @@ boto3 >= 1.34.131, < 1.35.0
3
  bs4
4
  chromedriver-py >= 128.0.6613.137
5
  cohere >= 5.9.2
6
- docopt >= 0.6.2
7
  faiss-cpu >= 1.8.0
8
  google-api-python-client >= 2.145.0
9
  pdfplumber >= 0.11.4
10
  python-dotenv >= 1.0.1
11
  langchain >= 0.3.0
 
12
  langchain-aws >= 0.2.0
13
  langchain-fireworks
14
  langchain_core >= 0.3.0
15
  langchain-cohere
16
  langchain_community
17
  langchain_experimental
 
 
18
  langchain_openai
19
  langchain-ollama
20
  langchain_groq
 
3
  bs4
4
  chromedriver-py >= 128.0.6613.137
5
  cohere >= 5.9.2
6
+ docopt
7
  faiss-cpu >= 1.8.0
8
  google-api-python-client >= 2.145.0
9
  pdfplumber >= 0.11.4
10
  python-dotenv >= 1.0.1
11
  langchain >= 0.3.0
12
+ langchain_anthropic
13
  langchain-aws >= 0.2.0
14
  langchain-fireworks
15
  langchain_core >= 0.3.0
16
  langchain-cohere
17
  langchain_community
18
  langchain_experimental
19
+ langchain_huggingface
20
+ langchain_mistralai
21
  langchain_openai
22
  langchain-ollama
23
  langchain_groq
search_agent.py CHANGED
@@ -1,7 +1,8 @@
1
- """search_agent.py
 
2
 
3
  Usage:
4
- search_agent.py
5
  [--domain=domain]
6
  [--provider=provider]
7
  [--model=model]
@@ -22,7 +23,7 @@ Options:
22
  -c --copywrite First produce a draft, review it and rewrite for a final text
23
  -d domain --domain=domain Limit search to a specific domain
24
  -t temp --temperature=temp Set the temperature of the LLM [default: 0.0]
25
- -m model --model=model Use a specific model [default: openai/gpt-4o-mini]
26
  -e model --embedding_model=model Use an embedding model
27
  -n num --max_pages=num Max number of pages to retrieve [default: 10]
28
  -x num --max_extracts=num Max number of page extract to consider [default: 7]
@@ -35,7 +36,6 @@ Options:
35
  import os
36
 
37
  from docopt import docopt
38
- #from schema import Schema, Use, SchemaError
39
  import dotenv
40
 
41
  from langchain.callbacks import LangChainTracer
@@ -51,10 +51,13 @@ import copywriter as cw
51
  import models as md
52
  import nlp_rag as nr
53
 
 
54
  console = Console()
 
55
  dotenv.load_dotenv()
56
 
57
  def get_selenium_driver():
 
58
  from selenium import webdriver
59
  from selenium.webdriver.chrome.options import Options
60
  from selenium.common.exceptions import WebDriverException
@@ -76,72 +79,76 @@ def get_selenium_driver():
76
  print(f"Error creating Selenium WebDriver: {e}")
77
  return None
78
 
 
79
  callbacks = []
 
80
  if os.getenv("LANGCHAIN_API_KEY"):
81
  callbacks.append(
82
  LangChainTracer(client=Client())
83
  )
 
84
  @traceable(run_type="tool", name="search_agent")
85
  def main(arguments):
 
86
  verbose = arguments["--verbose"]
87
  copywrite_mode = arguments["--copywrite"]
88
  model = arguments["--model"]
89
  embedding_model = arguments["--embedding_model"]
90
  temperature = float(arguments["--temperature"])
91
- domain=arguments["--domain"]
92
- max_pages=int(arguments["--max_pages"])
93
- max_extract=int(arguments["--max_extracts"])
94
- output=arguments["--output"]
95
- use_selenium=arguments["--use_browser"]
96
  query = arguments["SEARCH_QUERY"]
97
 
 
98
  chat = md.get_model(model, temperature)
 
 
99
  if embedding_model is None:
100
  use_nlp = True
101
  nlp = nr.get_nlp_model()
102
  else:
 
103
  embedding_model = md.get_embedding_model(embedding_model)
104
- use_nlp = False
105
 
 
106
  if verbose:
107
  model_name = getattr(chat, 'model_name', None) or getattr(chat, 'model', None) or getattr(chat, 'model_id', None) or str(chat)
108
- console.log(f"Using embedding model: {embedding_model_name}")
109
  if not use_nlp:
110
  embedding_model_name = getattr(embedding_model, 'model_name', None) or getattr(embedding_model, 'model', None) or getattr(embedding_model, 'model_id', None) or str(embedding_model)
111
- console.log(f"Using model: {embedding_model_name}")
112
-
113
 
 
114
  with console.status(f"[bold green]Optimizing query for search: {query}"):
115
  optimized_search_query = wr.optimize_search_query(chat, query)
116
  if len(optimized_search_query) < 3:
117
  optimized_search_query = query
118
  console.log(f"Optimized search query: [bold blue]{optimized_search_query}")
119
 
 
120
  with console.status(
121
  f"[bold green]Searching sources using the optimized query: {optimized_search_query}"
122
  ):
123
  sources = wc.get_sources(optimized_search_query, max_pages=max_pages, domain=domain)
124
  console.log(f"Found {len(sources)} sources {'on ' + domain if domain else ''}")
125
 
 
126
  with console.status(
127
  f"[bold green]Fetching content for {len(sources)} sources", spinner="growVertical"
128
  ):
129
  contents = wc.get_links_contents(sources, get_selenium_driver, use_selenium=use_selenium)
130
  console.log(f"Managed to extract content from {len(contents)} sources")
131
 
 
132
  if use_nlp:
133
- with console.status(f"[bold green]Splitting {len(contents)} sources for content", spinner="growVertical"):
134
  chunks = nr.recursive_split_documents(contents)
135
- #chunks = nr.chunk_contents(nlp, contents)
136
  console.log(f"Split {len(contents)} sources into {len(chunks)} chunks")
137
  with console.status(f"[bold green]Searching relevant chunks", spinner="growVertical"):
138
- import time
139
-
140
- start_time = time.time()
141
  relevant_results = nr.semantic_search(optimized_search_query, chunks, nlp, top_n=max_extract)
142
- end_time = time.time()
143
- execution_time = end_time - start_time
144
- console.log(f"Semantic search took {execution_time:.2f} seconds")
145
  console.log(f"Found {len(relevant_results)} relevant chunks")
146
  with console.status(f"[bold green]Writing content", spinner="growVertical"):
147
  draft = nr.query_rag(chat, query, relevant_results)
@@ -149,38 +156,29 @@ def main(arguments):
149
  with console.status(f"[bold green]Embedding {len(contents)} sources for content", spinner="growVertical"):
150
  vector_store = wc.vectorize(contents, embedding_model)
151
  with console.status("[bold green]Writing content", spinner='dots8Bit'):
152
- draft = wr.query_rag(chat, query, optimized_search_query, vector_store, top_k = max_extract)
153
-
154
 
155
- console.rule(f"[bold green]Response")
156
- if output == "text":
157
- console.print(draft)
158
- else:
159
- console.print(Markdown(draft))
160
- console.rule("[bold green]")
161
-
162
  if(copywrite_mode):
163
  with console.status("[bold green]Getting comments from the reviewer", spinner="dots8Bit"):
164
  comments = cw.generate_comments(chat, query, draft)
165
 
166
- console.rule("[bold green]Response from reviewer")
167
- if output == "text":
168
- console.print(comments)
169
- else:
170
- console.print(Markdown(comments))
171
- console.rule("[bold green]")
172
-
173
  with console.status("[bold green]Writing the final text", spinner="dots8Bit"):
174
  final_text = cw.generate_final_text(chat, query, draft, comments)
 
 
175
 
176
- console.rule("[bold green]Final text")
177
- if output == "text":
178
- console.print(final_text)
179
- else:
180
- console.print(Markdown(final_text))
181
- console.rule("[bold green]")
 
 
 
182
 
183
  if __name__ == '__main__':
 
184
  arguments = docopt(__doc__, version='Search Agent 0.1')
185
  main(arguments)
186
-
 
1
+ """
2
+ search_agent.py
3
 
4
  Usage:
5
+ search_agent.py
6
  [--domain=domain]
7
  [--provider=provider]
8
  [--model=model]
 
23
  -c --copywrite First produce a draft, review it and rewrite for a final text
24
  -d domain --domain=domain Limit search to a specific domain
25
  -t temp --temperature=temp Set the temperature of the LLM [default: 0.0]
26
+ -m model --model=model Use a specific model [default: hf:Qwen/Qwen2.5-72B-Instruct]
27
  -e model --embedding_model=model Use an embedding model
28
  -n num --max_pages=num Max number of pages to retrieve [default: 10]
29
  -x num --max_extracts=num Max number of page extract to consider [default: 7]
 
36
  import os
37
 
38
  from docopt import docopt
 
39
  import dotenv
40
 
41
  from langchain.callbacks import LangChainTracer
 
51
  import models as md
52
  import nlp_rag as nr
53
 
54
+ # Initialize console for rich text output
55
  console = Console()
56
+ # Load environment variables from a .env file
57
  dotenv.load_dotenv()
58
 
59
  def get_selenium_driver():
60
+ """Initialize and return a headless Selenium WebDriver for Chrome."""
61
  from selenium import webdriver
62
  from selenium.webdriver.chrome.options import Options
63
  from selenium.common.exceptions import WebDriverException
 
79
  print(f"Error creating Selenium WebDriver: {e}")
80
  return None
81
 
82
+ # Initialize callbacks list
83
  callbacks = []
84
+ # Add LangChainTracer to callbacks if API key is set
85
  if os.getenv("LANGCHAIN_API_KEY"):
86
  callbacks.append(
87
  LangChainTracer(client=Client())
88
  )
89
+
90
  @traceable(run_type="tool", name="search_agent")
91
  def main(arguments):
92
+ """Main function to execute the search agent logic."""
93
  verbose = arguments["--verbose"]
94
  copywrite_mode = arguments["--copywrite"]
95
  model = arguments["--model"]
96
  embedding_model = arguments["--embedding_model"]
97
  temperature = float(arguments["--temperature"])
98
+ domain = arguments["--domain"]
99
+ max_pages = int(arguments["--max_pages"])
100
+ max_extract = int(arguments["--max_extracts"])
101
+ output = arguments["--output"]
102
+ use_selenium = arguments["--use_browser"]
103
  query = arguments["SEARCH_QUERY"]
104
 
105
+ # Get the language model based on the provided model name and temperature
106
  chat = md.get_model(model, temperature)
107
+
108
+ # If no embedding model is provided, use spacy for semantic search
109
  if embedding_model is None:
110
  use_nlp = True
111
  nlp = nr.get_nlp_model()
112
  else:
113
+ use_nlp = False
114
  embedding_model = md.get_embedding_model(embedding_model)
 
115
 
116
+ # Log model details if verbose mode is enabled
117
  if verbose:
118
  model_name = getattr(chat, 'model_name', None) or getattr(chat, 'model', None) or getattr(chat, 'model_id', None) or str(chat)
119
+ console.log(f"Using model: {model_name}")
120
  if not use_nlp:
121
  embedding_model_name = getattr(embedding_model, 'model_name', None) or getattr(embedding_model, 'model', None) or getattr(embedding_model, 'model_id', None) or str(embedding_model)
122
+ console.log(f"Using embedding model: {embedding_model_name}")
 
123
 
124
+ # Optimize the search query
125
  with console.status(f"[bold green]Optimizing query for search: {query}"):
126
  optimized_search_query = wr.optimize_search_query(chat, query)
127
  if len(optimized_search_query) < 3:
128
  optimized_search_query = query
129
  console.log(f"Optimized search query: [bold blue]{optimized_search_query}")
130
 
131
+ # Retrieve sources using the optimized query
132
  with console.status(
133
  f"[bold green]Searching sources using the optimized query: {optimized_search_query}"
134
  ):
135
  sources = wc.get_sources(optimized_search_query, max_pages=max_pages, domain=domain)
136
  console.log(f"Found {len(sources)} sources {'on ' + domain if domain else ''}")
137
 
138
+ # Fetch content from the retrieved sources
139
  with console.status(
140
  f"[bold green]Fetching content for {len(sources)} sources", spinner="growVertical"
141
  ):
142
  contents = wc.get_links_contents(sources, get_selenium_driver, use_selenium=use_selenium)
143
  console.log(f"Managed to extract content from {len(contents)} sources")
144
 
145
+ # Process content using spaCy or embedding model
146
  if use_nlp:
147
+ with console.status(f"[bold green]Splitting {len(contents)} sources for content", spinner="growVertical"):
148
  chunks = nr.recursive_split_documents(contents)
 
149
  console.log(f"Split {len(contents)} sources into {len(chunks)} chunks")
150
  with console.status(f"[bold green]Searching relevant chunks", spinner="growVertical"):
 
 
 
151
  relevant_results = nr.semantic_search(optimized_search_query, chunks, nlp, top_n=max_extract)
 
 
 
152
  console.log(f"Found {len(relevant_results)} relevant chunks")
153
  with console.status(f"[bold green]Writing content", spinner="growVertical"):
154
  draft = nr.query_rag(chat, query, relevant_results)
 
156
  with console.status(f"[bold green]Embedding {len(contents)} sources for content", spinner="growVertical"):
157
  vector_store = wc.vectorize(contents, embedding_model)
158
  with console.status("[bold green]Writing content", spinner='dots8Bit'):
159
+ draft = wr.query_rag(chat, query, optimized_search_query, vector_store, top_k=max_extract)
 
160
 
161
+ # If copywrite mode is enabled, generate comments and final text
 
 
 
 
 
 
162
  if(copywrite_mode):
163
  with console.status("[bold green]Getting comments from the reviewer", spinner="dots8Bit"):
164
  comments = cw.generate_comments(chat, query, draft)
165
 
 
 
 
 
 
 
 
166
  with console.status("[bold green]Writing the final text", spinner="dots8Bit"):
167
  final_text = cw.generate_final_text(chat, query, draft, comments)
168
+ else:
169
+ final_text = draft
170
 
171
+ # Output the answer
172
+ console.rule(f"[bold green]Response")
173
+ if output == "text":
174
+ console.print(final_text)
175
+ else:
176
+ console.print(Markdown(final_text))
177
+ console.rule("[bold green]")
178
+
179
+ return final_text
180
 
181
  if __name__ == '__main__':
182
+ # Parse command-line arguments and execute the main function
183
  arguments = docopt(__doc__, version='Search Agent 0.1')
184
  main(arguments)
 
web_crawler.py CHANGED
@@ -7,20 +7,32 @@ import io
7
  from trafilatura import extract
8
  from selenium.common.exceptions import TimeoutException
9
  from langchain_core.documents.base import Document
 
10
  from langchain_experimental.text_splitter import SemanticChunker
11
  from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter
12
- from langchain_community.vectorstores.faiss import FAISS
13
  from langsmith import traceable
14
  import requests
15
  import pdfplumber
16
 
17
  @traceable(run_type="tool", name="get_sources")
18
- def get_sources(query, max_pages=10, domain=None):
 
 
 
 
 
 
 
 
 
 
 
19
  search_query = query
20
  if domain:
21
  search_query += f" site:{domain}"
22
 
23
  url = f"https://api.search.brave.com/res/v1/web/search?q={quote(search_query)}&count={max_pages}"
 
24
  headers = {
25
  'Accept': 'application/json',
26
  'Accept-Encoding': 'gzip',
@@ -52,9 +64,18 @@ def get_sources(query, max_pages=10, domain=None):
52
  print('Error fetching search results:', error)
53
  raise
54
 
 
 
 
55
 
 
 
 
 
56
 
57
- def fetch_with_selenium(url, driver, timeout=8,):
 
 
58
  try:
59
  driver.set_page_load_timeout(timeout)
60
  driver.get(url)
@@ -65,10 +86,20 @@ def fetch_with_selenium(url, driver, timeout=8,):
65
  html = None
66
  finally:
67
  driver.quit()
68
-
69
  return html
70
 
71
  def fetch_with_timeout(url, timeout=8):
 
 
 
 
 
 
 
 
 
 
72
  try:
73
  response = requests.get(url, timeout=timeout)
74
  response.raise_for_status()
@@ -76,8 +107,16 @@ def fetch_with_timeout(url, timeout=8):
76
  except requests.RequestException as error:
77
  return None
78
 
79
-
80
  def process_source(source):
 
 
 
 
 
 
 
 
 
81
  url = source['link']
82
  response = fetch_with_timeout(url, 2)
83
  if response:
@@ -109,6 +148,17 @@ def process_source(source):
109
 
110
  @traceable(run_type="tool", name="get_links_contents")
111
  def get_links_contents(sources, get_driver_func=None, use_selenium=False):
 
 
 
 
 
 
 
 
 
 
 
112
  with ThreadPoolExecutor() as executor:
113
  results = list(executor.map(process_source, sources))
114
 
@@ -128,6 +178,16 @@ def get_links_contents(sources, get_driver_func=None, use_selenium=False):
128
 
129
  @traceable(run_type="embedding")
130
  def vectorize(contents, embedding_model):
 
 
 
 
 
 
 
 
 
 
131
  documents = []
132
  for content in contents:
133
  try:
@@ -151,7 +211,7 @@ def vectorize(contents, embedding_model):
151
 
152
  for i in range(0, len(split_documents), batch_size):
153
  batch = split_documents[i:i+batch_size]
154
-
155
  if vector_store is None:
156
  vector_store = FAISS.from_documents(batch, embedding_model)
157
  else:
@@ -163,4 +223,4 @@ def vectorize(contents, embedding_model):
163
  metadatas
164
  )
165
 
166
- return vector_store
 
7
  from trafilatura import extract
8
  from selenium.common.exceptions import TimeoutException
9
  from langchain_core.documents.base import Document
10
+ from langchain_community.vectorstores.faiss import FAISS
11
  from langchain_experimental.text_splitter import SemanticChunker
12
  from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter
 
13
  from langsmith import traceable
14
  import requests
15
  import pdfplumber
16
 
17
  @traceable(run_type="tool", name="get_sources")
18
+ def get_sources(query, max_pages=10, domain=None):
19
+ """
20
+ Fetch search results from the Brave Search API based on the given query.
21
+
22
+ Args:
23
+ query (str): The search query.
24
+ max_pages (int): Maximum number of pages to retrieve.
25
+ domain (str, optional): Limit search to a specific domain.
26
+
27
+ Returns:
28
+ list: A list of search results with title, link, snippet, and favicon.
29
+ """
30
  search_query = query
31
  if domain:
32
  search_query += f" site:{domain}"
33
 
34
  url = f"https://api.search.brave.com/res/v1/web/search?q={quote(search_query)}&count={max_pages}"
35
+
36
  headers = {
37
  'Accept': 'application/json',
38
  'Accept-Encoding': 'gzip',
 
64
  print('Error fetching search results:', error)
65
  raise
66
 
67
+ def fetch_with_selenium(url, driver, timeout=8):
68
+ """
69
+ Fetch the HTML content of a webpage using Selenium.
70
 
71
+ Args:
72
+ url (str): The URL of the webpage.
73
+ driver: Selenium WebDriver instance.
74
+ timeout (int): Page load timeout in seconds.
75
 
76
+ Returns:
77
+ str: The HTML content of the page.
78
+ """
79
  try:
80
  driver.set_page_load_timeout(timeout)
81
  driver.get(url)
 
86
  html = None
87
  finally:
88
  driver.quit()
89
+
90
  return html
91
 
92
  def fetch_with_timeout(url, timeout=8):
93
+ """
94
+ Fetch a webpage with a specified timeout.
95
+
96
+ Args:
97
+ url (str): The URL of the webpage.
98
+ timeout (int): Request timeout in seconds.
99
+
100
+ Returns:
101
+ Response: The HTTP response object, or None if an error occurred.
102
+ """
103
  try:
104
  response = requests.get(url, timeout=timeout)
105
  response.raise_for_status()
 
107
  except requests.RequestException as error:
108
  return None
109
 
 
110
  def process_source(source):
111
+ """
112
+ Process a single source to extract its content.
113
+
114
+ Args:
115
+ source (dict): A dictionary containing the source's link and other metadata.
116
+
117
+ Returns:
118
+ dict: The source with its extracted page content.
119
+ """
120
  url = source['link']
121
  response = fetch_with_timeout(url, 2)
122
  if response:
 
148
 
149
  @traceable(run_type="tool", name="get_links_contents")
150
  def get_links_contents(sources, get_driver_func=None, use_selenium=False):
151
+ """
152
+ Retrieve and process the content of multiple sources.
153
+
154
+ Args:
155
+ sources (list): A list of source dictionaries.
156
+ get_driver_func (callable, optional): Function to get a Selenium WebDriver.
157
+ use_selenium (bool): Whether to use Selenium for fetching content.
158
+
159
+ Returns:
160
+ list: A list of processed sources with their page content.
161
+ """
162
  with ThreadPoolExecutor() as executor:
163
  results = list(executor.map(process_source, sources))
164
 
 
178
 
179
  @traceable(run_type="embedding")
180
  def vectorize(contents, embedding_model):
181
+ """
182
+ Vectorize the contents using the specified embedding model.
183
+
184
+ Args:
185
+ contents (list): A list of content dictionaries.
186
+ embedding_model: The embedding model to use.
187
+
188
+ Returns:
189
+ FAISS: A FAISS vector store containing the vectorized documents.
190
+ """
191
  documents = []
192
  for content in contents:
193
  try:
 
211
 
212
  for i in range(0, len(split_documents), batch_size):
213
  batch = split_documents[i:i+batch_size]
214
+
215
  if vector_store is None:
216
  vector_store = FAISS.from_documents(batch, embedding_model)
217
  else:
 
223
  metadatas
224
  )
225
 
226
+ return vector_store
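As a quick illustration of the crawler-side functions documented above, a minimal, untested sketch (a Brave Search API key is assumed to be configured in `.env`; the domain filter and the Selenium fallback are optional):

```python
import web_crawler as wc

sources = wc.get_sources("radioactive anomaly Pacific Ocean", max_pages=5, domain="wikipedia.org")
contents = wc.get_links_contents(sources)  # pass get_driver_func=... and use_selenium=True for JS-heavy pages

for item in contents:
    print(item["link"], len(item.get("page_content", "")))
```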
web_rag.py CHANGED
@@ -19,6 +19,7 @@ Perform RAG using a single query to retrieve relevant documents.
19
  """
20
  import os
21
  import json
 
22
  from langchain.schema import SystemMessage, HumanMessage
23
  from langchain.prompts.chat import (
24
  HumanMessagePromptTemplate,
@@ -53,13 +54,13 @@ def get_optimized_search_messages(query):
53
  content="""
54
  You are a prompt optimizer for web search. Your task is to take a given chat prompt or question and transform it into an optimized search string that will yield the most relevant and useful information from a search engine like Google.
55
  The goal is to create a search query that will help users find the most accurate and pertinent information related to their original prompt or question. An effective search string should be concise, use relevant keywords, and leverage search engine syntax for better results.
56
-
57
  To optimize the prompt:
58
  - Identify the key information being requested
59
  - Consider any implicit information or context that might be useful for the search.
60
  - Arrange the keywords into a concise search string
61
  - Put the most important keywords first
62
-
63
  Some tips and things to be sure to remove:
64
  - Remove any conversational or instructional phrases
65
  - Removed style such as "in the style of", "engaging", "short", "long"
@@ -68,7 +69,7 @@ def get_optimized_search_messages(query):
68
  - Remove lenght instruction (example: essay, article, letter, etc)
69
 
70
  You should answer only with the optimized search query and add "**" to the end of the search string to indicate the end of the query
71
-
72
  Example:
73
  Question: How do I bake chocolate chip cookies from scratch?
74
  chocolate chip cookies recipe from scratch**
@@ -105,9 +106,9 @@ def get_optimized_search_messages(query):
105
  """
106
  )
107
  human_message = HumanMessage(
108
- content=f"""
109
  Question: {query}
110
-
111
  """
112
  )
113
  return [system_message, human_message]
@@ -150,14 +151,14 @@ def get_optimized_search_messages2(query):
150
  3. Adding quotation marks around exact phrases if applicable
151
  4. Including relevant synonyms or related terms (in parentheses) to broaden the search
152
  5. Using Boolean operators if needed to refine the search
153
-
154
  You should answer only with the optimized search query and add "**" to the end of the search string to indicate the end of the optimized search query
155
  """
156
  )
157
  human_message = HumanMessage(
158
- content=f"""
159
  Question: {query}
160
-
161
  """
162
  )
163
  return [system_message, human_message]
@@ -165,20 +166,31 @@ def get_optimized_search_messages2(query):
165
 
166
  @traceable(run_type="llm", name="optimize_search_query")
167
  def optimize_search_query(chat_llm, query, callbacks=[]):
 
 
 
 
 
 
 
 
 
 
 
168
  messages = get_optimized_search_messages(query)
169
  response = chat_llm.invoke(messages)
170
  optimized_search_query = response.content.strip()
171
-
172
  # Split by '**' and take the first part, then strip whitespace
173
  optimized_search_query = optimized_search_query.split("**", 1)[0].strip()
174
-
175
  # Remove surrounding quotes if present
176
  optimized_search_query = optimized_search_query.strip('"')
177
-
178
  # If the result is empty, fall back to the original query
179
  if not optimized_search_query:
180
  optimized_search_query = query
181
-
182
  return optimized_search_query
183
 
184
  def get_rag_prompt_template():
@@ -193,7 +205,7 @@ def get_rag_prompt_template():
193
  input_variables=[],
194
  template="""
195
  You are an expert research assistant.
196
- You are provided with a Context in JSON format and a Question.
197
  Each JSON entry contains: content, title, link
198
 
199
  Use RAG to answer the Question, providing references and links to the Context material you retrieve and use in your answer:
@@ -203,7 +215,7 @@ def get_rag_prompt_template():
203
  - Synthesize the retrieved information into a clear, informative answer to the question
204
  - Format your answer in Markdown, using heading levels 2-3 as needed
205
  - Include a "References" section at the end with the full citations and link for each source you used
206
-
207
  If the provided context is not relevant to the question, say it and answer with your internal knowledge.
208
  If you cannot answer the question using either the extracts or your internal knowledge, state that you don't have enough information to provide an accurate answer.
209
  If the information in the provided context is in contradiction with your internal knowledge, answer but warn the user about the contradiction.
@@ -214,7 +226,7 @@ def get_rag_prompt_template():
214
  prompt=PromptTemplate(
215
  input_variables=["context", "query"],
216
  template="""
217
- Context:
218
  ---------------------
219
  {context}
220
  ---------------------
@@ -229,6 +241,15 @@ def get_rag_prompt_template():
229
  )
230
 
231
  def format_docs(docs):
 
 
 
 
 
 
 
 
 
232
  formatted_docs = []
233
  for d in docs:
234
  content = d.page_content
@@ -241,6 +262,19 @@ def format_docs(docs):
241
 
242
 
243
  def multi_query_rag(chat_llm, question, search_query, vectorstore, callbacks = []):
 
 
 
 
 
 
 
 
 
 
 
 
 
244
  retriever_from_llm = MultiQueryRetriever.from_llm(
245
  retriever=vectorstore.as_retriever(), llm=chat_llm, include_original=True,
246
  )
@@ -259,7 +293,7 @@ def get_context_size(chat_llm):
259
  else:
260
  return 16385
261
  if isinstance(chat_llm, ChatFireworks):
262
- 32768
263
  if isinstance(chat_llm, ChatGroq):
264
  return 32768
265
  if isinstance(chat_llm, ChatOllama):
@@ -278,9 +312,10 @@ def get_context_size(chat_llm):
278
  return 128000
279
  return 32000
280
  return 4096
281
-
282
- @traceable(run_type="retriever")
283
  def build_rag_prompt(chat_llm, question, search_query, vectorstore, top_k = 10, callbacks = []):
 
284
  done = False
285
  while not done:
286
  unique_docs = vectorstore.similarity_search(
@@ -292,14 +327,12 @@ def build_rag_prompt(chat_llm, question, search_query, vectorstore, top_k = 10,
292
  done = True
293
  else:
294
  top_k = int(top_k * 0.75)
295
-
296
  return prompt
297
 
298
  @traceable(run_type="llm", name="query_rag")
299
  def query_rag(chat_llm, question, search_query, vectorstore, top_k = 10, callbacks = []):
300
  prompt = build_rag_prompt(chat_llm, question, search_query, vectorstore, top_k=top_k, callbacks = callbacks)
301
  response = chat_llm.invoke(prompt, config={"callbacks": callbacks})
302
-
303
  # Ensure we're returning a string
304
  if isinstance(response.content, list):
305
  # If it's a list, join the elements into a single string
 
19
  """
20
  import os
21
  import json
22
+ from docopt import re
23
  from langchain.schema import SystemMessage, HumanMessage
24
  from langchain.prompts.chat import (
25
  HumanMessagePromptTemplate,
 
54
  content="""
55
  You are a prompt optimizer for web search. Your task is to take a given chat prompt or question and transform it into an optimized search string that will yield the most relevant and useful information from a search engine like Google.
56
  The goal is to create a search query that will help users find the most accurate and pertinent information related to their original prompt or question. An effective search string should be concise, use relevant keywords, and leverage search engine syntax for better results.
57
+
58
  To optimize the prompt:
59
  - Identify the key information being requested
60
  - Consider any implicit information or context that might be useful for the search.
61
  - Arrange the keywords into a concise search string
62
  - Put the most important keywords first
63
+
64
  Some tips and things to be sure to remove:
65
  - Remove any conversational or instructional phrases
66
  - Removed style such as "in the style of", "engaging", "short", "long"
 
69
  - Remove lenght instruction (example: essay, article, letter, etc)
70
 
71
  You should answer only with the optimized search query and add "**" to the end of the search string to indicate the end of the query
72
+
73
  Example:
74
  Question: How do I bake chocolate chip cookies from scratch?
75
  chocolate chip cookies recipe from scratch**
 
106
  """
107
  )
108
  human_message = HumanMessage(
109
+ content=f"""
110
  Question: {query}
111
+
112
  """
113
  )
114
  return [system_message, human_message]
 
151
  3. Adding quotation marks around exact phrases if applicable
152
  4. Including relevant synonyms or related terms (in parentheses) to broaden the search
153
  5. Using Boolean operators if needed to refine the search
154
+
155
  You should answer only with the optimized search query and add "**" to the end of the search string to indicate the end of the optimized search query
156
  """
157
  )
158
  human_message = HumanMessage(
159
+ content=f"""
160
  Question: {query}
161
+
162
  """
163
  )
164
  return [system_message, human_message]
 
166
 
167
  @traceable(run_type="llm", name="optimize_search_query")
168
  def optimize_search_query(chat_llm, query, callbacks=[]):
169
+ """
170
+ Optimize the search query using the chat language model.
171
+
172
+ Args:
173
+ chat_llm: The chat language model to use.
174
+ query (str): The user's query.
175
+ callbacks (list): Optional callbacks for tracing.
176
+
177
+ Returns:
178
+ str: The optimized search query.
179
+ """
180
  messages = get_optimized_search_messages(query)
181
  response = chat_llm.invoke(messages)
182
  optimized_search_query = response.content.strip()
183
+
184
  # Split by '**' and take the first part, then strip whitespace
185
  optimized_search_query = optimized_search_query.split("**", 1)[0].strip()
186
+
187
  # Remove surrounding quotes if present
188
  optimized_search_query = optimized_search_query.strip('"')
189
+
190
  # If the result is empty, fall back to the original query
191
  if not optimized_search_query:
192
  optimized_search_query = query
193
+
194
  return optimized_search_query
195
 
196
  def get_rag_prompt_template():
 
205
  input_variables=[],
206
  template="""
207
  You are an expert research assistant.
208
+ You are provided with a Context in JSON format and a Question.
209
  Each JSON entry contains: content, title, link
210
 
211
  Use RAG to answer the Question, providing references and links to the Context material you retrieve and use in your answer:
 
215
  - Synthesize the retrieved information into a clear, informative answer to the question
216
  - Format your answer in Markdown, using heading levels 2-3 as needed
217
  - Include a "References" section at the end with the full citations and link for each source you used
218
+
219
  If the provided context is not relevant to the question, say it and answer with your internal knowledge.
220
  If you cannot answer the question using either the extracts or your internal knowledge, state that you don't have enough information to provide an accurate answer.
221
  If the information in the provided context is in contradiction with your internal knowledge, answer but warn the user about the contradiction.
 
226
  prompt=PromptTemplate(
227
  input_variables=["context", "query"],
228
  template="""
229
+ Context:
230
  ---------------------
231
  {context}
232
  ---------------------
 
241
  )
242
 
243
  def format_docs(docs):
244
+ """
245
+ Format the retrieved documents into a JSON string.
246
+
247
+ Args:
248
+ docs (list): A list of documents to format.
249
+
250
+ Returns:
251
+ str: The formatted documents as a JSON string.
252
+ """
253
  formatted_docs = []
254
  for d in docs:
255
  content = d.page_content
 
262
 
263
 
264
  def multi_query_rag(chat_llm, question, search_query, vectorstore, callbacks = []):
265
+ """
266
+ Perform RAG using multiple queries to retrieve relevant documents.
267
+
268
+ Args:
269
+ chat_llm: The chat language model to use.
270
+ question (str): The user's question.
271
+ search_query (str): The search query to use.
272
+ vectorstore: The vector store for document retrieval.
273
+ callbacks (list): Optional callbacks for tracing.
274
+
275
+ Returns:
276
+ str: The generated answer to the question.
277
+ """
278
  retriever_from_llm = MultiQueryRetriever.from_llm(
279
  retriever=vectorstore.as_retriever(), llm=chat_llm, include_original=True,
280
  )
 
293
  else:
294
  return 16385
295
  if isinstance(chat_llm, ChatFireworks):
296
+ return 32768
297
  if isinstance(chat_llm, ChatGroq):
298
  return 32768
299
  if isinstance(chat_llm, ChatOllama):
 
312
  return 128000
313
  return 32000
314
  return 4096
315
+
316
+ @traceable(run_type="retriever")
317
  def build_rag_prompt(chat_llm, question, search_query, vectorstore, top_k = 10, callbacks = []):
318
+ prompt = ""
319
  done = False
320
  while not done:
321
  unique_docs = vectorstore.similarity_search(
 
327
  done = True
328
  else:
329
  top_k = int(top_k * 0.75)
 
330
  return prompt
331
 
332
  @traceable(run_type="llm", name="query_rag")
333
  def query_rag(chat_llm, question, search_query, vectorstore, top_k = 10, callbacks = []):
334
  prompt = build_rag_prompt(chat_llm, question, search_query, vectorstore, top_k=top_k, callbacks = callbacks)
335
  response = chat_llm.invoke(prompt, config={"callbacks": callbacks})
 
336
  # Ensure we're returning a string
337
  if isinstance(response.content, list):
338
  # If it's a list, join the elements into a single string
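Finally, a hedged sketch tying the web_rag entry points above together, including the `multi_query_rag` variant (untested; model choices are examples only, and the vector store is built with `web_crawler.vectorize`):

```python
import models as md
import web_crawler as wc
import web_rag as wr

chat = md.get_model("mistral", temperature=0.0)  # defaults to 'mistral-small-latest'
embeddings = md.get_embedding_model("cohere")

question = "How do I bake chocolate chip cookies from scratch?"
search_query = wr.optimize_search_query(chat, question)  # trailing "**" marker is stripped

store = wc.vectorize(wc.get_links_contents(wc.get_sources(search_query, max_pages=5)), embeddings)
print(wr.multi_query_rag(chat, question, search_query, store))  # MultiQueryRetriever-based variant
```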