Spaces:
Running
Running
implement multi-round for voice chat
Browse files- src/content/agent.py +3 -2
- src/content/common.py +37 -9
- src/content/playground.py +4 -3
- src/content/voice_chat.py +79 -42
- src/generation.py +39 -0
src/content/agent.py
CHANGED
@@ -9,10 +9,11 @@ from src.utils import bytes_to_array, array_to_bytes
|
|
9 |
from src.content.common import (
|
10 |
MODEL_NAMES,
|
11 |
AUDIO_SAMPLES_W_INSTRUCT,
|
12 |
-
|
13 |
init_state_section,
|
14 |
header_section,
|
15 |
sidebar_fragment,
|
|
|
16 |
retrive_response_with_ui
|
17 |
)
|
18 |
|
@@ -132,7 +133,7 @@ def bottom_input_section():
|
|
132 |
st.button(
|
133 |
'Clear',
|
134 |
disabled=st.session_state.disprompt,
|
135 |
-
on_click=lambda:
|
136 |
)
|
137 |
|
138 |
with bottom_cols[1]:
|
|
|
9 |
from src.content.common import (
|
10 |
MODEL_NAMES,
|
11 |
AUDIO_SAMPLES_W_INSTRUCT,
|
12 |
+
AGENT_DIALOGUE_STATES,
|
13 |
init_state_section,
|
14 |
header_section,
|
15 |
sidebar_fragment,
|
16 |
+
reset_states,
|
17 |
retrive_response_with_ui
|
18 |
)
|
19 |
|
|
|
133 |
st.button(
|
134 |
'Clear',
|
135 |
disabled=st.session_state.disprompt,
|
136 |
+
on_click=lambda: reset_states(AGENT_DIALOGUE_STATES)
|
137 |
)
|
138 |
|
139 |
with bottom_cols[1]:
|
src/content/common.py
CHANGED
@@ -13,20 +13,33 @@ from src.retrieval import load_retriever
|
|
13 |
from src.logger import load_logger
|
14 |
|
15 |
|
16 |
-
|
17 |
pg_audio_base64='',
|
18 |
pg_audio_array=np.array([]),
|
19 |
-
pg_messages=[]
|
|
|
|
|
|
|
|
|
20 |
vc_audio_base64='',
|
21 |
vc_audio_array=np.array([]),
|
22 |
-
vc_messages=[],
|
|
|
|
|
|
|
|
|
|
|
23 |
ag_audio_base64='',
|
24 |
ag_audio_array=np.array([]),
|
25 |
ag_visited_query_indices=[],
|
26 |
ag_messages=[],
|
27 |
-
ag_model_messages=[]
|
28 |
-
|
29 |
-
|
|
|
|
|
|
|
|
|
30 |
on_select=False,
|
31 |
on_upload=False,
|
32 |
on_record=False,
|
@@ -34,6 +47,14 @@ DEFAULT_DIALOGUE_STATES = dict(
|
|
34 |
)
|
35 |
|
36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
MODEL_NAMES = OrderedDict({})
|
38 |
|
39 |
|
@@ -329,9 +350,10 @@ def init_state_section():
|
|
329 |
if key not in st.session_state:
|
330 |
st.session_state[key]=copy.deepcopy(value)
|
331 |
|
332 |
-
for
|
333 |
-
|
334 |
-
st.session_state
|
|
|
335 |
|
336 |
|
337 |
def header_section(component_name, description="", concise_description="", icon="🤖"):
|
@@ -375,6 +397,12 @@ def sidebar_fragment():
|
|
375 |
st.slider(label="Repetition Penalty", min_value=1.0, max_value=1.2, value=1.1, key="repetition_penalty")
|
376 |
|
377 |
|
|
|
|
|
|
|
|
|
|
|
|
|
378 |
def retrive_response_with_ui(
|
379 |
model_name: str,
|
380 |
text_input: str,
|
|
|
13 |
from src.logger import load_logger
|
14 |
|
15 |
|
16 |
+
PLAYGROUND_DIALOGUE_STATES = dict(
|
17 |
pg_audio_base64='',
|
18 |
pg_audio_array=np.array([]),
|
19 |
+
pg_messages=[]
|
20 |
+
)
|
21 |
+
|
22 |
+
|
23 |
+
VOICE_CHAT_DIALOGUE_STATES = dict(
|
24 |
vc_audio_base64='',
|
25 |
vc_audio_array=np.array([]),
|
26 |
+
vc_messages=[],
|
27 |
+
vc_model_messages=[]
|
28 |
+
)
|
29 |
+
|
30 |
+
|
31 |
+
AGENT_DIALOGUE_STATES = dict(
|
32 |
ag_audio_base64='',
|
33 |
ag_audio_array=np.array([]),
|
34 |
ag_visited_query_indices=[],
|
35 |
ag_messages=[],
|
36 |
+
ag_model_messages=[]
|
37 |
+
)
|
38 |
+
|
39 |
+
|
40 |
+
COMMON_DIALOGUE_STATES = dict(
|
41 |
+
disprompt=False,
|
42 |
+
new_prompt="",
|
43 |
on_select=False,
|
44 |
on_upload=False,
|
45 |
on_record=False,
|
|
|
47 |
)
|
48 |
|
49 |
|
50 |
+
DEFAULT_DIALOGUE_STATE_DICTS = [
|
51 |
+
PLAYGROUND_DIALOGUE_STATES,
|
52 |
+
VOICE_CHAT_DIALOGUE_STATES,
|
53 |
+
AGENT_DIALOGUE_STATES,
|
54 |
+
COMMON_DIALOGUE_STATES
|
55 |
+
]
|
56 |
+
|
57 |
+
|
58 |
MODEL_NAMES = OrderedDict({})
|
59 |
|
60 |
|
|
|
350 |
if key not in st.session_state:
|
351 |
st.session_state[key]=copy.deepcopy(value)
|
352 |
|
353 |
+
for states in DEFAULT_DIALOGUE_STATE_DICTS:
|
354 |
+
for key, value in states.items():
|
355 |
+
if key not in st.session_state:
|
356 |
+
st.session_state[key]=copy.deepcopy(value)
|
357 |
|
358 |
|
359 |
def header_section(component_name, description="", concise_description="", icon="🤖"):
|
|
|
397 |
st.slider(label="Repetition Penalty", min_value=1.0, max_value=1.2, value=1.1, key="repetition_penalty")
|
398 |
|
399 |
|
400 |
+
def reset_states(*state_dicts):
|
401 |
+
for states in state_dicts:
|
402 |
+
st.session_state.update(copy.deepcopy(states))
|
403 |
+
st.session_state.update(copy.deepcopy(COMMON_DIALOGUE_STATES))
|
404 |
+
|
405 |
+
|
406 |
def retrive_response_with_ui(
|
407 |
model_name: str,
|
408 |
text_input: str,
|
src/content/playground.py
CHANGED
@@ -8,10 +8,11 @@ from src.utils import bytes_to_array, array_to_bytes
|
|
8 |
from src.content.common import (
|
9 |
MODEL_NAMES,
|
10 |
AUDIO_SAMPLES_W_INSTRUCT,
|
11 |
-
|
12 |
init_state_section,
|
13 |
header_section,
|
14 |
sidebar_fragment,
|
|
|
15 |
retrive_response_with_ui
|
16 |
)
|
17 |
|
@@ -126,7 +127,7 @@ def bottom_input_section():
|
|
126 |
st.button(
|
127 |
'Clear',
|
128 |
disabled=st.session_state.disprompt,
|
129 |
-
on_click=lambda:
|
130 |
)
|
131 |
|
132 |
with bottom_cols[1]:
|
@@ -225,7 +226,7 @@ def playground_page():
|
|
225 |
<strong>Paralinguistics</strong> tasks.
|
226 |
This playground currently only support <strong>single-round</strong> conversation.
|
227 |
""",
|
228 |
-
concise_description="
|
229 |
)
|
230 |
|
231 |
with st.sidebar:
|
|
|
8 |
from src.content.common import (
|
9 |
MODEL_NAMES,
|
10 |
AUDIO_SAMPLES_W_INSTRUCT,
|
11 |
+
PLAYGROUND_DIALOGUE_STATES,
|
12 |
init_state_section,
|
13 |
header_section,
|
14 |
sidebar_fragment,
|
15 |
+
reset_states,
|
16 |
retrive_response_with_ui
|
17 |
)
|
18 |
|
|
|
127 |
st.button(
|
128 |
'Clear',
|
129 |
disabled=st.session_state.disprompt,
|
130 |
+
on_click=lambda: reset_states(PLAYGROUND_DIALOGUE_STATES)
|
131 |
)
|
132 |
|
133 |
with bottom_cols[1]:
|
|
|
226 |
<strong>Paralinguistics</strong> tasks.
|
227 |
This playground currently only support <strong>single-round</strong> conversation.
|
228 |
""",
|
229 |
+
concise_description=" This playground currently only support <strong>single-round</strong> conversation."
|
230 |
)
|
231 |
|
232 |
with st.sidebar:
|
src/content/voice_chat.py
CHANGED
@@ -4,20 +4,26 @@ import base64
|
|
4 |
import numpy as np
|
5 |
import streamlit as st
|
6 |
|
7 |
-
from src.generation import
|
8 |
-
|
|
|
|
|
|
|
9 |
from src.content.common import (
|
10 |
MODEL_NAMES,
|
11 |
-
|
12 |
init_state_section,
|
13 |
header_section,
|
14 |
sidebar_fragment,
|
|
|
15 |
retrive_response_with_ui
|
16 |
)
|
|
|
17 |
|
18 |
|
19 |
# TODO: change this.
|
20 |
-
DEFAULT_PROMPT = "Based on the information in this user’s voice, please reply the user in a friendly and helpful way."
|
|
|
21 |
|
22 |
|
23 |
def _update_audio(audio_bytes):
|
@@ -30,13 +36,12 @@ def _update_audio(audio_bytes):
|
|
30 |
|
31 |
|
32 |
def bottom_input_section():
|
33 |
-
st.info(":bulb: Ask something with clear intention.")
|
34 |
bottom_cols = st.columns([0.03, 0.97])
|
35 |
with bottom_cols[0]:
|
36 |
st.button(
|
37 |
'Clear',
|
38 |
disabled=st.session_state.disprompt,
|
39 |
-
on_click=lambda:
|
40 |
)
|
41 |
|
42 |
with bottom_cols[1]:
|
@@ -45,7 +50,6 @@ def bottom_input_section():
|
|
45 |
label_visibility="collapsed",
|
46 |
on_change=lambda: st.session_state.update(
|
47 |
on_record=True,
|
48 |
-
vc_messages=[],
|
49 |
disprompt=True
|
50 |
),
|
51 |
key='record'
|
@@ -56,13 +60,25 @@ def bottom_input_section():
|
|
56 |
_update_audio(audio_bytes)
|
57 |
st.session_state.update(
|
58 |
on_record=False,
|
59 |
-
new_prompt=DEFAULT_PROMPT
|
60 |
)
|
61 |
|
62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
def conversation_section():
|
|
|
64 |
for message in st.session_state.vc_messages:
|
65 |
-
with
|
66 |
if message.get("error"):
|
67 |
st.error(message["error"])
|
68 |
for warning_msg in message.get("warnings", []):
|
@@ -75,53 +91,74 @@ def conversation_section():
|
|
75 |
with st._bottom:
|
76 |
bottom_input_section()
|
77 |
|
78 |
-
if
|
79 |
-
|
80 |
-
one_time_base64 = st.session_state.vc_audio_base64
|
81 |
-
st.session_state.update(
|
82 |
-
new_prompt="",
|
83 |
-
one_time_array=np.array([]),
|
84 |
-
one_time_base64="",
|
85 |
-
vc_messages=[]
|
86 |
-
)
|
87 |
|
88 |
-
|
89 |
-
|
|
|
90 |
|
91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
|
113 |
|
114 |
def voice_chat_page():
|
115 |
init_state_section()
|
116 |
header_section(
|
117 |
component_name="Voice Chat",
|
118 |
-
description=""" It currently only support <strong>
|
119 |
Feel free to talk about anything.""",
|
120 |
-
concise_description=" It currently only support <strong>
|
121 |
icon="🗣️"
|
122 |
)
|
123 |
|
124 |
with st.sidebar:
|
125 |
sidebar_fragment()
|
126 |
|
|
|
127 |
conversation_section()
|
|
|
4 |
import numpy as np
|
5 |
import streamlit as st
|
6 |
|
7 |
+
from src.generation import (
|
8 |
+
MAX_AUDIO_LENGTH,
|
9 |
+
prepare_multimodal_content,
|
10 |
+
change_multimodal_content
|
11 |
+
)
|
12 |
from src.content.common import (
|
13 |
MODEL_NAMES,
|
14 |
+
VOICE_CHAT_DIALOGUE_STATES,
|
15 |
init_state_section,
|
16 |
header_section,
|
17 |
sidebar_fragment,
|
18 |
+
reset_states,
|
19 |
retrive_response_with_ui
|
20 |
)
|
21 |
+
from src.utils import bytes_to_array, array_to_bytes
|
22 |
|
23 |
|
24 |
# TODO: change this.
|
25 |
+
DEFAULT_PROMPT = "Based on the information in this user’s voice, please reply to the user in a friendly and helpful way."
|
26 |
+
MAX_VC_ROUNDS = 5
|
27 |
|
28 |
|
29 |
def _update_audio(audio_bytes):
|
|
|
36 |
|
37 |
|
38 |
def bottom_input_section():
|
|
|
39 |
bottom_cols = st.columns([0.03, 0.97])
|
40 |
with bottom_cols[0]:
|
41 |
st.button(
|
42 |
'Clear',
|
43 |
disabled=st.session_state.disprompt,
|
44 |
+
on_click=lambda: reset_states(VOICE_CHAT_DIALOGUE_STATES)
|
45 |
)
|
46 |
|
47 |
with bottom_cols[1]:
|
|
|
50 |
label_visibility="collapsed",
|
51 |
on_change=lambda: st.session_state.update(
|
52 |
on_record=True,
|
|
|
53 |
disprompt=True
|
54 |
),
|
55 |
key='record'
|
|
|
60 |
_update_audio(audio_bytes)
|
61 |
st.session_state.update(
|
62 |
on_record=False,
|
|
|
63 |
)
|
64 |
|
65 |
|
66 |
+
@st.fragment
|
67 |
+
def system_prompt_fragment():
|
68 |
+
with st.expander("System Prompt"):
|
69 |
+
st.text_area(
|
70 |
+
label="Insert system instructions or background knowledge here.",
|
71 |
+
label_visibility="collapsed",
|
72 |
+
max_chars=5000,
|
73 |
+
key="system_prompt",
|
74 |
+
value=DEFAULT_PROMPT,
|
75 |
+
)
|
76 |
+
|
77 |
+
|
78 |
def conversation_section():
|
79 |
+
chat_message_container = st.container(height=480)
|
80 |
for message in st.session_state.vc_messages:
|
81 |
+
with chat_message_container.chat_message(message["role"]):
|
82 |
if message.get("error"):
|
83 |
st.error(message["error"])
|
84 |
for warning_msg in message.get("warnings", []):
|
|
|
91 |
with st._bottom:
|
92 |
bottom_input_section()
|
93 |
|
94 |
+
if not st.session_state.vc_audio_base64:
|
95 |
+
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
|
97 |
+
if len(st.session_state.vc_messages) >= MAX_VC_ROUNDS * 2:
|
98 |
+
st.toast(f":warning: max conversation rounds ({MAX_VC_ROUNDS}) reached!")
|
99 |
+
return
|
100 |
|
101 |
+
one_time_prompt = DEFAULT_PROMPT
|
102 |
+
one_time_array = st.session_state.vc_audio_array
|
103 |
+
one_time_base64 = st.session_state.vc_audio_base64
|
104 |
+
st.session_state.update(
|
105 |
+
vc_audio_array=np.array([]),
|
106 |
+
vc_audio_base64="",
|
107 |
+
)
|
108 |
+
|
109 |
+
with chat_message_container.chat_message("user"):
|
110 |
+
st.audio(one_time_array, format="audio/wav", sample_rate=16000)
|
111 |
|
112 |
+
st.session_state.vc_messages.append({"role": "user", "audio": one_time_array})
|
113 |
+
|
114 |
+
if not st.session_state.vc_model_messages:
|
115 |
+
one_time_prompt = st.session_state.system_prompt
|
116 |
+
else:
|
117 |
+
st.session_state.vc_model_messages[0]["content"] = change_multimodal_content(
|
118 |
+
st.session_state.vc_model_messages[0]["content"],
|
119 |
+
text_input=st.session_state.system_prompt
|
120 |
+
)
|
121 |
+
|
122 |
+
with chat_message_container.chat_message("assistant"):
|
123 |
+
with st.spinner("Thinking..."):
|
124 |
+
error_msg, warnings, response = retrive_response_with_ui(
|
125 |
+
model_name=MODEL_NAMES["audiollm-it"]["vllm_name"],
|
126 |
+
text_input=one_time_prompt,
|
127 |
+
array_audio_input=one_time_array,
|
128 |
+
base64_audio_input=one_time_base64,
|
129 |
+
stream=True,
|
130 |
+
history=st.session_state.vc_model_messages
|
131 |
+
)
|
132 |
+
|
133 |
+
st.session_state.vc_messages.append({
|
134 |
+
"role": "assistant",
|
135 |
+
"error": error_msg,
|
136 |
+
"warnings": warnings,
|
137 |
+
"content": response
|
138 |
+
})
|
139 |
+
|
140 |
+
mm_content = prepare_multimodal_content(one_time_prompt, one_time_base64)
|
141 |
+
st.session_state.vc_model_messages.extend([
|
142 |
+
{"role": "user", "content": mm_content},
|
143 |
+
{"role": "assistant", "content": response}
|
144 |
+
])
|
145 |
+
|
146 |
+
st.session_state.disprompt=False
|
147 |
+
st.rerun(scope="app")
|
148 |
|
149 |
|
150 |
def voice_chat_page():
|
151 |
init_state_section()
|
152 |
header_section(
|
153 |
component_name="Voice Chat",
|
154 |
+
description=""" It currently only support up to <strong>5 rounds</strong> of conversations.
|
155 |
Feel free to talk about anything.""",
|
156 |
+
concise_description=" It currently only support up to <strong>5 rounds</strong> of conversations.",
|
157 |
icon="🗣️"
|
158 |
)
|
159 |
|
160 |
with st.sidebar:
|
161 |
sidebar_fragment()
|
162 |
|
163 |
+
system_prompt_fragment()
|
164 |
conversation_section()
|
src/generation.py
CHANGED
@@ -40,6 +40,45 @@ def load_model() -> Dict:
|
|
40 |
return name_to_client_mapper
|
41 |
|
42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
def _retrive_response(
|
44 |
model: str,
|
45 |
text_input: str,
|
|
|
40 |
return name_to_client_mapper
|
41 |
|
42 |
|
43 |
+
def prepare_multimodal_content(text_input, base64_audio_input):
|
44 |
+
return [
|
45 |
+
{
|
46 |
+
"type": "text",
|
47 |
+
"text": f"Text instruction: {text_input}"
|
48 |
+
},
|
49 |
+
{
|
50 |
+
"type": "audio_url",
|
51 |
+
"audio_url": {
|
52 |
+
"url": f"data:audio/ogg;base64,{base64_audio_input}"
|
53 |
+
},
|
54 |
+
},
|
55 |
+
]
|
56 |
+
|
57 |
+
|
58 |
+
def change_multimodal_content(
|
59 |
+
original_content,
|
60 |
+
text_input="",
|
61 |
+
base64_audio_input=""):
|
62 |
+
|
63 |
+
# Since python 3.7 dictionary is ordered.
|
64 |
+
if text_input:
|
65 |
+
original_content[0] = {
|
66 |
+
"type": "text",
|
67 |
+
"text": f"Text instruction: {text_input}"
|
68 |
+
}
|
69 |
+
|
70 |
+
if base64_audio_input:
|
71 |
+
original_content[1] = {
|
72 |
+
"type": "audio_url",
|
73 |
+
"audio_url": {
|
74 |
+
"url": f"data:audio/ogg;base64,{base64_audio_input}"
|
75 |
+
}
|
76 |
+
}
|
77 |
+
|
78 |
+
return original_content
|
79 |
+
|
80 |
+
|
81 |
+
|
82 |
def _retrive_response(
|
83 |
model: str,
|
84 |
text_input: str,
|