import numpy as np
import streamlit as st
from src.content.common import (
MODEL_NAMES,
AUDIO_SAMPLES_W_INSTRUCT,
PLAYGROUND_DIALOGUE_STATES,
reset_states,
update_voice_instruction_state,
init_state_section,
header_section,
sidebar_fragment,
successful_example_section,
audio_attach_dialogue,
retrive_response_with_ui
)
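

# Preset one-click instructions shown above the conversation; "width" is the
# relative column width passed to st.columns in quick_actions_fragment.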
QUICK_ACTIONS = [
{
"name": "**Summary**",
"instruction": "Please summarise this speech.",
"width": 10,
},
{
"name": "**Transcript**",
"instruction": "Please transcribe the speech",
"width": 9.5,
}
]
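

# Per-page conversation state; resetting it clears the playground chat history.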
PG_CONVERSATION_STATES = dict(
pg_messages=[],
)


@st.fragment
def select_model_variants_fragment():
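    """Let users switch between MERaLiON-AudioLLM variants.

    Runs as an isolated Streamlit fragment, so changing the selection
    does not rerun the rest of the page.
    """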
display_mapper = {
value["vllm_name"]: value["ui_name"]
for key, value in MODEL_NAMES.items()
if "audiollm" in key
}
st.selectbox(
label=":fire: Explore more MERaLiON-AudioLLM variants!",
options=list(display_mapper.keys()),
index=0,
format_func=lambda o: display_mapper[o],
key="pg_model_name",
placeholder=":fire: Explore more MERaLiON-AudioLLM variants!",
disabled=st.session_state.disprompt,
)


def bottom_input_section():
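    """Render the bar pinned to the bottom of the page: the model-variant
    selector, a clear-conversation button, an attach-audio button, the text
    chat input, and a voice-instruction recorder.
    """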
    select_model_variants_fragment()
bottom_cols = st.columns([0.03, 0.03, 0.91, 0.03])
with bottom_cols[0]:
st.button(
':material/delete:',
disabled=st.session_state.disprompt,
on_click=lambda: reset_states(PLAYGROUND_DIALOGUE_STATES)
)
with bottom_cols[1]:
if st.button(":material/add:", disabled=st.session_state.disprompt):
audio_attach_dialogue(
audio_array_state="pg_audio_array",
audio_base64_state="pg_audio_base64",
restore_state=PG_CONVERSATION_STATES
)
with bottom_cols[2]:
if chat_input := st.chat_input(
placeholder="Instruction...",
disabled=st.session_state.disprompt,
on_submit=lambda: st.session_state.update(
disprompt=True,
**PG_CONVERSATION_STATES
)
):
st.session_state.new_prompt = chat_input
with bottom_cols[3]:
uploaded_voice = st.audio_input(
label="voice_instruction",
label_visibility="collapsed",
disabled=st.session_state.disprompt,
on_change=lambda: st.session_state.update(
disprompt=True,
on_record_voice_instruction=True,
**PG_CONVERSATION_STATES
),
key='voice_instruction'
)
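
    # Process a newly recorded voice instruction exactly once per recording.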
if uploaded_voice and st.session_state.on_record_voice_instruction:
voice_bytes = uploaded_voice.read()
update_voice_instruction_state(voice_bytes)
st.session_state.on_record_voice_instruction = False


@st.fragment
def quick_actions_fragment():
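    """Render one-click shortcut buttons that submit a preset instruction
    for the attached audio clip.
    """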
    action_cols_spec = [action["width"] for action in QUICK_ACTIONS]
action_cols = st.columns(action_cols_spec)
for idx, action in enumerate(QUICK_ACTIONS):
action_cols[idx].button(
action["name"],
args=(action["instruction"],),
disabled=st.session_state.disprompt,
on_click=lambda p: st.session_state.update(
disprompt=True,
pg_messages=[],
new_prompt=p,
on_select_quick_action=True
)
)
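
    # The buttons above live in a fragment; escalate to a full-app rerun so
    # conversation_section picks up the newly selected instruction.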
if st.session_state.on_select_quick_action:
st.session_state.on_select_quick_action = False
st.rerun(scope="app")


def conversation_section():
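    """Display the attached audio and chat history, then process any newly
    submitted prompt or voice instruction.
    """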
if st.session_state.pg_audio_array.size:
with st.chat_message("user"):
st.audio(st.session_state.pg_audio_array, format="audio/wav", sample_rate=16000)
quick_actions_fragment()
for message in st.session_state.pg_messages:
with st.chat_message(message["role"]):
if message.get("error"):
st.error(message["error"])
for warning_msg in message.get("warnings", []):
st.warning(warning_msg)
if message.get("content"):
st.write(message["content"])
with st._bottom:
bottom_input_section()
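
    # Consume the pending prompt / voice instruction, if any, then clear the
    # one-time state so the next rerun does not replay it.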
if (not st.session_state.new_prompt) and (not st.session_state.new_vi_base64):
return
one_time_prompt = st.session_state.new_prompt
one_time_vi_array = st.session_state.new_vi_array
one_time_vi_base64 = st.session_state.new_vi_base64
st.session_state.update(
new_prompt="",
new_vi_array=np.array([]),
new_vi_base64="",
pg_messages=[]
)
with st.chat_message("user"):
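        # A recorded voice instruction is transcribed first; the transcript
        # then becomes the text prompt.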
if one_time_vi_base64:
with st.spinner("Transcribing..."):
error_msg, warnings, one_time_prompt = retrive_response_with_ui(
model_name=MODEL_NAMES["audiollm"]["vllm_name"],
text_input="Write out the dialogue as text.",
array_audio_input=one_time_vi_array,
base64_audio_input=one_time_vi_base64,
stream=False,
normalise_response=True
)
else:
error_msg, warnings = "", []
st.write(one_time_prompt)
st.session_state.pg_messages.append({
"role": "user",
"error": error_msg,
"warnings": warnings,
"content": one_time_prompt
})
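
    # Send the prompt together with the attached audio to the selected variant.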
with st.chat_message("assistant"):
with st.spinner("Thinking..."):
error_msg, warnings, response = retrive_response_with_ui(
model_name=st.session_state.pg_model_name,
text_input=one_time_prompt,
array_audio_input=st.session_state.pg_audio_array,
base64_audio_input=st.session_state.pg_audio_base64,
stream=True
)
st.session_state.pg_messages.append({
"role": "assistant",
"error": error_msg,
"warnings": warnings,
"content": response
})
    st.session_state.disprompt = False
st.rerun(scope="app")


def playground_page():
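    """Entry point of the Playground page."""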
init_state_section()
header_section(
component_name="Playground",
description=""" It is tailored for Singapore’s multilingual and multicultural landscape.
MERaLiON-AudioLLM supports
<strong>Automatic Speech Recognition</strong>,
<strong>Speech Translation</strong>,
<strong>Spoken Question Answering</strong>,
<strong>Spoken Dialogue Summarization</strong>,
<strong>Speech Instruction</strong>, and
<strong>Paralinguistics</strong> tasks.""",
concise_description=""
)
with st.sidebar:
sidebar_fragment()
    audio_sample_names = list(AUDIO_SAMPLES_W_INSTRUCT.keys())
successful_example_section(
audio_sample_names,
audio_array_state="pg_audio_array",
audio_base64_state="pg_audio_base64",
restore_state=PG_CONVERSATION_STATES
)
conversation_section()