Enhance model support, improve documentation, and refactor core components
- Added support for new language models including DeepSeek, Qwen, and expanded model options
- Updated README with comprehensive project structure and usage details
- Refactored core modules (copywriter.py, models.py, nlp_rag.py) with improved type hints and docstrings
- Updated default models and embedding configurations
- Enhanced search agent with more flexible model selection and verbose output options
- Improved requirements.txt with latest library dependencies
- README.md +47 -15
- copywriter.py +4 -9
- models.py +29 -11
- nlp_rag.py +41 -39
- requirements.txt +4 -1
- search_agent.py +42 -44
- web_crawler.py +67 -7
- web_rag.py +53 -20
README.md CHANGED

@@ -37,7 +37,7 @@ To run the script, users need to provide their API keys for the desired language
 
 ## Features
 
-- Supports multiple language model providers (Bedrock, OpenAI, Groq, Cohere, and Ollama)
+- Supports multiple language model providers (HuggingFace, Bedrock, OpenAI, Groq, Cohere, and Ollama)
 - Optimizes search queries using a language model
 - Fetches web pages and extracts main content (HTML and PDF)
 - Vectorizes the content for efficient retrieval

@@ -55,8 +55,9 @@ To run the script, users need to provide their API keys for the desired language
 
 3. Set up API keys:
 
-
-
+- create a `.env` file and add your API keys. Use `dotenv.sample` to create this file.
+- Get an API key from the following sources: https://brave.com/search/api/
+- Optionally you can add API keys from other LLM providers.
 
 ## Usage
 
@@ -68,20 +69,29 @@ python search_agent.py [OPTIONS] SEARCH_QUERY
 
 ### Options:
 
-
-
-
-
-
-
-
-
-
-
+  -h --help                           Show this screen.
+  --version                           Show version.
+  -c --copywrite                      First produce a draft, review it and rewrite for a final text
+  -d domain --domain=domain           Limit search to a specific domain
+  -t temp --temperature=temp          Set the temperature of the LLM [default: 0.0]
+  -m model --model=model              Use a specific model [default: hf:Qwen/Qwen2.5-72B-Instruct]
+  -e model --embedding_model=model    Use an embedding model
+  -n num --max_pages=num              Max number of pages to retrieve [default: 10]
+  -x num --max_extracts=num           Max number of page extract to consider [default: 7]
+  -b --use_browser                    Use browser to fetch content from the web [default: False]
+  -o text --output=text               Output format (choices: text, markdown) [default: markdown]
+  -v --verbose                        Print verbose output [default: False]
+
+The model can be a language model provider and a model name separated by a colon. e.g. `openai:gpt-4o-mini`
+If a embedding model is not specified, spaCy will be used for semantic search.
+
 
 ### Examples
 
+```bash
+python search_agent.py 'What is the radioactive anomaly in the Pacific Ocean?'
+```
+
 ```bash
 python search_agent.py -m openai:gpt-4o-mini "Write a linked post about the current state of M&A for startups. Write in the style of Russ from Silicon Valley TV show."
 ```

@@ -98,4 +108,26 @@ python search_agent.py -m openai:gpt-4o-mini "Write a linked post about the curr
 
 This project is licensed under the Apache License Version 2.0. See the `LICENSE` file for details.
 
-Let me know if you have any other questions! The key components are using a web search API to find relevant information, extracting the key snippets from the search results, passing that as context to a large language model, and having the LLM generate a natural language answer based on the web search context.
+Let me know if you have any other questions! The key components are using a web search API to find relevant information, extracting the key snippets from the search results, passing that as context to a large language model, and having the LLM generate a natural language answer based on the web search context.
+
+## Project Structure
+
+The project consists of several key components:
+
+- `search_agent.py`: The main script that handles the core search agent functionality
+- `search_agent_ui.py`: Streamlit-based user interface for the search agent
+- `web_crawler.py`: Handles web content fetching and processing
+- `web_rag.py`: Implements the Retrieval-Augmented Generation (RAG) functionality
+- `nlp_rag.py`: Natural language processing utilities for RAG
+- `models.py`: Contains model definitions and configurations
+- `copywriter.py`: Implements content rewriting and optimization features
+
+## Additional Tools
+
+The project includes several development and configuration files:
+
+- `requirements.txt`: Lists all Python dependencies
+- `.env`: Configuration file for API keys and settings (use `dotenv.sample` as a template)
+- `.gitignore`: Specifies which files Git should ignore
+- `LICENSE`: Apache License Version 2.0
+- `.devcontainer/`: Contains development container configuration for consistent development environments
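The closing paragraph added to the README summarizes the pipeline: optimize the query, search the web, fetch and split the pages, retrieve the most relevant chunks, and let the LLM answer from that context. A minimal, hypothetical sketch of that flow, reusing the module functions that appear in the diffs below (`get_model`, `get_sources`, `get_links_contents`, `semantic_search`, `query_rag`); the defaults and wiring are assumptions drawn from this commit, not a verbatim excerpt of the repository, and running it requires the API keys described above in `.env`.

```python
# Hypothetical end-to-end sketch of the README's "key components" paragraph,
# assembled from the functions visible in the diffs below (assumed signatures).
import models as md
import nlp_rag as nr
import web_crawler as wc
import web_rag as wr

chat = md.get_model("hf:Qwen/Qwen2.5-72B-Instruct", temperature=0.0)  # default model in this commit
nlp = nr.get_nlp_model()                                              # spaCy fallback when no embedding model is given

query = "What is the radioactive anomaly in the Pacific Ocean?"
search_query = wr.optimize_search_query(chat, query)        # 1. optimize the query with the LLM
sources = wc.get_sources(search_query, max_pages=10)        # 2. web search API (Brave)
contents = wc.get_links_contents(sources)                   # 3. fetch pages and extract main content
chunks = nr.recursive_split_documents(contents)             # 4. split content into chunks
relevant = nr.semantic_search(search_query, chunks, nlp)    # 5. keep the chunks most similar to the query
answer = nr.query_rag(chat, query, relevant)                # 6. LLM answers from the retrieved context
print(answer)
```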
copywriter.py CHANGED

@@ -1,10 +1,4 @@
 from langchain.schema import SystemMessage, HumanMessage
-from langchain.prompts.chat import (
-    HumanMessagePromptTemplate,
-    SystemMessagePromptTemplate,
-    ChatPromptTemplate
-)
-from langchain.prompts.prompt import PromptTemplate
 from langsmith import traceable
 
 
@@ -21,7 +15,6 @@ def get_comments_prompt(query, draft):
 5. Ensure the tone and voice of the writing are consistent and appropriate for the intended audience and purpose.
 6. Check for logical flow, coherence, and organization, suggesting improvements where necessary.
 7. Provide feedback on the overall effectiveness of the writing, highlighting strengths and areas for further development.
-
 Your suggestions should be constructive, insightful, and designed to help the user elevate the quality of their writing.
 You never generate the corrected text by itself. *Only* give the comment.
 """

@@ -35,12 +28,14 @@ def get_comments_prompt(query, draft):
     )
     return [system_message, human_message]
 
+
 @traceable(run_type="llm", name="generate_comments")
 def generate_comments(chat_llm, query, draft, callbacks=[]):
     messages = get_comments_prompt(query, draft)
     response = chat_llm.invoke(messages, config={"callbacks": callbacks})
     return response.content
 
+
 def get_final_text_prompt(query, draft, comments):
     system_message = SystemMessage(
         content="""

@@ -73,7 +68,7 @@ def get_final_text_prompt(query, draft, comments):
 def generate_final_text(chat_llm, query, draft, comments, callbacks=[]):
     messages = get_final_text_prompt(query, draft, comments)
     response = chat_llm.invoke(messages, config={"callbacks": callbacks})
-    return response.content
+    return response.content
 
 
 def get_compare_texts_prompts(query, draft_text, final_text):

@@ -109,4 +104,4 @@ def get_compare_texts_prompts(query, draft_text, final_text):
 def compare_text(chat_llm, query, draft, final, callbacks=[]):
     messages = get_compare_texts_prompts(query, draft_text=draft, final_text=final)
     response = chat_llm.invoke(messages, config={"callbacks": callbacks})
-    return response.content
+    return response.content
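The `--copywrite` option described in the README drives these two helpers: the draft is first critiqued by `generate_comments` and then rewritten by `generate_final_text`, exactly as `search_agent.py` does further down. A hedged usage sketch, with the chat model and draft assumed to come from the rest of the pipeline:

```python
# Hypothetical usage of the copywriter helpers above (signatures taken from the diff).
import copywriter as cw
import models as md

chat = md.get_model("openai:gpt-4o-mini", temperature=0.0)   # any supported provider:model string
query = "Write a short post about RAG search agents."
draft = "RAG combines web search results with an LLM to ground its answers..."  # normally produced by the agent

comments = cw.generate_comments(chat, query, draft)                 # reviewer feedback only, no rewrite
final_text = cw.generate_final_text(chat, query, draft, comments)   # rewrite guided by the comments
print(final_text)
```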
models.py CHANGED

@@ -30,6 +30,10 @@ from langchain.chat_models.base import BaseChatModel
 from langchain.embeddings.base import Embeddings
 
 def split_provider_model(provider_model: str) -> Tuple[str, Optional[str]]:
+    """
+    Split the provider and model name from a string.
+    returns Tuple[str, Optional[str]]
+    """
     parts = provider_model.split(":", 1)
     provider = parts[0]
     if len(parts) > 1:

@@ -48,7 +52,7 @@ def get_model(provider_model: str, temperature: float = 0.7) -> BaseChatModel:
     match provider.lower():
         case 'anthropic':
             if model is None:
-                model = "claude-3-
+                model = "claude-3-5-haiku-20241022"
             chat_llm = ChatAnthropic(model=model, temperature=temperature)
         case 'bedrock':
             if model is None:

@@ -58,22 +62,32 @@ def get_model(provider_model: str, temperature: float = 0.7) -> BaseChatModel:
             if model is None:
                 model = 'command-r-plus'
             chat_llm = ChatCohere(model=model, temperature=temperature)
+
+        case 'deepseek':
+            if model is None:
+                model='deepseek-chat'
+            chat_llm = ChatOpenAI(
+                model=model,
+                openai_api_key=os.getenv("DEEPSEEK_API_KEY"),
+                openai_api_base='https://api.deepseek.com',
+                max_tokens=8192
+            )
         case 'fireworks':
             if model is None:
-                model = 'accounts/fireworks/models/llama-
+                model = 'accounts/fireworks/models/llama-v3p3-70b-instruct'
             chat_llm = ChatFireworks(model_name=model, temperature=temperature, max_tokens=120000)
         case 'googlegenerativeai':
             if model is None:
-                model = "gemini-
+                model = "gemini-2.0-flash-exp"
             chat_llm = ChatGoogleGenerativeAI(model=model, temperature=temperature,
                                               max_tokens=None, timeout=None, max_retries=2,)
         case 'groq':
             if model is None:
-                model = '
+                model = 'qwen-2.5-32b'
             chat_llm = ChatGroq(model_name=model, temperature=temperature)
         case 'huggingface' | 'hf':
             if model is None:
-                model = '
+                model = 'Qwen/Qwen2.5-72B-Instruct'
             llm = HuggingFaceEndpoint(
                 repo_id=model,
                 max_length=8192,

@@ -91,23 +105,23 @@ def get_model(provider_model: str, temperature: float = 0.7) -> BaseChatModel:
             chat_llm = ChatOpenAI(model=model, temperature=temperature)
         case 'openrouter':
             if model is None:
-                model = "
+                model = "cognitivecomputations/dolphin3.0-mistral-24b:free"
             chat_llm = ChatOpenAI(model=model, temperature=temperature, base_url="https://openrouter.ai/api/v1", api_key=os.getenv("OPENROUTER_API_KEY"))
         case 'mistralai' | 'mistral':
             if model is None:
-                model = "
+                model = "mistral-small-latest"
             chat_llm = ChatMistralAI(model=model, temperature=temperature)
         case 'perplexity':
             if model is None:
-                model = '
+                model = 'sonar'
             chat_llm = ChatPerplexity(model=model, temperature=temperature)
         case 'together':
             if model is None:
-                model = 'meta-llama/
+                model = 'meta-llama/Llama-3.3-70B-Instruct-Turbo-Free'
             chat_llm = ChatTogether(model=model, temperature=temperature)
         case 'xai':
             if model is None:
-                model = 'grok-
+                model = 'grok-2-1212'
             chat_llm = ChatOpenAI(model=model,api_key=os.getenv("XAI_API_KEY"), base_url="https://api.x.ai/v1", temperature=temperature)
         case _:
             raise ValueError(f"Unknown LLM provider {provider}")

@@ -118,6 +132,10 @@ def get_model(provider_model: str, temperature: float = 0.7) -> BaseChatModel:
 
 
 def get_embedding_model(provider_model: str) -> Embeddings:
+    """
+    Get an embedding model from a provider and model name.
+    returns Embeddings
+    """
     provider, model = split_provider_model(provider_model)
     match provider.lower():
         case 'bedrock':

@@ -126,7 +144,7 @@ def get_embedding_model(provider_model: str) -> Embeddings:
             embedding_model = BedrockEmbeddings(model_id=model)
         case 'cohere':
             if model is None:
-                model = "embed-multilingual-v3"
+                model = "embed-multilingual-v3.0"
             embedding_model = CohereEmbeddings(model=model)
         case 'fireworks':
             if model is None:
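As the new docstring states, `split_provider_model` separates the `provider:model` string that every `--model` value uses; `get_model` then dispatches on the provider and falls back to the per-provider defaults listed above when no model name is given. A small illustrative sketch; the behaviour for a bare provider name (no colon) is inferred from the callers' `if model is None` checks, so treat the exact return values as assumptions.

```python
# Illustrative only: behaviour inferred from the split_provider_model/get_model diff above.
from models import split_provider_model, get_model

print(split_provider_model("openai:gpt-4o-mini"))             # expected: ("openai", "gpt-4o-mini")
print(split_provider_model("hf:Qwen/Qwen2.5-72B-Instruct"))   # expected: ("hf", "Qwen/Qwen2.5-72B-Instruct")
print(split_provider_model("deepseek"))                        # expected: ("deepseek", None) -> default filled in later

# With no model name, the provider-specific default from this commit is used,
# e.g. 'deepseek' -> 'deepseek-chat' (requires DEEPSEEK_API_KEY in the environment).
chat = get_model("deepseek", temperature=0.0)
```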
nlp_rag.py CHANGED

@@ -6,6 +6,12 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
 import numpy as np
 
 def get_nlp_model():
+    """
+    Load and return the spaCy NLP model. Downloads the model if not already installed.
+
+    Returns:
+        nlp: The loaded spaCy NLP model.
+    """
     if not spacy.util.is_package("en_core_web_md"):
         print("Downloading en_core_web_md model...")
         spacy.cli.download("en_core_web_md")

@@ -15,6 +21,17 @@ def get_nlp_model():
 
 
 def recursive_split_documents(contents, max_chunk_size=1000, overlap=100):
+    """
+    Split documents into smaller chunks using a recursive character text splitter.
+
+    Args:
+        contents (list): List of content dictionaries with 'page_content', 'title', and 'link'.
+        max_chunk_size (int): Maximum size of each chunk.
+        overlap (int): Overlap between chunks.
+
+    Returns:
+        list: List of chunks with text and metadata.
+    """
     from langchain_core.documents.base import Document
     from langchain.text_splitter import RecursiveCharacterTextSplitter
 

@@ -51,6 +68,19 @@ def recursive_split_documents(contents, max_chunk_size=1000, overlap=100):
 
 
 def semantic_search(query, chunks, nlp, similarity_threshold=0.5, top_n=10):
+    """
+    Perform semantic search to find relevant chunks based on similarity to the query.
+
+    Args:
+        query (str): The search query.
+        chunks (list): List of text chunks with vectors.
+        nlp: The spaCy NLP model.
+        similarity_threshold (float): Minimum similarity score to consider a chunk relevant.
+        top_n (int): Number of top relevant chunks to return.
+
+    Returns:
+        list: List of relevant chunks and their similarity scores.
+    """
     # Precompute query vector and its norm
     query_vector = nlp(query).vector
     query_norm = np.linalg.norm(query_vector) + 1e-8  # Add epsilon to avoid division by zero

@@ -84,47 +114,19 @@ def semantic_search(query, chunks, nlp, similarity_threshold=0.5, top_n=10):
     return relevant_chunks[:top_n]
 
 
-# Perform semantic search using spaCy
-def semantic_search(query, chunks, nlp, similarity_threshold=0.5, top_n=10):
-    import numpy as np
-    from concurrent.futures import ThreadPoolExecutor
-
-    # Precompute query vector and its norm with epsilon to prevent division by zero
-    with nlp.disable_pipes(*[pipe for pipe in nlp.pipe_names if pipe != 'tok2vec']):
-        query_vector = nlp(query).vector
-    query_norm = np.linalg.norm(query_vector) + 1e-8  # Add epsilon
-
-    # Prepare texts from chunks
-    texts = [chunk['text'] for chunk in chunks]
-
-    # Function to process each text and compute its vector
-    def compute_vector(text):
-        with nlp.disable_pipes(*[pipe for pipe in nlp.pipe_names if pipe != 'tok2vec']):
-            doc = nlp(text)
-            vector = doc.vector
-        return vector
-
-    # Process texts in parallel using ThreadPoolExecutor
-    with ThreadPoolExecutor() as executor:
-        chunk_vectors = list(executor.map(compute_vector, texts))
-
-    chunk_vectors = np.array(chunk_vectors)
-    chunk_norms = np.linalg.norm(chunk_vectors, axis=1) + 1e-8  # Add epsilon
-
-    # Compute similarities using vectorized operations
-    similarities = np.dot(chunk_vectors, query_vector) / (chunk_norms * query_norm)
-
-    # Filter and sort results
-    relevant_chunks = [
-        (chunk, sim) for chunk, sim in zip(chunks, similarities) if sim > similarity_threshold
-    ]
-    relevant_chunks.sort(key=lambda x: x[1], reverse=True)
-
-    return relevant_chunks[:top_n]
-
-
 @traceable(run_type="llm", name="nlp_rag")
 def query_rag(chat_llm, query, relevant_results):
+    """
+    Generate a response using retrieval-augmented generation (RAG) based on relevant results.
+
+    Args:
+        chat_llm: The chat language model to use.
+        query (str): The user's query.
+        relevant_results (list): List of relevant chunks and their similarity scores.
+
+    Returns:
+        str: The generated response.
+    """
    import web_rag as wr
 
     formatted_chunks = ""
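The retained `semantic_search` scores each chunk by cosine similarity between its spaCy vector and the query vector, adding a small epsilon to the norms to avoid division by zero, then keeps the chunks above the threshold. A standalone NumPy sketch of that scoring step (array shapes are assumed; this is not the module's exact code):

```python
# Minimal sketch of the cosine-similarity scoring used by semantic_search (shapes assumed).
import numpy as np

query_vector = np.random.rand(300)          # e.g. a spaCy en_core_web_md document vector
chunk_vectors = np.random.rand(5, 300)      # one vector per text chunk

query_norm = np.linalg.norm(query_vector) + 1e-8               # epsilon avoids division by zero
chunk_norms = np.linalg.norm(chunk_vectors, axis=1) + 1e-8

similarities = chunk_vectors @ query_vector / (chunk_norms * query_norm)

similarity_threshold, top_n = 0.5, 3
ranked = sorted(
    [(i, s) for i, s in enumerate(similarities) if s > similarity_threshold],
    key=lambda x: x[1], reverse=True,
)[:top_n]
print(ranked)   # indices of the most relevant chunks with their scores
```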
requirements.txt CHANGED

@@ -3,18 +3,21 @@ boto3 >= 1.34.131, < 1.35.0
 bs4
 chromedriver-py >= 128.0.6613.137
 cohere >= 5.9.2
-docopt
+docopt
 faiss-cpu >= 1.8.0
 google-api-python-client >= 2.145.0
 pdfplumber >= 0.11.4
 python-dotenv >= 1.0.1
 langchain >= 0.3.0
+langchain_anthropic
 langchain-aws >= 0.2.0
 langchain-fireworks
 langchain_core >= 0.3.0
 langchain-cohere
 langchain_community
 langchain_experimental
+langchain_huggingface
+langchain_mistralai
 langchain_openai
 langchain-ollama
 langchain_groq
search_agent.py CHANGED

@@ -1,7 +1,8 @@
-"""
+"""
+search_agent.py
 
 Usage:
-    search_agent.py
+    search_agent.py
         [--domain=domain]
         [--provider=provider]
         [--model=model]

@@ -22,7 +23,7 @@ Options:
    -c --copywrite                      First produce a draft, review it and rewrite for a final text
    -d domain --domain=domain           Limit search to a specific domain
    -t temp --temperature=temp          Set the temperature of the LLM [default: 0.0]
-   -m model --model=model              Use a specific model [default:
+   -m model --model=model              Use a specific model [default: hf:Qwen/Qwen2.5-72B-Instruct]
    -e model --embedding_model=model    Use an embedding model
    -n num --max_pages=num              Max number of pages to retrieve [default: 10]
    -x num --max_extracts=num           Max number of page extract to consider [default: 7]

@@ -35,7 +36,6 @@ Options:
 import os
 
 from docopt import docopt
-#from schema import Schema, Use, SchemaError
 import dotenv
 
 from langchain.callbacks import LangChainTracer

@@ -51,10 +51,13 @@ import copywriter as cw
 import models as md
 import nlp_rag as nr
 
+# Initialize console for rich text output
 console = Console()
+# Load environment variables from a .env file
 dotenv.load_dotenv()
 
 def get_selenium_driver():
+    """Initialize and return a headless Selenium WebDriver for Chrome."""
     from selenium import webdriver
     from selenium.webdriver.chrome.options import Options
     from selenium.common.exceptions import WebDriverException

@@ -76,72 +79,76 @@ def get_selenium_driver():
         print(f"Error creating Selenium WebDriver: {e}")
         return None
 
+# Initialize callbacks list
 callbacks = []
+# Add LangChainTracer to callbacks if API key is set
 if os.getenv("LANGCHAIN_API_KEY"):
     callbacks.append(
         LangChainTracer(client=Client())
     )
+
 @traceable(run_type="tool", name="search_agent")
 def main(arguments):
+    """Main function to execute the search agent logic."""
     verbose = arguments["--verbose"]
     copywrite_mode = arguments["--copywrite"]
     model = arguments["--model"]
     embedding_model = arguments["--embedding_model"]
     temperature = float(arguments["--temperature"])
-    domain=arguments["--domain"]
-    max_pages=int(arguments["--max_pages"])
-    max_extract=int(arguments["--max_extracts"])
-    output=arguments["--output"]
-    use_selenium=arguments["--use_browser"]
+    domain = arguments["--domain"]
+    max_pages = int(arguments["--max_pages"])
+    max_extract = int(arguments["--max_extracts"])
+    output = arguments["--output"]
+    use_selenium = arguments["--use_browser"]
     query = arguments["SEARCH_QUERY"]
 
+    # Get the language model based on the provided model name and temperature
     chat = md.get_model(model, temperature)
+
+    # If no embedding model is provided, use spacy for semantic search
     if embedding_model is None:
        use_nlp = True
        nlp = nr.get_nlp_model()
     else:
-        use_nlp = False
+        use_nlp = False
         embedding_model = md.get_embedding_model(embedding_model)
 
+    # Log model details if verbose mode is enabled
     if verbose:
         model_name = getattr(chat, 'model_name', None) or getattr(chat, 'model', None) or getattr(chat, 'model_id', None) or str(chat)
-        console.log(f"Using
+        console.log(f"Using model: {model_name}")
         if not use_nlp:
             embedding_model_name = getattr(embedding_model, 'model_name', None) or getattr(embedding_model, 'model', None) or getattr(embedding_model, 'model_id', None) or str(embedding_model)
-            console.log(f"Using model: {embedding_model_name}")
-
+            console.log(f"Using embedding model: {embedding_model_name}")
 
+    # Optimize the search query
     with console.status(f"[bold green]Optimizing query for search: {query}"):
         optimized_search_query = wr.optimize_search_query(chat, query)
         if len(optimized_search_query) < 3:
             optimized_search_query = query
     console.log(f"Optimized search query: [bold blue]{optimized_search_query}")
 
+    # Retrieve sources using the optimized query
     with console.status(
         f"[bold green]Searching sources using the optimized query: {optimized_search_query}"
     ):
         sources = wc.get_sources(optimized_search_query, max_pages=max_pages, domain=domain)
     console.log(f"Found {len(sources)} sources {'on ' + domain if domain else ''}")
 
+    # Fetch content from the retrieved sources
     with console.status(
         f"[bold green]Fetching content for {len(sources)} sources", spinner="growVertical"
     ):
         contents = wc.get_links_contents(sources, get_selenium_driver, use_selenium=use_selenium)
     console.log(f"Managed to extract content from {len(contents)} sources")
 
+    # Process content using spaCy or embedding model
     if use_nlp:
-        with console.status(f"[bold green]Splitting {len(contents)} sources for content", spinner="growVertical"):
+        with console.status(f"[bold green]Splitting {len(contents)} sources for content", spinner="growVertical"):
             chunks = nr.recursive_split_documents(contents)
-            #chunks = nr.chunk_contents(nlp, contents)
         console.log(f"Split {len(contents)} sources into {len(chunks)} chunks")
         with console.status(f"[bold green]Searching relevant chunks", spinner="growVertical"):
-            import time
-
-            start_time = time.time()
             relevant_results = nr.semantic_search(optimized_search_query, chunks, nlp, top_n=max_extract)
-            end_time = time.time()
-            execution_time = end_time - start_time
-            console.log(f"Semantic search took {execution_time:.2f} seconds")
         console.log(f"Found {len(relevant_results)} relevant chunks")
         with console.status(f"[bold green]Writing content", spinner="growVertical"):
             draft = nr.query_rag(chat, query, relevant_results)

@@ -149,38 +156,29 @@ def main(arguments):
         with console.status(f"[bold green]Embedding {len(contents)} sources for content", spinner="growVertical"):
             vector_store = wc.vectorize(contents, embedding_model)
         with console.status("[bold green]Writing content", spinner='dots8Bit'):
-            draft = wr.query_rag(chat, query, optimized_search_query, vector_store, top_k
+            draft = wr.query_rag(chat, query, optimized_search_query, vector_store, top_k=max_extract)
 
-    if output == "text":
-        console.print(draft)
-    else:
-        console.print(Markdown(draft))
-    console.rule("[bold green]")
-
+    # If copywrite mode is enabled, generate comments and final text
     if(copywrite_mode):
         with console.status("[bold green]Getting comments from the reviewer", spinner="dots8Bit"):
             comments = cw.generate_comments(chat, query, draft)
 
-        console.rule("[bold green]Response from reviewer")
-        if output == "text":
-            console.print(comments)
-        else:
-            console.print(Markdown(comments))
-        console.rule("[bold green]")
-
         with console.status("[bold green]Writing the final text", spinner="dots8Bit"):
             final_text = cw.generate_final_text(chat, query, draft, comments)
+    else:
+        final_text = draft
 
-
-        console.
+    # Output the answer
+    console.rule(f"[bold green]Response")
+    if output == "text":
+        console.print(final_text)
+    else:
+        console.print(Markdown(final_text))
+    console.rule("[bold green]")
+
+    return final_text
 
 if __name__ == '__main__':
+    # Parse command-line arguments and execute the main function
     arguments = docopt(__doc__, version='Search Agent 0.1')
     main(arguments)
-
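The script's command line is parsed by docopt straight from the module docstring: the `Usage:` and `Options:` blocks double as the grammar, and `docopt(__doc__)` returns a plain dict whose keys match the option names read in `main`. A small self-contained sketch of that pattern, with a deliberately cut-down docstring rather than the full option list above:

```python
"""Minimal docopt sketch mirroring search_agent.py's argument handling (cut-down option list).

Usage:
    mini_agent.py [--model=model] [--temperature=temp] SEARCH_QUERY

Options:
    -m model --model=model        Use a specific model [default: hf:Qwen/Qwen2.5-72B-Instruct]
    -t temp --temperature=temp    Set the temperature of the LLM [default: 0.0]
"""
from docopt import docopt

if __name__ == "__main__":
    arguments = docopt(__doc__, version="Mini Agent 0.1")
    model = arguments["--model"]                      # e.g. "openai:gpt-4o-mini"
    temperature = float(arguments["--temperature"])   # defaults come from the Options block
    query = arguments["SEARCH_QUERY"]
    print(model, temperature, query)
```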
web_crawler.py CHANGED

@@ -7,20 +7,32 @@ import io
 from trafilatura import extract
 from selenium.common.exceptions import TimeoutException
 from langchain_core.documents.base import Document
+from langchain_community.vectorstores.faiss import FAISS
 from langchain_experimental.text_splitter import SemanticChunker
 from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter
-from langchain_community.vectorstores.faiss import FAISS
 from langsmith import traceable
 import requests
 import pdfplumber
 
 @traceable(run_type="tool", name="get_sources")
-def get_sources(query, max_pages=10, domain=None):
+def get_sources(query, max_pages=10, domain=None):
+    """
+    Fetch search results from the Brave Search API based on the given query.
+
+    Args:
+        query (str): The search query.
+        max_pages (int): Maximum number of pages to retrieve.
+        domain (str, optional): Limit search to a specific domain.
+
+    Returns:
+        list: A list of search results with title, link, snippet, and favicon.
+    """
     search_query = query
     if domain:
         search_query += f" site:{domain}"
 
     url = f"https://api.search.brave.com/res/v1/web/search?q={quote(search_query)}&count={max_pages}"
+
     headers = {
         'Accept': 'application/json',
         'Accept-Encoding': 'gzip',

@@ -52,9 +64,18 @@ def get_sources(query, max_pages=10, domain=None):
         print('Error fetching search results:', error)
         raise
 
-
+def fetch_with_selenium(url, driver, timeout=8):
+    """
+    Fetch the HTML content of a webpage using Selenium.
+
+    Args:
+        url (str): The URL of the webpage.
+        driver: Selenium WebDriver instance.
+        timeout (int): Page load timeout in seconds.
+
+    Returns:
+        str: The HTML content of the page.
+    """
     try:
         driver.set_page_load_timeout(timeout)
         driver.get(url)

@@ -65,10 +86,20 @@ def fetch_with_selenium(url, driver, timeout=8,):
         html = None
     finally:
         driver.quit()
-
+
     return html
 
 def fetch_with_timeout(url, timeout=8):
+    """
+    Fetch a webpage with a specified timeout.
+
+    Args:
+        url (str): The URL of the webpage.
+        timeout (int): Request timeout in seconds.
+
+    Returns:
+        Response: The HTTP response object, or None if an error occurred.
+    """
     try:
         response = requests.get(url, timeout=timeout)
         response.raise_for_status()

@@ -76,8 +107,16 @@ def fetch_with_timeout(url, timeout=8):
     except requests.RequestException as error:
         return None
 
-
 def process_source(source):
+    """
+    Process a single source to extract its content.
+
+    Args:
+        source (dict): A dictionary containing the source's link and other metadata.
+
+    Returns:
+        dict: The source with its extracted page content.
+    """
     url = source['link']
     response = fetch_with_timeout(url, 2)
     if response:

@@ -109,6 +148,17 @@ def process_source(source):
 
 @traceable(run_type="tool", name="get_links_contents")
 def get_links_contents(sources, get_driver_func=None, use_selenium=False):
+    """
+    Retrieve and process the content of multiple sources.
+
+    Args:
+        sources (list): A list of source dictionaries.
+        get_driver_func (callable, optional): Function to get a Selenium WebDriver.
+        use_selenium (bool): Whether to use Selenium for fetching content.
+
+    Returns:
+        list: A list of processed sources with their page content.
+    """
     with ThreadPoolExecutor() as executor:
         results = list(executor.map(process_source, sources))
 

@@ -128,6 +178,16 @@ def get_links_contents(sources, get_driver_func=None, use_selenium=False):
 
 @traceable(run_type="embedding")
 def vectorize(contents, embedding_model):
+    """
+    Vectorize the contents using the specified embedding model.
+
+    Args:
+        contents (list): A list of content dictionaries.
+        embedding_model: The embedding model to use.
+
+    Returns:
+        FAISS: A FAISS vector store containing the vectorized documents.
+    """
     documents = []
     for content in contents:
         try:

@@ -151,7 +211,7 @@ def vectorize(contents, embedding_model):
 
     for i in range(0, len(split_documents), batch_size):
         batch = split_documents[i:i+batch_size]
-
+
         if vector_store is None:
             vector_store = FAISS.from_documents(batch, embedding_model)
         else:

@@ -163,4 +223,4 @@ def vectorize(contents, embedding_model):
         metadatas
     )
 
-    return vector_store
+    return vector_store
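`vectorize` builds the FAISS store incrementally: the documents are split, then indexed in batches, creating the store from the first batch and extending it with the rest. A hedged sketch of that batching pattern with LangChain's FAISS wrapper; `FakeEmbeddings`, the batch size, and the use of `add_documents` for later batches are stand-ins for illustration, and the module's own batching details may differ slightly.

```python
# Sketch of the incremental FAISS indexing pattern used by vectorize()
# (FakeEmbeddings and the batch size are placeholders, not the module's values).
from langchain_community.embeddings import FakeEmbeddings
from langchain_community.vectorstores.faiss import FAISS
from langchain_core.documents.base import Document

embedding_model = FakeEmbeddings(size=256)   # stand-in for md.get_embedding_model(...)
split_documents = [
    Document(page_content=f"chunk {i}", metadata={"source": "example"}) for i in range(10)
]

vector_store = None
batch_size = 4
for i in range(0, len(split_documents), batch_size):
    batch = split_documents[i:i + batch_size]
    if vector_store is None:
        vector_store = FAISS.from_documents(batch, embedding_model)  # create the store from the first batch
    else:
        vector_store.add_documents(batch)                            # extend it with later batches

print(vector_store.similarity_search("chunk 3", k=2))
```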
web_rag.py CHANGED

@@ -19,6 +19,7 @@ Perform RAG using a single query to retrieve relevant documents.
 """
 import os
 import json
+from docopt import re
 from langchain.schema import SystemMessage, HumanMessage
 from langchain.prompts.chat import (
     HumanMessagePromptTemplate,

@@ -53,13 +54,13 @@ def get_optimized_search_messages(query):
         content="""
 You are a prompt optimizer for web search. Your task is to take a given chat prompt or question and transform it into an optimized search string that will yield the most relevant and useful information from a search engine like Google.
 The goal is to create a search query that will help users find the most accurate and pertinent information related to their original prompt or question. An effective search string should be concise, use relevant keywords, and leverage search engine syntax for better results.
-
+
 To optimize the prompt:
 - Identify the key information being requested
 - Consider any implicit information or context that might be useful for the search.
 - Arrange the keywords into a concise search string
 - Put the most important keywords first
-
+
 Some tips and things to be sure to remove:
 - Remove any conversational or instructional phrases
 - Removed style such as "in the style of", "engaging", "short", "long"

@@ -68,7 +69,7 @@ def get_optimized_search_messages(query):
 - Remove lenght instruction (example: essay, article, letter, etc)
 
 You should answer only with the optimized search query and add "**" to the end of the search string to indicate the end of the query
-
+
 Example:
 Question: How do I bake chocolate chip cookies from scratch?
 chocolate chip cookies recipe from scratch**

@@ -105,9 +106,9 @@ def get_optimized_search_messages(query):
         """
     )
     human_message = HumanMessage(
-        content=f"""
+        content=f"""
 Question: {query}
-
+
 """
     )
     return [system_message, human_message]

@@ -150,14 +151,14 @@ def get_optimized_search_messages2(query):
 3. Adding quotation marks around exact phrases if applicable
 4. Including relevant synonyms or related terms (in parentheses) to broaden the search
 5. Using Boolean operators if needed to refine the search
-
+
 You should answer only with the optimized search query and add "**" to the end of the search string to indicate the end of the optimized search query
 """
     )
     human_message = HumanMessage(
-        content=f"""
+        content=f"""
 Question: {query}
-
+
 """
     )
     return [system_message, human_message]

@@ -165,20 +166,31 @@ def get_optimized_search_messages2(query):
 
 @traceable(run_type="llm", name="optimize_search_query")
 def optimize_search_query(chat_llm, query, callbacks=[]):
+    """
+    Optimize the search query using the chat language model.
+
+    Args:
+        chat_llm: The chat language model to use.
+        query (str): The user's query.
+        callbacks (list): Optional callbacks for tracing.
+
+    Returns:
+        str: The optimized search query.
+    """
     messages = get_optimized_search_messages(query)
     response = chat_llm.invoke(messages)
     optimized_search_query = response.content.strip()
-
+
     # Split by '**' and take the first part, then strip whitespace
     optimized_search_query = optimized_search_query.split("**", 1)[0].strip()
-
+
     # Remove surrounding quotes if present
     optimized_search_query = optimized_search_query.strip('"')
-
+
     # If the result is empty, fall back to the original query
     if not optimized_search_query:
         optimized_search_query = query
-
+
     return optimized_search_query
 
 def get_rag_prompt_template():

@@ -193,7 +205,7 @@ def get_rag_prompt_template():
         input_variables=[],
         template="""
 You are an expert research assistant.
-You are provided with a Context in JSON format and a Question.
+You are provided with a Context in JSON format and a Question.
 Each JSON entry contains: content, title, link
 
 Use RAG to answer the Question, providing references and links to the Context material you retrieve and use in your answer:

@@ -203,7 +215,7 @@ def get_rag_prompt_template():
 - Synthesize the retrieved information into a clear, informative answer to the question
 - Format your answer in Markdown, using heading levels 2-3 as needed
 - Include a "References" section at the end with the full citations and link for each source you used
-
+
 If the provided context is not relevant to the question, say it and answer with your internal knowledge.
 If you cannot answer the question using either the extracts or your internal knowledge, state that you don't have enough information to provide an accurate answer.
 If the information in the provided context is in contradiction with your internal knowledge, answer but warn the user about the contradiction.

@@ -214,7 +226,7 @@ def get_rag_prompt_template():
     prompt=PromptTemplate(
         input_variables=["context", "query"],
         template="""
-Context:
+Context:
 ---------------------
 {context}
 ---------------------

@@ -229,6 +241,15 @@ def get_rag_prompt_template():
     )
 
 def format_docs(docs):
+    """
+    Format the retrieved documents into a JSON string.
+
+    Args:
+        docs (list): A list of documents to format.
+
+    Returns:
+        str: The formatted documents as a JSON string.
+    """
     formatted_docs = []
     for d in docs:
         content = d.page_content

@@ -241,6 +262,19 @@ def format_docs(docs):
 
 
 def multi_query_rag(chat_llm, question, search_query, vectorstore, callbacks = []):
+    """
+    Perform RAG using multiple queries to retrieve relevant documents.
+
+    Args:
+        chat_llm: The chat language model to use.
+        question (str): The user's question.
+        search_query (str): The search query to use.
+        vectorstore: The vector store for document retrieval.
+        callbacks (list): Optional callbacks for tracing.
+
+    Returns:
+        str: The generated answer to the question.
+    """
     retriever_from_llm = MultiQueryRetriever.from_llm(
         retriever=vectorstore.as_retriever(), llm=chat_llm, include_original=True,
     )

@@ -259,7 +293,7 @@ def get_context_size(chat_llm):
     else:
         return 16385
     if isinstance(chat_llm, ChatFireworks):
-        32768
+        return 32768
     if isinstance(chat_llm, ChatGroq):
         return 32768
     if isinstance(chat_llm, ChatOllama):

@@ -278,9 +312,10 @@ def get_context_size(chat_llm):
             return 128000
         return 32000
     return 4096
-
-@traceable(run_type="retriever")
+
+@traceable(run_type="retriever")
 def build_rag_prompt(chat_llm, question, search_query, vectorstore, top_k = 10, callbacks = []):
+    prompt = ""
     done = False
     while not done:
         unique_docs = vectorstore.similarity_search(

@@ -292,14 +327,12 @@ def build_rag_prompt(chat_llm, question, search_query, vectorstore, top_k = 10,
             done = True
         else:
             top_k = int(top_k * 0.75)
-
     return prompt
 
 @traceable(run_type="llm", name="query_rag")
 def query_rag(chat_llm, question, search_query, vectorstore, top_k = 10, callbacks = []):
     prompt = build_rag_prompt(chat_llm, question, search_query, vectorstore, top_k=top_k, callbacks = callbacks)
     response = chat_llm.invoke(prompt, config={"callbacks": callbacks})
-
     # Ensure we're returning a string
     if isinstance(response.content, list):
         # If it's a list, join the elements into a single string
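The post-processing in `optimize_search_query` is deliberately defensive: the prompt tells the model to end the query with `**`, so the code keeps only the part before that marker, strips stray quotes, and falls back to the original question if nothing is left. The string handling in isolation, as a small sketch independent of any LLM call:

```python
# The "**" post-processing from optimize_search_query, isolated as a plain function (sketch only).
def clean_optimized_query(raw: str, original_query: str) -> str:
    cleaned = raw.strip()
    cleaned = cleaned.split("**", 1)[0].strip()   # keep only the text before the end-of-query marker
    cleaned = cleaned.strip('"')                  # remove surrounding quotes if present
    return cleaned or original_query              # fall back to the original question if empty

print(clean_optimized_query('  "chocolate chip cookies recipe from scratch**"  ',
                            "How do I bake chocolate chip cookies from scratch?"))
# -> chocolate chip cookies recipe from scratch
```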
|