"""Gradio front-end for the Trading Q&A Bot.

Per-session conversation state is kept in a ``gr.State`` holding
``[messages_archived, messages_current]``. The chat logic itself lives in
``helper_functions.get_reply``.
"""
# Imports
import os

import gradio as gr

# Provides get_reply and pre_text used below.
# NOTE(review): prefer explicit names over a star import once confirmed.
from helper_functions import *

# Once the archived message list reaches this length, reset_memory drops the
# two entries following index 0 (one pair per call) to bound its size.
MAX_ARCHIVED_MESSAGES = 21


def user(user_message, history):
    """Append the user's message to the chat history.

    Returns a pair: an empty string (clears the input textbox) and a new
    history list with ``[user_message, None]`` appended. ``bot`` fills in the
    ``None`` slot with the reply.
    """
    return "", history + [[user_message, None]]


def bot(history, session_data_fn):
    """Generate the bot reply for the last history entry and store it.

    Args:
        history: Chatbot history; its last entry is ``[user_message, None]``.
        session_data_fn: Session state ``[messages_archived, messages_current]``.

    Returns:
        ``(history, session_data_fn)``, both updated in place with the reply
        and the message lists returned by ``get_reply``.
    """
    messages_archived = session_data_fn[0]
    messages_current = session_data_fn[1]
    bot_message, messages_archived, messages_current = get_reply(
        history[-1][0], messages_archived, messages_current
    )
    history[-1][1] = bot_message
    session_data_fn[0] = messages_archived
    session_data_fn[1] = messages_current
    return history, session_data_fn


def reset_memory(session_data_fn):
    """Trim the archived messages once they reach MAX_ARCHIVED_MESSAGES.

    Keeps element 0 and drops elements 1 and 2. Only one pair is dropped per
    call, which bounds the list as long as each turn adds at most two messages.
    """
    messages_archived = session_data_fn[0]
    if len(messages_archived) >= MAX_ARCHIVED_MESSAGES:
        # Element 0 is the system prompt after clear_data; elements 1-2 are
        # presumably the oldest user/assistant pair — confirm against get_reply.
        messages_archived = messages_archived[:1] + messages_archived[3:]
    session_data_fn[0] = messages_archived
    return session_data_fn


def clear_data(session_data_fn):
    """Reset the session to just the system prompt and clear the chatbot.

    Returns ``(None, session_data_fn)``; ``None`` empties the Chatbot component.
    """
    session_data_fn[0] = [{"role": "system", "content": pre_text}]
    session_data_fn[1] = []
    return None, session_data_fn


def get_context_gr(session_data_fn):
    """Return the current-turn message list as a string for inspection."""
    return str(session_data_fn[1])


with gr.Blocks() as app:
    gr.Markdown('# Trading Q&A Bot')
    # Per-session state: [messages_archived, messages_current].
    # NOTE(review): clear_data seeds messages_archived with a system prompt at
    # index 0, but the initial state here is empty; confirm get_reply handles
    # an empty archive (reset_memory always preserves index 0).
    session_data = gr.State([[], []])

    with gr.Tab("Chat"):
        with gr.Row():
            with gr.Column():
                msg = gr.Textbox()
                with gr.Row():
                    submit = gr.Button("Submit")
                    clear = gr.Button("Clear")
            with gr.Column():
                chatbot = gr.Chatbot()

    with gr.Tab("Prompt"):
        context = gr.Textbox()
        submit_p = gr.Button("Check Prompt")

    # Tab Chat: pressing Enter in the textbox and clicking Submit run the same
    # chain: echo the user message, generate the reply, then trim memory.
    for trigger in (msg.submit, submit.click):
        trigger(user, [msg, chatbot], [msg, chatbot], queue=False).then(
            bot, [chatbot, session_data], [chatbot, session_data]
        ).then(
            fn=reset_memory, inputs=session_data, outputs=session_data
        )

    clear.click(
        fn=clear_data,
        inputs=session_data,
        outputs=[chatbot, session_data],
        queue=False,
    )

    # Tab Prompt
    submit_p.click(get_context_gr, session_data, context, queue=False)

# Login credentials come from the "id" and "password" environment variables.
# Without them Gradio would receive auth=(None, None), which no login can
# ever match, so fail fast with a clear message instead.
_auth_user = os.getenv("id")
_auth_password = os.getenv("password")
if _auth_user is None or _auth_password is None:
    raise RuntimeError(
        'Environment variables "id" and "password" must be set for app login.'
    )
app.launch(auth=(_auth_user, _auth_password), show_api=False)