import os
import requests
import numpy as np
import streamlit as st
from src.retrieval import STANDARD_QUERIES
from src.content.common import (
    MODEL_NAMES,
    AUDIO_SAMPLES_W_INSTRUCT,
    AGENT_DIALOGUE_STATES,
    reset_states,
    update_voice_instruction_state,
    init_state_section,
    header_section,
    sidebar_fragment,
    successful_example_section,
    audio_attach_dialogue,
    retrive_response_with_ui
)
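
# Base URL of the backend API; the retrieval endpoint below is resolved relative to it.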
API_BASE_URL = os.getenv('API_BASE_URL')
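# Prompt used before any audio clip is attached: the user's question is passed through unchanged.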
LLM_NO_AUDIO_PROMPT_TEMPLATE = """{user_question}"""
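# Prompt used once an audio clip is attached; {audio_information_prompt} is either empty or an AUDIO_INFO_TEMPLATE block.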
LLM_PROMPT_TEMPLATE = """The user asked a question about the audio clip.
## User Question
{user_question}
{audio_information_prompt}Please reply to the user's question with a friendly, accurate, and helpful answer."""
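# Wraps the analyses collected from MERaLiON-AudioLLM before they are injected into LLM_PROMPT_TEMPLATE.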
AUDIO_INFO_TEMPLATE = """Here is some information about this audio clip.
## Audio Information
{audio_information}
However, the audio analysis may or may not contain information relevant to the user's question, so please reply to the user with only the relevant information.
"""
AUDIO_ANALYSIS_STATUS = "MERaLiON-AudioLLM Analysis"
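# Conversation-level state that is restored (cleared) whenever a new audio clip or example is attached.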
AG_CONVERSATION_STATES = dict(
    ag_messages=[],
    ag_model_messages=[],
    ag_visited_query_indices=[],
)

def bottom_input_section():
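    """Render the bottom input bar: a clear-history button, an attach-audio button, a text chat input, and a voice-instruction recorder."""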
    bottom_cols = st.columns([0.03, 0.03, 0.91, 0.03])
    with bottom_cols[0]:
        st.button(
            ':material/delete:',
            disabled=st.session_state.disprompt,
            on_click=lambda: reset_states(AGENT_DIALOGUE_STATES)
        )

    with bottom_cols[1]:
        if st.button(":material/add:", disabled=st.session_state.disprompt):
            audio_attach_dialogue(
                audio_array_state="ag_audio_array",
                audio_base64_state="ag_audio_base64",
                restore_state=AG_CONVERSATION_STATES
            )

    with bottom_cols[2]:
        if chat_input := st.chat_input(
            placeholder="Instruction...",
            disabled=st.session_state.disprompt,
            on_submit=lambda: st.session_state.update(disprompt=True)
        ):
            st.session_state.new_prompt = chat_input

    with bottom_cols[3]:
        uploaded_voice = st.audio_input(
            label="voice_instruction",
            label_visibility="collapsed",
            disabled=st.session_state.disprompt,
            on_change=lambda: st.session_state.update(
                disprompt=True,
                on_record_voice_instruction=True
            ),
            key='voice_instruction'
        )

        if uploaded_voice and st.session_state.on_record_voice_instruction:
            voice_bytes = uploaded_voice.read()
            update_voice_instruction_state(voice_bytes)
            st.session_state.on_record_voice_instruction = False

def _prepare_final_prompt_with_ui(one_time_prompt):
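    """Build the LLM prompt for the current turn, running retrieval and on-demand audio analysis with UI feedback."""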
    if st.session_state.ag_audio_array.shape[0] == 0:
        return LLM_NO_AUDIO_PROMPT_TEMPLATE.format(user_question=one_time_prompt)

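    # Ask the retrieval service which standard audio-analysis queries are relevant to this question.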
    with st.spinner("Searching appropriate queries..."):
        response = requests.get(
            f"{API_BASE_URL}retrieve_relevant_docs",
            params={"user_question": one_time_prompt}
        )
        relevant_query_indices = response.json()

        if len(st.session_state.ag_messages) <= 2:
            relevant_query_indices.append(0)

    relevant_query_indices = list(
        set(relevant_query_indices).difference(st.session_state.ag_visited_query_indices)
    )
    st.session_state.ag_visited_query_indices.extend(relevant_query_indices)

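    # Nothing new to analyse: every relevant query has already been run on this clip, so send the question without an audio-information block.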
    if not relevant_query_indices:
        return LLM_PROMPT_TEMPLATE.format(
            user_question=one_time_prompt,
            audio_information_prompt=""
        )

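    # Run each newly relevant standard query against MERaLiON-AudioLLM, streaming progress into a status box.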
    audio_info = []
    with st.status(AUDIO_ANALYSIS_STATUS, expanded=False) as status:
        for i, standard_idx in enumerate(relevant_query_indices):
            new_label = (
                f"{AUDIO_ANALYSIS_STATUS}: "
                f"{STANDARD_QUERIES[standard_idx]['ui_text']} "
                f"({i+1}/{len(relevant_query_indices)})"
            )
            status.update(label=new_label, state="running")

            error_msg, warnings, response = retrive_response_with_ui(
                model_name=MODEL_NAMES["audiollm"]["vllm_name"],
                text_input=STANDARD_QUERIES[standard_idx]["query_text"],
                array_audio_input=st.session_state.ag_audio_array,
                base64_audio_input=st.session_state.ag_audio_base64,
                prefix=f"**{STANDARD_QUERIES[standard_idx]['ui_text']}**: ",
                stream=True,
                show_warning=i==0
            )

            audio_info.append(STANDARD_QUERIES[standard_idx]["response_prefix_text"] + response)
            st.session_state.ag_messages[-1]["process"].append({
                "error": error_msg,
                "warnings": warnings,
                "content": response
            })

        status.update(label=AUDIO_ANALYSIS_STATUS, state="complete")

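    # Fold the collected analyses into the audio-information block of the final prompt.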
    audio_information_prompt = AUDIO_INFO_TEMPLATE.format(
        audio_information="\n".join(audio_info)
    )
    return LLM_PROMPT_TEMPLATE.format(
        user_question=one_time_prompt,
        audio_information_prompt=audio_information_prompt
    )

def conversation_section():
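    """Render the chat history, collect new text or voice input, and orchestrate the analysis-then-answer flow for each turn."""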
    chat_message_container = st.container(height=480)
    if st.session_state.ag_audio_array.size:
        with chat_message_container.chat_message("user"):
            st.audio(st.session_state.ag_audio_array, format="audio/wav", sample_rate=16000)

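    # Replay the stored conversation, including each turn's audio-analysis process log.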
    for message in st.session_state.ag_messages:
        with chat_message_container.chat_message(name=message["role"]):
            if message.get("error"):
                st.error(message["error"])
            for warning_msg in message.get("warnings", []):
                st.warning(warning_msg)
            if process := message.get("process", []):
                with st.status(AUDIO_ANALYSIS_STATUS, expanded=False, state="complete"):
                    for proc in process:
                        if proc.get("error"):
                            st.error(proc["error"])
                        for proc_warning_msg in proc.get("warnings", []):
                            st.warning(proc_warning_msg)
                        if proc.get("content"):
                            st.write(proc["content"])
            if message.get("content"):
                st.write(message["content"])

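    # st._bottom is Streamlit's private bottom container; it pins the input controls to the bottom of the page.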
    with st._bottom:
        bottom_input_section()

    if (not st.session_state.new_prompt) and (not st.session_state.new_vi_base64):
        return

    one_time_prompt = st.session_state.new_prompt
    one_time_vi_array = st.session_state.new_vi_array
    one_time_vi_base64 = st.session_state.new_vi_base64
    st.session_state.update(
        new_prompt="",
        new_vi_array=np.array([]),
        new_vi_base64="",
    )

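    # Render the new user turn; a recorded voice instruction is first transcribed with MERaLiON-AudioLLM.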
    with chat_message_container.chat_message("user"):
        if one_time_vi_base64:
            with st.spinner("Transcribing..."):
                error_msg, warnings, one_time_prompt = retrive_response_with_ui(
                    model_name=MODEL_NAMES["audiollm"]["vllm_name"],
                    text_input="Write out the dialogue as text.",
                    array_audio_input=one_time_vi_array,
                    base64_audio_input=one_time_vi_base64,
                    stream=False,
                    normalise_response=True
                )
        else:
            error_msg, warnings = "", []

        st.write(one_time_prompt)
        st.session_state.ag_messages.append({
            "role": "user",
            "error": error_msg,
            "warnings": warnings,
            "content": one_time_prompt
        })

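    # Generate the assistant turn: build the final prompt (with audio analysis if needed), then stream the text LLM's reply.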
    with chat_message_container.chat_message("assistant"):
        assistant_message = {"role": "assistant", "process": []}
        st.session_state.ag_messages.append(assistant_message)

        final_prompt = _prepare_final_prompt_with_ui(one_time_prompt)

        llm_response_prefix = f"**{MODEL_NAMES['llm']['ui_name']}**: "
        error_msg, warnings, response = retrive_response_with_ui(
            model_name=MODEL_NAMES["llm"]["vllm_name"],
            text_input=final_prompt,
            array_audio_input=st.session_state.ag_audio_array,
            base64_audio_input="",
            prefix=llm_response_prefix,
            stream=True,
            history=st.session_state.ag_model_messages,
            show_warning=False
        )

        assistant_message.update({
            "error": error_msg,
            "warnings": warnings,
            "content": response
        })

        pure_response = response.replace(llm_response_prefix, "")
        st.session_state.ag_model_messages.extend([
            {"role": "user", "content": final_prompt},
            {"role": "assistant", "content": pure_response}
        ])

    st.session_state.disprompt = False
    st.rerun(scope="app")

def agent_page():
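    """Entry point of the agent (Chatbot) page: initialise state, then render the header, sidebar, examples, and conversation."""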
    init_state_section()
    header_section(
        component_name="Chatbot",
        description="""It is implemented by connecting multiple AI models,
            offers more flexibility, and supports multi-round conversations.""",
        concise_description="""It is implemented by connecting multiple AI models
            and supports multi-round conversations.""",
        icon="👥"
    )

    with st.sidebar:
        sidebar_fragment()

    audio_sample_names = [name for name in AUDIO_SAMPLES_W_INSTRUCT.keys() if "Paral" in name]
    successful_example_section(
        audio_sample_names,
        audio_array_state="ag_audio_array",
        audio_base64_state="ag_audio_base64",
        restore_state=AG_CONVERSATION_STATES
    )
    conversation_section()