Afritz committed on
Commit
743be46
·
1 Parent(s): 02f7718

Update new version

Browse files
Files changed (4) hide show
  1. app.py +7 -3
  2. assets/style.css +12 -3
  3. config.py +10 -0
  4. utils.py +40 -11
app.py CHANGED
@@ -34,12 +34,16 @@ with gr.Blocks(title=CFG_APP.BOT_NAME, css="assets/style.css", theme=theme) as d
34
  )
35
  state = gr.State([system_template])
36
 
 
37
  with gr.Row():
38
  ask = gr.Textbox(
39
  show_label=False,
40
  placeholder="Ask here your question and press enter",
41
- ).style(container=False)
42
- ask_examples_hidden = gr.Textbox(elem_id="hidden-message")
 
 
 
43
 
44
  examples_questions = gr.Examples(
45
  [*CFG_APP.DEFAULT_QUESTIONS],
@@ -53,7 +57,7 @@ with gr.Blocks(title=CFG_APP.BOT_NAME, css="assets/style.css", theme=theme) as d
53
 
54
  ask.submit(
55
  fn=chat,
56
- inputs=[ask, state],
57
  outputs=[chatbot, state, sources_textbox],
58
  )
59
  ask.submit(lambda x: gr.update(value=""), [], [ask])
 
34
  )
35
  state = gr.State([system_template])
36
 
37
+
38
  with gr.Row():
39
  ask = gr.Textbox(
40
  show_label=False,
41
  placeholder="Ask here your question and press enter",
42
+ )
43
+
44
+ query_mode = gr.Radio(choices=["HYDE", "Reformulation"], elem_id="type-emb", default="HYDE", label="Query Embedding's Mode")
45
+
46
+ ask_examples_hidden = gr.Textbox(elem_id="hidden-message")
47
 
48
  examples_questions = gr.Examples(
49
  [*CFG_APP.DEFAULT_QUESTIONS],
 
57
 
58
  ask.submit(
59
  fn=chat,
60
+ inputs=[ask, state, query_mode],
61
  outputs=[chatbot, state, sources_textbox],
62
  )
63
  ask.submit(lambda x: gr.update(value=""), [], [ask])
assets/style.css CHANGED
@@ -140,7 +140,6 @@ a {
140
 
141
 
142
  label>span {
143
- background-color: white !important;
144
  color: #577b9b !important;
145
  }
146
 
@@ -152,7 +151,7 @@ label>span {
152
  left: -10px;
153
  width: 30px;
154
  height: 30px;
155
- background-image: url('https://www.nexialog.com/wp-content/uploads/2021/10/cropped-icone-onglet-logo.png');
156
  background-color: #fff;
157
  background-size: cover;
158
  background-position: center;
@@ -181,4 +180,14 @@ label>span {
181
  padding: 17px 24px !important;
182
  text-align: justify !important;
183
  color: #fff !important;
184
- }
 
 
 
 
 
 
 
 
 
 
 
140
 
141
 
142
  label>span {
 
143
  color: #577b9b !important;
144
  }
145
 
 
151
  left: -10px;
152
  width: 30px;
153
  height: 30px;
154
+ background-image: url('data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAYAAABzenr0AAACnUlEQVR4AcXXA4wkQRiG4eHZtm3btm3btm3bDs+2bdvm2vPfm6Qu6XRuOz3LSp7xdH3ltCU8Za+lsA1JYLVER6HiOFiFXkgGa1QHiIvzCMQplI2uAKJMiY4A50ILwHs7bGG9eFqUQgx3A2gq74X+SAGrO5U7MQvfsAKF4XAzQD68QSDOoLbp3lAt/wxR3mMGssNmFEDTgAUQJQTTYDO7ticgEKLhwhMMRVpYDQIUwyeI8hhZzbbeipQYgNsIhmgE4xraIqk+AGJiNUQJwjCD1hsGSYfheIgQiIYXJuASRJmM8vgBUa4hdXi328yYgGdwQZSvuq4ehi0QxR9dYTVTUWIUQmEDtbESbzRBXBB4Yyb+QJTjSGx22U3DD/wMxQ+8xxXswRt8wjUInuKsboiamG19aXyBuCEQC9AIP/AZPhC4sBVxzVQeG2vgDR8YCYDgG1YhNZxoiWsIgi/2IA/iwojTwkMsFEN5VAhFRYzAc7hwFbXggBX5sB1+8MRNnNc5p3MAxcyuhOJ4ppvdX9ABuXET4qbtZocoLnZBFG+ch+AeNsED9/AFIRAY+YSSZjejBvCCKCdwGoJA+CII97EAA9Efg3SGYBRGoxkcZgIkwTGI8ge98RqCYHhClACcQRskMlqCZlvfCQEQZScqwQMCH6yFN0TDD0fRFAnCGiANrkKUH6iICvDRBKiOAZpe0fLBftRFXHf3/yG6k3ADYkIfoDzsKICV+ArR8cQGJDYbIBseQ5TP/2bt/wJo/hcD5bADHhCNrYhtNkA5PIILgiVwGgbQ7a6oh8PwxUeUdHcIcmABrqGAhWIygPY6CdEefY2XnfEpmQ52gwAVTKwmmyW8xTBAVBZ1yt2DK7oC2JAdc/EM5aPrztiJEkgXnuv8BdWTESwwR9FxAAAAAElFTkSuQmCC');
155
  background-color: #fff;
156
  background-size: cover;
157
  background-position: center;
 
180
  padding: 17px 24px !important;
181
  text-align: justify !important;
182
  color: #fff !important;
183
+ }
184
+
185
+ #chatbot{
186
+ height: auto !important;
187
+ max-height: 500px;
188
+ }
189
+
190
+ #type-emb label {
191
+ background: #ebeaea;
192
+ }
193
+
config.py CHANGED
@@ -58,4 +58,14 @@ class CFG_APP:
58
  standalone question: What does UL (Unexpected Loss) stand for?
59
  language: English
60
  """
 
 
 
 
 
 
 
 
 
 
61
  DOC_METADATA_PATH = f"{DATA_FOLDER}/doc_metadata.json"
 
58
  standalone question: What does UL (Unexpected Loss) stand for?
59
  language: English
60
  """
61
+ HYDE_PROMPT = """
62
+ Important ! Give the output as an answer to the query followed by the detected language whatever the form of the query.
63
+ You must answer to the query in a short answer, 2 sentences maximum, using the right vocabulary of the context of the query. You must keep the question at the beginning of the answer.
64
+ ---
65
+ query : C'est quoi les règles que les banques américaines doivent suivre ?
66
+ output : C'est quoi les règles que les banques américaines doivent suivre ? Les banques américaines doivent suivre un ensemble de réglementations fédérales et d'État imposées par des organismes tels que la Réserve fédérale et le Bureau de protection financière du consommateur.
67
+ language : French
68
+ """
69
+
70
+
71
  DOC_METADATA_PATH = f"{DATA_FOLDER}/doc_metadata.json"
utils.py CHANGED
@@ -1,4 +1,5 @@
1
  import json
 
2
  import openai
3
  import re
4
  from config import CFG_APP
@@ -37,6 +38,17 @@ def get_reformulation_prompt(query: str) -> list:
37
  }
38
  ]
39
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  def make_pairs(lst):
42
  """From a list of even length, make tuple pairs
@@ -68,10 +80,10 @@ def make_html_source(paragraph, meta_doc, i):
68
  """
69
 
70
 
71
- def preprocess_message(text: str) -> str:
72
  return re.sub(
73
  r"\[doc (\d+)\]",
74
- lambda match: f'<a href="#do-{match.group(1)}">{match.group(0)}</a>',
75
  text,
76
  )
77
 
@@ -96,6 +108,7 @@ def num_tokens_from_string(string: str, encoding_name: str) -> int:
96
  def chat(
97
  query: str,
98
  history: list,
 
99
  threshold: float = CFG_APP.THRESHOLD,
100
  k_total: int = CFG_APP.K_TOTAL,
101
  ) -> tuple:
@@ -108,13 +121,23 @@ def chat(
108
  Yields:
109
  tuple: chat gradio format, chat openai format, sources used.
110
  """
 
111
 
112
- reformulated_query = openai.ChatCompletion.create(
113
- model=CFG_APP.MODEL_NAME,
114
- messages=get_reformulation_prompt(parse_glossary(query)),
115
- temperature=0,
116
- max_tokens=CFG_APP.MAX_TOKENS_REF_QUESTION,
117
- )
 
 
 
 
 
 
 
 
 
118
 
119
  reformulated_query = reformulated_query["choices"][0]["message"]["content"]
120
  if len(reformulated_query.split("\n")) == 2:
@@ -134,6 +157,11 @@ def chat(
134
 
135
  messages = history + [{"role": "user", "content": query}]
136
 
 
 
 
 
 
137
  if len(sources) > 0:
138
  docs_string = []
139
  docs_html = []
@@ -150,6 +178,9 @@ def chat(
150
  docs_string.append(doc_content)
151
  docs_html.append(make_html_source(data, meta_doc, i))
152
 
 
 
 
153
  docs_string = "\n\n".join(
154
  [f"Query used for retrieval:\n{reformulated_query}"] + docs_string
155
  )
@@ -197,14 +228,12 @@ def chat(
197
  )
198
  complete_response = ""
199
  messages.pop()
200
-
201
  messages.append({"role": "assistant", "content": complete_response})
202
-
203
  for chunk in response:
204
  chunk_message = chunk["choices"][0]["delta"].get("content")
205
  if chunk_message:
206
  complete_response += chunk_message
207
- complete_response = preprocess_message(complete_response)
208
  messages[-1]["content"] = complete_response
209
  gradio_format = make_pairs([a["content"] for a in messages[1:]])
210
  yield gradio_format, messages, docs_html
 
1
  import json
2
+ from collections import defaultdict
3
  import openai
4
  import re
5
  from config import CFG_APP
 
38
  }
39
  ]
40
 
41
+ def get_hyde_prompt(query: str) -> list:
42
+ return [
43
+ {
44
+ "role": "user",
45
+ "content": f"""{CFG_APP.HYDE_PROMPT}
46
+ ---
47
+ query: {query}
48
+ output: """,
49
+ }
50
+ ]
51
+
52
 
53
  def make_pairs(lst):
54
  """From a list of even length, make tuple pairs
 
80
  """
81
 
82
 
83
+ def preprocess_message(text: str, docs_url: dict) -> str:
84
  return re.sub(
85
  r"\[doc (\d+)\]",
86
+ lambda match: f'<a href="{docs_url[match.group(1)]}" target="_blank" class="pdf-link">{match.group(0)}</a>',
87
  text,
88
  )
89
 
 
108
  def chat(
109
  query: str,
110
  history: list,
111
+ query_mode : str,
112
  threshold: float = CFG_APP.THRESHOLD,
113
  k_total: int = CFG_APP.K_TOTAL,
114
  ) -> tuple:
 
121
  Yields:
122
  tuple: chat gradio format, chat openai format, sources used.
123
  """
124
+ if query_mode == 'Reformulation':
125
 
126
+ reformulated_query = openai.ChatCompletion.create(
127
+ model=CFG_APP.MODEL_NAME,
128
+ messages=get_reformulation_prompt(parse_glossary(query)),
129
+ temperature=0,
130
+ max_tokens=CFG_APP.MAX_TOKENS_REF_QUESTION,
131
+ )
132
+
133
+ else :
134
+
135
+ reformulated_query = openai.ChatCompletion.create(
136
+ model=CFG_APP.MODEL_NAME,
137
+ messages=get_hyde_prompt(parse_glossary(query)),
138
+ temperature=0,
139
+ max_tokens=CFG_APP.MAX_TOKENS_REF_QUESTION,
140
+ )
141
 
142
  reformulated_query = reformulated_query["choices"][0]["message"]["content"]
143
  if len(reformulated_query.split("\n")) == 2:
 
157
 
158
  messages = history + [{"role": "user", "content": query}]
159
 
160
+ if query_mode == 'HYDE' :
161
+ reformulated_query = reformulated_query.split("?")[0] + '?'
162
+
163
+ docs_url = defaultdict(str)
164
+
165
  if len(sources) > 0:
166
  docs_string = []
167
  docs_html = []
 
178
  docs_string.append(doc_content)
179
  docs_html.append(make_html_source(data, meta_doc, i))
180
 
181
+ url_doc = f'<a href="{meta_doc["url"]}#page={data["meta"]["page_number"]}" target="_blank" class="pdf-link">'
182
+ docs_url[i] = url_doc
183
+
184
  docs_string = "\n\n".join(
185
  [f"Query used for retrieval:\n{reformulated_query}"] + docs_string
186
  )
 
228
  )
229
  complete_response = ""
230
  messages.pop()
 
231
  messages.append({"role": "assistant", "content": complete_response})
 
232
  for chunk in response:
233
  chunk_message = chunk["choices"][0]["delta"].get("content")
234
  if chunk_message:
235
  complete_response += chunk_message
236
+ complete_response = preprocess_message(complete_response, docs_url)
237
  messages[-1]["content"] = complete_response
238
  gradio_format = make_pairs([a["content"] for a in messages[1:]])
239
  yield gradio_format, messages, docs_html