# NOTE: removed non-code page chrome ("Spaces:", "Running", "Running") left over
# from a web capture of the hosting page; it was not valid Python.
import numpy as np | |
import streamlit as st | |
from src.generation import ( | |
prepare_multimodal_content, | |
change_multimodal_content | |
) | |
from src.content.common import ( | |
MODEL_NAMES, | |
VOICE_CHAT_DIALOGUE_STATES, | |
reset_states, | |
process_audio_bytes, | |
init_state_section, | |
header_section, | |
sidebar_fragment, | |
retrive_response_with_ui | |
) | |
# TODO: change this.
# Fallback system prompt used for the first round (and shown as the default
# value of the "System Prompt" text area).
DEFAULT_PROMPT = "Based on the information in this user’s voice, please reply to the user in a friendly and helpful way."
# Maximum number of user/assistant exchanges allowed per conversation; the UI
# stops accepting input once st.session_state.vc_messages holds 2x this many entries.
MAX_VC_ROUNDS = 5
def bottom_input_section():
    """Render the bottom input bar: a clear-history button and an audio recorder.

    When a new recording arrives it is decoded via ``process_audio_bytes`` and
    stashed in ``st.session_state`` (``vc_audio_array`` / ``vc_audio_base64``)
    for the conversation section to pick up on the next run.
    """
    clear_col, record_col = st.columns([0.03, 0.97])

    with clear_col:
        # Wipe every voice-chat dialogue key when pressed.
        st.button(
            ':material/delete:',
            disabled=st.session_state.disprompt,
            on_click=lambda: reset_states(VOICE_CHAT_DIALOGUE_STATES),
        )

    with record_col:
        recording = st.audio_input(
            label="record audio",
            label_visibility="collapsed",
            disabled=st.session_state.disprompt,
            # Lock the inputs while the fresh recording is being processed.
            on_change=lambda: st.session_state.update(on_record=True, disprompt=True),
            key='record',
        )
        if recording and st.session_state.on_record:
            raw_bytes = recording.read()
            audio_array, audio_b64 = process_audio_bytes(raw_bytes)
            st.session_state.vc_audio_array = audio_array
            st.session_state.vc_audio_base64 = audio_b64
            # Recording consumed; allow the widget to fire again next time.
            st.session_state.update(on_record=False)
def system_prompt_fragment():
    """Render an expandable editor bound to ``st.session_state.system_prompt``."""
    prompt_expander = st.expander("System Prompt")
    with prompt_expander:
        st.text_area(
            "Insert system instructions or background knowledge here.",
            value=DEFAULT_PROMPT,
            max_chars=5000,
            key="system_prompt",
            disabled=st.session_state.disprompt,
            label_visibility="collapsed",
        )
def conversation_section():
    """Render the chat history, accept a new recording, and run one model round.

    Flow: replay ``st.session_state.vc_messages`` into a fixed-height
    container, render the bottom input bar, and — if a fresh recording is
    pending in session state — send it to the model, stream the reply, and
    append both sides to the display and model histories before rerunning.
    """
    chat_message_container = st.container(height=480)

    # Replay the visible history. Each message dict may carry any of:
    # "error" (str), "warnings" (list[str]), "audio" (np.ndarray), "content" (str).
    for message in st.session_state.vc_messages:
        with chat_message_container.chat_message(message["role"]):
            if message.get("error"):
                st.error(message["error"])
            for warning_msg in message.get("warnings", []):
                st.warning(warning_msg)
            # Only render the audio player when the array is non-empty.
            # NOTE(review): assumes 16 kHz mono audio — sample_rate is hard-coded.
            if message.get("audio", np.array([])).shape[0]:
                st.audio(message["audio"], format="audio/wav", sample_rate=16000)
            if message.get("content"):
                st.write(message["content"])

    # st._bottom is a private Streamlit API pinning widgets to the page bottom.
    with st._bottom:
        bottom_input_section()

    # No pending recording -> nothing more to do this run.
    if not st.session_state.vc_audio_base64:
        return

    # Each round adds two entries (user + assistant), hence the * 2.
    if len(st.session_state.vc_messages) >= MAX_VC_ROUNDS * 2:
        st.toast(f":warning: max conversation rounds ({MAX_VC_ROUNDS}) reached!")
        return

    # Snapshot the pending input, then clear it immediately so a rerun
    # cannot process the same recording twice.
    one_time_prompt = DEFAULT_PROMPT
    one_time_array = st.session_state.vc_audio_array
    one_time_base64 = st.session_state.vc_audio_base64
    st.session_state.update(
        vc_audio_array=np.array([]),
        vc_audio_base64="",
    )

    # Echo the user's recording into the chat and the display history.
    with chat_message_container.chat_message("user"):
        st.audio(one_time_array, format="audio/wav", sample_rate=16000)
    st.session_state.vc_messages.append({"role": "user", "audio": one_time_array})

    if not st.session_state.vc_model_messages:
        # First round: the (possibly user-edited) system prompt rides along
        # with the first user turn.
        one_time_prompt = st.session_state.system_prompt
    else:
        # Later rounds: rewrite the text of the first model message so prompt
        # edits made mid-conversation still take effect.
        st.session_state.vc_model_messages[0]["content"] = change_multimodal_content(
            st.session_state.vc_model_messages[0]["content"],
            text_input=st.session_state.system_prompt
        )

    # Stream the assistant reply into the chat container.
    with chat_message_container.chat_message("assistant"):
        with st.spinner("Thinking..."):
            error_msg, warnings, response = retrive_response_with_ui(
                model_name=MODEL_NAMES["audiollm-it"]["vllm_name"],
                text_input=one_time_prompt,
                array_audio_input=one_time_array,
                base64_audio_input=one_time_base64,
                stream=True,
                history=st.session_state.vc_model_messages
            )

    # Record the assistant turn for display...
    st.session_state.vc_messages.append({
        "role": "assistant",
        "error": error_msg,
        "warnings": warnings,
        "content": response
    })

    # ...and both turns for the model-facing history (audio as base64).
    mm_content = prepare_multimodal_content(one_time_prompt, one_time_base64)
    st.session_state.vc_model_messages.extend([
        {"role": "user", "content": mm_content},
        {"role": "assistant", "content": response}
    ])

    # Unlock the inputs and redraw the whole app with the new messages.
    st.session_state.disprompt=False
    st.rerun(scope="app")
def voice_chat_page():
    """Assemble the Voice Chat page: state init, header, sidebar, prompt editor, chat."""
    init_state_section()

    long_desc = """ Currently support up to <strong>5 rounds</strong> of conversations.
        Feel free to talk about anything."""
    short_desc = " Currently support up to <strong>5 rounds</strong> of conversations."

    header_section(
        component_name="Voice Chat",
        description=long_desc,
        concise_description=short_desc,
        icon="🗣️",
    )

    with st.sidebar:
        sidebar_fragment()

    system_prompt_fragment()
    conversation_section()