# Scraped page header preserved as a comment (was not valid Python):
# YingxuHe's picture — "add mic button" — commit e9402b5
import numpy as np
import streamlit as st
from src.generation import (
prepare_multimodal_content,
change_multimodal_content
)
from src.content.common import (
MODEL_NAMES,
VOICE_CHAT_DIALOGUE_STATES,
reset_states,
process_audio_bytes,
init_state_section,
header_section,
sidebar_fragment,
retrive_response_with_ui
)
# Fallback instruction sent with every round after the first; on the first
# round it is replaced by the user-editable system prompt (see
# conversation_section). TODO: change this.
DEFAULT_PROMPT = "Based on the information in this user’s voice, please reply to the user in a friendly and helpful way."

# Hard cap on user/assistant exchanges per session; each round stores two
# entries in st.session_state.vc_messages, hence the `* 2` check below.
MAX_VC_ROUNDS = 5
def bottom_input_section():
    """Render the bottom input bar: a clear-history button and a mic recorder.

    When a fresh recording is present, decode it into session state as both a
    waveform array (``vc_audio_array``) and a base64 payload
    (``vc_audio_base64``) for the model request, then clear the
    ``on_record`` flag so the same clip is not reprocessed on rerun.
    """
    clear_col, record_col = st.columns([0.03, 0.97])

    with clear_col:
        # Trash-can button wipes the whole voice-chat dialogue state.
        st.button(
            ':material/delete:',
            disabled=st.session_state.disprompt,
            on_click=lambda: reset_states(VOICE_CHAT_DIALOGUE_STATES)
        )

    with record_col:
        # Recording immediately locks the prompt widgets (disprompt=True)
        # until the round finishes; on_record marks a clip awaiting decode.
        recording = st.audio_input(
            label="record audio",
            label_visibility="collapsed",
            disabled=st.session_state.disprompt,
            on_change=lambda: st.session_state.update(
                on_record=True,
                disprompt=True
            ),
            key='record'
        )

    if recording and st.session_state.on_record:
        raw_bytes = recording.read()
        audio_array, audio_b64 = process_audio_bytes(raw_bytes)
        st.session_state.vc_audio_array = audio_array
        st.session_state.vc_audio_base64 = audio_b64
        st.session_state.on_record = False
@st.fragment
def system_prompt_fragment():
    """Collapsible editor for the system prompt (stored under
    ``st.session_state.system_prompt``), re-run in isolation as a fragment."""
    prompt_expander = st.expander("System Prompt")
    with prompt_expander:
        # Editing is locked while a round is in flight (disprompt=True).
        text_area_options = dict(
            label="Insert system instructions or background knowledge here.",
            label_visibility="collapsed",
            disabled=st.session_state.disprompt,
            max_chars=5000,
            value=DEFAULT_PROMPT,
        )
        st.text_area(key="system_prompt", **text_area_options)
def conversation_section():
    """Render the chat transcript, pick up any pending recording, and run one
    user → assistant round against the audio LLM.

    Session-state keys used: ``vc_messages`` (UI transcript),
    ``vc_model_messages`` (model-format history), ``vc_audio_array`` /
    ``vc_audio_base64`` (pending recording), ``system_prompt`` (user-editable
    instruction), and ``disprompt`` (input lock while generating).
    """
    chat_message_container = st.container(height=480)
    # Replay the stored transcript: errors, warnings, audio clip, then text.
    for message in st.session_state.vc_messages:
        with chat_message_container.chat_message(message["role"]):
            if message.get("error"):
                st.error(message["error"])
            for warning_msg in message.get("warnings", []):
                st.warning(warning_msg)
            # Render audio only when a non-empty waveform was stored.
            if message.get("audio", np.array([])).shape[0]:
                st.audio(message["audio"], format="audio/wav", sample_rate=16000)
            if message.get("content"):
                st.write(message["content"])

    # NOTE(review): st._bottom is a private Streamlit API — may break on upgrade.
    with st._bottom:
        bottom_input_section()

    if not st.session_state.vc_audio_base64:
        # No pending recording: nothing to submit this run.
        return

    if len(st.session_state.vc_messages) >= MAX_VC_ROUNDS * 2:
        # Each round appends a user and an assistant entry, hence `* 2`.
        st.toast(f":warning: max conversation rounds ({MAX_VC_ROUNDS}) reached!")
        return

    one_time_prompt = DEFAULT_PROMPT
    one_time_array = st.session_state.vc_audio_array
    one_time_base64 = st.session_state.vc_audio_base64
    # Consume the pending recording so a rerun does not resubmit it.
    st.session_state.update(
        vc_audio_array=np.array([]),
        vc_audio_base64="",
    )

    with chat_message_container.chat_message("user"):
        st.audio(one_time_array, format="audio/wav", sample_rate=16000)
    st.session_state.vc_messages.append({"role": "user", "audio": one_time_array})

    if not st.session_state.vc_model_messages:
        # First round: send the user-edited system prompt instead of the default.
        one_time_prompt = st.session_state.system_prompt
    else:
        # Later rounds: rewrite the text of the first turn in the model history
        # so subsequent edits to the system prompt still take effect.
        st.session_state.vc_model_messages[0]["content"] = change_multimodal_content(
            st.session_state.vc_model_messages[0]["content"],
            text_input=st.session_state.system_prompt
        )

    with chat_message_container.chat_message("assistant"):
        with st.spinner("Thinking..."):
            # Streams the assistant reply into the UI and returns any
            # error/warning strings alongside the final response text.
            error_msg, warnings, response = retrive_response_with_ui(
                model_name=MODEL_NAMES["audiollm-it"]["vllm_name"],
                text_input=one_time_prompt,
                array_audio_input=one_time_array,
                base64_audio_input=one_time_base64,
                stream=True,
                history=st.session_state.vc_model_messages
            )

    # Record the assistant turn in the UI transcript...
    st.session_state.vc_messages.append({
        "role": "assistant",
        "error": error_msg,
        "warnings": warnings,
        "content": response
    })

    # ...and the full round (audio + prompt, then reply) in the model history.
    mm_content = prepare_multimodal_content(one_time_prompt, one_time_base64)
    st.session_state.vc_model_messages.extend([
        {"role": "user", "content": mm_content},
        {"role": "assistant", "content": response}
    ])

    # Unlock the input widgets and redraw the whole app with the new turn.
    st.session_state.disprompt=False
    st.rerun(scope="app")
def voice_chat_page():
    """Entry point for the Voice Chat page: initialise shared session state,
    draw the header and sidebar, then render the conversation body."""
    init_state_section()
    header_section(
        component_name="Voice Chat",
        description=""" Currently support up to <strong>5 rounds</strong> of conversations.
            Feel free to talk about anything.""",
        concise_description=" Currently support up to <strong>5 rounds</strong> of conversations.",
        icon="🗣️"
    )
    with st.sidebar:
        sidebar_fragment()
    # System-prompt editor lives above the chat so it applies to the next round.
    system_prompt_fragment()
    conversation_section()