martin committed on
Commit
930f36f
·
1 Parent(s): ed26be9

fix aqta predict

Browse files
Files changed (2) hide show
  1. app.py +18 -9
  2. stepaudio.py +6 -2
app.py CHANGED
@@ -34,7 +34,7 @@ class CustomAsr:
34
  return text
35
 
36
 
37
- def add_message(chatbot, history, mic, text, asr_model):
38
  if not mic and not text:
39
  return chatbot, history, "Input is empty"
40
 
@@ -43,10 +43,7 @@ def add_message(chatbot, history, mic, text, asr_model):
43
  history.append({"role": "user", "content": text})
44
  elif mic and Path(mic).exists():
45
  chatbot.append({"role": "user", "content": {"path": mic}})
46
- # 使用用户语音的 asr 结果为了加速推理
47
- text = asr_model.run(mic)
48
- chatbot.append({"role": "user", "content": text})
49
- history.append({"role": "user", "content": text})
50
 
51
  print(f"{history=}")
52
  return chatbot, history, None
@@ -69,12 +66,24 @@ def save_tmp_audio(audio, sr):
69
  return temp_audio.name
70
 
71
 
72
- def predict(chatbot, history, audio_model):
73
  """Generate a response from the model."""
74
  try:
 
 
 
 
 
 
75
  text, audio, sr = audio_model(history, "闫雨婷")
76
  print(f"predict {text=}")
77
  audio_path = save_tmp_audio(audio, sr)
 
 
 
 
 
 
78
  chatbot.append({"role": "assistant", "content": {"path": audio_path}})
79
  chatbot.append({"role": "assistant", "content": text})
80
  history.append({"role": "assistant", "content": text})
@@ -105,13 +114,13 @@ def _launch_demo(args, audio_model, asr_model):
105
 
106
  def on_submit(chatbot, history, mic, text):
107
  chatbot, history, error = add_message(
108
- chatbot, history, mic, text, asr_model
109
  )
110
  if error:
111
  gr.Warning(error) # 显示警告消息
112
  return chatbot, history, None, None
113
  else:
114
- chatbot, history = predict(chatbot, history, audio_model)
115
  return chatbot, history, None, None
116
 
117
  submit_btn.click(
@@ -133,7 +142,7 @@ def _launch_demo(args, audio_model, asr_model):
133
  while history and history[-1]["role"] == "assistant":
134
  print(f"discard {history[-1]}")
135
  history.pop()
136
- return predict(chatbot, history, audio_model)
137
 
138
  regen_btn.click(
139
  regenerate,
 
34
  return text
35
 
36
 
37
+ def add_message(chatbot, history, mic, text):
38
  if not mic and not text:
39
  return chatbot, history, "Input is empty"
40
 
 
43
  history.append({"role": "user", "content": text})
44
  elif mic and Path(mic).exists():
45
  chatbot.append({"role": "user", "content": {"path": mic}})
46
+ history.append({"role": "user", "content": {"type":"audio", "audio": mic}})
 
 
 
47
 
48
  print(f"{history=}")
49
  return chatbot, history, None
 
66
  return temp_audio.name
67
 
68
 
69
+ def predict(chatbot, history, audio_model, asr_model):
70
  """Generate a response from the model."""
71
  try:
72
+ is_input_audio = False
73
+ user_audio_path = None
74
+ # 检测用户输入的是音频还是文本
75
+ if isinstance(history[-1]["content"], dict):
76
+ is_input_audio = True
77
+ user_audio_path = history[-1]["content"]["audio"]
78
  text, audio, sr = audio_model(history, "闫雨婷")
79
  print(f"predict {text=}")
80
  audio_path = save_tmp_audio(audio, sr)
81
+ # 缓存用户语音的 asr 文本结果为了加速下一次推理
82
+ if is_input_audio:
83
+ asr_text = asr_model.run(user_audio_path)
84
+ chatbot.append({"role": "user", "content": asr_text})
85
+ history[-1]["content"] = asr_text
86
+ print(f"{asr_text=}")
87
  chatbot.append({"role": "assistant", "content": {"path": audio_path}})
88
  chatbot.append({"role": "assistant", "content": text})
89
  history.append({"role": "assistant", "content": text})
 
114
 
115
  def on_submit(chatbot, history, mic, text):
116
  chatbot, history, error = add_message(
117
+ chatbot, history, mic, text
118
  )
119
  if error:
120
  gr.Warning(error) # 显示警告消息
121
  return chatbot, history, None, None
122
  else:
123
+ chatbot, history = predict(chatbot, history, audio_model, asr_model)
124
  return chatbot, history, None, None
125
 
126
  submit_btn.click(
 
142
  while history and history[-1]["role"] == "assistant":
143
  print(f"discard {history[-1]}")
144
  history.pop()
145
+ return predict(chatbot, history, audio_model, asr_model)
146
 
147
  regen_btn.click(
148
  regenerate,
stepaudio.py CHANGED
@@ -42,6 +42,11 @@ class StepAudio:
42
  output_audio = volumn_adjust(output_audio, volumn_ratio)
43
  return output_text, output_audio, sr
44
 
 
 
 
 
 
45
  def apply_chat_template(self, messages: list):
46
  text_with_audio = ""
47
  for msg in messages:
@@ -55,8 +60,7 @@ class StepAudio:
55
  if content["type"] == "text":
56
  text_with_audio += f"<|BOT|>{role}\n{content['text']}<|EOT|>"
57
  elif content["type"] == "audio":
58
- audio_wav, sr = load_audio(content["audio"])
59
- audio_tokens = self.encoder(audio_wav, sr)
60
  text_with_audio += f"<|BOT|>{role}\n{audio_tokens}<|EOT|>"
61
  elif content is None:
62
  text_with_audio += f"<|BOT|>{role}\n"
 
42
  output_audio = volumn_adjust(output_audio, volumn_ratio)
43
  return output_text, output_audio, sr
44
 
45
+ def encode_audio(self, audio_path):
46
+ audio_wav, sr = load_audio(audio_path)
47
+ audio_tokens = self.encoder(audio_wav, sr)
48
+ return audio_tokens
49
+
50
  def apply_chat_template(self, messages: list):
51
  text_with_audio = ""
52
  for msg in messages:
 
60
  if content["type"] == "text":
61
  text_with_audio += f"<|BOT|>{role}\n{content['text']}<|EOT|>"
62
  elif content["type"] == "audio":
63
+ audio_tokens = self.encode_audio(content["audio"])
 
64
  text_with_audio += f"<|BOT|>{role}\n{audio_tokens}<|EOT|>"
65
  elif content is None:
66
  text_with_audio += f"<|BOT|>{role}\n"