|
import gradio as gr |
|
import time |
|
import yaml |
|
from langchain.prompts.chat import ChatPromptTemplate |
|
from huggingface_hub import hf_hub_download |
|
from spinoza_project.source.backend.llm_utils import ( |
|
get_llm, |
|
get_llm_api, |
|
get_vectorstore, |
|
get_vectorstore_api, |
|
) |
|
from spinoza_project.source.backend.document_store import pickle_to_document_store |
|
from spinoza_project.source.backend.get_prompts import get_qa_prompts |
|
from spinoza_project.source.frontend.utils import ( |
|
make_html_source, |
|
make_html_presse_source, |
|
make_html_afp_source, |
|
parse_output_llm_with_sources, |
|
init_env, |
|
) |
|
from spinoza_project.source.backend.prompt_utils import ( |
|
to_chat_instruction, |
|
SpecialTokens, |
|
) |
|
|
|
from assets.utils_javascript import ( |
|
accordion_trigger, |
|
accordion_trigger_end, |
|
accordion_trigger_spinoza, |
|
accordion_trigger_spinoza_end, |
|
update_footer, |
|
) |
|
|
|
# Load environment variables / API credentials used by the LLM and vector-store APIs.
init_env()

# Global application configuration (tabs, source mapping, retrieval settings).
with open("./spinoza_project/config.yaml") as f:
    config = yaml.full_load(f)

# One prompt-definition YAML per source listed under config["prompt_naming"].
prompts = {}
for source in config["prompt_naming"]:
    with open(f"./spinoza_project/prompt_{source}.yaml") as f:
        prompts[source] = yaml.full_load(f)


print("Building LLM")
# NOTE(review): `model` is assigned but never used below — get_llm_api()
# presumably selects the model internally; confirm before removing.
model = "gpt35turbo"
llm = get_llm_api()
|
|
|
|
|
print("Loading Databases")
# Presse and AFP corpora are served through a remote vector-store API.
bdd_presse = get_vectorstore_api("presse")
bdd_afp = get_vectorstore_api("afp")
# Every other tab's corpus is a pickled document store downloaded from the
# Hugging Face Hub dataset repo (one pickle per tab).
# NOTE(review): unpickling downloaded files executes arbitrary code if the
# repo is compromised — acceptable only because the repo is project-owned.
qdrants = {
    tab: pickle_to_document_store(
        hf_hub_download(
            repo_id="SpinozaProject/spinoza-database",
            filename=f"database_{tab}.pickle",
            repo_type="dataset",
        )
    )
    for tab in config["prompt_naming"]
    if tab != "Presse" and tab != "AFP"
}
|
|
|
|
|
print("Loading Prompts")
# Build per-source QA and reformulation chat prompts.
# NOTE(review): chat_summarize_memory_prompts is initialized but never
# populated or read in this file — possibly a leftover; confirm before removal.
chat_qa_prompts, chat_reformulation_prompts, chat_summarize_memory_prompts = {}, {}, {}
for source, prompt in prompts.items():
    chat_qa_prompt, chat_reformulation_prompt = get_qa_prompts(config, prompt)
    chat_qa_prompts[source] = chat_qa_prompt
    chat_reformulation_prompts[source] = chat_reformulation_prompt
|
|
|
|
|
# Custom stylesheet for the Gradio app.
# NOTE(review): style.css is read again further down (just before gr.Blocks);
# this first read is redundant but harmless.
with open("./assets/style.css", "r") as f:
    css = f.read()
|
|
|
|
|
special_tokens = SpecialTokens(config)

# Instruction for the final synthesis step: merge the per-agent answers into a
# single journalist-oriented summary. Typos and stray `",` residue from the
# original prompt text have been fixed (prompt quality is behavior here).
synthesis_template = """You are a factual journalist that summarizes the specialized answers from technical sources.

Based on the following question:
{question}

And the following expert answers:
{answers}

- When using legal answers, keep track of the names of the articles.
- When using ADEME answers, name the sources that are mainly used.
- List the different elements mentioned, and highlight the agreement points between the sources, as well as the contradictions or differences.
- Contradictions don't lie in whether or not a subject is dealt with, but more in the opinion given or the way the subject is dealt with.
- Generate the answer as markdown, with an aerated layout, and headlines in bold
- When you use information from a passage, mention where it came from by using [Doc i] at the end of the sentence. i stands for the number of the document.
- Do not use the sentence 'Doc i says ...' to say where information came from.
- If the same thing is said in more than one document, you can mention all of them like this: [Doc i, Doc j, Doc k]
- Start by highlighting contradictions, then do a general summary and finally get into the details that might be interesting for article writing. Where relevant, quote them.
- Answer in French / Réponds en français
"""

# Wrap the raw template with the model's special tokens, then build the
# LangChain prompt object used by get_synthesis().
synthesis_prompt = to_chat_instruction(synthesis_template, special_tokens)
synthesis_prompt_template = ChatPromptTemplate.from_messages([synthesis_prompt])
|
|
|
|
|
def zip_longest_fill(*args, fillvalue=None):
    """Zip iterables of unequal length, repeating each iterator's last value.

    Unlike itertools.zip_longest, an exhausted iterator keeps yielding the
    last value it produced (or *fillvalue* if it never produced one). The
    loop stops after a full round in which no iterator produced a new value;
    note that this final all-repeated tuple is yielded once before stopping,
    which the streaming Gradio callbacks rely on.

    Args:
        *args: iterables consumed in lockstep.
        fillvalue: initial stand-in for iterators that are empty from the start.

    Yields:
        Tuples with one entry per input iterable.
    """
    iterators = [iter(it) for it in args]
    if not iterators:
        return

    # Bug fix: `fillvalue` was previously accepted but ignored (the seed list
    # was hard-coded to None). Seeding with fillvalue is backward compatible
    # since the default is still None.
    fillvalues = [fillvalue] * len(iterators)
    cond = True
    while cond:
        values = []
        for i, it in enumerate(iterators):
            try:
                value = next(it)
            except StopIteration:
                # Exhausted: repeat that iterator's previous value.
                value = fillvalues[i]
            values.append(value)

        # Continue only if at least one iterator produced something new.
        cond = any(elt != fillvalues[i] for i, elt in enumerate(values))

        fillvalues = values.copy()
        yield tuple(values)
|
|
|
|
|
def format_question(question):
    """Format a raw user question for downstream use (currently an identity
    string conversion, kept as an extension point)."""
    return "{}".format(question)
|
|
|
|
|
def parse_question(question):
    """Strip Gradio's <p> wrapping and, when a '### ' heading is present,
    return only the segment immediately after the first heading marker."""
    cleaned = question.replace("<p>", "").replace("</p>\n", "")
    marker = "### "
    if marker not in cleaned:
        return cleaned
    return cleaned.split(marker)[1]
|
|
|
|
|
def reformulate(question, tab, config=config):
    """Stream an LLM reformulation of *question* tailored to *tab*'s source.

    Returns the raw llm.stream iterator for known tabs; for unknown tabs,
    returns a 5-item iterator of None so zip_longest_fill peers stay aligned.
    """
    if tab in list(config["tabs"].keys()):
        return llm.stream(
            chat_reformulation_prompts[config["source_mapping"][tab]],
            {"question": parse_question(question)},
        )
    else:
        return iter([None] * 5)
|
|
|
|
|
def reformulate_single_question(question, tab, config=config):
    """Yield the streamed reformulation for one tab, throttled so Gradio can
    render incremental tokens."""
    stream = reformulate(question, tab, config=config)
    for chunk in stream:
        time.sleep(0.02)  # small delay between UI updates
        yield chunk
|
|
|
|
|
def reformulate_questions(question, config=config):
    """Stream reformulations for every tab in lockstep (one tuple per tick).

    Each yielded tuple holds one partial reformulation per tab, matching the
    order of config["tabs"]; exhausted streams repeat their last value.
    """
    for elt in zip_longest_fill(
        *[reformulate(question, tab, config=config) for tab in config["tabs"]]
    ):
        time.sleep(0.02)  # throttle so the UI keeps up with streaming
        yield elt
|
|
|
|
|
def add_question(question):
    """Pass the submitted question through unchanged (Gradio state hook)."""
    submitted = question
    return submitted
|
|
|
|
|
def answer(question, source, tab, config=config):
    """Stream an answer for *tab* grounded in the retrieved *source* text.

    Returns an iterator of partial answers. When the source text is nearly
    empty (< 10 chars, i.e. retrieval found nothing usable), returns a
    single-item message asking the user to rephrase; unknown tabs return a
    5-item None iterator to keep zip_longest_fill peers aligned.
    """
    if tab in list(config["tabs"].keys()):
        if len(source) < 10:
            return iter(["Aucune source trouvée, veuillez reformuler votre question"])
        else:
            return llm.stream(
                chat_qa_prompts[config["source_mapping"][tab]],
                {
                    "question": parse_question(question),
                    # Strip Gradio's <p> wrapping from the source text as well.
                    "sources": source.replace("<p>", "").replace("</p>\n", ""),
                },
            )
    else:
        return iter([None] * 5)
|
|
|
|
|
def answer_single_question(source, question, tab, config=config):
    """Yield one tab's streamed answer, throttled for smooth UI rendering."""
    stream = answer(question, source, tab, config=config)
    for chunk in stream:
        time.sleep(0.02)  # small delay between UI updates
        yield chunk
|
|
|
|
|
def answer_questions(*questions_sources, config=config):
    """Stream per-tab answers from a flat (questions..., sources...) arg list.

    The Gradio handler passes all reformulated questions followed by all
    retrieved source texts, both in config["tabs"] order. Yields, each tick,
    a list with one [(question, parsed_partial_answer)] chat history per tab.
    """
    # First half of the varargs are the questions, second half the sources.
    questions = [elt for elt in questions_sources[: len(questions_sources) // 2]]
    sources = [elt for elt in questions_sources[len(questions_sources) // 2 :]]

    for elt in zip_longest_fill(
        *[
            answer(question, source, tab, config=config)
            for question, source, tab in zip(questions, sources, config["tabs"])
        ]
    ):
        time.sleep(0.02)  # throttle for incremental UI rendering
        yield [
            [(question, parse_output_llm_with_sources(ans))]
            for question, ans in zip(questions, elt)
        ]
|
|
|
|
|
def get_sources(
    questions, qdrants=qdrants, bdd_presse=bdd_presse, bdd_afp=bdd_afp, config=config
):
    """Retrieve and format supporting documents for each tab's question.

    For each (question, tab) pair — in config["tabs"] order — queries the
    matching store (Presse/AFP APIs or a local qdrant store), keeps documents
    above the similarity threshold, and produces both an HTML rendering for
    the sources panel and a plain-text block for the QA prompt.

    Args:
        questions: one reformulated question per tab.
        qdrants: tab-name -> local document store mapping.
        bdd_presse / bdd_afp: API-backed vector stores.
        config: app configuration (k, min_similarity, tabs, mappings).

    Returns:
        (formated, text): the concatenated HTML string, and a list with one
        "\\n\\n"-joined plain-text source block per tab.
    """
    k = config["num_document_retrieved"]
    min_similarity = config["min_similarity"]

    def _strip_html(q):
        # Gradio wraps chat messages in <p> tags; remove them before querying.
        return q.replace("<p>", "").replace("</p>\n", "")

    def _retrieve(store, query):
        # Query a vector store and keep only sufficiently similar documents.
        hits = store.similarity_search_with_relevance_scores(query, k=k)
        return [(doc, score) for doc, score in hits if score >= min_similarity]

    text, formated = [], []
    for i, (question, tab) in enumerate(zip(questions, list(config["tabs"].keys()))):
        # Document numbers for this tab occupy the fixed range k*i+1 .. k*(i+1),
        # so citations like [Doc j] stay unique across tabs.
        doc_numbers = range(k * i + 1, k * (i + 1) + 1)

        if tab == "Presse":
            sources = _retrieve(bdd_presse, _strip_html(question))
            formated.extend(
                [
                    make_html_presse_source(source[0], j, source[1])
                    for j, source in zip(doc_numbers, sources)
                ]
            )
        elif tab == "AFP":
            sources = _retrieve(bdd_afp, _strip_html(question))
            formated.extend(
                [
                    make_html_afp_source(source[0], j, source[1])
                    for j, source in zip(doc_numbers, sources)
                ]
            )
        else:
            # Local stores get the configured query pre-prompt prepended.
            sources = _retrieve(
                qdrants[config["source_mapping"][tab]],
                config["query_preprompt"] + _strip_html(question),
            )
            formated.extend(
                [
                    make_html_source(source[0], j, source[1], config)
                    for j, source in zip(doc_numbers, sources)
                ]
            )

        # Plain-text version fed to the QA prompt — exactly one entry per tab
        # (empty string when nothing passed the similarity threshold).
        text.extend(
            [
                "\n\n".join(
                    [
                        f"Doc {str(j)} with source type {source[0].metadata.get('file_source_type')}:\n"
                        + source[0].page_content
                        for j, source in zip(doc_numbers, sources)
                    ]
                )
            ]
        )

    formated = "".join(formated)
    return formated, text
|
|
|
|
|
def retrieve_sources(
    *questions, qdrants=qdrants, bdd_presse=bdd_presse, bdd_afp=bdd_afp, config=config
):
    """Gradio adapter around get_sources: flattens its outputs for the UI.

    Returns the combined HTML sources panel first, followed by one plain-text
    source block per tab, matching the submit handler's output components.
    """
    html_panel, per_tab_texts = get_sources(
        questions, qdrants, bdd_presse, bdd_afp, config
    )
    return (html_panel, *per_tab_texts)
|
|
|
|
|
def get_synthesis(question, *answers, config=config):
    """Stream Spinoza's final synthesis from the per-agent answers.

    Keeps only substantial agent answers (>= 100 chars — a heuristic that
    filters out error/placeholder messages), prefixes each with its tab name,
    and feeds them to the synthesis prompt. Yields chat-history updates of the
    form [(question, partial_answer)] as tokens stream in.
    """
    # Renamed from `answer` to stop shadowing the module-level answer().
    selected = []
    for i, tab in enumerate(config["tabs"]):
        if len(str(answers[i])) >= 100:
            selected.append(
                f"{tab}\n{answers[i]}".replace("<p>", "").replace("</p>\n", "")
            )

    if len(selected) == 0:
        # Bug fix: this function is a generator (it contains `yield`), so the
        # previous `return "<message>"` was swallowed as a StopIteration value
        # and the user saw nothing. Yield the message in chat format instead.
        yield [
            (
                question,
                "Aucune source n'a pu être identifiée pour répondre, veuillez modifier votre question",
            )
        ]
    else:
        for elt in llm.stream(
            synthesis_prompt_template,
            {
                "question": question.replace("<p>", "").replace("</p>\n", ""),
                "answers": "\n\n".join(selected),
            },
        ):
            time.sleep(0.01)  # throttle for incremental UI rendering
            yield [(question, parse_output_llm_with_sources(elt))]
|
|
|
|
|
# Gradio theme: Poppins font with blue primary / red secondary accents.
theme = gr.themes.Base(
    primary_hue="blue",
    secondary_hue="red",
    font=[gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"],
)

# Stylesheet injected into gr.Blocks below (also read earlier in the file;
# this second read wins).
with open("./assets/style.css", "r") as f:
    css = f.read()

# Markdown body for the "Source information" tab.
with open("./assets/source_information.md", "r") as f:
    source_information = f.read()
|
|
|
|
|
def start_agents():
    """Show a loading toast and seed the Spinoza chat panel with a waiting
    message while the agent chain runs."""
    gr.Info(message="The agents and Spinoza are loading...", duration=3)
    placeholder = "I am waiting until all the agents are done to generate an answer..."
    return [(None, placeholder)]
|
|
|
|
|
def end_agents():
    """Show a toast telling the user that every agent has finished."""
    done_message = "The agents and Spinoza have finished answering your question"
    gr.Info(message=done_message, duration=3)
|
|
|
|
|
def next_call():
    """No-op placeholder chained between Gradio events that only need their
    js= hook to run; returns None."""
    return None
|
|
|
|
|
# Greeting pre-loaded into the Spinoza chatbot panel (user-facing text —
# kept byte-identical).
init_prompt = """
Hello, I am Spinoza, a conversational assistant designed to help you in your journalistic journey. I will answer your questions based **on the provided sources**.

⚠️ Limitations
*Please note that this chatbot is in an early stage, it is not perfect and may sometimes give irrelevant answers. If you are not satisfied with the answer, please ask a more specific question or report your feedback to help us improve the system.*

What do you want to learn ?
"""
|
|
|
# UI definition: one collapsible chatbot per agent, the Spinoza synthesis
# panel, a sources side panel, and the chained submit pipeline.
with gr.Blocks(
    title=f"🔍 Spinoza",
    css=css,
    js=update_footer(),
    theme=theme,
) as demo:
    # Per-tab agent chatbots plus the final "spinoza" synthesis chatbot.
    chatbots = {}
    # Hidden state components shared across the chained event handlers below.
    question = gr.State("")
    docs_textbox = gr.State([""])
    agent_questions = {elt: gr.State("") for elt in config["tabs"]}
    component_sources = {elt: gr.State("") for elt in config["tabs"]}
    text_sources = {elt: gr.State("") for elt in config["tabs"]}
    tab_states = {elt: gr.State(elt) for elt in config["tabs"]}

    with gr.Tab("Q&A", elem_id="main-component"):
        with gr.Row(elem_id="chatbot-row"):
            with gr.Column(scale=2, elem_id="center-panel"):
                with gr.Group(elem_id="chatbot-group"):
                    # NOTE(review): the accordions below index
                    # list(config["tabs"].keys())[0..5] positionally — this
                    # assumes config declares exactly these six tabs in this
                    # exact order; confirm against config.yaml.
                    with gr.Accordion(
                        "Science agent",
                        open=False,
                        elem_id="accordion-science",
                        elem_classes="accordion",
                    ):
                        chatbots[list(config["tabs"].keys())[0]] = gr.Chatbot(
                            show_copy_button=True,
                            show_share_button=False,
                            show_label=False,
                            elem_id="chatbot-science",
                            layout="panel",
                            avatar_images=(
                                "./assets/logos/help.png",
                                None,
                            ),
                        )

                    with gr.Accordion(
                        "Law agent",
                        open=False,
                        elem_id="accordion-legal",
                        elem_classes="accordion",
                    ):
                        chatbots[list(config["tabs"].keys())[1]] = gr.Chatbot(
                            show_copy_button=True,
                            show_share_button=False,
                            show_label=False,
                            elem_id="chatbot-legal",
                            layout="panel",
                            avatar_images=(
                                "./assets/logos/help.png",
                                None,
                            ),
                        )

                    with gr.Accordion(
                        "Politics agent",
                        open=False,
                        elem_id="accordion-politique",
                        elem_classes="accordion",
                    ):
                        chatbots[list(config["tabs"].keys())[2]] = gr.Chatbot(
                            show_copy_button=True,
                            show_share_button=False,
                            show_label=False,
                            elem_id="chatbot-politique",
                            layout="panel",
                            avatar_images=(
                                "./assets/logos/help.png",
                                None,
                            ),
                        )

                    with gr.Accordion(
                        "ADEME agent",
                        open=False,
                        elem_id="accordion-ademe",
                        elem_classes="accordion",
                    ):
                        chatbots[list(config["tabs"].keys())[3]] = gr.Chatbot(
                            show_copy_button=True,
                            show_share_button=False,
                            show_label=False,
                            elem_id="chatbot-ademe",
                            layout="panel",
                            avatar_images=(
                                "./assets/logos/help.png",
                                None,
                            ),
                        )

                    with gr.Accordion(
                        "Press agent",
                        open=False,
                        elem_id="accordion-presse",
                        elem_classes="accordion",
                    ):
                        chatbots[list(config["tabs"].keys())[4]] = gr.Chatbot(
                            show_copy_button=True,
                            show_share_button=False,
                            show_label=False,
                            elem_id="chatbot-presse",
                            layout="panel",
                            avatar_images=(
                                "./assets/logos/help.png",
                                None,
                            ),
                        )

                    with gr.Accordion(
                        "AFP agent",
                        open=False,
                        elem_id="accordion-afp",
                        elem_classes="accordion",
                    ):
                        chatbots[list(config["tabs"].keys())[5]] = gr.Chatbot(
                            show_copy_button=True,
                            show_share_button=False,
                            show_label=False,
                            elem_id="chatbot-afp",
                            layout="panel",
                            avatar_images=(
                                "./assets/logos/help.png",
                                None,
                            ),
                        )

                    # Final synthesis panel — the only accordion open by default.
                    with gr.Accordion(
                        "Spinoza",
                        open=True,
                        elem_id="accordion-spinoza",
                        elem_classes="accordion",
                    ):
                        chatbots["spinoza"] = gr.Chatbot(
                            value=[(None, init_prompt)],
                            show_copy_button=True,
                            show_share_button=False,
                            show_label=False,
                            elem_id="chatbot-spinoza",
                            layout="panel",
                            avatar_images=(
                                "./assets/logos/help.png",
                                "./assets/logos/spinoza.png",
                            ),
                        )

                with gr.Row(elem_id="input-message"):
                    # Single question input driving the whole pipeline.
                    ask = gr.Textbox(
                        placeholder="Ask me anything here!",
                        show_label=False,
                        scale=7,
                        lines=1,
                        interactive=True,
                        elem_id="input-textbox",
                    )

            with gr.Column(scale=1, variant="panel", elem_id="right-panel"):
                with gr.TabItem("Sources", elem_id="tab-sources", id=0):
                    # HTML panel filled by retrieve_sources().
                    sources_textbox = gr.HTML(
                        show_label=False, elem_id="sources-textbox"
                    )

    with gr.Tab("Source information", elem_id="source-component"):
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown(source_information)

    with gr.Tab("Contact", elem_id="contact-component"):
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("For any issue contact **[email protected]**.")

    # Submit pipeline: toast -> reformulate per tab -> retrieve sources ->
    # stream per-agent answers -> synthesize -> closing toast. The js= hooks
    # open/close the matching accordions as each stage runs.
    ask.submit(
        start_agents, inputs=[], outputs=[chatbots["spinoza"]], js=accordion_trigger()
    ).then(
        fn=reformulate_questions,
        inputs=[ask],
        outputs=[agent_questions[tab] for tab in config["tabs"]],
    ).then(
        fn=retrieve_sources,
        inputs=[agent_questions[tab] for tab in config["tabs"]],
        outputs=[sources_textbox] + [text_sources[tab] for tab in config["tabs"]],
    ).then(
        fn=answer_questions,
        inputs=[agent_questions[tab] for tab in config["tabs"]]
        + [text_sources[tab] for tab in config["tabs"]],
        outputs=[chatbots[tab] for tab in config["tabs"]],
    ).then(
        fn=next_call, inputs=[], outputs=[], js=accordion_trigger_end()
    ).then(
        fn=next_call, inputs=[], outputs=[], js=accordion_trigger_spinoza()
    ).then(
        # NOTE(review): the synthesis question is tab index [1] (the Law
        # agent's reformulation) — looks arbitrary; confirm this is intended
        # rather than the raw `ask` value.
        fn=get_synthesis,
        inputs=[agent_questions[list(config["tabs"].keys())[1]]]
        + [chatbots[tab] for tab in config["tabs"]],
        outputs=[chatbots["spinoza"]],
    ).then(
        fn=next_call, inputs=[], outputs=[], js=accordion_trigger_spinoza_end()
    ).then(
        fn=end_agents, inputs=[], outputs=[]
    )
|
|
|
|
|
if __name__ == "__main__":
    # queue() is required for the streaming generator callbacks above;
    # debug=True prints server logs to the console.
    demo.queue().launch(debug=True)
|
|