Eddie Pick committed
Commit c7143b1 · unverified · 1 Parent(s): 5103a91

By default, use spaCy for retrieval and augmentation (instead of embeddings)

Files changed (5)
  1. nlp_rag.py +144 -0
  2. requirements.txt +2 -1
  3. search_agent.py +40 -19
  4. spacy.ipynb +0 -0
  5. web_rag.py +2 -2
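
Taken together, the change means that when search_agent.py is run without --embedding_model, retrieval goes through the new spaCy helpers in nlp_rag.py instead of a vector store. A minimal sketch of that default path (illustrative only, not part of the commit; the sample contents dict and query are made up, and chat would come from models.get_model() as in search_agent.py):

import nlp_rag as nr

# Hypothetical fetched page, shaped like the dicts recursive_split_documents() expects
contents = [
    {"title": "Example page", "link": "https://example.com", "page_content": "Some fetched text about spaCy ..."},
]

nlp = nr.get_nlp_model()                          # downloads en_core_web_md on first use
chunks = nr.recursive_split_documents(contents)   # ~1000-char chunks with 100-char overlap
results = nr.semantic_search("example query", chunks, nlp, top_n=7)
for chunk, score in results:
    print(f"{score:.3f}  {chunk['metadata']['source']}")
# draft = nr.query_rag(chat, "example query", results)   # chat = models.get_model(...)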
nlp_rag.py ADDED
@@ -0,0 +1,144 @@
+ import spacy
+ from itertools import groupby
+ from operator import itemgetter
+ from langsmith import traceable
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+ import numpy as np
+
+ def get_nlp_model():
+     if not spacy.util.is_package("en_core_web_md"):
+         print("Downloading en_core_web_md model...")
+         spacy.cli.download("en_core_web_md")
+         print("Model downloaded successfully!")
+     nlp = spacy.load("en_core_web_md")
+     return nlp
+
+
+ def recursive_split_documents(contents, max_chunk_size=1000, overlap=100):
+     from langchain_core.documents.base import Document
+     from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+     documents = []
+     for content in contents:
+         try:
+             page_content = content['page_content']
+             if page_content:
+                 metadata = {'title': content['title'], 'source': content['link']}
+                 doc = Document(page_content=content['page_content'], metadata=metadata)
+                 documents.append(doc)
+         except Exception as e:
+             print(f"Error processing content for {content['link']}: {e}")
+
+     # Initialize recursive text splitter
+     text_splitter = RecursiveCharacterTextSplitter(chunk_size=max_chunk_size, chunk_overlap=overlap)
+
+     # Split documents
+     split_documents = text_splitter.split_documents(documents)
+
+     # Convert split documents to the same format as recursive_split
+     chunks = []
+     for doc in split_documents:
+         chunk = {
+             'text': doc.page_content,
+             'metadata': {
+                 'title': doc.metadata.get('title', ''),
+                 'source': doc.metadata.get('source', '')
+             }
+         }
+         chunks.append(chunk)
+
+     return chunks
+
+
+ def semantic_search(query, chunks, nlp, similarity_threshold=0.5, top_n=10):
+     # Precompute query vector and its norm
+     query_vector = nlp(query).vector
+     query_norm = np.linalg.norm(query_vector) + 1e-8  # Add epsilon to avoid division by zero
+
+     # Check if chunks have precomputed vectors; if not, compute them
+     if 'vector' not in chunks[0]:
+         texts = [chunk['text'] for chunk in chunks]
+
+         # Process texts in batches using nlp.pipe()
+         batch_size = 1000  # Adjust based on available memory
+         with nlp.disable_pipes(*[pipe for pipe in nlp.pipe_names if pipe != 'tok2vec']):
+             docs = nlp.pipe(texts, batch_size=batch_size)
+
+             # Add vectors to chunks
+             for chunk, doc in zip(chunks, docs):
+                 chunk['vector'] = doc.vector
+
+     # Prepare chunk vectors and norms
+     chunk_vectors = np.array([chunk['vector'] for chunk in chunks])
+     chunk_norms = np.linalg.norm(chunk_vectors, axis=1) + 1e-8  # Add epsilon to avoid division by zero
+
+     # Compute similarities
+     similarities = np.dot(chunk_vectors, query_vector) / (chunk_norms * query_norm)
+
+     # Filter and sort results
+     relevant_chunks = [
+         (chunk, sim) for chunk, sim in zip(chunks, similarities) if sim > similarity_threshold
+     ]
+     relevant_chunks.sort(key=lambda x: x[1], reverse=True)
+
+     return relevant_chunks[:top_n]
+
+
+ # Perform semantic search using spaCy
+ def semantic_search(query, chunks, nlp, similarity_threshold=0.5, top_n=10):
+     import numpy as np
+     from concurrent.futures import ThreadPoolExecutor
+
+     # Precompute query vector and its norm with epsilon to prevent division by zero
+     with nlp.disable_pipes(*[pipe for pipe in nlp.pipe_names if pipe != 'tok2vec']):
+         query_vector = nlp(query).vector
+     query_norm = np.linalg.norm(query_vector) + 1e-8  # Add epsilon
+
+     # Prepare texts from chunks
+     texts = [chunk['text'] for chunk in chunks]
+
+     # Function to process each text and compute its vector
+     def compute_vector(text):
+         with nlp.disable_pipes(*[pipe for pipe in nlp.pipe_names if pipe != 'tok2vec']):
+             doc = nlp(text)
+             vector = doc.vector
+         return vector
+
+     # Process texts in parallel using ThreadPoolExecutor
+     with ThreadPoolExecutor() as executor:
+         chunk_vectors = list(executor.map(compute_vector, texts))
+
+     chunk_vectors = np.array(chunk_vectors)
+     chunk_norms = np.linalg.norm(chunk_vectors, axis=1) + 1e-8  # Add epsilon
+
+     # Compute similarities using vectorized operations
+     similarities = np.dot(chunk_vectors, query_vector) / (chunk_norms * query_norm)
+
+     # Filter and sort results
+     relevant_chunks = [
+         (chunk, sim) for chunk, sim in zip(chunks, similarities) if sim > similarity_threshold
+     ]
+     relevant_chunks.sort(key=lambda x: x[1], reverse=True)
+
+     return relevant_chunks[:top_n]
+
+
+ @traceable(run_type="llm", name="nlp_rag")
+ def query_rag(chat_llm, query, relevant_results):
+     import web_rag as wr
+
+     formatted_chunks = ""
+     for chunk, similarity in relevant_results:
+         formatted_chunk = f"""
+ <source>
+ <url>{chunk['metadata']['source']}</url>
+ <title>{chunk['metadata']['title']}</title>
+ <text>{chunk['text']}</text>
+ </source>
+ """
+         formatted_chunks += formatted_chunk
+
+     prompt = wr.get_rag_prompt_template().format(query=query, context=formatted_chunks)
+
+     draft = chat_llm.invoke(prompt).content
+     return draft
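
Both semantic_search variants above score chunks by cosine similarity between the spaCy query vector and each chunk vector, keep the scores above similarity_threshold, and return the top_n best. A small worked illustration of that scoring (the vectors are made up, not spaCy output):

import numpy as np

query_vector = np.array([0.2, 0.1, 0.7])
chunk_vectors = np.array([
    [0.2, 0.1, 0.7],   # same direction as the query -> similarity ~1.0
    [0.7, 0.2, 0.1],   # different direction         -> similarity ~0.43
    [0.0, 0.0, 0.0],   # zero vector: the epsilon keeps the division finite -> 0.0
])

query_norm = np.linalg.norm(query_vector) + 1e-8
chunk_norms = np.linalg.norm(chunk_vectors, axis=1) + 1e-8
similarities = np.dot(chunk_vectors, query_vector) / (chunk_norms * query_norm)
print(similarities)   # only chunks above the 0.5 default threshold would be kept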
requirements.txt CHANGED
@@ -30,4 +30,5 @@ tiktoken
  transformers >= 4.44.2
  rich >= 13.8.1
  trafilatura >= 1.12.2
- watchdog >= 2.1.5, < 5.0.0
+ watchdog >= 2.1.5, < 5.0.0
+ spacy >= 3.6.1, < 4.0.0
search_agent.py CHANGED
@@ -10,7 +10,7 @@ Usage:
      [--copywrite]
      [--max_pages=num]
      [--max_extracts=num]
-     [--use_selenium]
+     [--use_browser]
      [--output=text]
      [--verbose]
      SEARCH_QUERY
@@ -23,10 +23,10 @@ Options:
      -d domain --domain=domain Limit search to a specific domain
      -t temp --temperature=temp Set the temperature of the LLM [default: 0.0]
      -m model --model=model Use a specific model [default: openai/gpt-4o-mini]
-     -e model --embedding_model=model Use a specific embedding model [default: same provider as model]
+     -e model --embedding_model=model Use an embedding model
      -n num --max_pages=num Max number of pages to retrieve [default: 10]
      -x num --max_extracts=num Max number of page extract to consider [default: 7]
-     -s --use_selenium Use selenium to fetch content from the web [default: False]
+     -b --use_browser Use browser to fetch content from the web [default: False]
      -o text --output=text Output format (choices: text, markdown) [default: markdown]
      -v --verbose Print verbose output [default: False]

@@ -49,6 +49,7 @@ import web_rag as wr
  import web_crawler as wc
  import copywriter as cw
  import models as md
+ import nlp_rag as nr

  console = Console()
  dotenv.load_dotenv()
@@ -91,32 +92,35 @@ def main(arguments):
      max_pages=int(arguments["--max_pages"])
      max_extract=int(arguments["--max_extracts"])
      output=arguments["--output"]
-     use_selenium=arguments["--use_selenium"]
+     use_selenium=arguments["--use_browser"]
      query = arguments["SEARCH_QUERY"]

      chat = md.get_model(model, temperature)
-     if embedding_model.lower() == "same provider as model":
-         provider = model.split(':')[0]
-         embedding_model = md.get_embedding_model(f"{provider}")
+     if embedding_model is None:
+         use_nlp = True
+         nlp = nr.get_nlp_model()
      else:
          embedding_model = md.get_embedding_model(embedding_model)
+         use_nlp = False

      if verbose:
          model_name = getattr(chat, 'model_name', None) or getattr(chat, 'model', None) or getattr(chat, 'model_id', None) or str(chat)
-         embedding_model_name = getattr(embedding_model, 'model_name', None) or getattr(embedding_model, 'model', None) or getattr(embedding_model, 'model_id', None) or str(embedding_model)
-         console.log(f"Using model: {model_name}")
          console.log(f"Using embedding model: {embedding_model_name}")
+         if not use_nlp:
+             embedding_model_name = getattr(embedding_model, 'model_name', None) or getattr(embedding_model, 'model', None) or getattr(embedding_model, 'model_id', None) or str(embedding_model)
+             console.log(f"Using model: {embedding_model_name}")
+

      with console.status(f"[bold green]Optimizing query for search: {query}"):
-         optimize_search_query = wr.optimize_search_query(chat, query)
-         if len(optimize_search_query) < 3:
-             optimize_search_query = query
-         console.log(f"Optimized search query: [bold blue]{optimize_search_query}")
+         optimized_search_query = wr.optimize_search_query(chat, query)
+         if len(optimized_search_query) < 3:
+             optimized_search_query = query
+         console.log(f"Optimized search query: [bold blue]{optimized_search_query}")

      with console.status(
-         f"[bold green]Searching sources using the optimized query: {optimize_search_query}"
+         f"[bold green]Searching sources using the optimized query: {optimized_search_query}"
      ):
-         sources = wc.get_sources(optimize_search_query, max_pages=max_pages, domain=domain)
+         sources = wc.get_sources(optimized_search_query, max_pages=max_pages, domain=domain)
      console.log(f"Found {len(sources)} sources {'on ' + domain if domain else ''}")

      with console.status(
@@ -125,11 +129,28 @@ def main(arguments):
          contents = wc.get_links_contents(sources, get_selenium_driver, use_selenium=use_selenium)
      console.log(f"Managed to extract content from {len(contents)} sources")

-     with console.status(f"[bold green]Embedding {len(contents)} sources for content", spinner="growVertical"):
-         vector_store = wc.vectorize(contents, embedding_model)
-     with console.status("[bold green]Writing content", spinner='dots8Bit'):
-         draft = wr.query_rag(chat, query, optimize_search_query, vector_store, top_k = max_extract)
+     if use_nlp:
+         with console.status(f"[bold green]Splitting {len(contents)} sources for content", spinner="growVertical"):
+             chunks = nr.recursive_split_documents(contents)
+             #chunks = nr.chunk_contents(nlp, contents)
+             console.log(f"Split {len(contents)} sources into {len(chunks)} chunks")
+         with console.status(f"[bold green]Searching relevant chunks", spinner="growVertical"):
+             import time
+
+             start_time = time.time()
+             relevant_results = nr.semantic_search(optimized_search_query, chunks, nlp, top_n=max_extract)
+             end_time = time.time()
+             execution_time = end_time - start_time
+             console.log(f"Semantic search took {execution_time:.2f} seconds")
+             console.log(f"Found {len(relevant_results)} relevant chunks")
+         with console.status(f"[bold green]Writing content", spinner="growVertical"):
+             draft = nr.query_rag(chat, query, relevant_results)
+     else:
+         with console.status(f"[bold green]Embedding {len(contents)} sources for content", spinner="growVertical"):
+             vector_store = wc.vectorize(contents, embedding_model)
+         with console.status("[bold green]Writing content", spinner='dots8Bit'):
+             draft = wr.query_rag(chat, query, optimized_search_query, vector_store, top_k = max_extract)

      console.rule(f"[bold green]Response")
      if output == "text":
spacy.ipynb DELETED
The diff for this file is too large to render. See raw diff
 
web_rag.py CHANGED
@@ -74,13 +74,13 @@ def get_optimized_search_messages(query):
      chocolate chip cookies recipe from scratch**
      Example:
      Question: I would like you to show me a timeline of Marie Curie's life. Show results as a markdown table
-     "Marie Curie" timeline**
+     Marie Curie timeline**
      Example:
      Question: I would like you to write a long article on NATO vs Russia. Use known geopolitical frameworks.
      geopolitics nato russia**
      Example:
      Question: Write an engaging LinkedIn post about Andrew Ng
-     "Andrew Ng"**
+     Andrew Ng**
      Example:
      Question: Write a short article about the solar system in the style of Carl Sagan
      solar system**