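"""Streamlit playground page for the MERaLiON-AudioLLM demo.

Lets users attach or record audio, pick a model variant, and chat with the
model through text or voice instructions.
"""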
import numpy as np
import streamlit as st

from src.content.common import (
    MODEL_NAMES,
    AUDIO_SAMPLES_W_INSTRUCT,
    PLAYGROUND_DIALOGUE_STATES,
    reset_states,
    update_voice_instruction_state,
    init_state_section,
    header_section,
    sidebar_fragment,
    successful_example_section,
    audio_attach_dialogue,
    retrive_response_with_ui
)
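

# One-click preset instructions, rendered as buttons once an audio clip is attached.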
QUICK_ACTIONS = [
    {
        "name": "**Summary**",
        "instruction": "Please summarise this speech.",
        "width": 10,
    },
    {
        "name": "**Transcript**",
        "instruction": "Please transcribe the speech.",
        "width": 9.5,
    }
]
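

# Conversation state that is cleared whenever a new clip, prompt, or voice recording arrives.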
PG_CONVERSATION_STATES = dict(
    pg_messages=[],
)


def select_model_variants_fragment():
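    """Render a selector for the MERaLiON-AudioLLM variants.

    The chosen vLLM model name is kept in `st.session_state.pg_model_name`.
    """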
    display_mapper = {
        value["vllm_name"]: value["ui_name"]
        for key, value in MODEL_NAMES.items()
        if "audiollm" in key
    }

    st.selectbox(
        label=":fire: Explore more MERaLiON-AudioLLM variants!",
        options=list(display_mapper.keys()),
        index=0,
        format_func=lambda o: display_mapper[o],
        key="pg_model_name",
        placeholder=":fire: Explore more MERaLiON-AudioLLM variants!",
        disabled=st.session_state.disprompt,
    )


def bottom_input_section():
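    """Render the bottom bar: model selector, clear-history button, audio
    attach dialogue, text chat input, and voice-instruction recorder.
    """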
    select_model_variants_fragment()

    bottom_cols = st.columns([0.03, 0.03, 0.91, 0.03])

    with bottom_cols[0]:
        st.button(
            ":material/delete:",
            disabled=st.session_state.disprompt,
            on_click=lambda: reset_states(PLAYGROUND_DIALOGUE_STATES)
        )

    with bottom_cols[1]:
        if st.button(":material/add:", disabled=st.session_state.disprompt):
            audio_attach_dialogue(
                audio_array_state="pg_audio_array",
                audio_base64_state="pg_audio_base64",
                restore_state=PG_CONVERSATION_STATES
            )

    with bottom_cols[2]:
        if chat_input := st.chat_input(
            placeholder="Instruction...",
            disabled=st.session_state.disprompt,
            on_submit=lambda: st.session_state.update(
                disprompt=True,
                **PG_CONVERSATION_STATES
            )
        ):
            st.session_state.new_prompt = chat_input

    with bottom_cols[3]:
        uploaded_voice = st.audio_input(
            label="voice_instruction",
            label_visibility="collapsed",
            disabled=st.session_state.disprompt,
            on_change=lambda: st.session_state.update(
                disprompt=True,
                on_record_voice_instruction=True,
                **PG_CONVERSATION_STATES
            ),
            key="voice_instruction"
        )

        if uploaded_voice and st.session_state.on_record_voice_instruction:
            voice_bytes = uploaded_voice.read()
            update_voice_instruction_state(voice_bytes)
            st.session_state.on_record_voice_instruction = False


def quick_actions_fragment():
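    """Render one button per QUICK_ACTIONS entry; a click clears the chat
    history and submits the preset instruction as the new prompt.
    """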
    action_cols_spec = [action["width"] for action in QUICK_ACTIONS]
    action_cols = st.columns(action_cols_spec)

    for idx, action in enumerate(QUICK_ACTIONS):
        action_cols[idx].button(
            action["name"],
            args=(action["instruction"],),
            disabled=st.session_state.disprompt,
            on_click=lambda p: st.session_state.update(
                disprompt=True,
                pg_messages=[],
                new_prompt=p,
                on_select_quick_action=True
            )
        )

    if st.session_state.on_select_quick_action:
        st.session_state.on_select_quick_action = False
        st.rerun(scope="app")


def conversation_section():
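    """Render the attached audio, chat history, and bottom input bar, then
    answer any pending text or voice instruction with the selected model.
    """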
    if st.session_state.pg_audio_array.size:
        with st.chat_message("user"):
            st.audio(
                st.session_state.pg_audio_array,
                format="audio/wav",
                sample_rate=16000
            )
        quick_actions_fragment()

    for message in st.session_state.pg_messages:
        with st.chat_message(message["role"]):
            if message.get("error"):
                st.error(message["error"])
            for warning_msg in message.get("warnings", []):
                st.warning(warning_msg)
            if message.get("content"):
                st.write(message["content"])

    with st._bottom:
        bottom_input_section()

    if (not st.session_state.new_prompt) and (not st.session_state.new_vi_base64):
        return

    one_time_prompt = st.session_state.new_prompt
    one_time_vi_array = st.session_state.new_vi_array
    one_time_vi_base64 = st.session_state.new_vi_base64

    st.session_state.update(
        new_prompt="",
        new_vi_array=np.array([]),
        new_vi_base64="",
        pg_messages=[]
    )

    with st.chat_message("user"):
        if one_time_vi_base64:
            # A recorded voice instruction is first transcribed into text,
            # then used as the prompt for the assistant turn below.
            with st.spinner("Transcribing..."):
                error_msg, warnings, one_time_prompt = retrive_response_with_ui(
                    model_name=MODEL_NAMES["audiollm"]["vllm_name"],
                    text_input="Write out the dialogue as text.",
                    array_audio_input=one_time_vi_array,
                    base64_audio_input=one_time_vi_base64,
                    stream=False,
                    normalise_response=True
                )
        else:
            error_msg, warnings = "", []

        st.write(one_time_prompt)

    st.session_state.pg_messages.append({
        "role": "user",
        "error": error_msg,
        "warnings": warnings,
        "content": one_time_prompt
    })

    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            error_msg, warnings, response = retrive_response_with_ui(
                model_name=st.session_state.pg_model_name,
                text_input=one_time_prompt,
                array_audio_input=st.session_state.pg_audio_array,
                base64_audio_input=st.session_state.pg_audio_base64,
                stream=True
            )

    st.session_state.pg_messages.append({
        "role": "assistant",
        "error": error_msg,
        "warnings": warnings,
        "content": response
    })

    st.session_state.disprompt = False
    st.rerun(scope="app")


def playground_page():
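    """Page entry point: initialise state, render the header, sidebar, and
    example clips, then hand over to the conversation section.
    """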
    init_state_section()
    header_section(
        component_name="Playground",
        description="""MERaLiON-AudioLLM is tailored for Singapore’s multilingual
            and multicultural landscape. It supports
            <strong>Automatic Speech Recognition</strong>,
            <strong>Speech Translation</strong>,
            <strong>Spoken Question Answering</strong>,
            <strong>Spoken Dialogue Summarization</strong>,
            <strong>Speech Instruction</strong>, and
            <strong>Paralinguistics</strong> tasks.""",
        concise_description=""
    )

    with st.sidebar:
        sidebar_fragment()

    audio_sample_names = list(AUDIO_SAMPLES_W_INSTRUCT.keys())

    successful_example_section(
        audio_sample_names,
        audio_array_state="pg_audio_array",
        audio_base64_state="pg_audio_base64",
        restore_state=PG_CONVERSATION_STATES
    )
    conversation_section()
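

# A minimal sketch of how this page could be exercised standalone with
# `streamlit run`; the real app presumably imports `playground_page` from its
# own multi-page entry point, so this guard is an illustrative assumption,
# not part of the original wiring.
if __name__ == "__main__":
    playground_page()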