tomxxie committed
Commit d9ae2cb · Parent: a5dad3b

Update English display text

Files changed (2):
  1. app.py +9 -8
  2. app_old.py +254 -49
app.py CHANGED
@@ -158,8 +158,8 @@ def download_audio(input_wav_path):
     else:
         return None
 
 # Create the Gradio interface
-with gr.Blocks(css=custom_css) as demo:
+with gr.Blocks() as demo:
     # Add the title
     gr.Markdown(
         f"""
@@ -176,13 +176,13 @@ with gr.Blocks(css=custom_css) as demo:
         with gr.Column(scale=1):
             audio_input = gr.Audio(label="Record", type="filepath")
         with gr.Column(scale=1, min_width=300):  # give the output box a minimum width so the columns stay aligned
-            output_text = gr.Textbox(label=" Output", lines=8, placeholder="The generated result will be displayed here...", interactive=False)
+            output_text = gr.Textbox(label="Output", lines=8, placeholder="The generated result will be displayed here...", interactive=False)
 
     # Add the task selector and the custom prompt box
     with gr.Row():
         task_dropdown = gr.Dropdown(
             label="Task",
-            choices=list(TASK_PROMPT_MAPPING.keys()) + ["Custom Input Text"],  # new option
+            choices=list(TASK_PROMPT_MAPPING.keys()) + ["Custom Task Prompt"],  # new option
             value="ASR (Automatic Speech Recognition)"
         )
         custom_prompt_input = gr.Textbox(label="Custom Task Prompt", placeholder="Please enter a custom task prompt...", visible=False)  # new text input box
@@ -202,7 +202,7 @@ with gr.Blocks(css=custom_css) as demo:
             container=False,
             elem_classes="confirmation-buttons"
         )
-        save_button = gr.Button("提交反馈", variant="secondary")
+        save_button = gr.Button("Submit Feedback", variant="secondary")
 
     # Add the footer content
     with gr.Row():
@@ -229,15 +229,16 @@ with gr.Blocks(css=custom_css) as demo:
         return gr.update(visible=False)
 
     def handle_submit(input_wav_path, task_choice, custom_prompt):
-        if task_choice == "自主输入文本":
+        if task_choice == "Custom Task Prompt":
             input_prompt = custom_prompt  # use the user-provided custom text
         else:
             input_prompt = TASK_PROMPT_MAPPING.get(task_choice, "未知任务类型")  # use the predefined prompt
         output_res = do_decode(input_wav_path, input_prompt)
         return output_res
 
+    # When the task selection changes, update the visibility of the custom prompt box
     task_dropdown.change(
-        fn=lambda choice: gr.update(visible=choice == "自主输入文本"),
+        fn=lambda choice: gr.update(visible=choice == "Custom Task Prompt"),
         inputs=task_dropdown,
         outputs=custom_prompt_input
     )
@@ -264,5 +265,5 @@ with gr.Blocks(css=custom_css) as demo:
         outputs=confirmation_row
     )
 
-if __name__== "__main__":
+if __name__ == "__main__":
     demo.launch()
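
The functional fix in app.py is the sentinel string for the free-form task: before this commit the dropdown offered "Custom Input Text" while the visibility lambda and handle_submit compared against "自主输入文本", so the custom-prompt branch could never trigger. The commit aligns all three sites on "Custom Task Prompt". A minimal, runnable sketch of that pattern in isolation (the echo-style handler is a hypothetical stand-in for do_decode, not the Space's real decoding code):

```python
import gradio as gr

TASK_PROMPT_MAPPING = {"ASR (Automatic Speech Recognition)": "Transcribe the audio."}
CUSTOM_CHOICE = "Custom Task Prompt"  # single sentinel, reused everywhere

with gr.Blocks() as demo:
    task_dropdown = gr.Dropdown(
        label="Task",
        choices=list(TASK_PROMPT_MAPPING.keys()) + [CUSTOM_CHOICE],
        value="ASR (Automatic Speech Recognition)",
    )
    custom_prompt_input = gr.Textbox(label="Custom Task Prompt", visible=False)
    output_text = gr.Textbox(label="Output", interactive=False)

    # Show the free-text box only while the sentinel option is selected.
    task_dropdown.change(
        fn=lambda choice: gr.update(visible=(choice == CUSTOM_CHOICE)),
        inputs=task_dropdown,
        outputs=custom_prompt_input,
    )

    def handle_submit(task_choice, custom_prompt):
        # Stand-in for do_decode(...): just echo which prompt would be used.
        if task_choice == CUSTOM_CHOICE:
            prompt = custom_prompt
        else:
            prompt = TASK_PROMPT_MAPPING[task_choice]
        return f"decoded with prompt: {prompt}"

    gr.Button("Start to Process").click(
        fn=handle_submit,
        inputs=[task_dropdown, custom_prompt_input],
        outputs=output_text,
    )

if __name__ == "__main__":
    demo.launch()
```

Hoisting the sentinel into one constant (CUSTOM_CHOICE here) avoids exactly the mismatch this commit is fixing, where one comparison site was renamed and the others were not.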
app_old.py CHANGED
@@ -1,64 +1,269 @@
+import base64
+import json
+import time
+from types import SimpleNamespace
+
+# import spaces
+
 import gradio as gr
-from huggingface_hub import InferenceClient
+import os
 
+import sys
+
+import yaml
+
+sys.path.insert(0, './')
+# from wenet.utils.init_tokenizer import init_tokenizer
+# from wenet.utils.init_model import init_model
+import logging
+# import librosa
+# import torch
+# import torchaudio
+import numpy as np
+def makedir_for_file(filepath):
+    dirpath = os.path.dirname(filepath)
+    if not os.path.exists(dirpath):
+        os.makedirs(dirpath)
+def load_dict_from_yaml(file_path: str):
+    with open(file_path, 'rt', encoding='utf-8') as f:
+        dict_1 = yaml.load(f, Loader=yaml.FullLoader)
+    return dict_1
+
+
+# Get the absolute path of the current script
+abs_path = os.path.abspath(__file__)
+# Encode the image as Base64
+with open(os.path.join(os.path.dirname(abs_path), "lab.png"), "rb") as image_file:
+    encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
+
+# with open("./cat.jpg", "rb") as image_file:
+#     encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
+
+# Custom CSS styles
+custom_css = """
+/* 自定义CSS样式 */
 """
-For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
-"""
-client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
 
+# Task prompt mapping (the values are the Chinese instructions fed to the model)
+TASK_PROMPT_MAPPING = {
+    "ASR (Automatic Speech Recognition)": "执行语音识别任务,将音频转换为文字。",
+    "SRWT (Speech Recognition with Timestamps)": "请转录音频内容,并为每个英文词汇及其对应的中文翻译标注出精确到0.1秒的起止时间,时间范围用<>括起来。",
+    "VED (Vocal Event Detection)(Categories:laugh,cough,cry,screaming,sigh,throat clearing,sneeze,other)": "请将音频转录为文字记录,并在记录末尾标注<音频事件>标签,音频事件共8种:laugh,cough,cry,screaming,sigh,throat clearing,sneeze,other。",
+    "SER (Speech Emotion Recognition)(Categories:sad,anger,neutral,happy,surprise,fear,disgust,和other)": "请将音频内容转录成文字记录,并在记录末尾标注<情感>标签,情感共8种:sad,anger,neutral,happy,surprise,fear,disgust,和other。",
+    "SSR (Speaking Style Recognition)(Categories:新闻科普,恐怖故事,童话故事,客服,诗歌散文,有声书,日常口语,其他)": "请将音频内容进行文字转录,并在最后添加<风格>标签,标签共8种:新闻科普、恐怖故事、童话故事、客服、诗歌散文、有声书、日常口语、其他。",
+    "SGC (Speaker Gender Classification)(Categories:female,male)": "请将音频转录为文本,并在文本结尾处标注<性别>标签,性别为female或male。",
+    "SAP (Speaker Age Prediction)(Categories:child、adult和old)": "请将音频转录为文本,并在文本结尾处标注<年龄>标签,年龄划分为child、adult和old三种。",
+    "STTC (Speech to Text Chat)": "首先将语音转录为文字,然后对语音内容进行回复,转录和文字之间使用<开始回答>分割。"
+}
 
-def respond(
-    message,
-    history: list[tuple[str, str]],
-    system_message,
-    max_tokens,
-    temperature,
-    top_p,
-):
-    messages = [{"role": "system", "content": system_message}]
-
-    for val in history:
-        if val[0]:
-            messages.append({"role": "user", "content": val[0]})
-        if val[1]:
-            messages.append({"role": "assistant", "content": val[1]})
-
-    messages.append({"role": "user", "content": message})
-
-    response = ""
-
-    for message in client.chat_completion(
-        messages,
-        max_tokens=max_tokens,
-        stream=True,
-        temperature=temperature,
-        top_p=top_p,
-    ):
-        token = message.choices[0].delta.content
-
-        response += token
-        yield response
-
-
-"""
-For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
-"""
-demo = gr.ChatInterface(
-    respond,
-    additional_inputs=[
-        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-        gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            value=0.95,
-            step=0.05,
-            label="Top-p (nucleus sampling)",
-        ),
-    ],
-)
+def init_model_my():
+    logging.basicConfig(level=logging.DEBUG,
+                        format='%(asctime)s %(levelname)s %(message)s')
+    config_path = "train.yaml"
+    from huggingface_hub import hf_hub_download
+    # Download the .pt checkpoint from Hugging Face
+    pt_file_path = hf_hub_download(repo_id="ASLP-lab/OSUM", filename="infer.pt")
+    args = SimpleNamespace(**{
+        "checkpoint": pt_file_path,
+    })
+    configs = load_dict_from_yaml(config_path)
+    model, configs = init_model(args, configs)
+    model = model.cuda()
+    tokenizer = init_tokenizer(configs)
+    print(model)
+    return model, tokenizer
 
+# global_model, tokenizer = init_model_my()
+print("model init success")
+def do_resample(input_wav_path, output_wav_path):
+    """"""
+    print(f'input_wav_path: {input_wav_path}, output_wav_path: {output_wav_path}')
+    waveform, sample_rate = torchaudio.load(input_wav_path)
+    # Check the audio's channel dimension
+    num_channels = waveform.shape[0]
+    # If the audio is multi-channel, average across channels
+    if num_channels > 1:
+        waveform = torch.mean(waveform, dim=0, keepdim=True)
+    waveform = torchaudio.transforms.Resample(
+        orig_freq=sample_rate, new_freq=16000)(waveform)
+    makedir_for_file(output_wav_path)
+    torchaudio.save(output_wav_path, waveform, 16000)
 
+# @spaces.GPU
+def true_decode_fuc(input_wav_path, input_prompt):
+    # input_prompt = TASK_PROMPT_MAPPING.get(input_prompt, "未知任务类型")
+    print(f"wav_path: {input_wav_path}, prompt:{input_prompt}")
+    timestamp_ms = int(time.time() * 1000)
+    now_file_tmp_path_resample = f'/home/xlgeng/.cache/.temp/{timestamp_ms}_resample.wav'
+    do_resample(input_wav_path, now_file_tmp_path_resample)
+    input_wav_path = now_file_tmp_path_resample
+    waveform, sample_rate = torchaudio.load(input_wav_path)
+    waveform = waveform.squeeze(0)  # (channel=1, sample) -> (sample,)
+    print(f'wavform shape: {waveform.shape}, sample_rate: {sample_rate}')
+    window = torch.hann_window(400)
+    stft = torch.stft(waveform,
+                      400,
+                      160,
+                      window=window,
+                      return_complex=True)
+    magnitudes = stft[..., :-1].abs() ** 2
+
+    filters = torch.from_numpy(
+        librosa.filters.mel(sr=sample_rate,
+                            n_fft=400,
+                            n_mels=80))
+    mel_spec = filters @ magnitudes
+
+    # NOTE(xcsong): https://github.com/openai/whisper/discussions/269
+    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
+    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
+    log_spec = (log_spec + 4.0) / 4.0
+    feat = log_spec.transpose(0, 1)
+    feat_lens = torch.tensor([feat.shape[0]], dtype=torch.int64).cuda()
+    feat = feat.unsqueeze(0).cuda()
+    # feat = feat.half()
+    # feat_lens = feat_lens.half()
+    model = global_model.cuda()
+    model.eval()
+    res_text = model.generate(wavs=feat, wavs_len=feat_lens, prompt=input_prompt)[0]
+    print("耿雪龙哈哈:", res_text)
+    return res_text
 
+def do_decode(input_wav_path, input_prompt):
+    print(f'input_wav_path= {input_wav_path}, input_prompt= {input_prompt}')
+    # Processing logic omitted
+    # output_res = true_decode_fuc(input_wav_path, input_prompt)
+    output_res = f"耿雪龙哈哈:测试结果, input_wav_path= {input_wav_path}, input_prompt= {input_prompt}"
+    return output_res
 
+def save_to_jsonl(if_correct, wav, prompt, res):
+    data = {
+        "if_correct": if_correct,
+        "wav": wav,
+        "task": prompt,
+        "res": res
+    }
+    with open("results.jsonl", "a", encoding="utf-8") as f:
+        f.write(json.dumps(data, ensure_ascii=False) + "\n")
+
+def handle_submit(input_wav_path, input_prompt):
+    output_res = do_decode(input_wav_path, input_prompt)
+    return output_res
+
+def download_audio(input_wav_path):
+    if input_wav_path:
+        # Return the file path for download
+        return input_wav_path
+    else:
+        return None
+
+# Create the Gradio interface
+with gr.Blocks() as demo:
+    # Add the title
+    gr.Markdown(
+        f"""
+        <div style="display: flex; align-items: center; justify-content: center; text-align: center;">
+            <h1 style="font-family: 'Arial', sans-serif; color: #014377; font-size: 32px; margin-bottom: 0; display: inline-block; vertical-align: middle;">
+                OSUM Speech Understanding Model Test
+            </h1>
+        </div>
+        """
+    )
+
+    # Add the audio input and task selection
+    with gr.Row():
+        with gr.Column(scale=1):
+            audio_input = gr.Audio(label="Record", type="filepath")
+        with gr.Column(scale=1, min_width=300):  # give the output box a minimum width so the columns stay aligned
+            output_text = gr.Textbox(label="Output", lines=8, placeholder="The generated result will be displayed here...", interactive=False)
+
+    # Add the task selector and the custom prompt box
+    with gr.Row():
+        task_dropdown = gr.Dropdown(
+            label="Task",
+            choices=list(TASK_PROMPT_MAPPING.keys()) + ["Custom Task Prompt"],  # new option
+            value="ASR (Automatic Speech Recognition)"
+        )
+        custom_prompt_input = gr.Textbox(label="Custom Task Prompt", placeholder="Please enter a custom task prompt...", visible=False)  # new text input box
+
+    # Add the buttons (download on the left, process on the right)
+    with gr.Row():
+        download_button = gr.DownloadButton("Download Recording", variant="secondary", elem_classes=["button-height", "download-button"])
+        submit_button = gr.Button("Start to Process", variant="primary", elem_classes=["button-height", "submit-button"])
+
+    # Add the confirmation widgets
+    with gr.Row(visible=False) as confirmation_row:
+        gr.Markdown("Please determine whether the result is correct:")
+        confirmation_buttons = gr.Radio(
+            choices=["Correct", "Incorrect"],
+            label="",
+            interactive=True,
+            container=False,
+            elem_classes="confirmation-buttons"
+        )
+        save_button = gr.Button("Submit Feedback", variant="secondary")
+
+    # Add the footer content
+    with gr.Row():
+        # Footer content container
+        with gr.Column(scale=1, min_width=800):  # set a minimum width so the content stays centered
+            gr.Markdown(
+                f"""
+                <div style="position: fixed; bottom: 20px; left: 50%; transform: translateX(-50%); display: flex; align-items: center; justify-content: center; gap: 20px;">
+                    <div style="text-align: center;">
+                        <p style="margin: 0;"><strong>Audio, Speech and Language Processing Group (ASLP@NPU),</strong></p>
+                        <p style="margin: 0;"><strong>Northwestern Polytechnical University</strong></p>
+                    </div>
+                    <img src="data:image/png;base64,{encoded_string}" alt="OSUM Logo" style="height: 80px; width: auto;">
+                </div>
+                """
+            )
+
+    # Bind events
+    def show_confirmation(output_res, input_wav_path, input_prompt):
+        return gr.update(visible=True), output_res, input_wav_path, input_prompt
+
+    def save_result(if_correct, wav, prompt, res):
+        save_to_jsonl(if_correct, wav, prompt, res)
+        return gr.update(visible=False)
+
+    def handle_submit(input_wav_path, task_choice, custom_prompt):
+        if task_choice == "Custom Task Prompt":
+            input_prompt = custom_prompt  # use the user-provided custom text
+        else:
+            input_prompt = TASK_PROMPT_MAPPING.get(task_choice, "未知任务类型")  # use the predefined prompt
+        output_res = do_decode(input_wav_path, input_prompt)
+        return output_res
+
+    # When the task selection changes, update the visibility of the custom prompt box
+    task_dropdown.change(
+        fn=lambda choice: gr.update(visible=choice == "Custom Task Prompt"),
+        inputs=task_dropdown,
+        outputs=custom_prompt_input
+    )
+
+    submit_button.click(
+        fn=handle_submit,
+        inputs=[audio_input, task_dropdown, custom_prompt_input],
+        outputs=output_text
+    ).then(
+        fn=show_confirmation,
+        inputs=[output_text, audio_input, task_dropdown],
+        outputs=[confirmation_row, output_text, audio_input, task_dropdown]
+    )
+
+    download_button.click(
+        fn=download_audio,
+        inputs=[audio_input],
+        outputs=[download_button]  # output to download_button
+    )
 
+    save_button.click(
+        fn=save_result,
+        inputs=[confirmation_buttons, audio_input, task_dropdown, output_text],
+        outputs=confirmation_row
+    )
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
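
For reference, the feature front end inside true_decode_fuc is the Whisper log-mel recipe: 16 kHz mono input, a 400-sample Hann window with hop 160 (25 ms / 10 ms frames), 80 mel bins, log10 with the dynamic range clamped to 8, then (x + 4) / 4 scaling (see the NOTE(xcsong) link in the diff). A self-contained sketch of just that extractor, assuming torch, torchaudio, and librosa are installed and the input has already been resampled to 16 kHz mono, as do_resample guarantees in the app:

```python
import torch
import torchaudio
import librosa


def whisper_log_mel(wav_path: str) -> torch.Tensor:
    """80-bin Whisper-style log-mel features, shape (frames, 80)."""
    waveform, sample_rate = torchaudio.load(wav_path)  # expects 16 kHz mono
    waveform = waveform.squeeze(0)                     # (1, samples) -> (samples,)

    window = torch.hann_window(400)
    stft = torch.stft(waveform, 400, 160, window=window, return_complex=True)
    magnitudes = stft[..., :-1].abs() ** 2             # power spectrum, drop last frame

    mel_filters = torch.from_numpy(
        librosa.filters.mel(sr=sample_rate, n_fft=400, n_mels=80))
    mel_spec = mel_filters @ magnitudes

    # Normalization from https://github.com/openai/whisper/discussions/269
    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    return log_spec.transpose(0, 1)                    # (frames, 80)
```

The batch dimension and length tensor passed to model.generate are then just feat.unsqueeze(0) and torch.tensor([feat.shape[0]]), as in the diff.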