import os

import requests
import numpy as np
import streamlit as st

from src.retrieval import STANDARD_QUERIES
from src.content.common import (
    MODEL_NAMES,
    AUDIO_SAMPLES_W_INSTRUCT,
    AGENT_DIALOGUE_STATES,
    reset_states,
    update_voice_instruction_state,
    init_state_section,
    header_section,
    sidebar_fragment,
    successful_example_section,
    audio_attach_dialogue,
    retrive_response_with_ui
)


API_BASE_URL = os.getenv('API_BASE_URL')

LLM_NO_AUDIO_PROMPT_TEMPLATE = """{user_question}"""

LLM_PROMPT_TEMPLATE = """User asked a question about the audio clip.

## User Question
{user_question}

{audio_information_prompt}Please reply to the user's question with a friendly, accurate, and helpful answer."""

AUDIO_INFO_TEMPLATE = """Here is some information about this audio clip.

## Audio Information
{audio_information}

However, the audio analysis may or may not contain information relevant to the user question. Please only reply to the user with the relevant information.
"""

AUDIO_ANALYSIS_STATUS = "MERaLiON-AudioLLM Analysis"

# Conversation-related session-state keys, restored whenever a new audio clip is attached.
AG_CONVERSATION_STATES = dict(
    ag_messages=[],
    ag_model_messages=[],
    ag_visited_query_indices=[],
)


def bottom_input_section():
    """Render the bottom input bar: clear button, audio attach, text input, and voice recorder."""
    bottom_cols = st.columns([0.03, 0.03, 0.91, 0.03])

    with bottom_cols[0]:
        st.button(
            ':material/delete:',
            disabled=st.session_state.disprompt,
            on_click=lambda: reset_states(AGENT_DIALOGUE_STATES)
        )

    with bottom_cols[1]:
        if st.button(":material/add:", disabled=st.session_state.disprompt):
            audio_attach_dialogue(
                audio_array_state="ag_audio_array",
                audio_base64_state="ag_audio_base64",
                restore_state=AG_CONVERSATION_STATES
            )

    with bottom_cols[2]:
        if chat_input := st.chat_input(
            placeholder="Instruction...",
            disabled=st.session_state.disprompt,
            on_submit=lambda: st.session_state.update(disprompt=True)
        ):
            st.session_state.new_prompt = chat_input

    with bottom_cols[3]:
        uploaded_voice = st.audio_input(
            label="voice_instruction",
            label_visibility="collapsed",
            disabled=st.session_state.disprompt,
            on_change=lambda: st.session_state.update(
                disprompt=True,
                on_record_voice_instruction=True
            ),
            key='voice_instruction'
        )

        if uploaded_voice and st.session_state.on_record_voice_instruction:
            voice_bytes = uploaded_voice.read()
            update_voice_instruction_state(voice_bytes)
            st.session_state.on_record_voice_instruction = False


def _prepare_final_prompt_with_ui(one_time_prompt):
    """Build the final LLM prompt, enriching it with MERaLiON-AudioLLM analyses of the attached audio."""
    # No audio attached: forward the question as-is.
    if st.session_state.ag_audio_array.shape[0] == 0:
        return LLM_NO_AUDIO_PROMPT_TEMPLATE.format(user_question=one_time_prompt)

    # Ask the retrieval API which standard audio queries are relevant to this question.
    with st.spinner("Searching appropriate queries..."):
        response = requests.get(
            f"{API_BASE_URL}retrieve_relevant_docs",
            params={"user_question": one_time_prompt}
        )
        relevant_query_indices = response.json()

    # At the start of a conversation, also include the first standard query.
    if len(st.session_state.ag_messages) <= 2:
        relevant_query_indices.append(0)

    # Skip queries whose analyses were already collected earlier in this conversation.
    relevant_query_indices = list(
        set(relevant_query_indices).difference(st.session_state.ag_visited_query_indices)
    )
    st.session_state.ag_visited_query_indices.extend(relevant_query_indices)

    if not relevant_query_indices:
        return LLM_PROMPT_TEMPLATE.format(
            user_question=one_time_prompt,
            audio_information_prompt=""
        )

    # Run MERaLiON-AudioLLM once per relevant standard query and collect its answers.
    audio_info = []
    with st.status(AUDIO_ANALYSIS_STATUS, expanded=False) as status:
        for i, standard_idx in enumerate(relevant_query_indices):
            new_label = (
                f"{AUDIO_ANALYSIS_STATUS}: "
                f"{STANDARD_QUERIES[standard_idx]['ui_text']} "
                f"({i+1}/{len(relevant_query_indices)})"
            )
            status.update(label=new_label, state="running")

            error_msg, warnings, response = retrive_response_with_ui(
                model_name=MODEL_NAMES["audiollm"]["vllm_name"],
                text_input=STANDARD_QUERIES[standard_idx]["query_text"],
                array_audio_input=st.session_state.ag_audio_array,
                base64_audio_input=st.session_state.ag_audio_base64,
                prefix=f"**{STANDARD_QUERIES[standard_idx]['ui_text']}**: ",
                stream=True,
                show_warning=(i == 0)
            )

            audio_info.append(STANDARD_QUERIES[standard_idx]["response_prefix_text"] + response)

            # Keep each per-query analysis in the latest assistant message so it can be replayed later.
            st.session_state.ag_messages[-1]["process"].append({
                "error": error_msg,
                "warnings": warnings,
                "content": response
            })

        status.update(label=AUDIO_ANALYSIS_STATUS, state="complete")

    audio_information_prompt = AUDIO_INFO_TEMPLATE.format(
        audio_information="\n".join(audio_info)
    )

    return LLM_PROMPT_TEMPLATE.format(
        user_question=one_time_prompt,
        audio_information_prompt=audio_information_prompt
    )


def conversation_section():
    """Render the conversation history and process a newly submitted text or voice instruction."""
    chat_message_container = st.container(height=480)

    # Show the attached audio clip at the top of the conversation.
    if st.session_state.ag_audio_array.size:
        with chat_message_container.chat_message("user"):
            st.audio(st.session_state.ag_audio_array, format="audio/wav", sample_rate=16000)

    # Replay past messages, including any intermediate audio analyses.
    for message in st.session_state.ag_messages:
        with chat_message_container.chat_message(name=message["role"]):
            if message.get("error"):
                st.error(message["error"])
            for warning_msg in message.get("warnings", []):
                st.warning(warning_msg)
            if process := message.get("process", []):
                with st.status(AUDIO_ANALYSIS_STATUS, expanded=False, state="complete"):
                    for proc in process:
                        if proc.get("error"):
                            st.error(proc["error"])
                        for proc_warning_msg in proc.get("warnings", []):
                            st.warning(proc_warning_msg)
                        if proc.get("content"):
                            st.write(proc["content"])
            if message.get("content"):
                st.write(message["content"])

    # Pin the input controls to the bottom of the page.
    with st._bottom:
        bottom_input_section()

    if (not st.session_state.new_prompt) and (not st.session_state.new_vi_base64):
        return

    # Consume the pending text prompt / voice instruction and clear it from session state.
    one_time_prompt = st.session_state.new_prompt
    one_time_vi_array = st.session_state.new_vi_array
    one_time_vi_base64 = st.session_state.new_vi_base64

    st.session_state.update(
        new_prompt="",
        new_vi_array=np.array([]),
        new_vi_base64="",
    )

    with chat_message_container.chat_message("user"):
        if one_time_vi_base64:
            # Voice instruction: transcribe it with MERaLiON-AudioLLM before proceeding.
            with st.spinner("Transcribing..."):
                error_msg, warnings, one_time_prompt = retrive_response_with_ui(
                    model_name=MODEL_NAMES["audiollm"]["vllm_name"],
                    text_input="Write out the dialogue as text.",
                    array_audio_input=one_time_vi_array,
                    base64_audio_input=one_time_vi_base64,
                    stream=False,
                    normalise_response=True
                )
        else:
            error_msg, warnings = "", []
            st.write(one_time_prompt)

        st.session_state.ag_messages.append({
            "role": "user",
            "error": error_msg,
            "warnings": warnings,
            "content": one_time_prompt
        })

    with chat_message_container.chat_message("assistant"):
        assistant_message = {"role": "assistant", "process": []}
        st.session_state.ag_messages.append(assistant_message)

        final_prompt = _prepare_final_prompt_with_ui(one_time_prompt)

        llm_response_prefix = f"**{MODEL_NAMES['llm']['ui_name']}**: "
        error_msg, warnings, response = retrive_response_with_ui(
            model_name=MODEL_NAMES["llm"]["vllm_name"],
            text_input=final_prompt,
            array_audio_input=st.session_state.ag_audio_array,
            base64_audio_input="",
            prefix=llm_response_prefix,
            stream=True,
            history=st.session_state.ag_model_messages,
            show_warning=False
        )

        assistant_message.update({
            "error": error_msg,
            "warnings": warnings,
            "content": response
        })

        # Store the model-facing history without the UI prefix.
        pure_response = response.replace(llm_response_prefix, "")
        st.session_state.ag_model_messages.extend([
            {"role": "user", "content": final_prompt},
            {"role": "assistant", "content": pure_response}
        ])

    st.session_state.disprompt = False
    st.rerun(scope="app")


def agent_page():
    """Entry point for the agent (Chatbot) page."""
    init_state_section()
    header_section(
        component_name="Chatbot",
        description="""It is implemented by connecting multiple AI models,
            which offers more flexibility and supports multi-round conversation.""",
        concise_description="""It is implemented by connecting multiple AI models and supports multi-round conversation.""",
        icon="👥"
    )

    with st.sidebar:
        sidebar_fragment()

    # Only audio samples whose names contain "Paral" are offered as examples on this page.
    audio_sample_names = [name for name in AUDIO_SAMPLES_W_INSTRUCT.keys() if "Paral" in name]

    successful_example_section(
        audio_sample_names,
        audio_array_state="ag_audio_array",
        audio_base64_state="ag_audio_base64",
        restore_state=AG_CONVERSATION_STATES
    )
    conversation_section()
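
# --- Hypothetical local entry point (not part of the original module) ---
# A minimal sketch for previewing this page on its own with `streamlit run`;
# in the full app, agent_page() is presumably imported and invoked by the
# top-level page router instead, so this guard is only a convenience.
if __name__ == "__main__":
    agent_page()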