Forbu14 commited on
Commit
d6bb506
1 Parent(s): 7017f65

adding code structure

Browse files
Files changed (4) hide show
  1. .gitignore +1 -0
  2. app.py +359 -0
  3. requirements.txt +6 -0
  4. utils.py +67 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ **.key
app.py ADDED
@@ -0,0 +1,359 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LoiLibreQA is an open source AI assistant for legal assistance.
3
+ Le code est inspiré de ClimateQA
4
+ """
5
+
6
+ import gradio as gr
7
+ from haystack.document_stores import FAISSDocumentStore
8
+ from haystack.nodes import EmbeddingRetriever
9
+ import openai
10
+ import pandas as pd
11
+ import os
12
+ from utils import (
13
+ make_pairs,
14
+ set_openai_api_key,
15
+ create_user_id,
16
+ to_completion,
17
+ )
18
+ import numpy as np
19
+ from datetime import datetime
20
+
21
+ try:
22
+ from dotenv import load_dotenv
23
+
24
+ load_dotenv()
25
+ except:
26
+ pass
27
+
28
+ list_codes = []
29
+
30
+ theme = gr.themes.Soft(
31
+ primary_hue="sky",
32
+ font=[gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"],
33
+ )
34
+
35
+ init_prompt = (
36
+ "Vous êtes LoiLibreQA, un assistant AI open source pour l'assistance juridique.",
37
+ "Vous recevez une question et des extraits d'article de loi",
38
+ "Fournissez une réponse claire et structurée en vous basant sur le contexte fourni.",
39
+ "Lorsque cela est pertinent, utilisez des points et des listes pour structurer vos réponses.",
40
+ )
41
+ sources_prompt = (
42
+ "Lorsque cela est pertinent, utilisez les documents suivants dans votre réponse.",
43
+ "Chaque fois que vous utilisez des informations provenant d'un document, référencez-le à la fin de la phrase (ex : [doc 2]).",
44
+ "Vous n'êtes pas obligé d'utiliser tous les documents, seulement s'ils ont du sens dans la conversation.",
45
+ "Si aucune information pertinente pour répondre à la question n'est présente dans les documents, indiquez simplement que vous n'avez pas suffisamment d'informations pour répondre.",
46
+ )
47
+
48
+
49
+ def get_reformulation_prompt(query: str) -> str:
50
+ return f"""Reformulez le message utilisateur suivant en une question courte et autonome en français, dans le contexte d'une discussion autour de questions juridiques.
51
+ ---
52
+ requête: La justice doit-elle être la même pour tous ?
53
+ question autonome : Pensez-vous que la justice devrait être appliquée de manière égale à tous, indépendamment de leur statut social ou de leur origine ?
54
+ langage: French
55
+ ---
56
+ requête: Comment protéger ses droits d'auteur ?
57
+ question autonome : Quelles sont les mesures à prendre pour protéger ses droits d'auteur en tant qu'auteur ?
58
+ langage: French
59
+ ---
60
+ requête: Peut-on utiliser une photo trouvée sur Internet pour un projet commercial ?
61
+ question autonome : Est-il légalement permis d'utiliser une photographie trouvée sur Internet pour un projet commercial sans obtenir l'autorisation du titulaire des droits d'auteur ?
62
+ langage: French
63
+ ---
64
+ requête : {query}
65
+ question autonome : """
66
+
67
+
68
+ system_template = {
69
+ "role": "system",
70
+ "content": init_prompt,
71
+ }
72
+
73
+ # read key.key file and set openai api key
74
+ with open("key.key", "r") as f:
75
+ key = f.read()
76
+
77
+ # set api_key environment variable
78
+ os.environ["api_key"] = key
79
+
80
+ set_openai_api_key(key)
81
+
82
+ openai.api_key = os.environ["api_key"]
83
+
84
+ retriever = EmbeddingRetriever(
85
+ document_store=FAISSDocumentStore.load(
86
+ index_path="faiss_index.index",
87
+ config_path="faiss_config.json",
88
+ ),
89
+ embedding_model="text-embedding-ada-002",
90
+ model_format="openai",
91
+ progress_bar=False,
92
+ api_key=os.environ["api_key"],
93
+ )
94
+
95
+
96
+ file_share_name = "loilibregpt"
97
+
98
+ user_id = create_user_id(10)
99
+
100
+
101
+ def filter_sources(df, k_summary=3, k_total=10, source="code civil"):
102
+ # assert source in ["ipcc", "ipbes", "all"]
103
+
104
+ # # Filter by source
105
+ # if source == "Code civil":
106
+ # df = df.loc[df["source"] == "codecivil"]
107
+ # elif source == "ipbes":
108
+ # df = df.loc[df["source"] == "IPBES"]
109
+ # else:
110
+ # pass
111
+
112
+ # Separate summaries and full reports
113
+ df_summaries = df # .loc[df["report_type"].isin(["SPM", "TS"])]
114
+ df_full = df # .loc[~df["report_type"].isin(["SPM", "TS"])]
115
+
116
+ # Find passages from summaries dataset
117
+ passages_summaries = df_summaries.head(k_summary)
118
+
119
+ # Find passages from full reports dataset
120
+ passages_fullreports = df_full.head(k_total - len(passages_summaries))
121
+
122
+ # Concatenate passages
123
+ passages = pd.concat(
124
+ [passages_summaries, passages_fullreports], axis=0, ignore_index=True
125
+ )
126
+ return passages
127
+
128
+
129
+ def retrieve_with_summaries(
130
+ query,
131
+ retriever,
132
+ k_summary=3,
133
+ k_total=10,
134
+ source="ipcc",
135
+ max_k=100,
136
+ threshold=0.49,
137
+ as_dict=True,
138
+ ):
139
+ """
140
+ compare to retrieve_with_summaries, this function returns a dataframe with the content of the passages
141
+ """
142
+ assert max_k > k_total
143
+ docs = retriever.retrieve(query, top_k=max_k)
144
+ docs = [
145
+ {**x.meta, "score": x.score, "content": x.content}
146
+ for x in docs
147
+ if x.score > threshold
148
+ ]
149
+ if len(docs) == 0:
150
+ return []
151
+ res = pd.DataFrame(docs)
152
+ passages_df = filter_sources(res, k_summary, k_total, source)
153
+ if as_dict:
154
+ contents = passages_df["content"].tolist()
155
+ meta = passages_df.drop(columns=["content"]).to_dict(orient="records")
156
+ passages = []
157
+ for i in range(len(contents)):
158
+ passages.append({"content": contents[i], "meta": meta[i]})
159
+ return passages
160
+ else:
161
+ return passages_df
162
+
163
+
164
+ def make_html_source(source, i):
165
+ """ """
166
+ meta = source["meta"]
167
+ return f"""
168
+ <div class="card">
169
+ <div class="card-content">
170
+ <h2>Doc {i} - </h2>
171
+ <p>{source['content']}</p>
172
+ </div>
173
+ <div class="card-footer">
174
+ <span>link to code</span>
175
+ </div>
176
+ </div>
177
+ """
178
+
179
+
180
+ def chat(
181
+ user_id: str,
182
+ query: str,
183
+ history: list = [system_template],
184
+ threshold: float = 0.49,
185
+ ) -> tuple:
186
+ """retrieve relevant documents in the document store then query gpt-turbo
187
+ Args:
188
+ query (str): user message.
189
+ history (list, optional): history of the conversation. Defaults to [system_template].
190
+ report_type (str, optional): should be "All available" or "IPCC only". Defaults to "All available".
191
+ threshold (float, optional): similarity threshold, don't increase more than 0.568. Defaults to 0.56.
192
+ Yields:
193
+ tuple: chat gradio format, chat openai format, sources used.
194
+ """
195
+ reformulated_query = openai.Completion.create(
196
+ model="text-davinci-002",
197
+ prompt=get_reformulation_prompt(query),
198
+ temperature=0,
199
+ max_tokens=128,
200
+ stop=["\n---\n", "<|im_end|>"],
201
+ )
202
+
203
+ reformulated_query = reformulated_query["choices"][0]["text"]
204
+ language = "francais"
205
+
206
+ sources = retrieve_with_summaries(
207
+ reformulated_query,
208
+ retriever,
209
+ k_total=10,
210
+ k_summary=3,
211
+ as_dict=True,
212
+ threshold=threshold,
213
+ )
214
+
215
+ # docs = [d for d in retriever.retrieve(query=reformulated_query, top_k=10) if d.score > threshold]
216
+ messages = history + [{"role": "user", "content": query}]
217
+
218
+ if len(sources) > 0:
219
+ docs_string = []
220
+ docs_html = []
221
+ for i, d in enumerate(sources, 1):
222
+ docs_string.append(f"📃 Doc {i}: \n{d['content']}")
223
+ docs_html.append(make_html_source(d, i))
224
+ docs_string = "\n\n".join(
225
+ [f"Query used for retrieval:\n{reformulated_query}"] + docs_string
226
+ )
227
+ docs_html = "\n\n".join(
228
+ [f"Query used for retrieval:\n{reformulated_query}"] + docs_html
229
+ )
230
+ messages.append(
231
+ {
232
+ "role": "system",
233
+ "content": f"{sources_prompt}\n\n{docs_string}\n\nAnswer in {language}:",
234
+ }
235
+ )
236
+
237
+ response = openai.Completion.create(
238
+ model="text-davinci-002",
239
+ prompt=to_completion(messages),
240
+ temperature=0, # deterministic
241
+ stream=True,
242
+ max_tokens=1024,
243
+ )
244
+
245
+ complete_response = ""
246
+ messages.pop()
247
+
248
+ messages.append({"role": "assistant", "content": complete_response})
249
+ timestamp = str(datetime.now().timestamp())
250
+ file = user_id[0] + timestamp + ".json"
251
+
252
+ for chunk in response:
253
+ if (
254
+ chunk_message := chunk["choices"][0].get("text")
255
+ ) and chunk_message != "<|im_end|>":
256
+ complete_response += chunk_message
257
+ messages[-1]["content"] = complete_response
258
+ gradio_format = make_pairs([a["content"] for a in messages[1:]])
259
+ yield gradio_format, messages, docs_html
260
+
261
+ else:
262
+ docs_string = "Pas d'élements juridique trouvé dans les codes de loi"
263
+ complete_response = (
264
+ "**Pas d'élément trouvé dans les textes de loi. Préciser votre réponse**"
265
+ )
266
+ messages.append({"role": "assistant", "content": complete_response})
267
+ gradio_format = make_pairs([a["content"] for a in messages[1:]])
268
+ yield gradio_format, messages, docs_string
269
+
270
+
271
+ def save_feedback(feed: str, user_id):
272
+ if len(feed) > 1:
273
+ timestamp = str(datetime.now().timestamp())
274
+ file = user_id[0] + timestamp + ".json"
275
+ logs = {
276
+ "user_id": user_id[0],
277
+ "feedback": feed,
278
+ "time": timestamp,
279
+ }
280
+ return "Feedback submitted, thank you!"
281
+
282
+
283
+ def reset_textbox():
284
+ return gr.update(value="")
285
+
286
+
287
+ with gr.Blocks(title="LoiLibre Q&A", css="style.css", theme=theme) as demo:
288
+ user_id_state = gr.State([user_id])
289
+
290
+ # Gradio
291
+ gr.Markdown("<h1><center>LoiLibre Q&A</center></h1>")
292
+ gr.Markdown("<h4><center>Pose tes questions aux textes de loi ici</center></h4>")
293
+
294
+ with gr.Row():
295
+ with gr.Column(scale=2):
296
+ chatbot = gr.Chatbot(
297
+ elem_id="chatbot", label="LoiLibreQ&A chatbot", show_label=False
298
+ )
299
+ state = gr.State([system_template])
300
+
301
+ with gr.Row():
302
+ ask = gr.Textbox(
303
+ show_label=False,
304
+ placeholder="Pose ta question ici",
305
+ ).style(container=False)
306
+ ask_examples_hidden = gr.Textbox(elem_id="hidden-message")
307
+
308
+ examples_questions = gr.Examples(
309
+ [
310
+ "Quelles sont les options légales pour une personne qui souhaite divorcer, notamment en matière de garde d'enfants et de pension alimentaire ?",
311
+ "Quelles sont les démarches à suivre pour créer une entreprise et quels sont les risques et les responsabilités juridiques associés ?",
312
+ "Comment pouvez-vous m'aider à protéger mes droits d'auteur et à faire respecter mes droits de propriété intellectuelle ?",
313
+ "Quels sont mes droits si j'ai été victime de harcèlement au travail ou de discrimination en raison de mon âge, de ma race ou de mon genre ?",
314
+ "Quelles sont les conséquences légales pour une entreprise qui a été poursuivie pour négligence ou faute professionnelle ?",
315
+ "Comment pouvez-vous m'aider à négocier un contrat de location commercial ou résidentiel, et quels sont mes droits et obligations en tant que locataire ou propriétaire ?",
316
+ "Quels sont les défenses possibles pour une personne accusée de crimes sexuels ou de violence domestique ?",
317
+ "Quelles sont les options légales pour une personne qui souhaite contester un testament ou un héritage ?",
318
+ "Comment pouvez-vous m'aider à obtenir une compensation en cas d'accident de voiture ou de blessure personnelle causée par la négligence d'une autre personne ?",
319
+ "Comment pouvez-vous m'aider à obtenir un visa ou un statut de résident permanent aux États-Unis, et quels sont les risques et les avantages associés ?",
320
+ ],
321
+ [ask_examples_hidden],
322
+ )
323
+
324
+ with gr.Column(scale=1, variant="panel"):
325
+ gr.Markdown("### Sources")
326
+ sources_textbox = gr.Markdown(show_label=False)
327
+
328
+
329
+
330
+ ask.submit(
331
+ fn=chat,
332
+ inputs=[user_id_state, ask, state],
333
+ outputs=[chatbot, state, sources_textbox],
334
+ )
335
+ ask.submit(reset_textbox, [], [ask])
336
+
337
+ ask_examples_hidden.change(
338
+ fn=chat,
339
+ inputs=[user_id_state, ask_examples_hidden, state],
340
+ outputs=[chatbot, state, sources_textbox],
341
+ )
342
+
343
+ with gr.Row():
344
+ with gr.Column(scale=1):
345
+ gr.Markdown(
346
+ """
347
+ <div class="warning-box">
348
+ Version 0.1-beta - This tool is under active development
349
+
350
+ </div>
351
+ """)
352
+ gr.Markdown(
353
+ """
354
+
355
+ """)
356
+
357
+ demo.queue(concurrency_count=16)
358
+
359
+ demo.launch(server_name="0.0.0.0")
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ faiss-cpu==1.7.2
2
+ farm-haystack==1.14.0
3
+ gradio==3.22.1
4
+ openai==0.27.0
5
+ python-dotenv==1.0.0
6
+ pdfminer.six
utils.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import openai
3
+ import os
4
+ import random
5
+ import string
6
+
7
+
8
+ def is_climate_change_related(sentence: str, classifier) -> bool:
9
+ """_summary_
10
+ Args:
11
+ sentence (str): your sentence to classify
12
+ classifier (_type_): zero shot hugging face pipeline classifier
13
+ Returns:
14
+ bool: is_climate_change_related or not
15
+ """
16
+ results = classifier(
17
+ sequences=sentence,
18
+ candidate_labels=["climate change related", "non climate change related"],
19
+ )
20
+ print(f" ## Result from is climate change related {results}")
21
+ return results["labels"][np.argmax(results["scores"])] == "climate change related"
22
+
23
+
24
+ def make_pairs(lst):
25
+ """From a list of even lenght, make tupple pairs
26
+ Args:
27
+ lst (list): a list of even lenght
28
+ Returns:
29
+ list: the list as tupple pairs
30
+ """
31
+ assert not (l := len(lst) % 2), f"your list is of lenght {l} which is not even"
32
+ return [(lst[i], lst[i + 1]) for i in range(0, len(lst), 2)]
33
+
34
+
35
+ def set_openai_api_key(text):
36
+ """Set the api key and return chain.If no api_key, then None is returned.
37
+ To do : add raise error & Warning message
38
+ Args:
39
+ text (str): openai api key
40
+ Returns:
41
+ str: Result of connection
42
+ """
43
+ openai.api_key = os.environ["api_key"]
44
+
45
+ if text.startswith("sk-") and len(text) > 10:
46
+ openai.api_key = text
47
+ return f"You're all set: this is your api key: {openai.api_key}"
48
+
49
+
50
+ def create_user_id(length):
51
+ """Create user_id
52
+ Args:
53
+ length (int): length of user id
54
+ Returns:
55
+ str: String to id user
56
+ """
57
+ letters = string.ascii_lowercase
58
+ user_id = "".join(random.choice(letters) for i in range(length))
59
+ return user_id
60
+
61
+
62
+ def to_completion(messages):
63
+ s = []
64
+ for message in messages:
65
+ s.append(f"<|im_start|>{message['role']}\n{message['content']}<|im_end|>")
66
+ s.append("<|im_start|>assistant\n")
67
+ return "\n".join(s)