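"""Agent (chatbot) page of the Streamlit demo.

The page chains two models: MERaLiON-AudioLLM first answers a set of standard
queries about the attached audio clip, and a text-only LLM then uses those
answers as context to reply to the user across multiple conversation rounds.
"""
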
import os

import requests
import numpy as np
import streamlit as st

from src.retrieval import STANDARD_QUERIES
from src.content.common import (
    MODEL_NAMES,
    AUDIO_SAMPLES_W_INSTRUCT,
    AGENT_DIALOGUE_STATES,
    reset_states,
    update_voice_instruction_state,
    init_state_section,
    header_section,
    sidebar_fragment,
    successful_example_section,
    audio_attach_dialogue,
    retrive_response_with_ui
)
API_BASE_URL = os.getenv('API_BASE_URL')

LLM_NO_AUDIO_PROMPT_TEMPLATE = """{user_question}"""

LLM_PROMPT_TEMPLATE = """The user asked a question about the audio clip.
## User Question
{user_question}
{audio_information_prompt}Please reply to the user's question with a friendly, accurate, and helpful answer."""

AUDIO_INFO_TEMPLATE = """Here is some information about this audio clip.
## Audio Information
{audio_information}
However, the audio analysis may or may not contain information relevant to the user's question; please reply to the user with only the relevant information.
"""

AUDIO_ANALYSIS_STATUS = "MERaLiON-AudioLLM Analysis"

# Session-state keys that hold the agent conversation; passed as restore_state
# so they are reset whenever a new audio clip is loaded.
AG_CONVERSATION_STATES = dict(
    ag_messages=[],
    ag_model_messages=[],
    ag_visited_query_indices=[],
)
def bottom_input_section():
    """Render the bottom input bar: clear and attach-audio buttons, text prompt, and voice-instruction recorder."""
    bottom_cols = st.columns([0.03, 0.03, 0.91, 0.03])

    with bottom_cols[0]:
        # Clear the agent dialogue.
        st.button(
            ':material/delete:',
            disabled=st.session_state.disprompt,
            on_click=lambda: reset_states(AGENT_DIALOGUE_STATES)
        )

    with bottom_cols[1]:
        # Attach a new audio clip and reset the conversation state.
        if st.button(":material/add:", disabled=st.session_state.disprompt):
            audio_attach_dialogue(
                audio_array_state="ag_audio_array",
                audio_base64_state="ag_audio_base64",
                restore_state=AG_CONVERSATION_STATES
            )

    with bottom_cols[2]:
        if chat_input := st.chat_input(
            placeholder="Instruction...",
            disabled=st.session_state.disprompt,
            on_submit=lambda: st.session_state.update(disprompt=True)
        ):
            st.session_state.new_prompt = chat_input

    with bottom_cols[3]:
        uploaded_voice = st.audio_input(
            label="voice_instruction",
            label_visibility="collapsed",
            disabled=st.session_state.disprompt,
            on_change=lambda: st.session_state.update(
                disprompt=True,
                on_record_voice_instruction=True
            ),
            key='voice_instruction'
        )

        if uploaded_voice and st.session_state.on_record_voice_instruction:
            voice_bytes = uploaded_voice.read()
            update_voice_instruction_state(voice_bytes)
            st.session_state.on_record_voice_instruction = False
def _prepare_final_prompt_with_ui(one_time_prompt):
    """Compose the prompt for the text LLM, collecting audio analysis from the AudioLLM when a clip is loaded."""
    if st.session_state.ag_audio_array.shape[0] == 0:
        return LLM_NO_AUDIO_PROMPT_TEMPLATE.format(user_question=one_time_prompt)

    # Ask the retrieval backend which standard queries are relevant to this question.
    with st.spinner("Searching appropriate queries..."):
        response = requests.get(
            f"{API_BASE_URL}retrieve_relevant_docs",
            params={"user_question": one_time_prompt}
        )
        relevant_query_indices = response.json()

    # Early in the conversation, always include the first standard query.
    if len(st.session_state.ag_messages) <= 2:
        relevant_query_indices.append(0)

    # Skip queries that have already been answered for this clip.
    relevant_query_indices = list(
        set(relevant_query_indices).difference(st.session_state.ag_visited_query_indices)
    )
    st.session_state.ag_visited_query_indices.extend(relevant_query_indices)

    if not relevant_query_indices:
        return LLM_PROMPT_TEMPLATE.format(
            user_question=one_time_prompt,
            audio_information_prompt=""
        )

    audio_info = []
    with st.status(AUDIO_ANALYSIS_STATUS, expanded=False) as status:
        for i, standard_idx in enumerate(relevant_query_indices):
            new_label = (
                f"{AUDIO_ANALYSIS_STATUS}: "
                f"{STANDARD_QUERIES[standard_idx]['ui_text']} "
                f"({i+1}/{len(relevant_query_indices)})"
            )
            status.update(label=new_label, state="running")

            error_msg, warnings, response = retrive_response_with_ui(
                model_name=MODEL_NAMES["audiollm"]["vllm_name"],
                text_input=STANDARD_QUERIES[standard_idx]["query_text"],
                array_audio_input=st.session_state.ag_audio_array,
                base64_audio_input=st.session_state.ag_audio_base64,
                prefix=f"**{STANDARD_QUERIES[standard_idx]['ui_text']}**: ",
                stream=True,
                show_warning=(i == 0)
            )
            audio_info.append(STANDARD_QUERIES[standard_idx]["response_prefix_text"] + response)

            # Record each analysis step on the pending assistant message.
            st.session_state.ag_messages[-1]["process"].append({
                "error": error_msg,
                "warnings": warnings,
                "content": response
            })

        status.update(label=AUDIO_ANALYSIS_STATUS, state="complete")

    audio_information_prompt = AUDIO_INFO_TEMPLATE.format(
        audio_information="\n".join(audio_info)
    )
    return LLM_PROMPT_TEMPLATE.format(
        user_question=one_time_prompt,
        audio_information_prompt=audio_information_prompt
    )
def conversation_section():
    """Render the conversation history and handle any newly submitted text or voice prompt."""
    chat_message_container = st.container(height=480)
    if st.session_state.ag_audio_array.size:
        with chat_message_container.chat_message("user"):
            st.audio(st.session_state.ag_audio_array, format="audio/wav", sample_rate=16000)

    # Replay the stored conversation, including the per-message analysis steps.
    for message in st.session_state.ag_messages:
        with chat_message_container.chat_message(name=message["role"]):
            if message.get("error"):
                st.error(message["error"])
            for warning_msg in message.get("warnings", []):
                st.warning(warning_msg)
            if process := message.get("process", []):
                with st.status(AUDIO_ANALYSIS_STATUS, expanded=False, state="complete"):
                    for proc in process:
                        if proc.get("error"):
                            st.error(proc["error"])
                        for proc_warning_msg in proc.get("warnings", []):
                            st.warning(proc_warning_msg)
                        if proc.get("content"):
                            st.write(proc["content"])
            if message.get("content"):
                st.write(message["content"])

    # Pin the input controls to the bottom of the page (st._bottom is an internal Streamlit container).
    with st._bottom:
        bottom_input_section()

    if (not st.session_state.new_prompt) and (not st.session_state.new_vi_base64):
        return

    # Consume the pending prompt / voice instruction so a rerun does not process it twice.
    one_time_prompt = st.session_state.new_prompt
    one_time_vi_array = st.session_state.new_vi_array
    one_time_vi_base64 = st.session_state.new_vi_base64
    st.session_state.update(
        new_prompt="",
        new_vi_array=np.array([]),
        new_vi_base64="",
    )

    with chat_message_container.chat_message("user"):
        if one_time_vi_base64:
            # A voice instruction is transcribed first; the transcript becomes the prompt.
            with st.spinner("Transcribing..."):
                error_msg, warnings, one_time_prompt = retrive_response_with_ui(
                    model_name=MODEL_NAMES["audiollm"]["vllm_name"],
                    text_input="Write out the dialogue as text.",
                    array_audio_input=one_time_vi_array,
                    base64_audio_input=one_time_vi_base64,
                    stream=False,
                    normalise_response=True
                )
        else:
            error_msg, warnings = "", []

        st.write(one_time_prompt)
        st.session_state.ag_messages.append({
            "role": "user",
            "error": error_msg,
            "warnings": warnings,
            "content": one_time_prompt
        })

    with chat_message_container.chat_message("assistant"):
        assistant_message = {"role": "assistant", "process": []}
        st.session_state.ag_messages.append(assistant_message)

        final_prompt = _prepare_final_prompt_with_ui(one_time_prompt)

        llm_response_prefix = f"**{MODEL_NAMES['llm']['ui_name']}**: "
        error_msg, warnings, response = retrive_response_with_ui(
            model_name=MODEL_NAMES["llm"]["vllm_name"],
            text_input=final_prompt,
            array_audio_input=st.session_state.ag_audio_array,
            base64_audio_input="",
            prefix=llm_response_prefix,
            stream=True,
            history=st.session_state.ag_model_messages,
            show_warning=False
        )

        assistant_message.update({
            "error": error_msg,
            "warnings": warnings,
            "content": response
        })

        # Keep the model-facing history free of the UI display prefix.
        pure_response = response.replace(llm_response_prefix, "")
        st.session_state.ag_model_messages.extend([
            {"role": "user", "content": final_prompt},
            {"role": "assistant", "content": pure_response}
        ])

    st.session_state.disprompt = False
    st.rerun(scope="app")
def agent_page():
    """Entry point of the agent (chatbot) page."""
    init_state_section()
    header_section(
        component_name="Chatbot",
        description=""" It is implemented by <strong>connecting multiple AI models</strong>,
            offers more flexibility, and supports <strong>multi-round</strong> conversation.""",
        concise_description=""" It is implemented by connecting multiple AI models and
            supports <strong>multi-round</strong> conversation.""",
        icon="👥"
    )

    with st.sidebar:
        sidebar_fragment()

    # Offer only the paralinguistic ("Paral") audio samples as quick-start examples on this page.
    audio_sample_names = [name for name in AUDIO_SAMPLES_W_INSTRUCT.keys() if "Paral" in name]
    successful_example_section(
        audio_sample_names,
        audio_array_state="ag_audio_array",
        audio_base64_state="ag_audio_base64",
        restore_state=AG_CONVERSATION_STATES
    )
    conversation_section()
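
# Note: this module is one page of a larger Streamlit app, which is expected to
# import and call agent_page() from its own navigation. A minimal way to preview
# the page on its own (assuming the `src` package is importable and the
# API_BASE_URL backend is reachable) is a thin wrapper script run with
# `streamlit run`, e.g. a hypothetical preview.py containing:
#
#     from agent import agent_page  # adjust to this module's actual import path
#
#     agent_page()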