from ast import literal_eval  # noqa: F401  (not referenced in this module)

import gradio as gr
from loguru import logger
from pydantic import BaseModel

from allofresh_chatbot import AllofreshChatbot
from utils import cut_dialogue_history
from prompts.mod_prompt import FALLBACK_MESSAGE

# One chatbot backend instance shared by all the handlers below.
allo_chatbot = AllofreshChatbot(debug=True)


class Message(BaseModel):
    """A single chat turn, validated before being handed to the chatbot."""

    role: str  # 'user' or 'assistant' as produced by the handlers in this module
    content: str  # message text; entries whose content is None are filtered out first


def fetch_messages(history):
    """
    Turn the flat role/content history into (left, right) pairs for gr.Chatbot.

    Entries are paired positionally: (history[0], history[1]),
    (history[2], history[3]), ... A trailing unpaired entry is dropped.
    """
    left_side = history[0::2]
    right_side = history[1::2]
    return [(left["content"], right["content"]) for left, right in zip(left_side, right_side)]


def preproc_history(history):
    """
    Build the dialogue context string passed to the chatbot.

    Placeholder entries (content is None) are skipped. The remaining dicts are
    validated as Message models, and the str() of that list is passed through
    cut_dialogue_history.
    """
    messages = []
    for entry in history:
        if entry["content"] is None:
            continue
        messages.append(Message(**entry))
    return cut_dialogue_history(str(messages))


def user_input(input, history):
    """
    Record the user's text as a new turn.

    An assistant placeholder (content None) is added right after it so the
    history stays in user/assistant pairs for fetch_messages.

    Returns (chatbot pairs, history); history is mutated in place.
    """
    history.extend([
        {'role': 'user', 'content': input},
        {'role': 'assistant', 'content': None},
    ])
    return fetch_messages(history), history


def predict_answer(input, history):
    """
    Answer the user's text with the chatbot and append the reply as a new turn.

    The context sent to the chatbot is built from the history, which already
    contains the user's text (added by user_input). Returns
    (chatbot pairs, history); history is mutated in place.
    """
    context = preproc_history(history)
    answer = allo_chatbot.answer_optim_v2(input, context)
    # The None user slot keeps history in user/assistant pairs.
    history.extend([
        {'role': 'user', 'content': None},
        {'role': 'assistant', 'content': answer},
    ])
    return fetch_messages(history), history


def predict_reco(history):
    """
    Append a recommendation turn, unless the last message is FALLBACK_MESSAGE.

    Returns (chatbot pairs, history); history is mutated in place when a
    recommendation is added.
    """
    if history[-1]["content"] == FALLBACK_MESSAGE:
        # The answer step fell back; skip recommendations for this turn.
        return fetch_messages(history), history

    reco = allo_chatbot.reco_optim_v1(preproc_history(history))
    history.extend([
        {'role': 'user', 'content': None},
        {'role': 'assistant', 'content': reco},
    ])
    return fetch_messages(history), history


# Gradio Blocks is the low-level API for building a custom web app (here, the chat UI).
with gr.Blocks() as app:
    logger.info("Starting app...")
    chatbot = gr.Chatbot(label="Allofresh Assistant")
    state = gr.State([])  # flat list of {'role', 'content'} dicts

    with gr.Row():
        txt = gr.Textbox(show_label=False, placeholder="Enter text, then press enter").style(container=False)

    # Pressing enter runs: record the user text -> answer -> recommend.
    # Each .success step runs only if the previous step completed without error.
    on_submit = txt.submit(user_input, [txt, state], [chatbot, state])
    on_answer = on_submit.success(predict_answer, [txt, state], [chatbot, state])
    on_answer.success(predict_reco, [state], [chatbot, state])

app.queue(concurrency_count=4)
app.launch()