|
import time |
|
import yaml |
|
import logging |
|
import gradio as gr |
|
from langchain.prompts.chat import ChatPromptTemplate |
|
from huggingface_hub import hf_hub_download, whoami |
|
from spinoza_project.source.backend.llm_utils import get_llm, get_vectorstore |
|
from spinoza_project.source.backend.document_store import pickle_to_document_store |
|
from spinoza_project.source.backend.get_prompts import get_qa_prompts |
|
from spinoza_project.source.frontend.utils import ( |
|
make_html_source, |
|
make_html_presse_source, |
|
init_env, |
|
) |
|
from spinoza_project.source.backend.prompt_utils import ( |
|
to_chat_instruction, |
|
SpecialTokens, |
|
) |
|
|
|
# ---- Application bootstrap (runs at import time) ----

# Load environment variables / credentials required by the backends.
init_env()

# Global app configuration (tab names, prompt naming, retrieval settings, ...).
with open("./spinoza_project/config.yaml") as f:
    config = yaml.full_load(f)

# One prompt YAML per source listed under config["prompt_naming"].
prompts = {}
for source in config["prompt_naming"]:
    with open(f"./spinoza_project/prompt_{source}.yaml") as f:
        prompts[source] = yaml.full_load(f)


print("Building LLM")
# NOTE(review): `model` is never used — get_llm() takes no argument here;
# confirm whether the model name should be passed through.
model = "gpt35turbo"
llm = get_llm()


print("Loading Databases")
# One pickled document store per source tab, fetched from the HF dataset repo.
# "Presse" is excluded: its store is served by get_vectorstore() below.
qdrants = {
    tab: pickle_to_document_store(
        hf_hub_download(
            repo_id="SpinozaProject/spinoza-database",
            filename=f"database_{tab}.pickle",
            repo_type="dataset",
        )
    )
    for tab in config["prompt_naming"]
    if tab != "Presse"
}

# Press articles live in a dedicated (non-pickled) vector store.
bdd_presse = get_vectorstore("presse")


print("Loading Prompts")
# Per-source chat prompts. NOTE(review): chat_summarize_memory_prompts is
# created but never filled in this file — confirm it is still needed.
chat_qa_prompts, chat_reformulation_prompts, chat_summarize_memory_prompts = {}, {}, {}
for source, prompt in prompts.items():
    chat_qa_prompt, chat_reformulation_prompt = get_qa_prompts(config, prompt)
    chat_qa_prompts[source] = chat_qa_prompt
    chat_reformulation_prompts[source] = chat_reformulation_prompt


# Stylesheet injected into the Gradio app.
with open("./assets/style.css", "r") as f:
    css = f.read()
|
|
|
|
|
def update_tabs(outil, visible_tabs):
    """Replace the visible-tabs state with the current checkbox selection."""
    return outil
|
|
|
|
|
special_tokens = SpecialTokens(config)

# Prompt used to merge the per-tab expert answers into one French synthesis.
# Spelling fixed from the original ("secialized awnsers", "thechnical",
# "folowing", "mentionned", "contracdiction", "paraphasing", ...): typos in a
# prompt degrade LLM instruction-following. The {question} / {answers}
# placeholders are unchanged.
synthesis_template = """You are a factual journalist that summarizes the specialized answers from technical sources.

Based on the following question:
{question}

And the following expert answers:
{answers}

Answer the question, in French.
When using legal answers, keep track of the names of the articles.
When using ADEME answers, name the sources that are mainly used.
List the different elements mentioned, and highlight the agreement points between the sources, as well as the contradictions or differences.
Generate the answer as markdown, with an aerated layout, and headlines in bold.
Start with a general summary, agreements and contradictions, and then go into detail without paraphrasing the experts' answers.
"""

synthesis_prompt = to_chat_instruction(synthesis_template, special_tokens)
synthesis_prompt_template = ChatPromptTemplate.from_messages([synthesis_prompt])
|
|
|
|
|
def zip_longest_fill(*args, fillvalue=None):
    """Zip several iterables, repeating the last value of any exhausted
    iterable instead of stopping at the shortest one.

    Iteration ends after a full round produces no value that differs from
    the previous round; that final, unchanged tuple is still yielded once
    before stopping. With no arguments, yields nothing.
    """
    streams = [iter(arg) for arg in args]
    if not streams:
        return

    previous = [None] * len(streams)
    keep_going = True
    while keep_going:
        current = []
        for idx, stream in enumerate(streams):
            try:
                item = next(stream)
            except StopIteration:
                # Exhausted: carry forward the value from the previous round.
                item = previous[idx]
            current.append(item)

        # Continue only if at least one stream produced something new.
        keep_going = any(item != prev for item, prev in zip(current, previous))
        previous = list(current)
        yield tuple(current)
|
|
|
|
|
def build_data_dict(config):
    """Create the per-tab component registry used by the Gradio UI.

    For every tab declared in ``config["tabs"]`` the registry holds five
    named slots (tab, description, question, answer, sources). Each slot
    stores its initial value, a placeholder for the Gradio component
    (attached later by ``init_gradio``), and its CSS element id.
    """
    # (field name, elem_id) pairs — order mirrors the original layout.
    field_specs = [
        ("tab", "tab"),
        ("description", "desc"),
        ("question", "question"),
        ("answer", "answer"),
        ("sources", "src"),
    ]

    registry = {}
    for tab_name, tab_description in config["tabs"].items():
        # Only the tab name and its description have non-None initial values.
        init_values = {"tab": tab_name, "description": tab_description}
        registry[tab_name] = {
            field: {
                "init_value": init_values.get(field),
                "component": None,
                "elem_id": elem_id,
            }
            for field, elem_id in field_specs
        }
    return registry
|
|
|
|
|
def init_gradio(data, config=config):
    """Instantiate the Gradio widgets for every tab registered in ``data``.

    Creates one gr.Tab per entry; inside it the "question" slot becomes an
    editable Textbox and every other slot (except the tab itself) becomes a
    Markdown widget seeded with its initial value. Mutates and returns
    ``data``.
    """
    for tab_name, fields in data.items():
        tab_widget = gr.Tab(fields["tab"]["init_value"], elem_id="tab")
        fields["tab"]["component"] = tab_widget
        with tab_widget:
            for field_name, spec in fields.items():
                if field_name == "tab":
                    continue
                if field_name == "question":
                    spec["component"] = gr.Textbox(
                        elem_id=spec["elem_id"],
                        show_label=False,
                        interactive=True,
                        placeholder="",
                    )
                else:
                    spec["component"] = gr.Markdown(
                        spec["init_value"],
                        elem_id=spec["elem_id"],
                    )
    return data
|
|
|
|
|
def add_warning():
    """Return the French status note shown while per-tab answers stream in."""
    notice = (
        "*Les éléments cochés ont commencé à être généré dans les onglets "
        "spécifiques, la synthèse ne sera disponible qu'après la mise à "
        "disposition de ces derniers.*"
    )
    return notice
|
|
|
|
|
def format_question(question):
    """Render the raw question as a plain string (identity formatting)."""
    return "{}".format(question)
|
|
|
|
|
def parse_question(question):
    """Strip HTML paragraph tags; if a '### ' heading is present, return
    the text immediately following the first heading marker."""
    cleaned = question.replace("<p>", "").replace("</p>\n", "")
    if "### " not in cleaned:
        return cleaned
    return cleaned.split("### ")[1]
|
|
|
|
|
def reformulate(outils, question, tab, config=config):
    """Stream a tab-specific reformulation of ``question``.

    Returns the LLM token stream when ``tab`` is among the selected tools;
    otherwise a dummy iterator of five ``None`` values so downstream
    zipping keeps the same arity.
    """
    if tab not in outils:
        return iter([None] * 5)
    prompt = chat_reformulation_prompts[config["source_mapping"][tab]]
    return llm.stream(prompt, {"question": parse_question(question)})
|
|
|
|
|
def reformulate_single_question(outils, question, tab, config=config):
    """Yield one tab's reformulated question chunks, throttled for the UI."""
    for chunk in reformulate(outils, question, tab, config=config):
        time.sleep(0.02)  # small delay keeps the streamed UI update smooth
        yield chunk
|
|
|
|
|
def reformulate_questions(outils, question, config=config):
    """Yield reformulations for every tab in lockstep, throttled for the UI."""
    streams = [
        reformulate(outils, question, tab, config=config) for tab in config["tabs"]
    ]
    for chunk in zip_longest_fill(*streams):
        time.sleep(0.02)  # small delay keeps the streamed UI update smooth
        yield chunk
|
|
|
|
|
def add_question(question):
    """Pass the submitted question through unchanged (Gradio event glue)."""
    return question
|
|
|
|
|
def answer(question, source, outils, tab, config=config):
    """Stream the per-tab expert answer.

    - Unselected tab: yields five ``None`` placeholders (keeps zip arity).
    - Sources shorter than 10 characters: yields a fixed French fallback.
    - Otherwise: streams the tab's QA prompt through the LLM.
    """
    if tab not in outils:
        return iter([None] * 5)
    if len(source) < 10:
        return iter(["Aucune source trouvée, veuillez reformuler votre question"])
    payload = {
        "question": parse_question(question),
        "sources": source.replace("<p>", "").replace("</p>\n", ""),
    }
    return llm.stream(chat_qa_prompts[config["source_mapping"][tab]], payload)
|
|
|
|
|
def answer_single_question(outils, source, question, tab, config=config):
    """Yield one tab's answer chunks, throttled for smooth UI streaming."""
    for chunk in answer(question, source, outils, tab, config=config):
        time.sleep(0.02)  # small delay keeps the streamed UI update smooth
        yield chunk
|
|
|
|
|
def answer_questions(outils, *questions_sources, config=config):
    """Stream answers for all tabs in lockstep.

    ``questions_sources`` packs the reformulated questions in its first
    half and the retrieved source texts in its second half (one of each
    per tab, same order as ``config["tabs"]``).
    """
    half = len(questions_sources) // 2
    questions = list(questions_sources[:half])
    sources = list(questions_sources[half:])

    streams = [
        answer(question, source, outils, tab, config=config)
        for question, source, tab in zip(questions, sources, config["tabs"])
    ]
    for chunk in zip_longest_fill(*streams):
        time.sleep(0.02)  # small delay keeps the streamed UI update smooth
        yield chunk
|
|
|
|
|
def get_source_link(metadata):
    """Build a deep link to the source document page (1-indexed viewer page)."""
    viewer_page = metadata["content_page_number"] + 1
    return f"{metadata['file_url']}#page={viewer_page}"
|
|
|
|
|
def get_button(i, tag):
    """Return the HTML toggle button [i] for source ``i`` of tab ``tag``.

    The id ``btn_<tag>_<i>`` is parsed by the JS click listener to find the
    matching source card (``btn_<tag>_<i>_source``).

    Bug fix: the original markup was malformed — the ``style`` attribute's
    closing quote was missing, so ``align="right`` was swallowed into the
    style value and the attribute quoting was broken.
    """
    return (
        f'<button id="btn_{tag}_{i}" type="button" '
        f'style="margin: 0; display: inline;" align="right">[{i}]</button>'
    )
|
|
|
|
|
def get_html_sources(buttons, cards):
    """Wrap the per-source toggle buttons and detail cards under the
    'Sources utilisées' header block."""
    return f"""
    <p style="margin: 0; display: inline;"><strong><br>Sources utilisées : </strong></p>
    {buttons}
    {cards}
    """
|
|
|
|
|
def get_sources(
    outils, question, tab, qdrants=qdrants, bdd_presse=bdd_presse, config=config
):
    """Retrieve and format the documents backing one tab's answer.

    Returns a pair ``(formated, text)``:
    - ``formated``: HTML (header + toggle buttons + source cards) for display,
      or "" when no document clears the similarity threshold;
    - ``text``: plain-text concatenation of the retrieved documents, fed to
      the QA prompt.
    Both are empty strings when ``tab`` is not among the selected tools.
    """
    k = config["num_document_retrieved"]
    min_similarity = config["min_similarity"]
    if tab in outils:
        # "Presse" queries its dedicated store without the query pre-prompt;
        # every other tab queries its pickled store with the pre-prompt
        # prepended. HTML paragraph tags are stripped from the question.
        sources = (
            (
                bdd_presse.similarity_search_with_relevance_scores(
                    question.replace("<p>", "").replace("</p>\n", ""),
                    k=k,
                )
            )
            if tab == "Presse"
            else qdrants[
                config["source_mapping"][tab]
            ].similarity_search_with_relevance_scores(
                config["query_preprompt"]
                + question.replace("<p>", "").replace("</p>\n", ""),
                k=k,
            )
        )

        # Keep only (document, score) pairs above the relevance threshold.
        sources = [(doc, score) for doc, score in sources if score >= min_similarity]

        # One toggle button per retained source; ids match the JS listener.
        buttons_ids = list(range(len(sources)))
        buttons = " ".join(
            [get_button(i, tab) for i, source in zip(buttons_ids, sources)]
        )
        # Press sources get their own card renderer.
        formated = (
            "\n\n".join(
                [
                    make_html_presse_source(source[0], i, tab, source[1], config)
                    for i, source in zip(buttons_ids, sources)
                ]
            )
            if tab == "Presse"
            else "\n\n".join(
                [
                    make_html_source(source[0], i, tab, source[1], config)
                    for i, source in zip(buttons_ids, sources)
                ]
            )
        )
        formated = get_html_sources(buttons, formated) if sources else ""
        # Plain-text version used as LLM context.
        text = "\n\n".join(
            [
                f"Doc {str(i)} with source type {elt[0].metadata.get('file_source_type')}:\n"
                + elt[0].page_content
                for i, elt in enumerate(sources)
            ]
        )
        return str(formated), str(text)
    else:
        return "", ""
|
|
|
|
|
def retrieve_sources(
    outils, *questions, qdrants=qdrants, bdd_presse=bdd_presse, config=config
):
    """Fetch sources for every tab.

    Returns a flat tuple: all HTML renderings first (one per tab), then all
    plain-text renderings — matching the Gradio output component order.
    """
    per_tab = [
        get_sources(outils, question, tab, qdrants, bdd_presse, config)
        for question, tab in zip(questions, config["tabs"])
    ]
    html_parts = [html for html, _ in per_tab]
    text_parts = [text for _, text in per_tab]
    return tuple(html_parts + text_parts)
|
|
|
|
|
def get_experts(outils, *answers, config=config):
    """Concatenate each selected tab's answer, prefixed by the tab name."""
    sections = []
    for position, tab in enumerate(config["tabs"]):
        if tab in outils:
            sections.append(f"{tab}\n{answers[position]}")
    return "\n\n".join(sections)
|
|
|
|
|
def get_synthesis(outils, question, *answers, config=config):
    """Stream a French synthesis of the substantial per-tab answers.

    Tabs that are unselected, or whose answer is shorter than 100
    characters, are excluded. If nothing remains, a fixed fallback
    message is yielded instead.

    Bug fix: this function is a generator (it contains ``yield``), so the
    original ``return "<message>"`` silently discarded the fallback text —
    the Gradio caller never displayed it. The message is now yielded.
    (Also replaced bitwise ``&`` with logical ``and`` for short-circuiting.)
    """
    selected = []
    for i, tab in enumerate(config["tabs"]):
        if (tab in outils) and (len(str(answers[i])) >= 100):
            # Strip HTML paragraph tags before feeding the LLM.
            selected.append(
                f"{tab}\n{answers[i]}".replace("<p>", "").replace("</p>\n", "")
            )

    if len(selected) == 0:
        yield "Aucune source n'a pu être identifiée pour répondre, veuillez modifier votre question"
        return

    for elt in llm.stream(
        synthesis_prompt_template,
        {
            "question": question.replace("<p>", "").replace("</p>\n", ""),
            "answers": "\n\n".join(selected),
        },
    ):
        time.sleep(0.01)  # small delay keeps the streamed UI update smooth
        yield elt
|
|
|
|
|
def get_listener():
    """Return the browser-side JS snippet passed to gr.Blocks(js=...).

    Registers a document-wide click handler that toggles the visibility of
    the source card matching a clicked button: buttons are ``btn_<tab>_<i>``
    and the corresponding cards are ``btn_<tab>_<i>_source``.

    NOTE(review): the handler runs for every click on the page and assumes
    the target id has at least three "_"-separated parts; for unrelated
    clicks getElementById returns null and the style access throws — confirm
    this is harmless (only a console error) in the deployed app.
    """
    return """
    function my_func_body() {
        const body = document.querySelector("body");
        body.addEventListener("click", e => {
            console.log(e)
            const sourceId = "btn_" + e.target.id.split("_")[1] + "_" + e.target.id.split("_")[2] + "_source"
            console.log(sourceId)
            if (document.getElementById(sourceId).style.display === "none") {
                document.getElementById(sourceId).style.display = "";
            } else {
                document.getElementById(sourceId).style.display = "none";
            }
        }
        )}
    """
|
|
|
|
|
def get_source_template(buttons, divs_source):
    """Render the sources panel: header line, toggle buttons, source cards.

    Bug fix: the original template was a plain string (no ``f`` prefix), so
    the ``buttons`` and ``divs_source`` arguments were never interpolated and
    the literal ``{buttons}`` / ``{divs_source}`` placeholders were returned.

    NOTE(review): the trailing extra ``</div>`` is kept as-is — it may close
    a wrapper opened by the caller's markup; confirm before removing.
    """
    return f"""
    <div class="source">
        <p style="margin: 0; display: inline;"><strong><br>Sources utilisées :</strong></p>
        {buttons}
        {divs_source}
    </div>
    </div>
    """
|
|
|
|
|
def activate_questions(outils, *textboxes, config=config):
    """Return one Textbox per tab, editable only when its tab is selected.

    The two original branches differed only in ``interactive``; the
    placeholder text is identical either way.
    """
    placeholder_text = (
        "Sélectionnez cet outil et posez une question sur l'onglet de synthèse"
    )
    return [
        gr.Textbox(
            show_label=False,
            interactive=tab in outils,
            placeholder=placeholder_text,
        )
        for tab in config["tabs"]
    ]
|
|
|
|
|
def empty():
    """Clear helper for Gradio .then() chains: yields an empty string."""
    return str()
|
|
|
|
|
def empty_none():
    """Clear helper for Gradio .then() chains: resets a component to None."""
    return None
|
|
|
|
|
# Shared Gradio theme for the whole app.
theme = gr.themes.Base(
    primary_hue="blue",
    secondary_hue="red",
    font=[gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"],
)


# Greeting shown to the user. Typo fix: "secialized" -> "specialized"
# (user-facing text).
init_prompt = """
Hello, I am Spinoza Q&A, a conversational assistant designed to help journalists by providing specialized answers from technical sources. I will answer your questions based **on the official definition of each ESRS as well as guidelines**.

⚠️ Limitations
*Please note that this chatbot is in an early stage phase, it is not perfect and may sometimes give irrelevant answers. If you are not satisfied with the answer, please ask a more specific question or report your feedback to help us improve the system.*

What do you want to learn ?
"""

# Logo URLs injected into the page header HTML.
logo_rsf = config["logo_rsf"]
logo_ap = config["logo_ap"]

# Per-tab component registry; widgets are attached later by init_gradio().
data = build_data_dict(config)
|
|
|
|
|
def update_visible(oauth_token: gr.OAuthToken | None):
    """Toggle three page blocks based on Hugging Face OAuth state.

    - no token: show bloc_1 only (presumably the login prompt);
    - authenticated member of the "SpinozaProject" org: show bloc_2 only;
    - authenticated non-member: show bloc_3 only.

    NOTE(review): ``bloc_1`` / ``bloc_2`` / ``bloc_3`` are not defined
    anywhere in this file — they must come from another module or revision;
    confirm, otherwise calling this raises NameError.
    """
    if oauth_token is None:
        return {
            bloc_1: gr.update(visible=True),
            bloc_2: gr.update(visible=False),
            bloc_3: gr.update(visible=False),
        }

    # Organizations of the authenticated Hugging Face user.
    org_names = [org["name"] for org in whoami(oauth_token.token)["orgs"]]

    if "SpinozaProject" in org_names:
        return {
            bloc_1: gr.update(visible=False),
            bloc_2: gr.update(visible=True),
            bloc_3: gr.update(visible=False),
        }

    else:
        return {
            bloc_1: gr.update(visible=False),
            bloc_2: gr.update(visible=False),
            bloc_3: gr.update(visible=True),
        }
|
|
|
|
|
# ---- UI layout and event wiring ----
with gr.Blocks(
    title=f"🔍{config['demo_name']}",
    css=css,
    js=get_listener(),  # browser-side handler toggling source cards
    theme=theme,
) as demo:
    # Page header with the two partner logos.
    with gr.Column(visible=True):
        gr.HTML(
            f"""<div class="row_logo">
            <img src={logo_rsf} alt="logo RSF" style="float:left; width:120px; height:70px">
            <img src={logo_ap} alt="logo AP" style="width:120px; height:70px">
            </div>"""
        )

    # Hidden state: raw source text per tab, and each tab's own name
    # (used to tell per-tab callbacks which tab fired).
    text_sources = {elt: gr.State("") for elt in config["tabs"]}
    tab_states = {elt: gr.State(elt) for elt in config["tabs"]}
    with gr.Row():
        with gr.Column(scale=3):
            # All databases are selected by default.
            outils = gr.CheckboxGroup(
                choices=list(config["tabs"].keys()),
                value=list(config["tabs"].keys()),
                type="value",
                label="Choisir les bases de données à interroger",
            )
        with gr.Column(scale=1):
            submit_btn = gr.Button(
                "Relancer la Synthèse", variant="primary", elem_id="synthese_btn"
            )


    # Main tab: user question + streamed synthesis.
    synthesis_tab = gr.Tab("Synthesis", elem_id="tab")
    with synthesis_tab:
        question = gr.Textbox(
            show_label=True,
            label="Posez une question à Spinoza",
            placeholder="Quelle est votre question ?",
        )
        md_question = gr.Markdown(None, visible=False)  # hidden copy of the question
        warning = gr.Markdown(None, elem_id="warn")
        synthesis = gr.Markdown(None, elem_id="synthesis")

    # Build one database tab (question/answer/sources widgets) per entry.
    data = init_gradio(data)
    # Synthesis pipeline: store question -> show warning -> clear synthesis
    # -> reformulate per tab -> retrieve sources -> stream per-tab answers
    # -> stream the final synthesis.
    (
        question.submit(add_question, [question], [md_question])
        .then(add_warning, [], [warning])
        .then(empty, [], [synthesis])
        .then(
            reformulate_questions,
            [outils, md_question],
            [data[tab]["question"]["component"] for tab in config["tabs"]],
        )
        .then(
            retrieve_sources,
            [outils]
            + [data[tab]["question"]["component"] for tab in config["tabs"]],
            [data[tab]["sources"]["component"] for tab in config["tabs"]]
            + [text_sources[tab] for tab in config["tabs"]],
        )
        .then(
            answer_questions,
            [outils]
            + [data[tab]["question"]["component"] for tab in config["tabs"]]
            + [text_sources[tab] for tab in config["tabs"]],
            [data[tab]["answer"]["component"] for tab in config["tabs"]],
        )
        .then(
            get_synthesis,
            [outils, md_question]
            + [data[tab]["answer"]["component"] for tab in config["tabs"]],
            [synthesis],
        )
    )

    # Each tab's own question box re-runs retrieval + answering for that
    # tab only (sources and answer are cleared first).
    for tab in config["tabs"]:
        (
            data[tab]["question"]["component"]
            .submit(empty, [], [data[tab]["sources"]["component"]])
            .then(empty, [], [text_sources[tab]])
            .then(empty, [], [data[tab]["answer"]["component"]])
            .then(
                get_sources,
                [outils, data[tab]["question"]["component"], tab_states[tab]],
                [data[tab]["sources"]["component"], text_sources[tab]],
            )
            .then(
                answer_single_question,
                [
                    outils,
                    text_sources[tab],
                    data[tab]["question"]["component"],
                    tab_states[tab],
                ],
                [data[tab]["answer"]["component"]],
            )
        )

    # Manual re-run of the synthesis from the already-streamed answers.
    (
        submit_btn.click(empty, [], [synthesis]).then(
            get_synthesis,
            [outils, md_question]
            + [data[tab]["answer"]["component"] for tab in config["tabs"]],
            [synthesis],
        )
    )
|
|
|
if __name__ == "__main__":
    # queue() enables generator streaming; share/debug are dev-time settings.
    demo.queue().launch(share=True, debug=True)
|
|