Spaces:
Sleeping
Sleeping
""" | |
LoiLibreQA is an open source AI assistant for legal assistance. | |
Le code est inspiré de ClimateQA | |
""" | |
import gradio as gr | |
from haystack.document_stores import FAISSDocumentStore | |
from haystack.nodes import EmbeddingRetriever | |
import openai | |
import pandas as pd | |
import os | |
from utils import ( | |
make_pairs, | |
set_openai_api_key, | |
create_user_id, | |
to_completion, | |
) | |
import numpy as np | |
from datetime import datetime | |
try: | |
from dotenv import load_dotenv | |
load_dotenv() | |
except: | |
pass | |
list_codes = [] | |
theme = gr.themes.Soft( | |
primary_hue="sky", | |
font=[gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"], | |
) | |
init_prompt = ( | |
"Vous êtes LoiLibreQA, un assistant AI open source pour l'assistance juridique.", | |
"Vous recevez une question et des extraits d'article de loi", | |
"Fournissez une réponse claire et structurée en vous basant sur le contexte fourni.", | |
"Lorsque cela est pertinent, utilisez des points et des listes pour structurer vos réponses.", | |
) | |
sources_prompt = ( | |
"Lorsque cela est pertinent, utilisez les documents suivants dans votre réponse.", | |
"Chaque fois que vous utilisez des informations provenant d'un document, référencez-le à la fin de la phrase (ex : [doc 2]).", | |
"Vous n'êtes pas obligé d'utiliser tous les documents, seulement s'ils ont du sens dans la conversation.", | |
"Si aucune information pertinente pour répondre à la question n'est présente dans les documents, indiquez simplement que vous n'avez pas suffisamment d'informations pour répondre.", | |
) | |
def get_reformulation_prompt(query: str) -> str: | |
return f"""Reformulez le message utilisateur suivant en une question courte et autonome en français, dans le contexte d'une discussion autour de questions juridiques. | |
--- | |
requête: La justice doit-elle être la même pour tous ? | |
question autonome : Pensez-vous que la justice devrait être appliquée de manière égale à tous, indépendamment de leur statut social ou de leur origine ? | |
langage: French | |
--- | |
requête: Comment protéger ses droits d'auteur ? | |
question autonome : Quelles sont les mesures à prendre pour protéger ses droits d'auteur en tant qu'auteur ? | |
langage: French | |
--- | |
requête: Peut-on utiliser une photo trouvée sur Internet pour un projet commercial ? | |
question autonome : Est-il légalement permis d'utiliser une photographie trouvée sur Internet pour un projet commercial sans obtenir l'autorisation du titulaire des droits d'auteur ? | |
langage: French | |
--- | |
requête : {query} | |
question autonome : """ | |
system_template = { | |
"role": "system", | |
"content": init_prompt, | |
} | |
# if file key.key exist read the key if note read the env variable OPENAI_TOKEN | |
if os.path.isfile("key.key"): | |
# read key.key file and set openai api key | |
with open("key.key", "r") as f: | |
key = f.read() | |
else: | |
key = os.environ["OPENAI_TOKEN"] | |
# set api_key environment variable | |
os.environ["api_key"] = key | |
set_openai_api_key(key) | |
openai.api_key = os.environ["api_key"] | |
retriever = EmbeddingRetriever( | |
document_store=FAISSDocumentStore.load( | |
index_path="faiss_index.index", | |
config_path="faiss_config.json", | |
), | |
embedding_model="text-embedding-ada-002", | |
model_format="openai", | |
progress_bar=False, | |
api_key=os.environ["api_key"], | |
) | |
file_share_name = "loilibregpt" | |
user_id = create_user_id(10) | |
def filter_sources(df, k_summary=3, k_total=10, source="code civil"): | |
# assert source in ["ipcc", "ipbes", "all"] | |
# # Filter by source | |
# if source == "Code civil": | |
# df = df.loc[df["source"] == "codecivil"] | |
# elif source == "ipbes": | |
# df = df.loc[df["source"] == "IPBES"] | |
# else: | |
# pass | |
# Separate summaries and full reports | |
df_summaries = df # .loc[df["report_type"].isin(["SPM", "TS"])] | |
df_full = df # .loc[~df["report_type"].isin(["SPM", "TS"])] | |
# Find passages from summaries dataset | |
passages_summaries = df_summaries.head(k_summary) | |
# Find passages from full reports dataset | |
passages_fullreports = df_full.head(k_total - len(passages_summaries)) | |
# Concatenate passages | |
passages = pd.concat( | |
[passages_summaries, passages_fullreports], axis=0, ignore_index=True | |
) | |
return passages | |
def retrieve_with_summaries( | |
query, | |
retriever, | |
k_summary=3, | |
k_total=10, | |
source="ipcc", | |
max_k=100, | |
threshold=0.49, | |
as_dict=True, | |
): | |
""" | |
compare to retrieve_with_summaries, this function returns a dataframe with the content of the passages | |
""" | |
assert max_k > k_total | |
docs = retriever.retrieve(query, top_k=max_k) | |
docs = [ | |
{**x.meta, "score": x.score, "content": x.content} | |
for x in docs | |
if x.score > threshold | |
] | |
if len(docs) == 0: | |
return [] | |
res = pd.DataFrame(docs) | |
passages_df = filter_sources(res, k_summary, k_total, source) | |
if as_dict: | |
contents = passages_df["content"].tolist() | |
meta = passages_df.drop(columns=["content"]).to_dict(orient="records") | |
passages = [] | |
for i in range(len(contents)): | |
passages.append({"content": contents[i], "meta": meta[i]}) | |
return passages | |
else: | |
return passages_df | |
def make_html_source(source, i): | |
""" """ | |
meta = source["meta"] | |
return f""" | |
<div class="card"> | |
<div class="card-content"> | |
<h2>Doc {i} - </h2> | |
<p>{source['content']}</p> | |
</div> | |
<div class="card-footer"> | |
<span>link to code</span> | |
</div> | |
</div> | |
""" | |
def chat( | |
user_id: str, | |
query: str, | |
history: list = [system_template], | |
threshold: float = 0.49, | |
) -> tuple: | |
"""retrieve relevant documents in the document store then query gpt-turbo | |
Args: | |
query (str): user message. | |
history (list, optional): history of the conversation. Defaults to [system_template]. | |
report_type (str, optional): should be "All available" or "IPCC only". Defaults to "All available". | |
threshold (float, optional): similarity threshold, don't increase more than 0.568. Defaults to 0.56. | |
Yields: | |
tuple: chat gradio format, chat openai format, sources used. | |
""" | |
reformulated_query = openai.Completion.create( | |
model="text-davinci-002", | |
prompt=get_reformulation_prompt(query), | |
temperature=0, | |
max_tokens=128, | |
stop=["\n---\n", "<|im_end|>"], | |
) | |
reformulated_query = reformulated_query["choices"][0]["text"] | |
language = "francais" | |
sources = retrieve_with_summaries( | |
reformulated_query, | |
retriever, | |
k_total=10, | |
k_summary=3, | |
as_dict=True, | |
threshold=threshold, | |
) | |
# docs = [d for d in retriever.retrieve(query=reformulated_query, top_k=10) if d.score > threshold] | |
messages = history + [{"role": "user", "content": query}] | |
if len(sources) > 0: | |
docs_string = [] | |
docs_html = [] | |
for i, d in enumerate(sources, 1): | |
docs_string.append(f"📃 Doc {i}: \n{d['content']}") | |
docs_html.append(make_html_source(d, i)) | |
docs_string = "\n\n".join( | |
[f"Query used for retrieval:\n{reformulated_query}"] + docs_string | |
) | |
docs_html = "\n\n".join( | |
[f"Query used for retrieval:\n{reformulated_query}"] + docs_html | |
) | |
messages.append( | |
{ | |
"role": "system", | |
"content": f"{sources_prompt}\n\n{docs_string}\n\nAnswer in {language}:", | |
} | |
) | |
response = openai.Completion.create( | |
model="text-davinci-002", | |
prompt=to_completion(messages), | |
temperature=0, # deterministic | |
stream=True, | |
max_tokens=1024, | |
) | |
complete_response = "" | |
messages.pop() | |
messages.append({"role": "assistant", "content": complete_response}) | |
timestamp = str(datetime.now().timestamp()) | |
file = user_id[0] + timestamp + ".json" | |
for chunk in response: | |
if ( | |
chunk_message := chunk["choices"][0].get("text") | |
) and chunk_message != "<|im_end|>": | |
complete_response += chunk_message | |
messages[-1]["content"] = complete_response | |
gradio_format = make_pairs([a["content"] for a in messages[1:]]) | |
yield gradio_format, messages, docs_html | |
else: | |
docs_string = "Pas d'élements juridique trouvé dans les codes de loi" | |
complete_response = ( | |
"**Pas d'élément trouvé dans les textes de loi. Préciser votre réponse**" | |
) | |
messages.append({"role": "assistant", "content": complete_response}) | |
gradio_format = make_pairs([a["content"] for a in messages[1:]]) | |
yield gradio_format, messages, docs_string | |
def save_feedback(feed: str, user_id): | |
if len(feed) > 1: | |
timestamp = str(datetime.now().timestamp()) | |
file = user_id[0] + timestamp + ".json" | |
logs = { | |
"user_id": user_id[0], | |
"feedback": feed, | |
"time": timestamp, | |
} | |
return "Feedback submitted, thank you!" | |
def reset_textbox(): | |
return gr.update(value="") | |
with gr.Blocks(title="LoiLibre Q&A", css="style.css", theme=theme) as demo: | |
user_id_state = gr.State([user_id]) | |
# Gradio | |
gr.Markdown("<h1><center>LoiLibre Q&A</center></h1>") | |
gr.Markdown("<h4><center>Pose tes questions aux textes de loi ici</center></h4>") | |
with gr.Row(): | |
with gr.Column(scale=2): | |
chatbot = gr.Chatbot( | |
elem_id="chatbot", label="LoiLibreQ&A chatbot", show_label=False | |
) | |
state = gr.State([system_template]) | |
with gr.Row(): | |
ask = gr.Textbox( | |
show_label=False, | |
placeholder="Pose ta question ici", | |
).style(container=False) | |
ask_examples_hidden = gr.Textbox(elem_id="hidden-message") | |
examples_questions = gr.Examples( | |
[ | |
"Quelles sont les options légales pour une personne qui souhaite divorcer, notamment en matière de garde d'enfants et de pension alimentaire ?", | |
"Quelles sont les démarches à suivre pour créer une entreprise et quels sont les risques et les responsabilités juridiques associés ?", | |
"Comment pouvez-vous m'aider à protéger mes droits d'auteur et à faire respecter mes droits de propriété intellectuelle ?", | |
"Quels sont mes droits si j'ai été victime de harcèlement au travail ou de discrimination en raison de mon âge, de ma race ou de mon genre ?", | |
"Quelles sont les conséquences légales pour une entreprise qui a été poursuivie pour négligence ou faute professionnelle ?", | |
"Comment pouvez-vous m'aider à négocier un contrat de location commercial ou résidentiel, et quels sont mes droits et obligations en tant que locataire ou propriétaire ?", | |
"Quels sont les défenses possibles pour une personne accusée de crimes sexuels ou de violence domestique ?", | |
"Quelles sont les options légales pour une personne qui souhaite contester un testament ou un héritage ?", | |
"Comment pouvez-vous m'aider à obtenir une compensation en cas d'accident de voiture ou de blessure personnelle causée par la négligence d'une autre personne ?", | |
"Comment pouvez-vous m'aider à obtenir un visa ou un statut de résident permanent aux États-Unis, et quels sont les risques et les avantages associés ?", | |
], | |
[ask_examples_hidden], | |
) | |
with gr.Column(scale=1, variant="panel"): | |
gr.Markdown("### Sources") | |
sources_textbox = gr.Markdown(show_label=False) | |
ask.submit( | |
fn=chat, | |
inputs=[user_id_state, ask, state], | |
outputs=[chatbot, state, sources_textbox], | |
) | |
ask.submit(reset_textbox, [], [ask]) | |
ask_examples_hidden.change( | |
fn=chat, | |
inputs=[user_id_state, ask_examples_hidden, state], | |
outputs=[chatbot, state, sources_textbox], | |
) | |
with gr.Row(): | |
with gr.Column(scale=1): | |
gr.Markdown( | |
""" | |
<div class="warning-box"> | |
Version 0.1-beta - This tool is under active development | |
</div> | |
""") | |
gr.Markdown( | |
""" | |
""") | |
demo.queue(concurrency_count=16) | |
demo.launch(server_name="0.0.0.0") | |