"""Gradio front-end for Spinoza.

For each user question, every agent tab configured in ``config["tabs"]``
reformulates the question, retrieves documents from its own store, and
streams an answer into its own chatbot. A final "Spinoza" chatbot then
streams a synthesis of the agents' answers.
"""

import gradio as gr
import time
import yaml
from langchain.prompts.chat import ChatPromptTemplate
from huggingface_hub import hf_hub_download

from spinoza_project.source.backend.llm_utils import get_llm, get_vectorstore
from spinoza_project.source.backend.document_store import pickle_to_document_store
from spinoza_project.source.backend.get_prompts import get_qa_prompts
from spinoza_project.source.frontend.utils import (
    make_html_source,
    make_html_presse_source,
    parse_output_llm_with_sources,
    init_env,
)
from spinoza_project.source.backend.prompt_utils import (
    to_chat_instruction,
    SpecialTokens,
)
from assets.utils_javascript import (
    accordion_trigger,
    accordion_trigger_end,
    accordion_trigger_spinoza,
    accordion_trigger_spinoza_end,
)

init_env()

with open("./spinoza_project/config.yaml", encoding="utf-8") as f:
    config = yaml.full_load(f)

prompts = {}
for source in config["prompt_naming"]:
    with open(f"./spinoza_project/prompt_{source}.yaml", encoding="utf-8") as f:
        prompts[source] = yaml.full_load(f)

## Building LLM
print("Building LLM")
model = "gpt35turbo"
llm = get_llm()

## Loading_tools
print("Loading Databases")
# SECURITY NOTE(review): the file downloaded from the Hugging Face Hub is
# passed to `pickle_to_document_store`; if that helper unpickles it, anyone
# able to write to that dataset repo can execute code here. Confirm the repo
# is trusted or switch to a non-pickle format.
qdrants = {
    tab: pickle_to_document_store(
        hf_hub_download(
            repo_id="SpinozaProject/spinoza-database",
            filename=f"database_{tab}.pickle",
            repo_type="dataset",
        )
    )
    for tab in config["prompt_naming"]
    if tab != "Presse"
}
bdd_presse = get_vectorstore("presse")

## Load Prompts
print("Loading Prompts")
chat_qa_prompts, chat_reformulation_prompts, chat_summarize_memory_prompts = {}, {}, {}
for source, prompt in prompts.items():
    chat_qa_prompt, chat_reformulation_prompt = get_qa_prompts(config, prompt)
    chat_qa_prompts[source] = chat_qa_prompt
    chat_reformulation_prompts[source] = chat_reformulation_prompt

with open("./assets/style.css", "r", encoding="utf-8") as f:
    css = f.read()

special_tokens = SpecialTokens(config)

synthesis_template = """You are a factual journalist that summarize the secialized awnsers from thechnical sources.

Based on the folowing question:
{question}

And the following expert answer:
{answers}

Answer in French.
When using legal answers, keep tracking of the name of the articles.
When using ADEME answers, name the sources that are mainly used.
List the different elements mentionned, and highlight the agreement points between the sources, as well as the contradictions or differences.
Generate the answer as markdown, with an aerated layout, and headlines in bold
Start by highlighting contradictions followed by a general summary and then go into detail that could be interesting for writing an article about.
"""

synthesis_prompt = to_chat_instruction(synthesis_template, special_tokens)
synthesis_prompt_template = ChatPromptTemplate.from_messages([synthesis_prompt])


def zip_longest_fill(*args, fillvalue=None):
    """Zip iterables, padding exhausted ones with the last item they produced.

    Unlike ``itertools.zip_longest``, an iterable that runs out keeps
    repeating its most recent item; ``fillvalue`` is used only for iterables
    that have produced nothing yet. Iteration stops once every iterable is
    exhausted. This lets several streams be displayed side by side while the
    slower ones are still producing.

    >>> list(zip_longest_fill("ABCD", "xy"))
    [('A', 'x'), ('B', 'y'), ('C', 'y'), ('D', 'y')]
    """
    iterators = [iter(it) for it in args]
    latest = [fillvalue] * len(iterators)
    active = [True] * len(iterators)
    while True:
        progressed = False
        for idx, it in enumerate(iterators):
            if not active[idx]:
                continue
            try:
                latest[idx] = next(it)
            except StopIteration:
                active[idx] = False
            else:
                progressed = True
        # Track exhaustion explicitly instead of comparing consecutive tuples:
        # a stream may legitimately repeat a value, which must not end the zip.
        if not progressed:
            return
        yield tuple(latest)


def format_question(question):
    """Return *question* formatted as a string."""
    return f"{question}"


def _strip_paragraph_tags(text):
    """Remove ``<p>`` tags and ``</p>``-plus-newline sequences from *text*.

    Presumably these wrap the HTML-rendered values that Gradio components
    hand back to callbacks.

    NOTE(review): these two literals were reconstructed from a damaged copy
    of this file in which the HTML tags had been stripped; confirm against
    version control.
    """
    return text.replace("<p>", "").replace("</p>\n", "")


def parse_question(question):
    """Clean *question* and keep only the text after the first ``"### "``.

    When no ``"### "`` marker is present, the cleaned text is returned whole.
    """
    x = _strip_paragraph_tags(question)
    if "### " in x:
        return x.split("### ")[1]
    return x


def reformulate(question, tab, config=config):
    """Return a stream reformulating *question* with *tab*'s prompt.

    For a tab that is not in ``config["tabs"]``, an iterator of five ``None``
    placeholders is returned instead.
    """
    if tab in config["tabs"]:
        return llm.stream(
            chat_reformulation_prompts[config["source_mapping"][tab]],
            {"question": parse_question(question)},
        )
    return iter([None] * 5)


def reformulate_single_question(question, tab, config=config):
    """Yield the reformulation stream for one tab, throttled for the UI."""
    for elt in reformulate(question, tab, config=config):
        time.sleep(0.02)
        yield elt


def reformulate_questions(question, config=config):
    """Yield one tuple per step with every tab's reformulation so far.

    The tuple is ordered like ``config["tabs"]``.
    """
    for elt in zip_longest_fill(
        *[reformulate(question, tab, config=config) for tab in config["tabs"]]
    ):
        time.sleep(0.02)
        yield elt


def add_question(question):
    """Identity passthrough for *question*."""
    return question


def answer(question, source, tab, config=config):
    """Return a stream answering *question* from the retrieved *source* text.

    For an unknown tab, five ``None`` placeholders are returned. When *source*
    is shorter than 10 characters, a single "no source found" message is
    returned instead of calling the LLM.
    """
    if tab not in config["tabs"]:
        return iter([None] * 5)
    if len(source) < 10:
        return iter(["Aucune source trouvée, veuillez reformuler votre question"])
    return llm.stream(
        chat_qa_prompts[config["source_mapping"][tab]],
        {
            "question": parse_question(question),
            "sources": _strip_paragraph_tags(source),
        },
    )


def answer_single_question(source, question, tab, config=config):
    """Yield the answer stream for one tab, throttled for the UI."""
    for elt in answer(question, source, tab, config=config):
        time.sleep(0.02)
        yield elt


def answer_questions(*questions_sources, config=config):
    """Stream every tab's answer into its chatbot.

    Positional arguments are the per-tab questions followed by the per-tab
    source texts (first half, then second half), both in ``config["tabs"]``
    order. Each yielded item is one chat history per tab.
    """
    half = len(questions_sources) // 2
    questions = list(questions_sources[:half])
    sources = list(questions_sources[half:])
    streams = [
        answer(question, source, tab, config=config)
        for question, source, tab in zip(questions, sources, config["tabs"])
    ]
    for elt in zip_longest_fill(*streams):
        time.sleep(0.02)
        yield [
            [(question, parse_output_llm_with_sources(ans))]
            for question, ans in zip(questions, elt)
        ]


def get_sources(questions, qdrants=qdrants, bdd_presse=bdd_presse, config=config):
    """Retrieve up to ``k`` documents per tab for the matching question.

    Documents are numbered globally: tab ``i`` gets numbers
    ``k*i + 1 .. k*(i + 1)``.

    Returns:
        A pair ``(html, texts)``. ``html`` concatenates the HTML cards of all
        retrieved documents. ``texts`` holds one prompt-ready string per tab
        listing that tab's documents.
    """
    k = config["num_document_retrieved"]
    # NOTE(review): unused while the relevance filter below stays disabled.
    min_similarity = config["min_similarity"]
    formated = []
    text = []
    for i, (question, tab) in enumerate(zip(questions, config["tabs"])):
        clean_question = _strip_paragraph_tags(question)
        if tab == "Presse":
            sources = bdd_presse.similarity_search_with_relevance_scores(
                clean_question, k=k
            )
        else:
            sources = qdrants[
                config["source_mapping"][tab]
            ].similarity_search_with_relevance_scores(
                config["query_preprompt"] + clean_question, k=k
            )
        # NOTE(review): a `score >= min_similarity` filter is disabled, so
        # every hit is kept — confirm this is intended.
        numbered = list(zip(range(k * i + 1, k * (i + 1) + 1), sources))
        if tab == "Presse":
            formated.extend(
                make_html_presse_source(doc, j, score) for j, (doc, score) in numbered
            )
        else:
            formated.extend(
                make_html_source(doc, j, score, config)
                for j, (doc, score) in numbered
            )
        text.append(
            "\n\n".join(
                f"Doc {j} with source type {doc.metadata.get('file_source_type')}:\n"
                + doc.page_content
                for j, (doc, _score) in numbered
            )
        )
    return "".join(formated), text


def retrieve_sources(*questions, qdrants=qdrants, bdd_presse=bdd_presse, config=config):
    """Gradio adapter: return the sources HTML followed by one text per tab."""
    formated_sources, text_sources = get_sources(questions, qdrants, bdd_presse, config)
    return (formated_sources, *text_sources)


def get_synthesis(question, *answers, config=config):
    """Stream Spinoza's synthesis of the agents' answers.

    *answers* are the agents' chatbot values in ``config["tabs"]`` order.
    Only answers whose string form has at least 100 characters are included.
    Each yielded item is a one-turn chat history for the Spinoza chatbot.
    """
    expert_answers = [
        _strip_paragraph_tags(f"{tab}\n{agent_answer}")
        for tab, agent_answer in zip(config["tabs"], answers)
        if len(str(agent_answer)) >= 100
    ]
    if not expert_answers:
        # This is a generator: a bare `return "..."` would only set
        # StopIteration.value and never reach the UI, so yield the message.
        yield [
            (
                question,
                "Aucune source n'a pu être identifiée pour répondre, veuillez modifier votre question",
            )
        ]
        return
    for elt in llm.stream(
        synthesis_prompt_template,
        {
            "question": _strip_paragraph_tags(question),
            "answers": "\n\n".join(expert_answers),
        },
    ):
        time.sleep(0.01)
        yield [(question, elt)]


theme = gr.themes.Base(
    primary_hue="blue",
    secondary_hue="red",
    font=[gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"],
)

with open("./assets/source_information.md", "r", encoding="utf-8") as f:
    source_information = f.read()


def start_agents():
    """Show a toast announcing that the agents are loading."""
    gr.Info(message="The agents and Spinoza are loading...", duration=3)


def end_agents():
    """Show a toast announcing that all answers are complete."""
    gr.Info(
        message="The agents and Spinoza have finished answering your question",
        duration=3,
    )


def next_call():
    """No-op step used to run the chained JavaScript triggers."""
    print("Next call")


init_prompt = """
Hello, I am Spinoza, a conversational assistant designed to help you in your journalistic journey. I will answer your questions based **on the provided sources**.

⚠️ Limitations
*Please note that this chatbot is in an early stage phase, it is not perfect and may sometimes give irrelevant answers. If you are not satisfied with the answer, please ask a more specific question or report your feedback to help us improve the system.*

What do you want to learn ?
"""

tab_keys = list(config["tabs"])

# (accordion label, elem_id suffix) for each agent, in config["tabs"] order.
# The elem_ids are kept identical to the original hand-written ones because
# the CSS / JS triggers presumably target them.
AGENT_PANELS = [
    ("Science agent", "science"),
    ("Law agent", "legal"),
    ("Politics agent", "politique"),
    ("ADEME agent", "ademe"),
    ("Press agent", "presse"),
]

with gr.Blocks(
    title="🔍 Spinoza",
    css=css,
    theme=theme,
) as demo:
    chatbots = {}
    question = gr.State("")
    docs_textbox = gr.State([""])
    agent_questions = {elt: gr.State("") for elt in config["tabs"]}
    component_sources = {elt: gr.State("") for elt in config["tabs"]}
    text_sources = {elt: gr.State("") for elt in config["tabs"]}
    tab_states = {elt: gr.State(elt) for elt in config["tabs"]}
    chatbot_states = [
        gr.State(name) for name in ["science", "presse", "politique", "legal", "spinoza"]
    ]

    with gr.Tab("Q&A", elem_id="main-component"):
        with gr.Row(elem_id="chatbot-row"):
            with gr.Column(scale=2, elem_id="center-panel"):
                with gr.Group(elem_id="chatbot-group"):
                    for (agent_label, agent_slug), agent_tab in zip(AGENT_PANELS, tab_keys):
                        with gr.Accordion(
                            agent_label,
                            open=False,
                            elem_id=f"accordion-{agent_slug}",
                            elem_classes="accordion",
                        ):
                            chatbots[agent_tab] = gr.Chatbot(
                                show_copy_button=True,
                                show_share_button=False,
                                show_label=False,
                                elem_id=f"chatbot-{agent_slug}",
                                layout="panel",
                                avatar_images=(
                                    "./assets/logos/help.png",
                                    None,
                                ),
                            )

                    with gr.Accordion(
                        "Spinoza",
                        open=True,
                        elem_id="accordion-spinoza",
                        elem_classes="accordion",
                    ):
                        chatbots["spinoza"] = gr.Chatbot(
                            value=[(None, init_prompt)],
                            show_copy_button=True,
                            show_share_button=False,
                            show_label=False,
                            elem_id="chatbot-spinoza",
                            layout="panel",
                            avatar_images=(
                                "./assets/logos/help.png",
                                "./assets/logos/spinoza.png",
                            ),
                        )

                with gr.Row(elem_id="input-message"):
                    ask = gr.Textbox(
                        placeholder="Ask me anything here!",
                        show_label=False,
                        scale=7,
                        lines=1,
                        interactive=True,
                        elem_id="input-textbox",
                    )

            with gr.Column(scale=1, variant="panel", elem_id="right-panel"):
                with gr.TabItem("Sources", elem_id="tab-sources", id=0):
                    sources_textbox = gr.HTML(show_label=False, elem_id="sources-textbox")

    with gr.Tab("Source information", elem_id="source-component"):
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown(source_information)

    with gr.Tab("Contact", elem_id="contact-component"):
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("For any issue contact **spinoza.support@ekimetrics.com**.")

    # Pipeline: reformulate per tab -> retrieve sources -> per-tab answers ->
    # Spinoza synthesis, with JS accordion toggles between steps.
    (
        ask.submit(start_agents, inputs=[], outputs=[], js=accordion_trigger())
        .then(
            fn=reformulate_questions,
            inputs=[ask],
            outputs=[agent_questions[tab] for tab in config["tabs"]],
        )
        .then(
            fn=retrieve_sources,
            inputs=[agent_questions[tab] for tab in config["tabs"]],
            outputs=[sources_textbox] + [text_sources[tab] for tab in config["tabs"]],
        )
        .then(
            fn=answer_questions,
            inputs=[agent_questions[tab] for tab in config["tabs"]]
            + [text_sources[tab] for tab in config["tabs"]],
            outputs=[chatbots[tab] for tab in config["tabs"]],
        )
        .then(fn=next_call, inputs=[], outputs=[], js=accordion_trigger_end())
        .then(fn=next_call, inputs=[], outputs=[], js=accordion_trigger_spinoza())
        .then(
            fn=get_synthesis,
            # NOTE(review): the synthesis is keyed on the *second* tab's
            # reformulated question rather than the raw user input — confirm.
            inputs=[agent_questions[tab_keys[1]]]
            + [chatbots[tab] for tab in config["tabs"]],
            outputs=[chatbots["spinoza"]],
        )
        .then(fn=next_call, inputs=[], outputs=[], js=accordion_trigger_spinoza_end())
        .then(fn=end_agents, inputs=[], outputs=[])
    )

if __name__ == "__main__":
    demo.queue().launch(share=True, debug=True)