import gradio as gr
import time
import yaml
from langchain.prompts.chat import ChatPromptTemplate
from huggingface_hub import hf_hub_download
from spinoza_project.source.backend.llm_utils import (
get_llm,
get_llm_api,
get_vectorstore,
get_vectorstore_api,
)
from spinoza_project.source.backend.document_store import pickle_to_document_store
from spinoza_project.source.backend.get_prompts import get_qa_prompts
from spinoza_project.source.frontend.utils import (
make_html_source,
make_html_presse_source,
make_html_afp_source,
parse_output_llm_with_sources,
init_env,
)
from spinoza_project.source.backend.prompt_utils import (
to_chat_instruction,
SpecialTokens,
)
from assets.utils_javascript import (
accordion_trigger,
accordion_trigger_end,
accordion_trigger_spinoza,
accordion_trigger_spinoza_end,
update_footer,
)
init_env()
with open("./spinoza_project/config.yaml") as f:
config = yaml.full_load(f)
prompts = {}
for source in config["prompt_naming"]:
with open(f"./spinoza_project/prompt_{source}.yaml") as f:
prompts[source] = yaml.full_load(f)
## Building LLM
print("Building LLM")
model = "gpt35turbo"  # note: not referenced elsewhere in this script
llm = get_llm_api()
## Loading tools
print("Loading Databases")
bdd_presse = get_vectorstore_api("presse")
bdd_afp = get_vectorstore_api("afp")
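# the remaining tabs load pickled document stores from the Spinoza HF dataset repo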
qdrants = {
tab: pickle_to_document_store(
hf_hub_download(
repo_id="SpinozaProject/spinoza-database",
filename=f"database_{tab}.pickle",
repo_type="dataset",
)
)
for tab in config["prompt_naming"]
if tab != "Presse" and tab != "AFP"
}
## Load Prompts
print("Loading Prompts")
chat_qa_prompts, chat_reformulation_prompts, chat_summarize_memory_prompts = {}, {}, {}
for source, prompt in prompts.items():
chat_qa_prompt, chat_reformulation_prompt = get_qa_prompts(config, prompt)
chat_qa_prompts[source] = chat_qa_prompt
chat_reformulation_prompts[source] = chat_reformulation_prompt
with open("./assets/style.css", "r") as f:
css = f.read()
special_tokens = SpecialTokens(config)
synthesis_template = """You are a factual journalist who summarizes the specialized answers from technical sources.
Based on the following question:
{question}
And the following expert answers:
{answers}
- When using legal answers, keep track of the names of the articles.
- When using ADEME answers, name the sources that are mainly used.
- List the different elements mentioned, and highlight the agreement points between the sources, as well as the contradictions or differences.
- Contradictions don't lie in whether or not a subject is dealt with, but rather in the opinion given or the way the subject is dealt with.
- Generate the answer as markdown, with an uncluttered layout and headlines in bold.
- When you use information from a passage, mention where it came from by using [Doc i] at the end of the sentence. i stands for the number of the document.
- Do not use the sentence 'Doc i says ...' to say where information came from.
- If the same thing is said in more than one document, you can mention all of them like this: [Doc i, Doc j, Doc k].
- Start by highlighting the contradictions, then give a general summary, and finally get into the details that might be interesting for article writing. Where relevant, quote them.
- Answer in French / Réponds en français.
"""
synthesis_prompt = to_chat_instruction(synthesis_template, special_tokens)
synthesis_prompt_template = ChatPromptTemplate.from_messages([synthesis_prompt])
def zip_longest_fill(*args, fillvalue=None):
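    """Zip several iterators, but unlike itertools.zip_longest an exhausted
    iterator keeps repeating its most recent value (or ``fillvalue`` if it was
    empty from the start); iteration ends after a round in which no iterator
    produced a new value."""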
iterators = [iter(it) for it in args]
num_active = len(iterators)
if not num_active:
return
cond = True
    fillvalues = [fillvalue] * len(iterators)
while cond:
values = []
for i, it in enumerate(iterators):
try:
value = next(it)
except StopIteration:
value = fillvalues[i]
values.append(value)
new_cond = False
for i, elt in enumerate(values):
if elt != fillvalues[i]:
new_cond = True
cond = new_cond
fillvalues = values.copy()
yield tuple(values)
def format_question(question):
    return question  # pass-through; parse_question still strips an optional "### " prefix
def parse_question(question):
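    """Strip the <p> wrappers Gradio adds and, if a "### " marker is present,
    keep only the text that follows it."""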
x = question.replace("<p>", "").replace("</p>\n", "")
if "### " in x:
return x.split("### ")[1]
return x
def reformulate(question, tab, config=config):
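    """Stream a tab-specific reformulation of the user question; tabs missing
    from the config get a short stream of None values."""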
if tab in list(config["tabs"].keys()):
return llm.stream(
chat_reformulation_prompts[config["source_mapping"][tab]],
{"question": parse_question(question)},
)
else:
return iter([None] * 5)
def reformulate_single_question(question, tab, config=config):
for elt in reformulate(question, tab, config=config):
time.sleep(0.02)
yield elt
def reformulate_questions(question, config=config):
for elt in zip_longest_fill(
*[reformulate(question, tab, config=config) for tab in config["tabs"]]
):
time.sleep(0.02)
yield elt
def add_question(question):
return question
def answer(question, source, tab, config=config):
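    """Stream one tab's answer built from its QA prompt and retrieved sources;
    when the sources are (nearly) empty, stream a short French notice instead."""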
if tab in list(config["tabs"].keys()):
if len(source) < 10:
return iter(["Aucune source trouvée, veuillez reformuler votre question"])
else:
return llm.stream(
chat_qa_prompts[config["source_mapping"][tab]],
{
"question": parse_question(question),
"sources": source.replace("<p>", "").replace("</p>\n", ""),
},
)
else:
return iter([None] * 5)
def answer_single_question(source, question, tab, config=config):
for elt in answer(question, source, tab, config=config):
time.sleep(0.02)
yield elt
def answer_questions(*questions_sources, config=config):
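    """Interleave all tabs' answer streams; the positional arguments are the
    reformulated questions followed by the corresponding source texts, in tab
    order."""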
    questions = list(questions_sources[: len(questions_sources) // 2])
    sources = list(questions_sources[len(questions_sources) // 2 :])
for elt in zip_longest_fill(
*[
answer(question, source, tab, config=config)
for question, source, tab in zip(questions, sources, config["tabs"])
]
):
time.sleep(0.02)
yield [
[(question, parse_output_llm_with_sources(ans))]
for question, ans in zip(questions, elt)
]
def get_sources(
questions, qdrants=qdrants, bdd_presse=bdd_presse, bdd_afp=bdd_afp, config=config
):
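    """Search each tab's database with its reformulated question and return
    (formated, text): the HTML rendering of all retrieved sources, and one
    plain-text block of numbered documents per tab for the QA prompts.
    Documents scoring below config["min_similarity"] are dropped."""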
k = config["num_document_retrieved"]
min_similarity = config["min_similarity"]
text, formated = [], []
for i, (question, tab) in enumerate(zip(questions, list(config["tabs"].keys()))):
if tab == "Presse":
sources = bdd_presse.similarity_search_with_relevance_scores(
question.replace("<p>", "").replace("</p>\n", ""), k=k
)
sources = [
(doc, score) for doc, score in sources if score >= min_similarity
]
formated.extend(
[
make_html_presse_source(source[0], j, source[1])
for j, source in zip(range(k * i + 1, k * (i + 1) + 1), sources)
]
)
elif tab == "AFP":
sources = bdd_afp.similarity_search_with_relevance_scores(
question.replace("<p>", "").replace("</p>\n", ""), k=k
)
sources = [
(doc, score) for doc, score in sources if score >= min_similarity
]
formated.extend(
[
make_html_afp_source(source[0], j, source[1])
for j, source in zip(range(k * i + 1, k * (i + 1) + 1), sources)
]
)
else:
sources = qdrants[
config["source_mapping"][tab]
].similarity_search_with_relevance_scores(
config["query_preprompt"]
+ question.replace("<p>", "").replace("</p>\n", ""),
k=k,
)
sources = [
(doc, score) for doc, score in sources if score >= min_similarity
]
formated.extend(
[
make_html_source(source[0], j, source[1], config)
for j, source in zip(range(k * i + 1, k * (i + 1) + 1), sources)
]
)
text.extend(
[
"\n\n".join(
[
f"Doc {str(j)} with source type {source[0].metadata.get('file_source_type')}:\n"
+ source[0].page_content
for j, source in zip(range(k * i + 1, k * (i + 1) + 1), sources)
]
)
]
)
formated = "".join(formated)
return formated, text
def retrieve_sources(
*questions, qdrants=qdrants, bdd_presse=bdd_presse, bdd_afp=bdd_afp, config=config
):
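    """Gradio-facing wrapper around get_sources: receives one reformulated
    question per tab and returns the sources HTML followed by each tab's text
    block."""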
formated_sources, text_sources = get_sources(
questions, qdrants, bdd_presse, bdd_afp, config
)
return (formated_sources, *text_sources)
def get_synthesis(question, *answers, config=config):
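    """Stream Spinoza's synthesis of the agents' answers, ignoring answers too
    short to be informative."""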
    selected_answers = []
    for i, tab in enumerate(config["tabs"]):
        # each element of answers is a chat history; its stringified length is
        # a crude check that the agent actually produced a substantive answer
        if len(str(answers[i])) >= 100:
            selected_answers.append(
                f"{tab}\n{answers[i]}".replace("<p>", "").replace("</p>\n", "")
            )
    if len(selected_answers) == 0:
        # yield (not return) the notice: a plain `return value` inside a
        # generator ends the stream without ever showing the message
        yield [
            (
                question,
                "Aucune source n'a pu être identifiée pour répondre, veuillez modifier votre question",
            )
        ]
    else:
        for elt in llm.stream(
            synthesis_prompt_template,
            {
                "question": question.replace("<p>", "").replace("</p>\n", ""),
                "answers": "\n\n".join(selected_answers),
            },
},
):
time.sleep(0.01)
yield [(question, parse_output_llm_with_sources(elt))]
theme = gr.themes.Base(
primary_hue="blue",
secondary_hue="red",
font=[gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"],
)
with open("./assets/style.css", "r") as f:
css = f.read()
with open("./assets/source_information.md", "r") as f:
source_information = f.read()
def start_agents():
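    """Show a loading notice and seed the Spinoza chat with a waiting message."""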
gr.Info(message="The agents and Spinoza are loading...", duration=3)
return [
(None, "I am waiting until all the agents are done to generate an answer...")
]
def end_agents():
gr.Info(
message="The agents and Spinoza have finished answering your question",
duration=3,
)
def next_call():
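    """No-op callback used only to fire the JS accordion triggers in the chain."""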
return
init_prompt = """
Hello, I am Spinoza, a conversational assistant designed to help you in your journalistic journey. I will answer your questions based **on the provided sources**.
⚠️ Limitations
*Please note that this chatbot is in an early stage, it is not perfect and may sometimes give irrelevant answers. If you are not satisfied with the answer, please ask a more specific question or report your feedback to help us improve the system.*
What do you want to learn?
"""
with gr.Blocks(
    title="🔍 Spinoza",
css=css,
js=update_footer(),
theme=theme,
) as demo:
chatbots = {}
question = gr.State("")
docs_textbox = gr.State([""])
agent_questions = {elt: gr.State("") for elt in config["tabs"]}
component_sources = {elt: gr.State("") for elt in config["tabs"]}
text_sources = {elt: gr.State("") for elt in config["tabs"]}
tab_states = {elt: gr.State(elt) for elt in config["tabs"]}
with gr.Tab("Q&A", elem_id="main-component"):
with gr.Row(elem_id="chatbot-row"):
with gr.Column(scale=2, elem_id="center-panel"):
with gr.Group(elem_id="chatbot-group"):
                    # the six specialist agents share the same accordion and
                    # chatbot structure, so build them in a loop; labels and
                    # element ids follow the original hand-written blocks
                    agent_specs = [
                        ("Science agent", "science"),
                        ("Law agent", "legal"),
                        ("Politics agent", "politique"),
                        ("ADEME agent", "ademe"),
                        ("Press agent", "presse"),
                        ("AFP agent", "afp"),
                    ]
                    for (label, slug), tab in zip(agent_specs, config["tabs"]):
                        with gr.Accordion(
                            label,
                            open=False,
                            elem_id=f"accordion-{slug}",
                            elem_classes="accordion",
                        ):
                            chatbots[tab] = gr.Chatbot(
                                show_copy_button=True,
                                show_share_button=False,
                                show_label=False,
                                elem_id=f"chatbot-{slug}",
                                layout="panel",
                                avatar_images=(
                                    "./assets/logos/help.png",
                                    None,
                                ),
                            )
with gr.Accordion(
"Spinoza",
open=True,
elem_id="accordion-spinoza",
elem_classes="accordion",
):
chatbots["spinoza"] = gr.Chatbot(
value=[(None, init_prompt)],
show_copy_button=True,
show_share_button=False,
show_label=False,
elem_id="chatbot-spinoza",
layout="panel",
avatar_images=(
"./assets/logos/help.png",
"./assets/logos/spinoza.png",
),
)
with gr.Row(elem_id="input-message"):
ask = gr.Textbox(
placeholder="Ask me anything here!",
show_label=False,
scale=7,
lines=1,
interactive=True,
elem_id="input-textbox",
)
with gr.Column(scale=1, variant="panel", elem_id="right-panel"):
with gr.TabItem("Sources", elem_id="tab-sources", id=0):
sources_textbox = gr.HTML(
show_label=False, elem_id="sources-textbox"
)
with gr.Tab("Source information", elem_id="source-component"):
with gr.Row():
with gr.Column(scale=1):
gr.Markdown(source_information)
with gr.Tab("Contact", elem_id="contact-component"):
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("For any issue contact **[email protected]**.")
ask.submit(
start_agents, inputs=[], outputs=[chatbots["spinoza"]], js=accordion_trigger()
).then(
fn=reformulate_questions,
inputs=[ask],
outputs=[agent_questions[tab] for tab in config["tabs"]],
).then(
fn=retrieve_sources,
inputs=[agent_questions[tab] for tab in config["tabs"]],
outputs=[sources_textbox] + [text_sources[tab] for tab in config["tabs"]],
).then(
fn=answer_questions,
inputs=[agent_questions[tab] for tab in config["tabs"]]
+ [text_sources[tab] for tab in config["tabs"]],
outputs=[chatbots[tab] for tab in config["tabs"]],
).then(
fn=next_call, inputs=[], outputs=[], js=accordion_trigger_end()
).then(
fn=next_call, inputs=[], outputs=[], js=accordion_trigger_spinoza()
).then(
fn=get_synthesis,
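        # get_synthesis receives the second tab's reformulated question as its question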
inputs=[agent_questions[list(config["tabs"].keys())[1]]]
+ [chatbots[tab] for tab in config["tabs"]],
outputs=[chatbots["spinoza"]],
).then(
fn=next_call, inputs=[], outputs=[], js=accordion_trigger_spinoza_end()
).then(
fn=end_agents, inputs=[], outputs=[]
)
if __name__ == "__main__":
demo.queue().launch(debug=True)