Fixed web UI
- models.py +0 -2
- nlp_rag.py +30 -14
- search_agent_ui.py +22 -47
- web_rag.py +5 -50
models.py
CHANGED
@@ -90,9 +90,7 @@ def get_model(provider_model: str, temperature: float = 0.7) -> BaseChatModel:
             model = 'Qwen/Qwen2.5-72B-Instruct'
             llm = HuggingFaceEndpoint(
                 repo_id=model,
-                max_length=8192,
                 temperature=temperature,
-                huggingfacehub_api_token=os.getenv("HUGGINGFACE_API_KEY"),
            )
             chat_llm = ChatHuggingFace(llm=llm)
         case 'ollama':
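Note: with the explicit huggingfacehub_api_token argument removed, HuggingFaceEndpoint is left to resolve credentials on its own; the assumption (consistent with the UI now checking HF_TOKEN) is that the token is picked up from the environment. A minimal sketch of exercising the updated provider, with a placeholder token value:

# Minimal sketch, assuming HF_TOKEN is set so the endpoint
# resolves credentials implicitly (token value is a placeholder).
import os
import models as md

os.environ.setdefault("HF_TOKEN", "hf_...")  # hypothetical placeholder

chat = md.get_model("huggingface", temperature=0.2)
print(chat.invoke("Say hello in one word.").content)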
nlp_rag.py
CHANGED
@@ -115,7 +115,7 @@ def semantic_search(query, chunks, nlp, similarity_threshold=0.5, top_n=10):
 
 
 @traceable(run_type="llm", name="nlp_rag")
-def query_rag(chat_llm, query, relevant_results):
+def query_rag(chat_llm, query, relevant_results, callbacks = []):
     """
     Generate a response using retrieval-augmented generation (RAG) based on relevant results.
 
@@ -127,20 +127,36 @@ def query_rag(chat_llm, query, relevant_results):
     Returns:
         str: The generated response.
     """
-
+    prompt = build_rag_prompt(query, relevant_results)
+    response = chat_llm.invoke(prompt).content
+    return response
 
-    formatted_chunks = ""
-    for chunk, similarity in relevant_results:
-        formatted_chunk = f"""
-        <source>
-            <url>{chunk['metadata']['source']}</url>
-            <title>{chunk['metadata']['title']}</title>
-            <text>{chunk['text']}</text>
-        </source>
-        """
-        formatted_chunks += formatted_chunk
 
+def build_rag_prompt(query, relevant_results):
+    import web_rag as wr
+    formatted_chunks = format_docs(relevant_results)
     prompt = wr.get_rag_prompt_template().format(query=query, context=formatted_chunks)
+    return prompt
+
+def format_docs(relevant_results):
+    """
+    Convert relevant search results into a JSON-formatted string.
+
+    Args:
+        relevant_results (list): List of relevant chunks with metadata.
+
+    Returns:
+        str: JSON-formatted string of document chunks.
+    """
+    import json
+
+    formatted_chunks = []
+    for chunk, _ in relevant_results:  # Unpack the tuple, ignore similarity score
+        formatted_chunk = {
+            "content": chunk['text'],
+            "link": chunk['metadata'].get('source', ''),
+            "title": chunk['metadata'].get('title', ''),
+        }
+        formatted_chunks.append(formatted_chunk)
 
-
-    return draft
+    return json.dumps(formatted_chunks, indent=2)
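Note: format_docs replaces the old XML-ish <source> blocks with JSON, matching the "Context in JSON format" wording in the RAG system prompt. A quick sketch of what the new functions produce (the sample chunk below is invented for illustration):

# Hedged example of the new JSON chunk formatting (sample data is made up;
# relevant_results is a list of (chunk, similarity) tuples from semantic_search).
import nlp_rag as nr

relevant_results = [
    ({"text": "LangChain is a framework for LLM apps.",
      "metadata": {"source": "https://example.com/langchain", "title": "LangChain"}},
     0.83),
]

print(nr.format_docs(relevant_results))
# -> JSON list of {"content", "link", "title"} objects
prompt = nr.build_rag_prompt("What is LangChain?", relevant_results)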
search_agent_ui.py
CHANGED
@@ -3,16 +3,19 @@ import os
 
 import dotenv
 import streamlit as st
+import spacy
 
 from langchain_core.tracers.langchain import LangChainTracer
 from langchain.callbacks.base import BaseCallbackHandler
 from langsmith.client import Client
 
-import web_rag as wr
 import web_crawler as wc
-import copywriter as cw
 import models as md
+import nlp_rag as nr
+import web_rag as wr
+
 dotenv.load_dotenv()
+nlp = nr.get_nlp_model()
 
 ls_tracer = LangChainTracer(
     project_name=os.getenv("LANGSMITH_PROJECT_NAME"),
@@ -56,6 +59,12 @@ st.title("🔍 Simple Search Agent 💬")
 
 if "models" not in st.session_state:
     models = []
+    if os.getenv("MISTRAL_API_KEY"):
+        models.append("mistral")
+    if os.getenv("HF_TOKEN"):
+        models.append("huggingface")
+    if os.getenv("COHERE_API_KEY"):
+        models.append("cohere")
     if os.getenv("FIREWORKS_API_KEY"):
         models.append("fireworks")
     if os.getenv("TOGETHER_API_KEY"):
@@ -75,24 +84,12 @@ if "models" not in st.session_state:
 with st.sidebar.expander("Options", expanded=False):
     model_provider = st.selectbox("Model provider 🧠", st.session_state["models"])
     temperature = st.slider("Model temperature 🌡️", 0.0, 1.0, 0.1, help="The higher the more creative")
-    max_pages = st.slider("Max pages to retrieve 🔍", 1, 20, 10, help="How many web pages to
+    max_pages = st.slider("Max pages to retrieve 🔍", 1, 20, 10, help="How many web pages to retrieve from the internet")
     top_k_documents = st.slider("Nbr of doc extracts to consider 📄", 1, 20, 10, help="How many of the top extracts to consider")
-    reviewer_mode = st.checkbox("Draft / Comment / Rewrite mode ✍️", value=False, help="First generate a draft, then comments and then rewrite")
 
 with st.sidebar.expander("Links", expanded=False):
     links_md = st.markdown("")
 
-if reviewer_mode:
-    with st.sidebar.expander("Answer review", expanded=False):
-        st.caption("Draft")
-        draft_md = st.markdown("")
-        st.divider()
-        st.caption("Comments")
-        comments_md = st.markdown("")
-        st.divider()
-        st.caption("Comparaison")
-        comparaison_md = st.markdown("")
-
 if "messages" not in st.session_state:
     st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?"}]
 
@@ -106,47 +103,30 @@ for message in st.session_state.messages:
             mime="text/plain"
         )
 
-if prompt := st.chat_input("Enter you instructions..." ):
+if prompt := st.chat_input("Enter your instructions..." ):
     st.chat_message("user").write(prompt)
     st.session_state.messages.append({"role": "user", "content": prompt})
 
     chat = md.get_model(model_provider, temperature)
-    embedding_model = md.get_embedding_model(model_provider)
 
     with st.status("Thinking", expanded=True):
         st.write("I first need to do some research")
 
-        optimize_search_query = wr.optimize_search_query(chat, query=prompt, callbacks=[ls_tracer])
-        st.write(f"I should search the web for: {optimize_search_query}")
+        optimized_search_query = wr.optimize_search_query(chat, query=prompt, callbacks=[ls_tracer])
+        st.write(f"I should search the web for: {optimized_search_query}")
 
-        sources = wc.get_sources(optimize_search_query, max_pages=max_pages)
+        sources = wc.get_sources(optimized_search_query, max_pages=max_pages)
         links_md.markdown(create_links_markdown(sources))
 
         st.write(f"I'll now retrieve the {len(sources)} webpages and documents I found")
         contents = wc.get_links_contents(sources, use_selenium=False)
 
-        st.write(
-
-
+        st.write(f"Reading through the {len(contents)} sources I managed to retrieve")
+        chunks = nr.recursive_split_documents(contents)
+        relevant_results = nr.semantic_search(optimized_search_query, chunks, nlp, top_n=top_k_documents)
+        st.write(f"I collected {len(relevant_results)} chunks of data and I can now answer")
 
-
-        if reviewer_mode:
-            st.write("Creating a draft")
-            draft_prompt = wr.build_rag_prompt(
-                chat, prompt, optimize_search_query,
-                vector_store, top_k=top_k_documents, callbacks=[ls_tracer])
-            draft = chat.invoke(draft_prompt, stream=False, config={ "callbacks": [ls_tracer]})
-            draft_md.markdown(draft.content)
-            st.write("Sending draft for review")
-            comments = cw.generate_comments(chat, prompt, draft, callbacks=[ls_tracer])
-            comments_md.markdown(comments)
-            st.write("Reviewing comments and generating final answer")
-            rag_prompt = cw.get_final_text_prompt(prompt, draft, comments)
-        else:
-            rag_prompt = wr.build_rag_prompt(
-                chat, prompt, optimize_search_query, vector_store,
-                top_k=top_k_documents, callbacks=[ls_tracer]
-            )
+        rag_prompt = nr.build_rag_prompt(query=prompt, relevant_results=relevant_results)
 
     with st.chat_message("assistant"):
         st_cb = StreamHandler(st.empty())
@@ -185,9 +165,4 @@ if prompt := st.chat_input("Enter you instructions..." ):
                 data=st.session_state.messages[-1]["content"],
                 file_name=f"{message_id}.txt",
                 mime="text/plain"
-            )
-
-    if reviewer_mode:
-        compare_prompt = cw.get_compare_texts_prompts(prompt, draft_text=draft, final_text=response)
-        result = chat.invoke(compare_prompt, stream=False, config={ "callbacks": [ls_tracer]})
-        comparaison_md.markdown(result.content)
+            )
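Taken together, the UI now drops the embedding/vector-store and reviewer paths and drives the spaCy-based nlp_rag pipeline end to end. A headless sketch of the same flow outside Streamlit (all calls mirror the diff; the query and provider choice are placeholders):

# Headless sketch of the new UI flow; function signatures follow the diff,
# query string and provider name are placeholders for illustration.
import models as md
import nlp_rag as nr
import web_crawler as wc
import web_rag as wr

nlp = nr.get_nlp_model()
chat = md.get_model("cohere", temperature=0.1)  # any provider with a key set

query = "What's new in Python 3.13?"
optimized = wr.optimize_search_query(chat, query=query)
sources = wc.get_sources(optimized, max_pages=10)
contents = wc.get_links_contents(sources, use_selenium=False)

chunks = nr.recursive_split_documents(contents)
results = nr.semantic_search(optimized, chunks, nlp, top_n=10)
print(nr.query_rag(chat, query, results))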
web_rag.py
CHANGED
@@ -19,6 +19,7 @@ Perform RAG using a single query to retrieve relevant documents.
 """
 import os
 import json
+from datetime import datetime
 from docopt import re
 from langchain.schema import SystemMessage, HumanMessage
 from langchain.prompts.chat import (
@@ -115,55 +116,6 @@ def get_optimized_search_messages(query):
 
 
 
-def get_optimized_search_messages2(query):
-    """
-    Generate optimized search messages for a given query.
-
-    Args:
-        query (str): The user's query.
-
-    Returns:
-        list: A list containing the system message and human message for optimized search.
-    """
-    system_message = SystemMessage(
-        content="""
-        You are a prompt optimizer for web search. Your task is to take a given chat prompt or question and transform it into an optimized search string that will yield the most relevant and useful information from a search engine like Google.
-
-        The goal is to create a search query that will help users find the most accurate and pertinent information related to their original prompt or question. An effective search string should be concise, use relevant keywords, and leverage search engine syntax for better results.
-
-        Here are some key principles for creating effective search queries:
-        1. Use specific and relevant keywords
-        2. Remove unnecessary words (articles, prepositions, etc.)
-        3. Utilize quotation marks for exact phrases
-        4. Employ Boolean operators (AND, OR, NOT) when appropriate
-        5. Include synonyms or related terms to broaden the search
-
-        I will provide you with a chat prompt or question. Your task is to optimize this into an effective search string.
-
-        Process the input as follows:
-        1. Analyze the Question to identify the main topic and key concepts.
-        2. Extract the most relevant keywords and phrases.
-        3. Consider any implicit information or context that might be useful for the search.
-
-        Then, optimize the search string by:
-        1. Removing filler words and unnecessary language
-        2. Rearranging keywords in a logical order
-        3. Adding quotation marks around exact phrases if applicable
-        4. Including relevant synonyms or related terms (in parentheses) to broaden the search
-        5. Using Boolean operators if needed to refine the search
-
-        You should answer only with the optimized search query and add "**" to the end of the search string to indicate the end of the optimized search query
-        """
-    )
-    human_message = HumanMessage(
-        content=f"""
-        Question: {query}
-
-        """
-    )
-    return [system_message, human_message]
-
-
 @traceable(run_type="llm", name="optimize_search_query")
 def optimize_search_query(chat_llm, query, callbacks=[]):
     """
@@ -200,10 +152,11 @@ def get_rag_prompt_template():
     Returns:
         ChatPromptTemplate: The prompt template for RAG.
     """
+    today = datetime.now().strftime("%Y-%m-%d")
     system_prompt = SystemMessagePromptTemplate(
         prompt=PromptTemplate(
             input_variables=[],
-            template="""
+            template=f"""
 You are an expert research assistant.
 You are provided with a Context in JSON format and a Question.
 Each JSON entry contains: content, title, link
@@ -219,6 +172,8 @@ def get_rag_prompt_template():
 If the provided context is not relevant to the question, say it and answer with your internal knowledge.
 If you cannot answer the question using either the extracts or your internal knowledge, state that you don't have enough information to provide an accurate answer.
 If the information in the provided context is in contradiction with your internal knowledge, answer but warn the user about the contradiction.
+
+Today's date is {today}
 """
         )
     )
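One subtlety in this change: because template is now an f-string, {today} is interpolated once, when get_rag_prompt_template() builds the string, not when the template is later formatted; any literal PromptTemplate placeholders inside an f-string would need doubled braces ({{query}}). A standalone sketch of the mechanism (illustrative strings only):

# Standalone sketch of the date-stamping trick: the f-string bakes the
# date into the system prompt at construction time, leaving no template
# variables for PromptTemplate to fill.
from datetime import datetime
from langchain.prompts import PromptTemplate

today = datetime.now().strftime("%Y-%m-%d")
prompt = PromptTemplate(
    input_variables=[],
    template=f"""You are an expert research assistant.
Today's date is {today}""",
)
print(prompt.format())  # date already interpolated; nothing left to fill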