Spaces:
Running
Running
import os | |
import re | |
import copy | |
import base64 | |
import requests | |
import itertools | |
from collections import OrderedDict | |
from typing import List, Optional | |
import numpy as np | |
import streamlit as st | |
from src.logger import load_logger | |
from src.utils import array_to_bytes, bytes_to_array, postprocess_voice_transcription | |
from src.generation import FIXED_GENERATION_CONFIG, MAX_AUDIO_LENGTH | |
# Base URL of the inference backend, read once at import time; None when the
# environment variable is unset. Request URLs in this module are built as
# f"{API_BASE_URL}chat", so the configured value is expected to end with "/".
API_BASE_URL = os.environ.get('API_BASE_URL')
# Default values for the "pg_"-prefixed session-state keys: the attached audio
# (base64 text and sample array) and the message list.
PLAYGROUND_DIALOGUE_STATES = {
    "pg_audio_base64": '',
    "pg_audio_array": np.array([]),
    "pg_messages": [],
}
# Default values for the "vc_"-prefixed session-state keys: the attached audio
# (base64 text and sample array) plus two separate message lists.
VOICE_CHAT_DIALOGUE_STATES = {
    "vc_audio_base64": '',
    "vc_audio_array": np.array([]),
    "vc_messages": [],
    "vc_model_messages": [],
}
# Default values for the "ag_"-prefixed session-state keys: the attached audio
# (base64 text and sample array), the list of visited query indices, and two
# separate message lists.
AGENT_DIALOGUE_STATES = {
    "ag_audio_base64": '',
    "ag_audio_array": np.array([]),
    "ag_visited_query_indices": [],
    "ag_messages": [],
    "ag_model_messages": [],
}
# Default values for the session-state keys that are not tied to one page
# prefix: the pending prompt / voice instruction and the "on_*" flags that the
# widget callbacks in this module raise and the handlers clear again.
COMMON_DIALOGUE_STATES = {
    "disprompt": False,
    "new_prompt": "",
    "new_vi_array": np.array([]),
    "new_vi_base64": "",
    "on_select": False,
    "on_upload": False,
    "on_record": False,
    "on_select_quick_action": False,
    "on_record_voice_instruction": False,
}
# Every default-state mapping, in the order init_state_section() walks them
# when seeding keys that are missing from st.session_state.
DEFAULT_DIALOGUE_STATE_DICTS = [
    PLAYGROUND_DIALOGUE_STATES,   # "pg_" keys
    VOICE_CHAT_DIALOGUE_STATES,   # "vc_" keys
    AGENT_DIALOGUE_STATES,        # "ag_" keys
    COMMON_DIALOGUE_STATES,       # unprefixed keys
]
# Ordered registry of models, keyed by a short alias. Each entry carries three
# spellings of the model's name under "vllm_name", "model_name" and "ui_name".
MODEL_NAMES = OrderedDict([
    ("llm", {
        "vllm_name": "MERaLiON-Gemma",
        "model_name": "MERaLiON-Gemma",
        "ui_name": "MERaLiON-Gemma",
    }),
    ("audiollm", {
        "vllm_name": "MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION",
        "model_name": "MERaLiON-AudioLLM-Whisper-SEA-LION",
        "ui_name": "MERaLiON-AudioLLM",
    }),
    ("audiollm-it", {
        "vllm_name": "MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION-it",
        "model_name": "MERaLiON-AudioLLM-Whisper-SEA-LION-it",
        "ui_name": "MERaLiON-AudioLLM-Instruction-Tuning",
    }),
])
def _audio_sample_entry(appearance, *instructions):
    """Build one catalogue entry: a display label plus its instruction list.

    The label is stored under the key "apperance" (sic) because that is the
    spelling the select box's ``format_func`` in this module looks up.
    """
    return {"apperance": appearance, "instructions": list(instructions)}


# Catalogue of bundled audio samples. Keys are sample names; the part of a key
# before "#" is used as the .wav file stem by successful_example_section().
# Each value holds the label shown in the select box and the instruction(s)
# offered for that sample (the first one becomes the prompt on selection).
AUDIO_SAMPLES_W_INSTRUCT = {
    "song_1": _audio_sample_entry(
        "Instruction Following Demo: Music Question Answering",
        "Please provide a detailed description of the song in both English and Chinese.",
    ),
    "7_ASR_IMDA_PART3_30_ASR_v2_2269": _audio_sample_entry(
        "7. Automatic Speech Recognition task: conversation in Singapore accent",
        "Need this talk written down, please.",
    ),
    "11_ASR_IMDA_PART4_30_ASR_v2_3771": _audio_sample_entry(
        "11. Automatic Speech Recognition task: conversation with Singlish code-switch",
        "Write out the dialogue as text.",
    ),
    "12_ASR_IMDA_PART4_30_ASR_v2_103": _audio_sample_entry(
        "12. Automatic Speech Recognition task: conversation with Singlish code-switch",
        "Write out the dialogue as text.",
    ),
    "17_ASR_IMDA_PART6_30_ASR_v2_1413": _audio_sample_entry(
        "17. Automatic Speech Recognition task: conversation in Singapore accent",
        "Record the spoken word in text form.",
    ),
    "32_SQA_CN_COLLEDGE_ENTRANCE_ENGLISH_TEST_SQA_V2_572": _audio_sample_entry(
        "32. Spoken Question Answering task: general speech",
        "What does the man think the woman should do at 4:00.",
    ),
    "33_SQA_IMDA_PART3_30_SQA_V2_2310": _audio_sample_entry(
        "33. Spoken Question Answering task: conversation in Singapore accent",
        "Does Speaker2's wife cook for Speaker2 when they are at home.",
    ),
    "34_SQA_IMDA_PART3_30_SQA_V2_3621": _audio_sample_entry(
        "34. Spoken Question Answering task: conversation in Singapore accent",
        "Does the phrase \"#gai-gai#\" have a meaning in Chinese or Hokkien language.",
    ),
    "35_SQA_IMDA_PART3_30_SQA_V2_4062": _audio_sample_entry(
        "35. Spoken Question Answering task: conversation in Singapore accent",
        "What is the color of the vase mentioned in the dialogue.",
    ),
    "36_DS_IMDA_PART4_30_DS_V2_849": _audio_sample_entry(
        "36. Spoken Dialogue Summarization task: conversation with Singlish code-switch",
        "Condense the dialogue into a concise summary highlighting major topics and conclusions.",
    ),
    "39_Paralingual_IEMOCAP_ER_V2_91": _audio_sample_entry(
        "39. Paralinguistics task: general speech",
        "Based on the speaker's speech patterns, what do you think they are feeling.",
    ),
    "40_Paralingual_IEMOCAP_ER_V2_567": _audio_sample_entry(
        "40. Paralinguistics task: general speech",
        "Based on the speaker's speech patterns, what do you think they are feeling.",
    ),
    "42_Paralingual_IEMOCAP_GR_V2_320": _audio_sample_entry(
        "42. Paralinguistics task: general speech",
        "Is it possible for you to identify whether the speaker in this recording is male or female.",
    ),
    "47_Paralingual_IMDA_PART3_30_NR_V2_10479": _audio_sample_entry(
        "47. Paralinguistics task: conversation in Singapore accent",
        "Can you guess which ethnic group this person is from based on their accent.",
    ),
    "49_Paralingual_MELD_ER_V2_676": _audio_sample_entry(
        "49. Paralinguistics task: general speech",
        "What emotions do you think the speaker is expressing.",
    ),
    "50_Paralingual_MELD_ER_V2_692": _audio_sample_entry(
        "50. Paralinguistics task: general speech",
        "Based on the speaker's speech patterns, what do you think they are feeling.",
    ),
    "51_Paralingual_VOXCELEB1_GR_V2_2148": _audio_sample_entry(
        "51. Paralinguistics task: general speech",
        "May I know the gender of the speaker.",
    ),
    "53_Paralingual_VOXCELEB1_NR_V2_2286": _audio_sample_entry(
        "53. Paralinguistics task: general speech",
        "What's the nationality identity of the speaker.",
    ),
    "55_SQA_PUBLIC_SPEECH_SG_TEST_SQA_V2_2": _audio_sample_entry(
        "55. Spoken Question Answering task: general speech",
        "What impact would the growth of the healthcare sector have on the country's economy in terms of employment and growth.",
    ),
    "56_SQA_PUBLIC_SPEECH_SG_TEST_SQA_V2_415": _audio_sample_entry(
        "56. Spoken Question Answering task: general speech",
        "Based on the statement, can you summarize the speaker's position on the recent controversial issues in Singapore.",
    ),
    "57_SQA_PUBLIC_SPEECH_SG_TEST_SQA_V2_460": _audio_sample_entry(
        "57. Spoken Question Answering task: general speech",
        "How does the author respond to parents' worries about masks in schools.",
    ),
    "1_ASR_IMDA_PART1_ASR_v2_141": _audio_sample_entry(
        "1. Automatic Speech Recognition task: phonetically balanced reading",
        "Turn the spoken language into a text format.",
        "Please translate the content into Chinese.",
    ),
    "2_ASR_IMDA_PART1_ASR_v2_2258": _audio_sample_entry(
        "2. Automatic Speech Recognition task: phonetically balanced reading",
        "Turn the spoken language into a text format.",
        "Please translate the content into Chinese.",
    ),
    "3_ASR_IMDA_PART1_ASR_v2_2265": _audio_sample_entry(
        "3. Automatic Speech Recognition task: phonetically balanced reading",
        "Turn the spoken language into a text format.",
    ),
    "4_ASR_IMDA_PART2_ASR_v2_999": _audio_sample_entry(
        "4. Automatic Speech Recognition task: reading in Singapore context",
        "Translate the spoken words into text format.",
    ),
    "5_ASR_IMDA_PART2_ASR_v2_2241": _audio_sample_entry(
        "5. Automatic Speech Recognition task: reading in Singapore context",
        "Translate the spoken words into text format.",
    ),
    "6_ASR_IMDA_PART2_ASR_v2_3409": _audio_sample_entry(
        "6. Automatic Speech Recognition task: reading in Singapore context",
        "Translate the spoken words into text format.",
    ),
    "8_ASR_IMDA_PART3_30_ASR_v2_1698": _audio_sample_entry(
        "8. Automatic Speech Recognition task: conversation in Singapore accent",
        "Need this talk written down, please.",
    ),
    "9_ASR_IMDA_PART3_30_ASR_v2_2474": _audio_sample_entry(
        "9. Automatic Speech Recognition task: conversation in Singapore accent",
        "Need this talk written down, please.",
    ),
    "10_ASR_IMDA_PART4_30_ASR_v2_1527": _audio_sample_entry(
        "10. Automatic Speech Recognition task: conversation with Singlish code-switch",
        "Write out the dialogue as text.",
    ),
    "13_ASR_IMDA_PART5_30_ASR_v2_1446": _audio_sample_entry(
        "13. Automatic Speech Recognition task: conversation in Singapore accent",
        "Translate this vocal recording into a textual format.",
    ),
    "14_ASR_IMDA_PART5_30_ASR_v2_2281": _audio_sample_entry(
        "14. Automatic Speech Recognition task: conversation in Singapore accent",
        "Translate this vocal recording into a textual format.",
    ),
    "15_ASR_IMDA_PART5_30_ASR_v2_4388": _audio_sample_entry(
        "15. Automatic Speech Recognition task: conversation in Singapore accent",
        "Translate this vocal recording into a textual format.",
    ),
    "16_ASR_IMDA_PART6_30_ASR_v2_576": _audio_sample_entry(
        "16. Automatic Speech Recognition task: conversation in Singapore accent",
        "Record the spoken word in text form.",
    ),
    "18_ASR_IMDA_PART6_30_ASR_v2_2834": _audio_sample_entry(
        "18. Automatic Speech Recognition task: conversation in Singapore accent",
        "Record the spoken word in text form.",
    ),
    "19_ASR_AIShell_zh_ASR_v2_5044": _audio_sample_entry(
        "19. Automatic Speech Recognition task: speech in Chinese ",
        "Transform the oral presentation into a text document.",
    ),
    "20_ASR_LIBRISPEECH_CLEAN_ASR_V2_833": _audio_sample_entry(
        "20. Automatic Speech Recognition task: general speech",
        "Please provide a written transcription of the speech.",
    ),
    "25_ST_COVOST2_ZH-CN_EN_ST_V2_4567": _audio_sample_entry(
        "25. Speech Translation task: Chinese to English",
        "Please translate the given speech to English.",
    ),
    "26_ST_COVOST2_EN_ZH-CN_ST_V2_5422": _audio_sample_entry(
        "26. Speech Translation task: English to Chinese",
        "Please translate the given speech to Chinese.",
    ),
    "27_ST_COVOST2_EN_ZH-CN_ST_V2_6697": _audio_sample_entry(
        "27. Speech Translation task: English to Chinese",
        "Please translate the given speech to Chinese.",
    ),
    "28_SI_ALPACA-GPT4-AUDIO_SI_V2_299": _audio_sample_entry(
        "28. Speech Instruction task: general speech",
        "Please follow the instruction in the speech.",
    ),
    "29_SI_ALPACA-GPT4-AUDIO_SI_V2_750": _audio_sample_entry(
        "29. Speech Instruction task: general speech",
        "Please follow the instruction in the speech.",
    ),
    "30_SI_ALPACA-GPT4-AUDIO_SI_V2_1454": _audio_sample_entry(
        "30. Speech Instruction task: general speech",
        "Please follow the instruction in the speech.",
    ),
    "female_pilot#1": _audio_sample_entry(
        "Female Pilot Interview: Transcription",
        "Please transcribe the speech",
    ),
    "female_pilot#2": _audio_sample_entry(
        "Female Pilot Interview: Aircraft name",
        "What does 大力士 mean in the conversation",
    ),
    "female_pilot#3": _audio_sample_entry(
        "Female Pilot Interview: Air Force Personnel Count",
        "How many air force personnel are there?",
    ),
    "female_pilot#4": _audio_sample_entry(
        "Female Pilot Interview: Air Force Personnel Name",
        "Can you tell me the names of the two pilots?",
    ),
    "female_pilot#5": _audio_sample_entry(
        "Female Pilot Interview: Conversation Mood",
        "What is the mood of the conversation?",
    ),
}
def reset_states(*state_dicts):
    """Overwrite session-state entries with fresh copies of default values.

    Each mapping in ``state_dicts`` is applied in order, followed by
    ``COMMON_DIALOGUE_STATES``, which is always re-applied. Values are
    deep-copied so the module-level defaults are never shared with (or
    mutated through) ``st.session_state``.
    """
    for defaults in (*state_dicts, COMMON_DIALOGUE_STATES):
        st.session_state.update(copy.deepcopy(defaults))
def process_audio_bytes(audio_bytes):
    """Decode raw audio bytes and produce a base64 payload of a clipped copy.

    Returns:
        A ``(array, base64_text)`` tuple. ``array`` is the full result of
        ``bytes_to_array`` (NOT clipped); ``base64_text`` encodes only the
        first ``MAX_AUDIO_LENGTH * 16000`` elements of it, converted back to
        bytes with ``array_to_bytes``.
    """
    full_array = bytes_to_array(audio_bytes)

    # NOTE(review): 16000 is presumably the sample rate in Hz and
    # MAX_AUDIO_LENGTH a duration in seconds — confirm against src.utils.
    max_samples = MAX_AUDIO_LENGTH * 16000
    clipped_bytes = array_to_bytes(full_array[:max_samples])
    encoded = base64.b64encode(clipped_bytes)

    return full_array, encoded.decode('utf-8')
def update_voice_instruction_state(voice_bytes):
    """Store a processed voice instruction in session state.

    Runs ``process_audio_bytes`` on ``voice_bytes`` and writes the resulting
    array and base64 text to ``new_vi_array`` and ``new_vi_base64``.
    """
    voice_array, voice_base64 = process_audio_bytes(voice_bytes)
    st.session_state.new_vi_array = voice_array
    st.session_state.new_vi_base64 = voice_base64
def _read_text_file(path):
    """Return the full text of the file at ``path``, closing it afterwards."""
    # The platform default encoding is kept on purpose so the text read is the
    # same as with the previous bare ``open(path).read()``.
    with open(path) as handle:
        return handle.read()


def init_state_section():
    """Configure the page, inject the stylesheets and seed session state.

    Steps, in order:
      1. Set the page title, icon and wide layout.
      2. Inline the three CSS files from ``./style`` inside one ``<style>`` tag.
      3. On the first run of a session (no ``logger`` key yet), create the
         logger and register a session id with it.
      4. Copy every generation-config and dialogue default into
         ``st.session_state`` unless the key is already present.

    Raises:
        OSError: if one of the stylesheet files cannot be read.
    """
    st.set_page_config(page_title='MERaLiON-AudioLLM', page_icon="🔥", layout='wide')

    # Each stylesheet is read through a context manager so its file handle is
    # closed right away instead of lingering until garbage collection.
    stylesheet_paths = (
        './style/app_style.css',
        './style/normal_window.css',
        './style/small_window.css',
    )
    css = ''.join(_read_text_file(path) for path in stylesheet_paths)
    st.markdown('<style>' + css + '</style>', unsafe_allow_html=True)

    if "logger" not in st.session_state:
        st.session_state.logger = load_logger()
        st.session_state.session_id = st.session_state.logger.register_session()

    # Only missing keys are filled in, so values already present in the
    # session are left alone; deep copies keep the module defaults pristine.
    for key, value in FIXED_GENERATION_CONFIG.items():
        if key not in st.session_state:
            st.session_state[key] = copy.deepcopy(value)

    for states in DEFAULT_DIALOGUE_STATE_DICTS:
        for key, value in states.items():
            if key not in st.session_state:
                st.session_state[key] = copy.deepcopy(value)
def header_section(component_name, description="", concise_description="", icon="🤖"):
    """Render the page title and two variants of the introductory paragraph.

    Args:
        component_name: Name shown in the title; its lower-cased form is used
            inside the intro sentences.
        description: Extra HTML/text appended to the intro in the
            ``main-intro-normal-window`` block.
        concise_description: Extra HTML/text appended to the shorter intro in
            the ``main-intro-small-window`` block.
        icon: Text placed after the component name in the title.

    All three fragments are emitted as raw HTML, so the arguments are inserted
    unescaped.
    """
    def render_html(fragment):
        st.markdown(fragment, unsafe_allow_html=True)

    # Centered page title.
    render_html(
        f"<h1 style='text-align: center;'>MERaLiON-AudioLLM {component_name} {icon}</h1>"
    )

    # Full introduction (class "main-intro-normal-window").
    render_html(
        f"""<div class="main-intro-normal-window">
<p>This {component_name.lower()} is based on
<a href="https://huggingface.co/MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION"
target="_blank" rel="noopener noreferrer"> MERaLiON-AudioLLM</a>,
developed by I2R, A*STAR, in collaboration with AISG, Singapore.
{description}</p></div>"""
    )

    # Condensed introduction (class "main-intro-small-window").
    render_html(
        f"""<div class="main-intro-small-window">
<p>This {component_name.lower()} is based on
<a href="https://huggingface.co/MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION"
target="_blank" rel="noopener noreferrer"> MERaLiON-AudioLLM</a>.{concise_description}</p></div>"""
    )
def sidebar_fragment():
    """Render the navigation links and the sampling-parameter sliders.

    The three page links are disabled while ``st.session_state.disprompt`` is
    true. Each slider is bound to a session-state key (``temperature``,
    ``top_p``, ``repetition_penalty``) through its ``key`` argument.
    """
    page_links = (
        ("pages/playground.py", "🚀 Playground"),
        ("pages/agent.py", "👥 Cascade System"),
        ("pages/voice_chat.py", "🗣️ End-to-End Voice Chat"),
    )
    with st.container(height=256, border=False):
        for page, label in page_links:
            st.page_link(page, disabled=st.session_state.disprompt, label=label)

    st.divider()

    slider_specs = (
        dict(label='Temperature', min_value=0.0, max_value=2.0, value=0.1, key='temperature'),
        dict(label='Top P', min_value=0.0, max_value=1.0, value=0.9, key='top_p'),
        dict(label="Repetition Penalty", min_value=1.0, max_value=1.2, value=1.1, key="repetition_penalty"),
    )
    for spec in slider_specs:
        st.slider(**spec)
def successful_example_section(audio_sample_names, audio_array_state, audio_base64_state, restore_state=None):
    """Render the sample picker and load the chosen sample into session state.

    Args:
        audio_sample_names: Keys of ``AUDIO_SAMPLES_W_INSTRUCT`` to offer.
        audio_array_state: Session-state key that receives the audio array.
        audio_base64_state: Session-state key that receives the base64 text.
        restore_state: Optional mapping of session-state entries to apply
            (deep-copied) whenever the selection changes. Defaults to no
            extra entries.

    When a sample has just been selected, its ``.wav`` file is read from
    ``audio_samples/``, the sample's first instruction becomes ``new_prompt``,
    and the whole app is rerun.

    Raises:
        OSError: if the sample's audio file cannot be read.
    """
    # ``None`` replaces the former mutable ``{}`` default; behaviour for
    # callers that omit the argument is unchanged.
    if restore_state is None:
        restore_state = {}

    def mark_selected():
        # Runs as the select box's on_change callback: raise the "just
        # selected" flag and lock prompting until the sample has been loaded.
        st.session_state.update(
            on_select=True,
            disprompt=True,
            **copy.deepcopy(restore_state)
        )

    st.markdown(":fire: **Successful Tasks and Examples**")

    sample_name = st.selectbox(
        label="**Select Audio:**",
        label_visibility="collapsed",
        options=audio_sample_names,
        format_func=lambda o: AUDIO_SAMPLES_W_INSTRUCT[o]["apperance"],
        index=None,
        placeholder="Select an audio sample:",
        on_change=mark_selected,
        key='select')

    if sample_name and st.session_state.on_select:
        # Several catalogue entries ("name#1", "name#2", ...) share one file.
        file_name = sample_name.split("#")[0]
        # Context manager so the file handle is closed before the rerun.
        with open(f"audio_samples/{file_name}.wav", "rb") as audio_file:
            audio_bytes = audio_file.read()

        st.session_state.update(
            on_select=False,
            new_prompt=AUDIO_SAMPLES_W_INSTRUCT[sample_name]["instructions"][0]
        )
        st.session_state[audio_array_state], st.session_state[audio_base64_state] = \
            process_audio_bytes(audio_bytes)
        st.rerun(scope="app")
def _store_attached_audio(audio_file, pending_flag, audio_array_state, audio_base64_state):
    """Load a freshly provided audio file into session state, then rerun.

    Does nothing unless ``audio_file`` is truthy and the session-state flag
    named ``pending_flag`` is set. Otherwise the file is read and processed,
    the results are stored under the two given keys, the flag is cleared and
    ``st.rerun()`` is called.
    """
    if not (audio_file and st.session_state[pending_flag]):
        return

    audio_bytes = audio_file.read()
    st.session_state[audio_array_state], st.session_state[audio_base64_state] = \
        process_audio_bytes(audio_bytes)
    # Clear the flag before rerunning so the same audio is not processed again.
    st.session_state[pending_flag] = False
    st.rerun()


def audio_attach_dialogue(audio_array_state, audio_base64_state, restore_state=None):
    """Render the upload and record widgets for attaching audio.

    Args:
        audio_array_state: Session-state key that receives the audio array.
        audio_base64_state: Session-state key that receives the base64 text.
        restore_state: Optional mapping of session-state entries to apply
            (deep-copied) whenever either widget changes. Defaults to no
            extra entries.

    Each widget's ``on_change`` callback raises its flag (``on_upload`` or
    ``on_record``); the audio is then processed once and the app is rerun.
    """
    # ``None`` replaces the former mutable ``{}`` default; behaviour for
    # callers that omit the argument is unchanged.
    if restore_state is None:
        restore_state = {}

    st.markdown("**Upload**")

    uploaded_file = st.file_uploader(
        label="**Upload Audio:**",
        label_visibility="collapsed",
        type=['wav', 'mp3'],
        on_change=lambda: st.session_state.update(
            on_upload=True,
            **copy.deepcopy(restore_state)
        ),
        key='upload'
    )
    _store_attached_audio(uploaded_file, "on_upload", audio_array_state, audio_base64_state)

    st.markdown("**Record**")

    recorded_file = st.audio_input(
        label="**Record Audio:**",
        label_visibility="collapsed",
        on_change=lambda: st.session_state.update(
            on_record=True,
            **copy.deepcopy(restore_state)
        ),
        key='record'
    )
    _store_attached_audio(recorded_file, "on_record", audio_array_state, audio_base64_state)
def _iter_decoded_text(byte_chunks):
    """Yield UTF-8 text decoded incrementally from an iterable of byte chunks.

    A multi-byte character split across two chunks is reassembled instead of
    raising ``UnicodeDecodeError``; undecodable bytes become U+FFFD.
    """
    import codecs  # local import: this helper is the only user in the module

    decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
    for chunk in byte_chunks:
        text = decoder.decode(chunk)
        if text:
            yield text
    # Flush whatever is still buffered (e.g. a truncated trailing character).
    tail = decoder.decode(b"", final=True)
    if tail:
        yield tail


def retrive_response_with_ui(
    model_name: str,
    text_input: str,
    array_audio_input: np.ndarray,
    base64_audio_input: str,
    prefix: str = "",
    stream: bool = True,
    normalise_response: bool = False,
    history: Optional[List] = None,
    show_warning: bool = True,
    **kwargs
):
    """Send a chat request to the backend, render the reply and log the query.

    Args:
        model_name: Model identifier sent as ``model_name``.
        text_input: Text prompt sent as ``text_input``.
        array_audio_input: Audio samples; sent as a plain list.
        base64_audio_input: Base64 audio; sent as ``None`` when empty.
        prefix: Text emitted before the model output.
        stream: When true, the reply is streamed to the UI chunk by chunk.
        normalise_response: Non-streaming only — pass the reply through
            ``postprocess_voice_transcription`` before prefixing it.
        history: Prior messages; an empty or missing history is sent as ``None``.
        show_warning: Whether warnings returned by the backend are displayed.
        **kwargs: Accepted and ignored.

    Sampling parameters are read from ``st.session_state``.

    Returns:
        ``(error_msg, warnings, response)``. ``error_msg`` is ``""`` on
        success; ``warnings`` is only populated by non-streaming replies;
        ``response`` is ``""`` when the request failed.
    """
    if history is None:
        history = []

    # Prepare request data
    request_data = {
        "text_input": str(text_input),
        "model_name": str(model_name),
        "array_audio_input": array_audio_input.tolist(),  # Convert numpy array to list
        "base64_audio_input": str(base64_audio_input) if base64_audio_input else None,
        "history": list(history) if history else None,
        "stream": bool(stream),
        "max_completion_tokens": int(st.session_state.max_completion_tokens),
        "temperature": float(st.session_state.temperature),
        "top_p": float(st.session_state.top_p),
        "repetition_penalty": float(st.session_state.repetition_penalty),
        "top_k": int(st.session_state.top_k),
        "length_penalty": float(st.session_state.length_penalty),
        "seed": int(st.session_state.seed),
        "extra_params": {}
    }

    error_msg = ""
    warnings = []
    response = ""

    # NOTE(review): neither request sets a timeout, so an unresponsive backend
    # blocks this call indefinitely. Left unchanged because a suitable limit
    # for long generations cannot be determined from this file.
    try:
        if stream:
            # Streaming response; the context manager releases the connection
            # even if rendering stops part-way through.
            with requests.post(f"{API_BASE_URL}chat", json=request_data, stream=True) as response_stream:
                response_stream.raise_for_status()
                # Chunks arrive at arbitrary byte boundaries, so they are
                # decoded incrementally rather than one by one.
                response_obj = itertools.chain([prefix], _iter_decoded_text(response_stream))
                response = st.write_stream(response_obj)
        else:
            # Non-streaming response
            api_response = requests.post(f"{API_BASE_URL}chat", json=request_data)
            api_response.raise_for_status()
            result = api_response.json()
            if "warnings" in result:
                warnings = result["warnings"]
            response = result.get("response", "")
            if normalise_response:
                response = postprocess_voice_transcription(response)
            response = prefix + response
            st.write(response)
    except requests.exceptions.RequestException as e:
        # Mask host names ending in ".com" so the backend address is not shown
        # to the user or written to the log. Raw string: the pattern contains
        # backslash escapes that are meant for the regex engine.
        error_msg = re.sub(r"[a-zA-Z0-9_\-.]+\.com", "<url>", str(e))
        error_msg = f"API request failed: {error_msg}"
        st.error(error_msg)

    if show_warning:
        for warning_msg in warnings:
            st.warning(warning_msg)

    st.session_state.logger.register_query(
        session_id=st.session_state.session_id,
        base64_audio=base64_audio_input,
        text_input=text_input,
        history=history,
        params=request_data["extra_params"],
        response=response,
        warnings=warnings,
        error_msg=error_msg
    )

    return error_msg, warnings, response