CyranoB commited on
Commit
d21cce9
·
1 Parent(s): 8d0d362

Cohere support. Added Selenium to fetch some pages.

Browse files
Files changed (2) hide show
  1. requirements.txt +3 -1
  2. search_agent.py +44 -23
requirements.txt CHANGED
@@ -1,11 +1,13 @@
1
  boto3
2
  bs4
 
3
  docopt
4
  faiss-cpu
5
  google-api-python-client
6
  pdfplumber
7
  python-dotenv
8
  langchain
 
9
  langchain_core
10
  langchain_community
11
  langchain_experimental
@@ -13,4 +15,4 @@ langchain_openai
13
  langchain_groq
14
  langsmith
15
  rich
16
- trafilatura
 
1
  boto3
2
  bs4
3
+ cohere
4
  docopt
5
  faiss-cpu
6
  google-api-python-client
7
  pdfplumber
8
  python-dotenv
9
  langchain
10
+ langchain-cohere
11
  langchain_core
12
  langchain_community
13
  langchain_experimental
 
15
  langchain_groq
16
  langsmith
17
  rich
18
+ trafilatura
search_agent.py CHANGED
@@ -16,7 +16,7 @@ Options:
16
  --version Show version.
17
  -d domain --domain=domain Limit search to a specific domain
18
  -t temp --temperature=temp Set the temperature of the LLM [default: 0.0]
19
- -p provider --provider=provider Use a specific LLM (choices: bedrock,openai,groq,ollama) [default: openai]
20
  -m model --model=model Use a specific model
21
  -n num --max_pages=num Max number of pages to retrieve [default: 10]
22
  -o text --output=text Output format (choices: text, markdown) [default: markdown]
@@ -29,16 +29,19 @@ import io
29
  from concurrent.futures import ThreadPoolExecutor
30
  from urllib.parse import quote
31
 
32
- from bs4 import BeautifulSoup
33
  from docopt import docopt
34
  import dotenv
35
  import pdfplumber
36
  from trafilatura import extract
37
 
 
 
 
38
  from langchain_core.documents.base import Document
39
  from langchain_experimental.text_splitter import SemanticChunker
40
  from langchain.retrievers.multi_query import MultiQueryRetriever
41
  from langchain.callbacks import LangChainTracer
 
42
  from langchain_groq import ChatGroq
43
  from langchain_openai import ChatOpenAI
44
  from langchain_openai import OpenAIEmbeddings
@@ -78,6 +81,10 @@ def get_chat_llm(provider, model=None, temperature=0.0):
78
  if model is None:
79
  model = 'llama2'
80
  chat_llm = ChatOllama(model=model, temperature=temperature)
 
 
 
 
81
  case _:
82
  raise ValueError(f"Unknown LLM provider {provider}")
83
 
@@ -127,31 +134,39 @@ def get_sources(query, max_pages=10, domain=None):
127
  console.log('Error fetching search results:', error)
128
  raise
129
 
130
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
  def fetch_with_timeout(url, timeout=8):
133
-
134
  try:
135
  response = requests.get(url, timeout=timeout)
136
  response.raise_for_status()
137
  return response
138
  except requests.RequestException as error:
139
- console.log(f"Skipping {url}! Error: {error}")
140
  return None
141
 
142
- def extract_main_content(html):
143
- try:
144
- soup = BeautifulSoup(html, 'html.parser')
145
- for element in soup(["script", "style", "head", "nav", "footer", "iframe", "img"]):
146
- element.extract()
147
- main_content = soup.get_text(separator='\n', strip=True)
148
- return main_content
149
- except Exception:
150
- return None
151
 
152
  def process_source(source):
153
- response = fetch_with_timeout(source['link'], 8)
154
- console.log(f"Processing {source['link']}")
 
155
  if response:
156
  content_type = response.headers.get('Content-Type')
157
  if content_type:
@@ -172,16 +187,24 @@ def process_source(source):
172
  main_content = extract(html, output_format='txt', include_links=True)
173
  return {**source, 'page_content': main_content}
174
  else:
175
- console.log(f"Skipping {source['link']}! Unsupported content type: {content_type}")
176
  return {**source, 'page_content': source['snippet']}
177
  else:
178
- console.log(f"Skipping {source['link']}! No content type")
179
  return {**source, 'page_content': source['snippet']}
180
- return None
181
 
182
  def get_links_contents(sources):
183
  with ThreadPoolExecutor() as executor:
184
- results = list(executor.map(process_source, sources))
 
 
 
 
 
 
 
 
185
 
186
  # Filter out None results
187
  return [result for result in results if result is not None]
@@ -228,9 +251,7 @@ def multi_query_rag(chat_llm, question, search_query, vectorstore):
228
 
229
 
230
  def query_rag(chat_llm, question, search_query, vectorstore):
231
- #retriver = vectorstore.as_retriever()
232
- #unique_docs = retriver.get_relevant_documents(search_query, callbacks=callbacks, verbose=True)
233
- unique_docs = vectorstore.similarity_search(search_query, k=5)
234
  context = format_docs(unique_docs)
235
  prompt = get_rag_prompt_template().format(query=question, context=context)
236
  response = chat_llm.invoke(prompt, config={"callbacks": callbacks})
 
16
  --version Show version.
17
  -d domain --domain=domain Limit search to a specific domain
18
  -t temp --temperature=temp Set the temperature of the LLM [default: 0.0]
19
+ -p provider --provider=provider Use a specific LLM (choices: bedrock,openai,groq,ollama,cohere) [default: openai]
20
  -m model --model=model Use a specific model
21
  -n num --max_pages=num Max number of pages to retrieve [default: 10]
22
  -o text --output=text Output format (choices: text, markdown) [default: markdown]
 
29
  from concurrent.futures import ThreadPoolExecutor
30
  from urllib.parse import quote
31
 
 
32
  from docopt import docopt
33
  import dotenv
34
  import pdfplumber
35
  from trafilatura import extract
36
 
37
+ from selenium import webdriver
38
+ from selenium.webdriver.chrome.options import Options
39
+
40
  from langchain_core.documents.base import Document
41
  from langchain_experimental.text_splitter import SemanticChunker
42
  from langchain.retrievers.multi_query import MultiQueryRetriever
43
  from langchain.callbacks import LangChainTracer
44
+ from langchain_cohere.chat_models import ChatCohere
45
  from langchain_groq import ChatGroq
46
  from langchain_openai import ChatOpenAI
47
  from langchain_openai import OpenAIEmbeddings
 
81
  if model is None:
82
  model = 'llama2'
83
  chat_llm = ChatOllama(model=model, temperature=temperature)
84
+ case 'cohere':
85
+ if model is None:
86
+ model = 'command-r-plus'
87
+ chat_llm = ChatCohere(model=model, temperature=temperature)
88
  case _:
89
  raise ValueError(f"Unknown LLM provider {provider}")
90
 
 
134
  console.log('Error fetching search results:', error)
135
  raise
136
 
137
+ def fetch_with_selenium(url, timeout=8):
138
+ chrome_options = Options()
139
+ chrome_options.add_argument("headless")
140
+ chrome_options.add_argument("--disable-extensions")
141
+ chrome_options.add_argument("--disable-gpu")
142
+ chrome_options.add_argument("--no-sandbox")
143
+ chrome_options.add_argument("--disable-dev-shm-usage")
144
+ chrome_options.add_argument("--remote-debugging-port=9222")
145
+ chrome_options.add_argument('--blink-settings=imagesEnabled=false')
146
+ chrome_options.add_argument("--window-size=1920,1080")
147
+
148
+ driver = webdriver.Chrome(options=chrome_options)
149
+
150
+ driver.get(url)
151
+ driver.execute_script("window.scrollTo(0, document.body.scrollHeight);")
152
+ html = driver.page_source
153
+ driver.quit()
154
+
155
+ return html
156
 
157
  def fetch_with_timeout(url, timeout=8):
 
158
  try:
159
  response = requests.get(url, timeout=timeout)
160
  response.raise_for_status()
161
  return response
162
  except requests.RequestException as error:
 
163
  return None
164
 
 
 
 
 
 
 
 
 
 
165
 
166
  def process_source(source):
167
+ url = source['link']
168
+ #console.log(f"Processing {url}")
169
+ response = fetch_with_timeout(url, 8)
170
  if response:
171
  content_type = response.headers.get('Content-Type')
172
  if content_type:
 
187
  main_content = extract(html, output_format='txt', include_links=True)
188
  return {**source, 'page_content': main_content}
189
  else:
190
+ console.log(f"Skipping {url}! Unsupported content type: {content_type}")
191
  return {**source, 'page_content': source['snippet']}
192
  else:
193
+ console.log(f"Skipping {url}! No content type")
194
  return {**source, 'page_content': source['snippet']}
195
+ return {**source, 'page_content': None}
196
 
197
  def get_links_contents(sources):
198
  with ThreadPoolExecutor() as executor:
199
+ results = list(executor.map(process_source, sources))
200
+ for result in results:
201
+ if result['page_content'] is None:
202
+ url = result['link']
203
+ console.log(f"Fetching with selenium {url}")
204
+ html = fetch_with_selenium(url, 8)
205
+ main_content = extract(html, output_format='txt', include_links=True)
206
+ if main_content:
207
+ result['page_content'] = main_content
208
 
209
  # Filter out None results
210
  return [result for result in results if result is not None]
 
251
 
252
 
253
  def query_rag(chat_llm, question, search_query, vectorstore):
254
+ unique_docs = vectorstore.similarity_search(search_query, k=15, callbacks=callbacks, verbose=True)
 
 
255
  context = format_docs(unique_docs)
256
  prompt = get_rag_prompt_template().format(query=question, context=context)
257
  response = chat_llm.invoke(prompt, config={"callbacks": callbacks})