Update new version
Browse files
app.py
CHANGED
@@ -34,12 +34,16 @@ with gr.Blocks(title=CFG_APP.BOT_NAME, css="assets/style.css", theme=theme) as d
|
|
34 |
)
|
35 |
state = gr.State([system_template])
|
36 |
|
|
|
37 |
with gr.Row():
|
38 |
ask = gr.Textbox(
|
39 |
show_label=False,
|
40 |
placeholder="Ask here your question and press enter",
|
41 |
-
|
42 |
-
|
|
|
|
|
|
|
43 |
|
44 |
examples_questions = gr.Examples(
|
45 |
[*CFG_APP.DEFAULT_QUESTIONS],
|
@@ -53,7 +57,7 @@ with gr.Blocks(title=CFG_APP.BOT_NAME, css="assets/style.css", theme=theme) as d
|
|
53 |
|
54 |
ask.submit(
|
55 |
fn=chat,
|
56 |
-
inputs=[ask, state],
|
57 |
outputs=[chatbot, state, sources_textbox],
|
58 |
)
|
59 |
ask.submit(lambda x: gr.update(value=""), [], [ask])
|
|
|
34 |
)
|
35 |
state = gr.State([system_template])
|
36 |
|
37 |
+
|
38 |
with gr.Row():
|
39 |
ask = gr.Textbox(
|
40 |
show_label=False,
|
41 |
placeholder="Ask here your question and press enter",
|
42 |
+
)
|
43 |
+
|
44 |
+
query_mode = gr.Radio(choices=["HYDE", "Reformulation"], elem_id="type-emb", default="HYDE", label="Query Embedding's Mode")
|
45 |
+
|
46 |
+
ask_examples_hidden = gr.Textbox(elem_id="hidden-message")
|
47 |
|
48 |
examples_questions = gr.Examples(
|
49 |
[*CFG_APP.DEFAULT_QUESTIONS],
|
|
|
57 |
|
58 |
ask.submit(
|
59 |
fn=chat,
|
60 |
+
inputs=[ask, state, query_mode],
|
61 |
outputs=[chatbot, state, sources_textbox],
|
62 |
)
|
63 |
ask.submit(lambda x: gr.update(value=""), [], [ask])
|
assets/style.css
CHANGED
@@ -140,7 +140,6 @@ a {
|
|
140 |
|
141 |
|
142 |
label>span {
|
143 |
-
background-color: white !important;
|
144 |
color: #577b9b !important;
|
145 |
}
|
146 |
|
@@ -152,7 +151,7 @@ label>span {
|
|
152 |
left: -10px;
|
153 |
width: 30px;
|
154 |
height: 30px;
|
155 |
-
background-image: url('
|
156 |
background-color: #fff;
|
157 |
background-size: cover;
|
158 |
background-position: center;
|
@@ -181,4 +180,14 @@ label>span {
|
|
181 |
padding: 17px 24px !important;
|
182 |
text-align: justify !important;
|
183 |
color: #fff !important;
|
184 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
140 |
|
141 |
|
142 |
label>span {
|
|
|
143 |
color: #577b9b !important;
|
144 |
}
|
145 |
|
|
|
151 |
left: -10px;
|
152 |
width: 30px;
|
153 |
height: 30px;
|
154 |
+
background-image: url('');
|
155 |
background-color: #fff;
|
156 |
background-size: cover;
|
157 |
background-position: center;
|
|
|
180 |
padding: 17px 24px !important;
|
181 |
text-align: justify !important;
|
182 |
color: #fff !important;
|
183 |
+
}
|
184 |
+
|
185 |
+
#chatbot{
|
186 |
+
height: auto !important;
|
187 |
+
max-height: 500px;
|
188 |
+
}
|
189 |
+
|
190 |
+
#type-emb label {
|
191 |
+
background: #ebeaea;
|
192 |
+
}
|
193 |
+
|
config.py
CHANGED
@@ -58,4 +58,14 @@ class CFG_APP:
|
|
58 |
standalone question: What does UL (Unexpected Loss) stand for?
|
59 |
language: English
|
60 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
DOC_METADATA_PATH = f"{DATA_FOLDER}/doc_metadata.json"
|
|
|
58 |
standalone question: What does UL (Unexpected Loss) stand for?
|
59 |
language: English
|
60 |
"""
|
61 |
+
HYDE_PROMPT = """
|
62 |
+
Important ! Give the output as a answer to the query followed by the detected language whatever the form of the query.
|
63 |
+
You must answer to the query in a short answer, 2 sentences maximum, using the right vocabulary of the context of the query. You must keep the question at the begining of the answer.
|
64 |
+
---
|
65 |
+
query : C'est quoi les régles que les banques américaines doivent suivre ?
|
66 |
+
output : C'est quoi les régles que les banques américaines doivent suivre ? Les banques américaines doivent suivre un ensemble de réglementations fédérales et d'État imposées par des organismes tels que la Réserve fédérale et le Bureau de protection financière du consommateur.
|
67 |
+
language : French
|
68 |
+
"""
|
69 |
+
|
70 |
+
|
71 |
DOC_METADATA_PATH = f"{DATA_FOLDER}/doc_metadata.json"
|
utils.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import json
|
|
|
2 |
import openai
|
3 |
import re
|
4 |
from config import CFG_APP
|
@@ -37,6 +38,17 @@ def get_reformulation_prompt(query: str) -> list:
|
|
37 |
}
|
38 |
]
|
39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
|
41 |
def make_pairs(lst):
|
42 |
"""From a list of even lenght, make tupple pairs
|
@@ -68,10 +80,10 @@ def make_html_source(paragraph, meta_doc, i):
|
|
68 |
"""
|
69 |
|
70 |
|
71 |
-
def preprocess_message(text: str) -> str:
|
72 |
return re.sub(
|
73 |
r"\[doc (\d+)\]",
|
74 |
-
lambda match: f'<a href="
|
75 |
text,
|
76 |
)
|
77 |
|
@@ -96,6 +108,7 @@ def num_tokens_from_string(string: str, encoding_name: str) -> int:
|
|
96 |
def chat(
|
97 |
query: str,
|
98 |
history: list,
|
|
|
99 |
threshold: float = CFG_APP.THRESHOLD,
|
100 |
k_total: int = CFG_APP.K_TOTAL,
|
101 |
) -> tuple:
|
@@ -108,13 +121,23 @@ def chat(
|
|
108 |
Yields:
|
109 |
tuple: chat gradio format, chat openai format, sources used.
|
110 |
"""
|
|
|
111 |
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
|
119 |
reformulated_query = reformulated_query["choices"][0]["message"]["content"]
|
120 |
if len(reformulated_query.split("\n")) == 2:
|
@@ -134,6 +157,11 @@ def chat(
|
|
134 |
|
135 |
messages = history + [{"role": "user", "content": query}]
|
136 |
|
|
|
|
|
|
|
|
|
|
|
137 |
if len(sources) > 0:
|
138 |
docs_string = []
|
139 |
docs_html = []
|
@@ -150,6 +178,9 @@ def chat(
|
|
150 |
docs_string.append(doc_content)
|
151 |
docs_html.append(make_html_source(data, meta_doc, i))
|
152 |
|
|
|
|
|
|
|
153 |
docs_string = "\n\n".join(
|
154 |
[f"Query used for retrieval:\n{reformulated_query}"] + docs_string
|
155 |
)
|
@@ -197,14 +228,12 @@ def chat(
|
|
197 |
)
|
198 |
complete_response = ""
|
199 |
messages.pop()
|
200 |
-
|
201 |
messages.append({"role": "assistant", "content": complete_response})
|
202 |
-
|
203 |
for chunk in response:
|
204 |
chunk_message = chunk["choices"][0]["delta"].get("content")
|
205 |
if chunk_message:
|
206 |
complete_response += chunk_message
|
207 |
-
complete_response = preprocess_message(complete_response)
|
208 |
messages[-1]["content"] = complete_response
|
209 |
gradio_format = make_pairs([a["content"] for a in messages[1:]])
|
210 |
yield gradio_format, messages, docs_html
|
|
|
1 |
import json
|
2 |
+
from collections import defaultdict
|
3 |
import openai
|
4 |
import re
|
5 |
from config import CFG_APP
|
|
|
38 |
}
|
39 |
]
|
40 |
|
41 |
+
def get_hyde_prompt(query: str) -> list:
|
42 |
+
return [
|
43 |
+
{
|
44 |
+
"role": "user",
|
45 |
+
"content": f"""{CFG_APP.HYDE_PROMPT}
|
46 |
+
---
|
47 |
+
query: {query}
|
48 |
+
output: """,
|
49 |
+
}
|
50 |
+
]
|
51 |
+
|
52 |
|
53 |
def make_pairs(lst):
|
54 |
"""From a list of even lenght, make tupple pairs
|
|
|
80 |
"""
|
81 |
|
82 |
|
83 |
+
def preprocess_message(text: str, docs_url: dict) -> str:
|
84 |
return re.sub(
|
85 |
r"\[doc (\d+)\]",
|
86 |
+
lambda match: f'<a href="{docs_url[match.group(1)]}" target="_blank" class="pdf-link">{match.group(0)}</a>',
|
87 |
text,
|
88 |
)
|
89 |
|
|
|
108 |
def chat(
|
109 |
query: str,
|
110 |
history: list,
|
111 |
+
query_mode : str,
|
112 |
threshold: float = CFG_APP.THRESHOLD,
|
113 |
k_total: int = CFG_APP.K_TOTAL,
|
114 |
) -> tuple:
|
|
|
121 |
Yields:
|
122 |
tuple: chat gradio format, chat openai format, sources used.
|
123 |
"""
|
124 |
+
if query_mode == 'Reformulation':
|
125 |
|
126 |
+
reformulated_query = openai.ChatCompletion.create(
|
127 |
+
model=CFG_APP.MODEL_NAME,
|
128 |
+
messages=get_reformulation_prompt(parse_glossary(query)),
|
129 |
+
temperature=0,
|
130 |
+
max_tokens=CFG_APP.MAX_TOKENS_REF_QUESTION,
|
131 |
+
)
|
132 |
+
|
133 |
+
else :
|
134 |
+
|
135 |
+
reformulated_query = openai.ChatCompletion.create(
|
136 |
+
model=CFG_APP.MODEL_NAME,
|
137 |
+
messages=get_hyde_prompt(parse_glossary(query)),
|
138 |
+
temperature=0,
|
139 |
+
max_tokens=CFG_APP.MAX_TOKENS_REF_QUESTION,
|
140 |
+
)
|
141 |
|
142 |
reformulated_query = reformulated_query["choices"][0]["message"]["content"]
|
143 |
if len(reformulated_query.split("\n")) == 2:
|
|
|
157 |
|
158 |
messages = history + [{"role": "user", "content": query}]
|
159 |
|
160 |
+
if query_mode == 'HYDE' :
|
161 |
+
reformulated_query = reformulated_query.split("?")[0] + '?'
|
162 |
+
|
163 |
+
docs_url = defaultdict(str)
|
164 |
+
|
165 |
if len(sources) > 0:
|
166 |
docs_string = []
|
167 |
docs_html = []
|
|
|
178 |
docs_string.append(doc_content)
|
179 |
docs_html.append(make_html_source(data, meta_doc, i))
|
180 |
|
181 |
+
url_doc = f'<a href="{meta_doc["url"]}#page={data["meta"]["page_number"]}" target="_blank" class="pdf-link">'
|
182 |
+
docs_url[i] = url_doc
|
183 |
+
|
184 |
docs_string = "\n\n".join(
|
185 |
[f"Query used for retrieval:\n{reformulated_query}"] + docs_string
|
186 |
)
|
|
|
228 |
)
|
229 |
complete_response = ""
|
230 |
messages.pop()
|
|
|
231 |
messages.append({"role": "assistant", "content": complete_response})
|
|
|
232 |
for chunk in response:
|
233 |
chunk_message = chunk["choices"][0]["delta"].get("content")
|
234 |
if chunk_message:
|
235 |
complete_response += chunk_message
|
236 |
+
complete_response = preprocess_message(complete_response, docs_url)
|
237 |
messages[-1]["content"] = complete_response
|
238 |
gradio_format = make_pairs([a["content"] for a in messages[1:]])
|
239 |
yield gradio_format, messages, docs_html
|