Cohere support. Added Selenium to fetch some pages.
- requirements.txt +3 -1
- search_agent.py +44 -23
requirements.txt
CHANGED
@@ -1,11 +1,13 @@
 boto3
 bs4
+cohere
 docopt
 faiss-cpu
 google-api-python-client
 pdfplumber
 python-dotenv
 langchain
+langchain-cohere
 langchain_core
 langchain_community
 langchain_experimental
@@ -13,4 +15,4 @@ langchain_openai
 langchain_groq
 langsmith
 rich
-trafilatura
+trafilatura

search_agent.py
CHANGED
@@ -16,7 +16,7 @@ Options:
   --version Show version.
   -d domain --domain=domain Limit search to a specific domain
   -t temp --temperature=temp Set the temperature of the LLM [default: 0.0]
-  -p provider --provider=provider Use a specific LLM (choices: bedrock,openai,groq,ollama) [default: openai]
+  -p provider --provider=provider Use a specific LLM (choices: bedrock,openai,groq,ollama,cohere) [default: openai]
   -m model --model=model Use a specific model
   -n num --max_pages=num Max number of pages to retrieve [default: 10]
   -o text --output=text Output format (choices: text, markdown) [default: markdown]
@@ -29,16 +29,19 @@ import io
 from concurrent.futures import ThreadPoolExecutor
 from urllib.parse import quote

-from bs4 import BeautifulSoup
 from docopt import docopt
 import dotenv
 import pdfplumber
 from trafilatura import extract

+from selenium import webdriver
+from selenium.webdriver.chrome.options import Options
+
 from langchain_core.documents.base import Document
 from langchain_experimental.text_splitter import SemanticChunker
 from langchain.retrievers.multi_query import MultiQueryRetriever
 from langchain.callbacks import LangChainTracer
+from langchain_cohere.chat_models import ChatCohere
 from langchain_groq import ChatGroq
 from langchain_openai import ChatOpenAI
 from langchain_openai import OpenAIEmbeddings
@@ -78,6 +81,10 @@ def get_chat_llm(provider, model=None, temperature=0.0):
             if model is None:
                 model = 'llama2'
             chat_llm = ChatOllama(model=model, temperature=temperature)
+        case 'cohere':
+            if model is None:
+                model = 'command-r-plus'
+            chat_llm = ChatCohere(model=model, temperature=temperature)
         case _:
             raise ValueError(f"Unknown LLM provider {provider}")

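For reference, a minimal usage sketch of the new provider branch (not part of the commit; assumes COHERE_API_KEY is set in the environment, which ChatCohere picks up by default):

    # Illustration only: exercises the new case 'cohere' branch above.
    # Assumes COHERE_API_KEY is exported; model defaults to command-r-plus.
    chat_llm = get_chat_llm('cohere', temperature=0.0)
    reply = chat_llm.invoke("Say hello in one sentence.")
    print(reply.content)  # .invoke returns an AIMessage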
@@ -127,31 +134,39 @@ def get_sources(query, max_pages=10, domain=None):
         console.log('Error fetching search results:', error)
         raise

-
+def fetch_with_selenium(url, timeout=8):
+    chrome_options = Options()
+    chrome_options.add_argument("headless")
+    chrome_options.add_argument("--disable-extensions")
+    chrome_options.add_argument("--disable-gpu")
+    chrome_options.add_argument("--no-sandbox")
+    chrome_options.add_argument("--disable-dev-shm-usage")
+    chrome_options.add_argument("--remote-debugging-port=9222")
+    chrome_options.add_argument('--blink-settings=imagesEnabled=false')
+    chrome_options.add_argument("--window-size=1920,1080")
+
+    driver = webdriver.Chrome(options=chrome_options)
+
+    driver.get(url)
+    driver.execute_script("window.scrollTo(0, document.body.scrollHeight);")
+    html = driver.page_source
+    driver.quit()
+
+    return html

 def fetch_with_timeout(url, timeout=8):
-
     try:
         response = requests.get(url, timeout=timeout)
         response.raise_for_status()
         return response
     except requests.RequestException as error:
-        console.log(f"Skipping {url}! Error: {error}")
         return None

-def extract_main_content(html):
-    try:
-        soup = BeautifulSoup(html, 'html.parser')
-        for element in soup(["script", "style", "head", "nav", "footer", "iframe", "img"]):
-            element.extract()
-        main_content = soup.get_text(separator='\n', strip=True)
-        return main_content
-    except Exception:
-        return None

 def process_source(source):
-    console.log(f"Processing {source['link']}")
-    response = fetch_with_timeout(source['link'], 8)
+    url = source['link']
+    #console.log(f"Processing {url}")
+    response = fetch_with_timeout(url, 8)
     if response:
         content_type = response.headers.get('Content-Type')
         if content_type:
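Two caveats on fetch_with_selenium as committed: the timeout argument is accepted but never used, and if driver.get() raises, driver.quit() is never reached and the headless Chrome process leaks. A hedged variant illustrating both fixes (the name fetch_with_selenium_safe is hypothetical; note that recent Chrome wants the flag spelled "--headless"):

    from selenium import webdriver
    from selenium.webdriver.chrome.options import Options

    def fetch_with_selenium_safe(url, timeout=8):
        chrome_options = Options()
        chrome_options.add_argument("--headless")  # dashes required on recent Chrome
        chrome_options.add_argument("--blink-settings=imagesEnabled=false")
        driver = webdriver.Chrome(options=chrome_options)
        driver.set_page_load_timeout(timeout)  # honor the timeout argument
        try:
            driver.get(url)
            driver.execute_script("window.scrollTo(0, document.body.scrollHeight);")
            return driver.page_source
        finally:
            driver.quit()  # always release the browser, even on error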
@@ -172,16 +187,24 @@ def process_source(source):
                 main_content = extract(html, output_format='txt', include_links=True)
                 return {**source, 'page_content': main_content}
             else:
-                console.log(f"Skipping {source['link']}! Unsupported content type: {content_type}")
+                console.log(f"Skipping {url}! Unsupported content type: {content_type}")
                 return {**source, 'page_content': source['snippet']}
         else:
-            console.log(f"Skipping {source['link']}! No content type")
+            console.log(f"Skipping {url}! No content type")
             return {**source, 'page_content': source['snippet']}
-    return None
+    return {**source, 'page_content': None}

 def get_links_contents(sources):
     with ThreadPoolExecutor() as executor:
-        results = list(executor.map(process_source, sources))
+        results = list(executor.map(process_source, sources))
+        for result in results:
+            if result['page_content'] is None:
+                url = result['link']
+                console.log(f"Fetching with selenium {url}")
+                html = fetch_with_selenium(url, 8)
+                main_content = extract(html, output_format='txt', include_links=True)
+                if main_content:
+                    result['page_content'] = main_content

     # Filter out None results
     return [result for result in results if result is not None]
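Illustration of the new fallback path end to end (hypothetical input; a source dict needs at least the 'link' and 'snippet' keys used above). When the plain requests fetch fails, process_source now returns page_content=None instead of dropping the source, and get_links_contents retries just those URLs with Selenium:

    sources = [{'link': 'https://example.com/', 'snippet': 'fallback snippet'}]
    results = get_links_contents(sources)
    for result in results:
        preview = (result['page_content'] or '')[:80]  # may still be None if Selenium also failed
        print(result['link'], '->', preview)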
@@ -228,9 +251,7 @@ def multi_query_rag(chat_llm, question, search_query, vectorstore):


 def query_rag(chat_llm, question, search_query, vectorstore):
-
-    #unique_docs = retriver.get_relevant_documents(search_query, callbacks=callbacks, verbose=True)
-    unique_docs = vectorstore.similarity_search(search_query, k=5)
+    unique_docs = vectorstore.similarity_search(search_query, k=15, callbacks=callbacks, verbose=True)
     context = format_docs(unique_docs)
     prompt = get_rag_prompt_template().format(query=question, context=context)
     response = chat_llm.invoke(prompt, config={"callbacks": callbacks})
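The retrieval change above replaces the commented-out retriever call with a plain similarity search and raises k from 5 to 15. A self-contained sketch of that call against a small FAISS index (assumes an OpenAI API key for the embeddings, as the rest of the script does; k is an upper bound, so a small index simply returns fewer hits):

    from langchain_community.vectorstores import FAISS
    from langchain_core.documents.base import Document
    from langchain_openai import OpenAIEmbeddings

    docs = [Document(page_content="command-r-plus is Cohere's conversational model.")]
    vectorstore = FAISS.from_documents(docs, OpenAIEmbeddings())
    hits = vectorstore.similarity_search("which Cohere model?", k=15)  # returns at most len(docs)
    print(hits[0].page_content)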