YingxuHe committed on
Commit
89ed0ae
·
1 Parent(s): 8573823

implement multi-round for voice chat

Browse files
src/content/agent.py CHANGED
@@ -9,10 +9,11 @@ from src.utils import bytes_to_array, array_to_bytes
9
  from src.content.common import (
10
  MODEL_NAMES,
11
  AUDIO_SAMPLES_W_INSTRUCT,
12
- DEFAULT_DIALOGUE_STATES,
13
  init_state_section,
14
  header_section,
15
  sidebar_fragment,
 
16
  retrive_response_with_ui
17
  )
18
 
@@ -132,7 +133,7 @@ def bottom_input_section():
132
  st.button(
133
  'Clear',
134
  disabled=st.session_state.disprompt,
135
- on_click=lambda: st.session_state.update(copy.deepcopy(DEFAULT_DIALOGUE_STATES))
136
  )
137
 
138
  with bottom_cols[1]:
 
9
  from src.content.common import (
10
  MODEL_NAMES,
11
  AUDIO_SAMPLES_W_INSTRUCT,
12
+ AGENT_DIALOGUE_STATES,
13
  init_state_section,
14
  header_section,
15
  sidebar_fragment,
16
+ reset_states,
17
  retrive_response_with_ui
18
  )
19
 
 
133
  st.button(
134
  'Clear',
135
  disabled=st.session_state.disprompt,
136
+ on_click=lambda: reset_states(AGENT_DIALOGUE_STATES)
137
  )
138
 
139
  with bottom_cols[1]:
src/content/common.py CHANGED
@@ -13,20 +13,33 @@ from src.retrieval import load_retriever
13
  from src.logger import load_logger
14
 
15
 
16
- DEFAULT_DIALOGUE_STATES = dict(
17
  pg_audio_base64='',
18
  pg_audio_array=np.array([]),
19
- pg_messages=[],
 
 
 
 
20
  vc_audio_base64='',
21
  vc_audio_array=np.array([]),
22
- vc_messages=[],
 
 
 
 
 
23
  ag_audio_base64='',
24
  ag_audio_array=np.array([]),
25
  ag_visited_query_indices=[],
26
  ag_messages=[],
27
- ag_model_messages=[],
28
- disprompt = False,
29
- new_prompt = "",
 
 
 
 
30
  on_select=False,
31
  on_upload=False,
32
  on_record=False,
@@ -34,6 +47,14 @@ DEFAULT_DIALOGUE_STATES = dict(
34
  )
35
 
36
 
 
 
 
 
 
 
 
 
37
  MODEL_NAMES = OrderedDict({})
38
 
39
 
@@ -329,9 +350,10 @@ def init_state_section():
329
  if key not in st.session_state:
330
  st.session_state[key]=copy.deepcopy(value)
331
 
332
- for key, value in DEFAULT_DIALOGUE_STATES.items():
333
- if key not in st.session_state:
334
- st.session_state[key]=copy.deepcopy(value)
 
335
 
336
 
337
  def header_section(component_name, description="", concise_description="", icon="🤖"):
@@ -375,6 +397,12 @@ def sidebar_fragment():
375
  st.slider(label="Repetition Penalty", min_value=1.0, max_value=1.2, value=1.1, key="repetition_penalty")
376
 
377
 
 
 
 
 
 
 
378
  def retrive_response_with_ui(
379
  model_name: str,
380
  text_input: str,
 
13
  from src.logger import load_logger
14
 
15
 
16
+ PLAYGROUND_DIALOGUE_STATES = dict(
17
  pg_audio_base64='',
18
  pg_audio_array=np.array([]),
19
+ pg_messages=[]
20
+ )
21
+
22
+
23
+ VOICE_CHAT_DIALOGUE_STATES = dict(
24
  vc_audio_base64='',
25
  vc_audio_array=np.array([]),
26
+ vc_messages=[],
27
+ vc_model_messages=[]
28
+ )
29
+
30
+
31
+ AGENT_DIALOGUE_STATES = dict(
32
  ag_audio_base64='',
33
  ag_audio_array=np.array([]),
34
  ag_visited_query_indices=[],
35
  ag_messages=[],
36
+ ag_model_messages=[]
37
+ )
38
+
39
+
40
+ COMMON_DIALOGUE_STATES = dict(
41
+ disprompt=False,
42
+ new_prompt="",
43
  on_select=False,
44
  on_upload=False,
45
  on_record=False,
 
47
  )
48
 
49
 
50
+ DEFAULT_DIALOGUE_STATE_DICTS = [
51
+ PLAYGROUND_DIALOGUE_STATES,
52
+ VOICE_CHAT_DIALOGUE_STATES,
53
+ AGENT_DIALOGUE_STATES,
54
+ COMMON_DIALOGUE_STATES
55
+ ]
56
+
57
+
58
  MODEL_NAMES = OrderedDict({})
59
 
60
 
 
350
  if key not in st.session_state:
351
  st.session_state[key]=copy.deepcopy(value)
352
 
353
+ for states in DEFAULT_DIALOGUE_STATE_DICTS:
354
+ for key, value in states.items():
355
+ if key not in st.session_state:
356
+ st.session_state[key]=copy.deepcopy(value)
357
 
358
 
359
  def header_section(component_name, description="", concise_description="", icon="🤖"):
 
397
  st.slider(label="Repetition Penalty", min_value=1.0, max_value=1.2, value=1.1, key="repetition_penalty")
398
 
399
 
400
+ def reset_states(*state_dicts):
401
+ for states in state_dicts:
402
+ st.session_state.update(copy.deepcopy(states))
403
+ st.session_state.update(copy.deepcopy(COMMON_DIALOGUE_STATES))
404
+
405
+
406
  def retrive_response_with_ui(
407
  model_name: str,
408
  text_input: str,
src/content/playground.py CHANGED
@@ -8,10 +8,11 @@ from src.utils import bytes_to_array, array_to_bytes
8
  from src.content.common import (
9
  MODEL_NAMES,
10
  AUDIO_SAMPLES_W_INSTRUCT,
11
- DEFAULT_DIALOGUE_STATES,
12
  init_state_section,
13
  header_section,
14
  sidebar_fragment,
 
15
  retrive_response_with_ui
16
  )
17
 
@@ -126,7 +127,7 @@ def bottom_input_section():
126
  st.button(
127
  'Clear',
128
  disabled=st.session_state.disprompt,
129
- on_click=lambda: st.session_state.update(copy.deepcopy(DEFAULT_DIALOGUE_STATES))
130
  )
131
 
132
  with bottom_cols[1]:
@@ -225,7 +226,7 @@ def playground_page():
225
  <strong>Paralinguistics</strong> tasks.
226
  This playground currently only support <strong>single-round</strong> conversation.
227
  """,
228
- concise_description=" It currently only support <strong>single-round</strong> conversation."
229
  )
230
 
231
  with st.sidebar:
 
8
  from src.content.common import (
9
  MODEL_NAMES,
10
  AUDIO_SAMPLES_W_INSTRUCT,
11
+ PLAYGROUND_DIALOGUE_STATES,
12
  init_state_section,
13
  header_section,
14
  sidebar_fragment,
15
+ reset_states,
16
  retrive_response_with_ui
17
  )
18
 
 
127
  st.button(
128
  'Clear',
129
  disabled=st.session_state.disprompt,
130
+ on_click=lambda: reset_states(PLAYGROUND_DIALOGUE_STATES)
131
  )
132
 
133
  with bottom_cols[1]:
 
226
  <strong>Paralinguistics</strong> tasks.
227
  This playground currently only support <strong>single-round</strong> conversation.
228
  """,
229
+ concise_description=" This playground currently only support <strong>single-round</strong> conversation."
230
  )
231
 
232
  with st.sidebar:
src/content/voice_chat.py CHANGED
@@ -4,20 +4,26 @@ import base64
4
  import numpy as np
5
  import streamlit as st
6
 
7
- from src.generation import MAX_AUDIO_LENGTH
8
- from src.utils import bytes_to_array, array_to_bytes
 
 
 
9
  from src.content.common import (
10
  MODEL_NAMES,
11
- DEFAULT_DIALOGUE_STATES,
12
  init_state_section,
13
  header_section,
14
  sidebar_fragment,
 
15
  retrive_response_with_ui
16
  )
 
17
 
18
 
19
  # TODO: change this.
20
- DEFAULT_PROMPT = "Based on the information in this user’s voice, please reply the user in a friendly and helpful way."
 
21
 
22
 
23
  def _update_audio(audio_bytes):
@@ -30,13 +36,12 @@ def _update_audio(audio_bytes):
30
 
31
 
32
  def bottom_input_section():
33
- st.info(":bulb: Ask something with clear intention.")
34
  bottom_cols = st.columns([0.03, 0.97])
35
  with bottom_cols[0]:
36
  st.button(
37
  'Clear',
38
  disabled=st.session_state.disprompt,
39
- on_click=lambda: st.session_state.update(copy.deepcopy(DEFAULT_DIALOGUE_STATES))
40
  )
41
 
42
  with bottom_cols[1]:
@@ -45,7 +50,6 @@ def bottom_input_section():
45
  label_visibility="collapsed",
46
  on_change=lambda: st.session_state.update(
47
  on_record=True,
48
- vc_messages=[],
49
  disprompt=True
50
  ),
51
  key='record'
@@ -56,13 +60,25 @@ def bottom_input_section():
56
  _update_audio(audio_bytes)
57
  st.session_state.update(
58
  on_record=False,
59
- new_prompt=DEFAULT_PROMPT
60
  )
61
 
62
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  def conversation_section():
 
64
  for message in st.session_state.vc_messages:
65
- with st.chat_message(message["role"]):
66
  if message.get("error"):
67
  st.error(message["error"])
68
  for warning_msg in message.get("warnings", []):
@@ -75,53 +91,74 @@ def conversation_section():
75
  with st._bottom:
76
  bottom_input_section()
77
 
78
- if one_time_prompt := st.session_state.new_prompt:
79
- one_time_array = st.session_state.vc_audio_array
80
- one_time_base64 = st.session_state.vc_audio_base64
81
- st.session_state.update(
82
- new_prompt="",
83
- one_time_array=np.array([]),
84
- one_time_base64="",
85
- vc_messages=[]
86
- )
87
 
88
- with st.chat_message("user"):
89
- st.audio(one_time_array, format="audio/wav", sample_rate=16000)
 
90
 
91
- st.session_state.vc_messages.append({"role": "user", "audio": one_time_array})
 
 
 
 
 
 
 
 
 
92
 
93
- with st.chat_message("assistant"):
94
- with st.spinner("Thinking..."):
95
- error_msg, warnings, response = retrive_response_with_ui(
96
- model_name=MODEL_NAMES["audiollm-it"]["vllm_name"],
97
- text_input=one_time_prompt,
98
- array_audio_input=one_time_array,
99
- base64_audio_input=one_time_base64,
100
- stream=True
101
- )
102
-
103
- st.session_state.vc_messages.append({
104
- "role": "assistant",
105
- "error": error_msg,
106
- "warnings": warnings,
107
- "content": response
108
- })
109
-
110
- st.session_state.disprompt=False
111
- st.rerun(scope="app")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
 
114
  def voice_chat_page():
115
  init_state_section()
116
  header_section(
117
  component_name="Voice Chat",
118
- description=""" It currently only support <strong>single-round</strong> conversation.
119
  Feel free to talk about anything.""",
120
- concise_description=" It currently only support <strong>single-round</strong> conversation.",
121
  icon="🗣️"
122
  )
123
 
124
  with st.sidebar:
125
  sidebar_fragment()
126
 
 
127
  conversation_section()
 
4
  import numpy as np
5
  import streamlit as st
6
 
7
+ from src.generation import (
8
+ MAX_AUDIO_LENGTH,
9
+ prepare_multimodal_content,
10
+ change_multimodal_content
11
+ )
12
  from src.content.common import (
13
  MODEL_NAMES,
14
+ VOICE_CHAT_DIALOGUE_STATES,
15
  init_state_section,
16
  header_section,
17
  sidebar_fragment,
18
+ reset_states,
19
  retrive_response_with_ui
20
  )
21
+ from src.utils import bytes_to_array, array_to_bytes
22
 
23
 
24
  # TODO: change this.
25
+ DEFAULT_PROMPT = "Based on the information in this user’s voice, please reply to the user in a friendly and helpful way."
26
+ MAX_VC_ROUNDS = 5
27
 
28
 
29
  def _update_audio(audio_bytes):
 
36
 
37
 
38
  def bottom_input_section():
 
39
  bottom_cols = st.columns([0.03, 0.97])
40
  with bottom_cols[0]:
41
  st.button(
42
  'Clear',
43
  disabled=st.session_state.disprompt,
44
+ on_click=lambda: reset_states(VOICE_CHAT_DIALOGUE_STATES)
45
  )
46
 
47
  with bottom_cols[1]:
 
50
  label_visibility="collapsed",
51
  on_change=lambda: st.session_state.update(
52
  on_record=True,
 
53
  disprompt=True
54
  ),
55
  key='record'
 
60
  _update_audio(audio_bytes)
61
  st.session_state.update(
62
  on_record=False,
 
63
  )
64
 
65
 
66
+ @st.fragment
67
+ def system_prompt_fragment():
68
+ with st.expander("System Prompt"):
69
+ st.text_area(
70
+ label="Insert system instructions or background knowledge here.",
71
+ label_visibility="collapsed",
72
+ max_chars=5000,
73
+ key="system_prompt",
74
+ value=DEFAULT_PROMPT,
75
+ )
76
+
77
+
78
  def conversation_section():
79
+ chat_message_container = st.container(height=480)
80
  for message in st.session_state.vc_messages:
81
+ with chat_message_container.chat_message(message["role"]):
82
  if message.get("error"):
83
  st.error(message["error"])
84
  for warning_msg in message.get("warnings", []):
 
91
  with st._bottom:
92
  bottom_input_section()
93
 
94
+ if not st.session_state.vc_audio_base64:
95
+ return
 
 
 
 
 
 
 
96
 
97
+ if len(st.session_state.vc_messages) >= MAX_VC_ROUNDS * 2:
98
+ st.toast(f":warning: max conversation rounds ({MAX_VC_ROUNDS}) reached!")
99
+ return
100
 
101
+ one_time_prompt = DEFAULT_PROMPT
102
+ one_time_array = st.session_state.vc_audio_array
103
+ one_time_base64 = st.session_state.vc_audio_base64
104
+ st.session_state.update(
105
+ vc_audio_array=np.array([]),
106
+ vc_audio_base64="",
107
+ )
108
+
109
+ with chat_message_container.chat_message("user"):
110
+ st.audio(one_time_array, format="audio/wav", sample_rate=16000)
111
 
112
+ st.session_state.vc_messages.append({"role": "user", "audio": one_time_array})
113
+
114
+ if not st.session_state.vc_model_messages:
115
+ one_time_prompt = st.session_state.system_prompt
116
+ else:
117
+ st.session_state.vc_model_messages[0]["content"] = change_multimodal_content(
118
+ st.session_state.vc_model_messages[0]["content"],
119
+ text_input=st.session_state.system_prompt
120
+ )
121
+
122
+ with chat_message_container.chat_message("assistant"):
123
+ with st.spinner("Thinking..."):
124
+ error_msg, warnings, response = retrive_response_with_ui(
125
+ model_name=MODEL_NAMES["audiollm-it"]["vllm_name"],
126
+ text_input=one_time_prompt,
127
+ array_audio_input=one_time_array,
128
+ base64_audio_input=one_time_base64,
129
+ stream=True,
130
+ history=st.session_state.vc_model_messages
131
+ )
132
+
133
+ st.session_state.vc_messages.append({
134
+ "role": "assistant",
135
+ "error": error_msg,
136
+ "warnings": warnings,
137
+ "content": response
138
+ })
139
+
140
+ mm_content = prepare_multimodal_content(one_time_prompt, one_time_base64)
141
+ st.session_state.vc_model_messages.extend([
142
+ {"role": "user", "content": mm_content},
143
+ {"role": "assistant", "content": response}
144
+ ])
145
+
146
+ st.session_state.disprompt=False
147
+ st.rerun(scope="app")
148
 
149
 
150
  def voice_chat_page():
151
  init_state_section()
152
  header_section(
153
  component_name="Voice Chat",
154
+ description=""" It currently only support up to <strong>5 rounds</strong> of conversations.
155
  Feel free to talk about anything.""",
156
+ concise_description=" It currently only support up to <strong>5 rounds</strong> of conversations.",
157
  icon="🗣️"
158
  )
159
 
160
  with st.sidebar:
161
  sidebar_fragment()
162
 
163
+ system_prompt_fragment()
164
  conversation_section()
src/generation.py CHANGED
@@ -40,6 +40,45 @@ def load_model() -> Dict:
40
  return name_to_client_mapper
41
 
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  def _retrive_response(
44
  model: str,
45
  text_input: str,
 
40
  return name_to_client_mapper
41
 
42
 
43
+ def prepare_multimodal_content(text_input, base64_audio_input):
44
+ return [
45
+ {
46
+ "type": "text",
47
+ "text": f"Text instruction: {text_input}"
48
+ },
49
+ {
50
+ "type": "audio_url",
51
+ "audio_url": {
52
+ "url": f"data:audio/ogg;base64,{base64_audio_input}"
53
+ },
54
+ },
55
+ ]
56
+
57
+
58
+ def change_multimodal_content(
59
+ original_content,
60
+ text_input="",
61
+ base64_audio_input=""):
62
+
63
+ # Since python 3.7 dictionary is ordered.
64
+ if text_input:
65
+ original_content[0] = {
66
+ "type": "text",
67
+ "text": f"Text instruction: {text_input}"
68
+ }
69
+
70
+ if base64_audio_input:
71
+ original_content[1] = {
72
+ "type": "audio_url",
73
+ "audio_url": {
74
+ "url": f"data:audio/ogg;base64,{base64_audio_input}"
75
+ }
76
+ }
77
+
78
+ return original_content
79
+
80
+
81
+
82
  def _retrive_response(
83
  model: str,
84
  text_input: str,