CyranoB committed
Commit 3f21537 · 1 Parent(s): 2e147cb

Fixed web UI

Files changed (4)
  1. models.py +0 -2
  2. nlp_rag.py +30 -14
  3. search_agent_ui.py +22 -47
  4. web_rag.py +5 -50
models.py CHANGED
@@ -90,9 +90,7 @@ def get_model(provider_model: str, temperature: float = 0.7) -> BaseChatModel:
             model = 'Qwen/Qwen2.5-72B-Instruct'
             llm = HuggingFaceEndpoint(
                 repo_id=model,
-                max_length=8192,
                 temperature=temperature,
-                huggingfacehub_api_token=os.getenv("HUGGINGFACE_API_KEY"),
             )
             chat_llm = ChatHuggingFace(llm=llm)
         case 'ollama':
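
Note: with the explicit `huggingfacehub_api_token` argument gone, the endpoint falls back on the standard `HF_TOKEN` environment variable, which `huggingface_hub` reads automatically; this matches the `HF_TOKEN` check added to `search_agent_ui.py` below. A minimal sketch of the simplified setup, assuming the `langchain_huggingface` import path (the repo may import these names from elsewhere):

```python
import os
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint

# Assumes HF_TOKEN is exported; huggingface_hub picks it up automatically,
# so no explicit huggingfacehub_api_token argument is needed.
assert os.getenv("HF_TOKEN"), "export HF_TOKEN before running"

llm = HuggingFaceEndpoint(
    repo_id="Qwen/Qwen2.5-72B-Instruct",
    temperature=0.7,
)
chat_llm = ChatHuggingFace(llm=llm)
print(chat_llm.invoke("Say hello in five words.").content)
```
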
nlp_rag.py CHANGED
@@ -115,7 +115,7 @@ def semantic_search(query, chunks, nlp, similarity_threshold=0.5, top_n=10):
 
 
 @traceable(run_type="llm", name="nlp_rag")
-def query_rag(chat_llm, query, relevant_results):
+def query_rag(chat_llm, query, relevant_results, callbacks=[]):
     """
     Generate a response using retrieval-augmented generation (RAG) based on relevant results.
 
@@ -127,20 +127,36 @@ def query_rag(chat_llm, query, relevant_results):
     Returns:
         str: The generated response.
     """
-    import web_rag as wr
-
-    formatted_chunks = ""
-    for chunk, similarity in relevant_results:
-        formatted_chunk = f"""
-        <source>
-            <url>{chunk['metadata']['source']}</url>
-            <title>{chunk['metadata']['title']}</title>
-            <text>{chunk['text']}</text>
-        </source>
-        """
-        formatted_chunks += formatted_chunk
-
-    prompt = wr.get_rag_prompt_template().format(query=query, context=formatted_chunks)
-
-    draft = chat_llm.invoke(prompt).content
-    return draft
+    prompt = build_rag_prompt(query, relevant_results)
+    response = chat_llm.invoke(prompt).content
+    return response
+
+
+def build_rag_prompt(query, relevant_results):
+    import web_rag as wr
+    formatted_chunks = format_docs(relevant_results)
+    prompt = wr.get_rag_prompt_template().format(query=query, context=formatted_chunks)
+    return prompt
+
+def format_docs(relevant_results):
+    """
+    Convert relevant search results into a JSON-formatted string.
+
+    Args:
+        relevant_results (list): List of relevant chunks with metadata.
+
+    Returns:
+        str: JSON-formatted string of document chunks.
+    """
+    import json
+
+    formatted_chunks = []
+    for chunk, _ in relevant_results:  # Unpack the tuple, ignore similarity score
+        formatted_chunk = {
+            "content": chunk['text'],
+            "link": chunk['metadata'].get('source', ''),
+            "title": chunk['metadata'].get('title', ''),
+        }
+        formatted_chunks.append(formatted_chunk)
+
+    return json.dumps(formatted_chunks, indent=2)
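
Note: `format_docs` replaces the ad-hoc `<source>` string concatenation with the JSON layout the rewritten prompt in `web_rag.py` expects ("Each JSON entry contains: content, title, link"). An illustration using hypothetical chunks (the URL, title, and text are made up):

```python
from nlp_rag import format_docs  # the helper added in this commit

# Hypothetical (chunk, similarity) tuples, shaped like semantic_search output.
relevant_results = [
    ({"text": "Qwen2.5 ships in sizes from 0.5B to 72B parameters.",
      "metadata": {"source": "https://example.com/qwen", "title": "Qwen2.5 overview"}},
     0.82),
]

print(format_docs(relevant_results))
# [
#   {
#     "content": "Qwen2.5 ships in sizes from 0.5B to 72B parameters.",
#     "link": "https://example.com/qwen",
#     "title": "Qwen2.5 overview"
#   }
# ]
```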
 
search_agent_ui.py CHANGED
@@ -3,16 +3,19 @@ import os
 
 import dotenv
 import streamlit as st
+import spacy
 
 from langchain_core.tracers.langchain import LangChainTracer
 from langchain.callbacks.base import BaseCallbackHandler
 from langsmith.client import Client
 
-import web_rag as wr
 import web_crawler as wc
-import copywriter as cw
 import models as md
+import nlp_rag as nr
+import web_rag as wr
+
 dotenv.load_dotenv()
+nlp = nr.get_nlp_model()
 
 ls_tracer = LangChainTracer(
     project_name=os.getenv("LANGSMITH_PROJECT_NAME"),
@@ -56,6 +59,12 @@ st.title("🔍 Simple Search Agent 💬")
 
 if "models" not in st.session_state:
     models = []
+    if os.getenv("MISTRAL_API_KEY"):
+        models.append("mistral")
+    if os.getenv("HF_TOKEN"):
+        models.append("huggingface")
+    if os.getenv("COHERE_API_KEY"):
+        models.append("cohere")
     if os.getenv("FIREWORKS_API_KEY"):
         models.append("fireworks")
     if os.getenv("TOGETHER_API_KEY"):
@@ -75,24 +84,12 @@ if "models" not in st.session_state:
 with st.sidebar.expander("Options", expanded=False):
     model_provider = st.selectbox("Model provider 🧠", st.session_state["models"])
     temperature = st.slider("Model temperature 🌡️", 0.0, 1.0, 0.1, help="The higher the more creative")
-    max_pages = st.slider("Max pages to retrieve 🔍", 1, 20, 10, help="How many web pages to retrive from the internet")
+    max_pages = st.slider("Max pages to retrieve 🔍", 1, 20, 10, help="How many web pages to retrieve from the internet")
     top_k_documents = st.slider("Nbr of doc extracts to consider 📄", 1, 20, 10, help="How many of the top extracts to consider")
-    reviewer_mode = st.checkbox("Draft / Comment / Rewrite mode ✍️", value=False, help="First generate a draft, then comments and then rewrite")
 
 with st.sidebar.expander("Links", expanded=False):
     links_md = st.markdown("")
 
-if reviewer_mode:
-    with st.sidebar.expander("Answer review", expanded=False):
-        st.caption("Draft")
-        draft_md = st.markdown("")
-        st.divider()
-        st.caption("Comments")
-        comments_md = st.markdown("")
-        st.divider()
-        st.caption("Comparaison")
-        comparaison_md = st.markdown("")
-
 if "messages" not in st.session_state:
     st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?"}]
@@ -106,47 +103,30 @@ for message in st.session_state.messages:
             mime="text/plain"
         )
 
-if prompt := st.chat_input("Enter you instructions..." ):
+if prompt := st.chat_input("Enter your instructions..." ):
    st.chat_message("user").write(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})
 
    chat = md.get_model(model_provider, temperature)
-    embedding_model = md.get_embedding_model(model_provider)
 
    with st.status("Thinking", expanded=True):
        st.write("I first need to do some research")
 
-        optimize_search_query = wr.optimize_search_query(chat, query=prompt, callbacks=[ls_tracer])
-        st.write(f"I should search the web for: {optimize_search_query}")
+        optimized_search_query = wr.optimize_search_query(chat, query=prompt, callbacks=[ls_tracer])
+        st.write(f"I should search the web for: {optimized_search_query}")
 
-        sources = wc.get_sources(optimize_search_query, max_pages=max_pages)
+        sources = wc.get_sources(optimized_search_query, max_pages=max_pages)
        links_md.markdown(create_links_markdown(sources))
 
        st.write(f"I'll now retrieve the {len(sources)} webpages and documents I found")
        contents = wc.get_links_contents(sources, use_selenium=False)
 
-        st.write( f"Reading through the {len(contents)} sources I managed to retrieve")
-        vector_store = wc.vectorize(contents, embedding_model=embedding_model)
-        st.write(f"I collected {vector_store.index.ntotal} chunk of data and I can now answer")
-
-        if reviewer_mode:
-            st.write("Creating a draft")
-            draft_prompt = wr.build_rag_prompt(
-                chat, prompt, optimize_search_query,
-                vector_store, top_k=top_k_documents, callbacks=[ls_tracer])
-            draft = chat.invoke(draft_prompt, stream=False, config={ "callbacks": [ls_tracer]})
-            draft_md.markdown(draft.content)
-            st.write("Sending draft for review")
-            comments = cw.generate_comments(chat, prompt, draft, callbacks=[ls_tracer])
-            comments_md.markdown(comments)
-            st.write("Reviewing comments and generating final answer")
-            rag_prompt = cw.get_final_text_prompt(prompt, draft, comments)
-        else:
-            rag_prompt = wr.build_rag_prompt(
-                chat, prompt, optimize_search_query, vector_store,
-                top_k=top_k_documents, callbacks=[ls_tracer]
-            )
+        st.write(f"Reading through the {len(contents)} sources I managed to retrieve")
+        chunks = nr.recursive_split_documents(contents)
+        relevant_results = nr.semantic_search(optimized_search_query, chunks, nlp, top_n=top_k_documents)
+        st.write(f"I collected {len(relevant_results)} chunks of data and I can now answer")
 
+        rag_prompt = nr.build_rag_prompt(query=prompt, relevant_results=relevant_results)
 
    with st.chat_message("assistant"):
        st_cb = StreamHandler(st.empty())
@@ -185,9 +165,4 @@ if prompt := st.chat_input("Enter you instructions..." ):
            data=st.session_state.messages[-1]["content"],
            file_name=f"{message_id}.txt",
            mime="text/plain"
-        )
-
-    if reviewer_mode:
-        compare_prompt = cw.get_compare_texts_prompts(prompt, draft_text=draft, final_text=response)
-        result = chat.invoke(compare_prompt, stream=False, config={ "callbacks": [ls_tracer]})
-        comparaison_md.markdown(result.content)
+        )
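
Note: the UI drops the embeddings/FAISS path (`md.get_embedding_model`, `wc.vectorize`) and the draft/comment/rewrite reviewer mode in favour of the spaCy-based pipeline from `nlp_rag.py`. A headless sketch of the new flow, assuming the helper signatures shown in this commit (the query and provider are illustrative):

```python
import models as md
import nlp_rag as nr
import web_crawler as wc
import web_rag as wr

nlp = nr.get_nlp_model()
chat = md.get_model("huggingface", temperature=0.1)

query = "What is new in Qwen2.5?"

# 1. Optimize the user question into a search query, then fetch pages.
optimized = wr.optimize_search_query(chat, query=query)
sources = wc.get_sources(optimized, max_pages=10)
contents = wc.get_links_contents(sources, use_selenium=False)

# 2. Chunk, rank by spaCy semantic similarity, then answer with RAG.
chunks = nr.recursive_split_documents(contents)
relevant = nr.semantic_search(optimized, chunks, nlp, top_n=10)
print(nr.query_rag(chat, query, relevant))
```
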
web_rag.py CHANGED
@@ -19,6 +19,7 @@ Perform RAG using a single query to retrieve relevant documents.
 """
 import os
 import json
+from datetime import datetime
 from docopt import re
 from langchain.schema import SystemMessage, HumanMessage
 from langchain.prompts.chat import (
@@ -115,55 +116,6 @@ def get_optimized_search_messages(query):
 
 
 
-def get_optimized_search_messages2(query):
-    """
-    Generate optimized search messages for a given query.
-
-    Args:
-        query (str): The user's query.
-
-    Returns:
-        list: A list containing the system message and human message for optimized search.
-    """
-    system_message = SystemMessage(
-        content="""
-        You are a prompt optimizer for web search. Your task is to take a given chat prompt or question and transform it into an optimized search string that will yield the most relevant and useful information from a search engine like Google.
-
-        The goal is to create a search query that will help users find the most accurate and pertinent information related to their original prompt or question. An effective search string should be concise, use relevant keywords, and leverage search engine syntax for better results.
-
-        Here are some key principles for creating effective search queries:
-        1. Use specific and relevant keywords
-        2. Remove unnecessary words (articles, prepositions, etc.)
-        3. Utilize quotation marks for exact phrases
-        4. Employ Boolean operators (AND, OR, NOT) when appropriate
-        5. Include synonyms or related terms to broaden the search
-
-        I will provide you with a chat prompt or question. Your task is to optimize this into an effective search string.
-
-        Process the input as follows:
-        1. Analyze the Question to identify the main topic and key concepts.
-        2. Extract the most relevant keywords and phrases.
-        3. Consider any implicit information or context that might be useful for the search.
-
-        Then, optimize the search string by:
-        1. Removing filler words and unnecessary language
-        2. Rearranging keywords in a logical order
-        3. Adding quotation marks around exact phrases if applicable
-        4. Including relevant synonyms or related terms (in parentheses) to broaden the search
-        5. Using Boolean operators if needed to refine the search
-
-        You should answer only with the optimized search query and add "**" to the end of the search string to indicate the end of the optimized search query
-        """
-    )
-    human_message = HumanMessage(
-        content=f"""
-        Question: {query}
-
-        """
-    )
-    return [system_message, human_message]
-
-
 @traceable(run_type="llm", name="optimize_search_query")
 def optimize_search_query(chat_llm, query, callbacks=[]):
     """
@@ -200,10 +152,11 @@ def get_rag_prompt_template():
     Returns:
         ChatPromptTemplate: The prompt template for RAG.
     """
+    today = datetime.now().strftime("%Y-%m-%d")
     system_prompt = SystemMessagePromptTemplate(
         prompt=PromptTemplate(
             input_variables=[],
-            template="""
+            template=f"""
             You are an expert research assistant.
             You are provided with a Context in JSON format and a Question.
             Each JSON entry contains: content, title, link
@@ -219,6 +172,8 @@ def get_rag_prompt_template():
             If the provided context is not relevant to the question, say it and answer with your internal knowledge.
             If you cannot answer the question using either the extracts or your internal knowledge, state that you don't have enough information to provide an accurate answer.
             If the information in the provided context is in contradiction with your internal knowledge, answer but warn the user about the contradiction.
+
+            Today's date is {today}
             """
         )
     )
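
Note: turning the template into an f-string interpolates `{today}` once, when `get_rag_prompt_template()` is called; since the system template declares `input_variables=[]` and (per the hunks shown) contains no other braces, this is safe, while `query` and `context` remain ordinary template variables filled in later via `.format()`. A minimal sketch of the pattern under those assumptions (names are illustrative, not the repo's):

```python
from datetime import datetime
from langchain.prompts.chat import ChatPromptTemplate

def build_template():
    # {today} is baked in now, at construction time...
    today = datetime.now().strftime("%Y-%m-%d")
    system = f"You are an expert research assistant. Today's date is {today}."
    # ...while {context} and {query} stay as placeholders (not an f-string).
    human = "Context:\n{context}\n---\nQuestion: {query}"
    return ChatPromptTemplate.from_messages([("system", system), ("human", human)])

# Rebuild per request, as get_rag_prompt_template() does, so the date
# never goes stale in a long-running app.
print(build_template().format(context="[]", query="What changed?"))
```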