Eddie Pick committed
By default now use spacy for retrieval and augmentation (vs embeddings)
- nlp_rag.py +144 -0
- requirements.txt +2 -1
- search_agent.py +40 -19
- spacy.ipynb +0 -0
- web_rag.py +2 -2
nlp_rag.py
ADDED
@@ -0,0 +1,144 @@
import spacy
from itertools import groupby
from operator import itemgetter
from langsmith import traceable
from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np


def get_nlp_model():
    if not spacy.util.is_package("en_core_web_md"):
        print("Downloading en_core_web_md model...")
        spacy.cli.download("en_core_web_md")
        print("Model downloaded successfully!")
    nlp = spacy.load("en_core_web_md")
    return nlp


def recursive_split_documents(contents, max_chunk_size=1000, overlap=100):
    from langchain_core.documents.base import Document
    from langchain.text_splitter import RecursiveCharacterTextSplitter

    documents = []
    for content in contents:
        try:
            page_content = content['page_content']
            if page_content:
                metadata = {'title': content['title'], 'source': content['link']}
                doc = Document(page_content=content['page_content'], metadata=metadata)
                documents.append(doc)
        except Exception as e:
            print(f"Error processing content for {content['link']}: {e}")

    # Initialize recursive text splitter
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=max_chunk_size, chunk_overlap=overlap)

    # Split documents
    split_documents = text_splitter.split_documents(documents)

    # Convert split documents to the same format as recursive_split
    chunks = []
    for doc in split_documents:
        chunk = {
            'text': doc.page_content,
            'metadata': {
                'title': doc.metadata.get('title', ''),
                'source': doc.metadata.get('source', '')
            }
        }
        chunks.append(chunk)

    return chunks


def semantic_search(query, chunks, nlp, similarity_threshold=0.5, top_n=10):
    # Precompute query vector and its norm
    query_vector = nlp(query).vector
    query_norm = np.linalg.norm(query_vector) + 1e-8  # Add epsilon to avoid division by zero

    # Check if chunks have precomputed vectors; if not, compute them
    if 'vector' not in chunks[0]:
        texts = [chunk['text'] for chunk in chunks]

        # Process texts in batches using nlp.pipe()
        batch_size = 1000  # Adjust based on available memory
        with nlp.disable_pipes(*[pipe for pipe in nlp.pipe_names if pipe != 'tok2vec']):
            docs = nlp.pipe(texts, batch_size=batch_size)

            # Add vectors to chunks
            for chunk, doc in zip(chunks, docs):
                chunk['vector'] = doc.vector

    # Prepare chunk vectors and norms
    chunk_vectors = np.array([chunk['vector'] for chunk in chunks])
    chunk_norms = np.linalg.norm(chunk_vectors, axis=1) + 1e-8  # Add epsilon to avoid division by zero

    # Compute similarities
    similarities = np.dot(chunk_vectors, query_vector) / (chunk_norms * query_norm)

    # Filter and sort results
    relevant_chunks = [
        (chunk, sim) for chunk, sim in zip(chunks, similarities) if sim > similarity_threshold
    ]
    relevant_chunks.sort(key=lambda x: x[1], reverse=True)

    return relevant_chunks[:top_n]


# Perform semantic search using spaCy
def semantic_search(query, chunks, nlp, similarity_threshold=0.5, top_n=10):
    import numpy as np
    from concurrent.futures import ThreadPoolExecutor

    # Precompute query vector and its norm with epsilon to prevent division by zero
    with nlp.disable_pipes(*[pipe for pipe in nlp.pipe_names if pipe != 'tok2vec']):
        query_vector = nlp(query).vector
    query_norm = np.linalg.norm(query_vector) + 1e-8  # Add epsilon

    # Prepare texts from chunks
    texts = [chunk['text'] for chunk in chunks]

    # Function to process each text and compute its vector
    def compute_vector(text):
        with nlp.disable_pipes(*[pipe for pipe in nlp.pipe_names if pipe != 'tok2vec']):
            doc = nlp(text)
            vector = doc.vector
        return vector

    # Process texts in parallel using ThreadPoolExecutor
    with ThreadPoolExecutor() as executor:
        chunk_vectors = list(executor.map(compute_vector, texts))

    chunk_vectors = np.array(chunk_vectors)
    chunk_norms = np.linalg.norm(chunk_vectors, axis=1) + 1e-8  # Add epsilon

    # Compute similarities using vectorized operations
    similarities = np.dot(chunk_vectors, query_vector) / (chunk_norms * query_norm)

    # Filter and sort results
    relevant_chunks = [
        (chunk, sim) for chunk, sim in zip(chunks, similarities) if sim > similarity_threshold
    ]
    relevant_chunks.sort(key=lambda x: x[1], reverse=True)

    return relevant_chunks[:top_n]


@traceable(run_type="llm", name="nlp_rag")
def query_rag(chat_llm, query, relevant_results):
    import web_rag as wr

    formatted_chunks = ""
    for chunk, similarity in relevant_results:
        formatted_chunk = f"""
<source>
<url>{chunk['metadata']['source']}</url>
<title>{chunk['metadata']['title']}</title>
<text>{chunk['text']}</text>
</source>
"""
        formatted_chunks += formatted_chunk

    prompt = wr.get_rag_prompt_template().format(query=query, context=formatted_chunks)

    draft = chat_llm.invoke(prompt).content
    return draft
requirements.txt
CHANGED
@@ -30,4 +30,5 @@ tiktoken
 transformers >= 4.44.2
 rich >= 13.8.1
 trafilatura >= 1.12.2
-watchdog >= 2.1.5, < 5.0.0
+watchdog >= 2.1.5, < 5.0.0
+spacy >= 3.6.1, < 4.0.0
search_agent.py
CHANGED
@@ -10,7 +10,7 @@ Usage:
     [--copywrite]
     [--max_pages=num]
     [--max_extracts=num]
-    [--
+    [--use_browser]
     [--output=text]
     [--verbose]
     SEARCH_QUERY
@@ -23,10 +23,10 @@ Options:
     -d domain --domain=domain Limit search to a specific domain
     -t temp --temperature=temp Set the temperature of the LLM [default: 0.0]
     -m model --model=model Use a specific model [default: openai/gpt-4o-mini]
-    -e model --embedding_model=model Use
+    -e model --embedding_model=model Use an embedding model
     -n num --max_pages=num Max number of pages to retrieve [default: 10]
     -x num --max_extracts=num Max number of page extract to consider [default: 7]
-
+    -b --use_browser Use browser to fetch content from the web [default: False]
     -o text --output=text Output format (choices: text, markdown) [default: markdown]
     -v --verbose Print verbose output [default: False]
 
@@ -49,6 +49,7 @@ import web_rag as wr
 import web_crawler as wc
 import copywriter as cw
 import models as md
+import nlp_rag as nr
 
 console = Console()
 dotenv.load_dotenv()
@@ -91,32 +92,35 @@ def main(arguments):
     max_pages=int(arguments["--max_pages"])
     max_extract=int(arguments["--max_extracts"])
     output=arguments["--output"]
-    use_selenium=arguments["--
+    use_selenium=arguments["--use_browser"]
     query = arguments["SEARCH_QUERY"]
 
     chat = md.get_model(model, temperature)
-    if embedding_model
-
-
+    if embedding_model is None:
+        use_nlp = True
+        nlp = nr.get_nlp_model()
     else:
         embedding_model = md.get_embedding_model(embedding_model)
+        use_nlp = False
 
     if verbose:
         model_name = getattr(chat, 'model_name', None) or getattr(chat, 'model', None) or getattr(chat, 'model_id', None) or str(chat)
-        embedding_model_name = getattr(embedding_model, 'model_name', None) or getattr(embedding_model, 'model', None) or getattr(embedding_model, 'model_id', None) or str(embedding_model)
-        console.log(f"Using model: {model_name}")
         console.log(f"Using embedding model: {embedding_model_name}")
+        if not use_nlp:
+            embedding_model_name = getattr(embedding_model, 'model_name', None) or getattr(embedding_model, 'model', None) or getattr(embedding_model, 'model_id', None) or str(embedding_model)
+            console.log(f"Using model: {embedding_model_name}")
+
 
     with console.status(f"[bold green]Optimizing query for search: {query}"):
-
-        if len(
-
-        console.log(f"Optimized search query: [bold blue]{
+        optimized_search_query = wr.optimize_search_query(chat, query)
+        if len(optimized_search_query) < 3:
+            optimized_search_query = query
+        console.log(f"Optimized search query: [bold blue]{optimized_search_query}")
 
     with console.status(
-        f"[bold green]Searching sources using the optimized query: {
+        f"[bold green]Searching sources using the optimized query: {optimized_search_query}"
     ):
-        sources = wc.get_sources(
+        sources = wc.get_sources(optimized_search_query, max_pages=max_pages, domain=domain)
         console.log(f"Found {len(sources)} sources {'on ' + domain if domain else ''}")
 
     with console.status(
@@ -125,11 +129,28 @@ def main(arguments):
         contents = wc.get_links_contents(sources, get_selenium_driver, use_selenium=use_selenium)
         console.log(f"Managed to extract content from {len(contents)} sources")
 
-
-
+    if use_nlp:
+        with console.status(f"[bold green]Splitting {len(contents)} sources for content", spinner="growVertical"):
+            chunks = nr.recursive_split_documents(contents)
+            #chunks = nr.chunk_contents(nlp, contents)
+            console.log(f"Split {len(contents)} sources into {len(chunks)} chunks")
+        with console.status(f"[bold green]Searching relevant chunks", spinner="growVertical"):
+            import time
+
+            start_time = time.time()
+            relevant_results = nr.semantic_search(optimized_search_query, chunks, nlp, top_n=max_extract)
+            end_time = time.time()
+            execution_time = end_time - start_time
+            console.log(f"Semantic search took {execution_time:.2f} seconds")
+            console.log(f"Found {len(relevant_results)} relevant chunks")
+        with console.status(f"[bold green]Writing content", spinner="growVertical"):
+            draft = nr.query_rag(chat, query, relevant_results)
+    else:
+        with console.status(f"[bold green]Embedding {len(contents)} sources for content", spinner="growVertical"):
+            vector_store = wc.vectorize(contents, embedding_model)
+        with console.status("[bold green]Writing content", spinner='dots8Bit'):
+            draft = wr.query_rag(chat, query, optimized_search_query, vector_store, top_k = max_extract)
 
-    with console.status("[bold green]Writing content", spinner='dots8Bit'):
-        draft = wr.query_rag(chat, query, optimize_search_query, vector_store, top_k = max_extract)
 
     console.rule(f"[bold green]Response")
     if output == "text":
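
With this commit, running the agent without -e/--embedding_model takes the spaCy retrieval path by default, while passing an embedding model keeps the previous vector-store path. Hypothetical invocations, using only flags from the Usage section above (the query string and the <embedding model> placeholder are illustrative):

python search_agent.py -n 10 -x 7 "history of the spaCy library"
python search_agent.py -e <embedding model> "history of the spaCy library"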
spacy.ipynb
DELETED
The diff for this file is too large to render.
See raw diff
web_rag.py
CHANGED
@@ -74,13 +74,13 @@ def get_optimized_search_messages(query):
 chocolate chip cookies recipe from scratch**
 Example:
 Question: I would like you to show me a timeline of Marie Curie's life. Show results as a markdown table
-
+Marie Curie timeline**
 Example:
 Question: I would like you to write a long article on NATO vs Russia. Use known geopolitical frameworks.
 geopolitics nato russia**
 Example:
 Question: Write an engaging LinkedIn post about Andrew Ng
-
+Andrew Ng**
 Example:
 Question: Write a short article about the solar system in the style of Carl Sagan
 solar system**