wwdok committed on
Commit
ea6d8b2
·
1 Parent(s): 489bea3

add chatgpt support

Browse files
Files changed (6) hide show
  1. README.md +14 -2
  2. app.py +26 -88
  3. docs/text_postprocess.png +0 -0
  4. requirements.txt +5 -3
  5. src/utils.py +152 -3
  6. src/vad.py +3 -3
README.md CHANGED
@@ -16,7 +16,18 @@ Fork from : https://huggingface.co/spaces/aadnk/faster-whisper-webui/tree/main
16
 
17
  我的更改:
18
 
19
- * 新添加了一个文本后处理的tab
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  # Running Locally
22
 
@@ -194,4 +205,5 @@ registry.gitlab.com/aadnk/whisper-webui:latest
194
 
195
  - [ ] 如果是一个视频列表,只下载第一个视频
196
  - [ ] ~~如果已经转录完了再选翻译任务,则不重新转录~~
197
- - [ ] ~~目前翻译任务只能由任意语言翻译成英语,不能指定其他语言,要能支持翻译成其他语言,至少支持中文~~
 
 
16
 
17
  我的更改:
18
 
19
+ * 新添加了一个文本后处理的tab
20
+ ![Alt text](docs/text_postprocess.png)
21
+ * 支持使用ChatGPT或Paddle auto punc(二选一)对文本自动添加合适的标点符号
22
+ * 支持使用pycorrector对文本进行纠错
23
+ * 支持去掉指定的语气助词
24
+ * 支持对输出文本的指定字符进行替换
25
+
26
+ 该App同时部署在:
27
+
28
+ * HuggingFace Spaces: https://huggingface.co/spaces/wwdok/faster-whisper-webui-cn
29
+ * OpenXLab: https://openxlab.org.cn/apps/detail/wwdok/faster-whisper-webui
30
+ * ModelScope: https://modelscope.cn/studios/wwd123/faster-whisper-webui-cn
31
 
32
  # Running Locally
33
 
 
205
 
206
  - [ ] 如果是一个视频列表,只下载第一个视频
207
  - [ ] ~~如果已经转录完了再选翻译任务,则不重新转录~~
208
+ - [ ] ~~目前翻译任务只能由任意语言翻译成英语,不能指定其他语言,要能支持翻译成其他语言,至少支持中文~~
209
+ - [ ] 使用ChatGPT自动纠正错别字和添加标点符号
app.py CHANGED
@@ -10,7 +10,6 @@ import pathlib
10
  import tempfile
11
  import zipfile
12
  import numpy as np
13
- import pyperclip
14
  import torch
15
 
16
  from src.config import VAD_INITIAL_PROMPT_MODE_VALUES, ApplicationConfig, VadInitialPromptMode
@@ -32,6 +31,7 @@ import gradio as gr
32
 
33
  from src.download import ExceededMaximumDuration, download_url
34
  from src.utils import optional_int, slugify, write_srt, write_vtt
 
35
  from src.vad import AbstractTranscription, NonSpeechStrategy, PeriodicTranscriptionConfig, TranscriptionConfig, VadPeriodicTranscription, VadSileroTranscription
36
  from src.whisper.abstractWhisperContainer import AbstractWhisperContainer
37
  from src.whisper.whisperFactory import create_whisper_container
@@ -614,102 +614,36 @@ def create_ui(app_config: ApplicationConfig):
614
  gr.Text(label="Segments")
615
  ])
616
 
617
- def get_chunks(s, maxlength, separator=None):
618
- start = 0
619
- end = 0
620
- while start + maxlength < len(s) and end != -1:
621
- if separator is not None:
622
- end = s.rfind(separator, start, start + maxlength + 1)
623
- segment = s[start:end]
624
- yield segment.replace(separator, "")
625
- start = end +1
626
- else:
627
- end = start + maxlength
628
- yield s[start:end]
629
- start = end
630
-
631
- yield s[start:]
632
-
633
- def post_processing(text, apply_correction, auto_punc, separator, remove_words):
634
- print(f"==>> separator: {separator}")
635
- original_separator1 = " "
636
- original_separator2 = ","
637
- # 对于长文本需要先分段再推理,推理完再合并
638
- # auto_punc:自动添加合适的、不同的标点符号
639
- if auto_punc == True:
640
- # 自动分段文本之前先去除原有的标点符号
641
- text = text.replace(original_separator1, "")
642
- text = text.replace(original_separator2, "")
643
- import paddlehub as hub
644
- model = hub.Module(name='auto_punc', version='1.0.0')
645
- t3 = time.time()
646
- # split long text to short text less than max_length and store them in list
647
- max_length = 256
648
- chunks = list(get_chunks(text, max_length))
649
- results = []
650
- results = model.add_puncs(chunks, max_length=max_length)
651
- text = ",".join(results) # 分段处硬编码成使用中文逗号分割
652
- t4 = time.time()
653
- print("Auto punc finished. Cost time: {:.2f}s".format(t4-t3))
654
- # print(f"==>> text after auto punc: {text}")
655
- else:
656
- # 将空格全部统一替换成一种分隔符
657
- if separator == "\\n":
658
- # 直接使用separator会无法换行
659
- text = text.replace(original_separator1, "\n")
660
- text = text.replace(original_separator2, "\n")
661
- else:
662
- text = text.replace(original_separator1, separator)
663
- text = text.replace(original_separator2, separator)
664
-
665
- if apply_correction == True:
666
- import pycorrector
667
- print("Start correcting...")
668
- t1 = time.time()
669
- text, detail = pycorrector.correct(text)
670
- t2 = time.time()
671
- print("Correcting finished. Cost time: {:.2f}s".format(t2-t1))
672
- print(f"==>> detail: {detail}")
673
-
674
- # 去掉语气词
675
- t5 = time.time()
676
- remove_words = remove_words.split(",") + remove_words.split(",") + remove_words.split(" ")
677
- for word in remove_words:
678
- text = text.replace(word, "")
679
- t6 = time.time()
680
- print("Remove words finished. Cost time: {:.2f}s".format(t6-t5))
681
-
682
- return text
683
-
684
- def replace(text, src_word, target_word):
685
- text = text.replace(src_word, target_word)
686
- return text
687
-
688
- def switch_punc_method(auto_punc):
689
- if auto_punc == True:
690
- # gr.update里的参数是给output的
691
- return gr.update(visible=False) # 教程:https://www.gradio.app/guides/blocks-and-event-listeners#updating-component-configurations
692
  else:
693
- return gr.update(visible=True)
694
-
695
- def copy_text(text):
696
- pyperclip.copy(text)
697
-
698
  test_postprocess = gr.Blocks()
699
 
700
  with test_postprocess:
701
  gr.Markdown(
702
  """
703
- 后处理Simple或Full标签页输出的Transcription里的文本
704
  """
705
  )
706
  with gr.Row():
707
  with gr.Column():
708
  input_text = gr.TextArea(label="输入文本", placeholder="在此处粘贴你的待处理文本")
709
- apply_correction = gr.Checkbox(label="文本纠错", value=False)
710
- auto_punc = gr.Checkbox(label="自动添加标点符号", value=False)
711
- separator = gr.Text(label="分隔符(一般是逗号,或换行\\n)", value="")
712
- remove_words = gr.Text(label="去掉的语气词", value="呢,啊,哦,嗯,嘛,吧,呀,哈,哇,呐,噢,嘞,嘛")
 
 
 
 
 
713
  submit_btn = gr.Button("提交")
714
  with gr.Column():
715
  output_text = gr.TextArea(label="输出文本", interactive=True).style(show_copy_button=True)
@@ -718,10 +652,14 @@ def create_ui(app_config: ApplicationConfig):
718
  target_word = gr.Text(label="替换后的字符")
719
  replace_btn = gr.Button("替换")
720
  copy_btn = gr.Button("复制到剪贴板")
721
- auto_punc.change(switch_punc_method, inputs=[auto_punc], outputs=[separator])
722
- submit_btn.click(post_processing, inputs=[input_text, apply_correction, auto_punc, separator, remove_words], outputs=output_text)
 
 
 
723
  replace_btn.click(replace, inputs=[output_text, src_word, target_word], outputs=output_text)
724
  copy_btn.click(copy_text, inputs=output_text)
 
725
  demo = gr.TabbedInterface([simple_transcribe, full_transcribe, test_postprocess], tab_names=["Simple", "Full", "Text Postprocess"])
726
 
727
  # Queue up the demo
 
10
  import tempfile
11
  import zipfile
12
  import numpy as np
 
13
  import torch
14
 
15
  from src.config import VAD_INITIAL_PROMPT_MODE_VALUES, ApplicationConfig, VadInitialPromptMode
 
31
 
32
  from src.download import ExceededMaximumDuration, download_url
33
  from src.utils import optional_int, slugify, write_srt, write_vtt
34
+ from src.utils import post_processing, replace, copy_text, on_token_change, num_tokens_from_messages, chat_with_gpt
35
  from src.vad import AbstractTranscription, NonSpeechStrategy, PeriodicTranscriptionConfig, TranscriptionConfig, VadPeriodicTranscription, VadSileroTranscription
36
  from src.whisper.abstractWhisperContainer import AbstractWhisperContainer
37
  from src.whisper.whisperFactory import create_whisper_container
 
614
  gr.Text(label="Segments")
615
  ])
616
 
617
+ def switch_punc_method(use_chatgpt, auto_punc):
618
+ if use_chatgpt == True and auto_punc == True:
619
+ return gr.update(), gr.update(), gr.update()
620
+ elif use_chatgpt == True and auto_punc == False:
621
+ return gr.update(visible=True), gr.update(visible=True, interactive=True), gr.update(visible=False)
622
+ elif use_chatgpt == False and auto_punc == True:
623
+ return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
624
  else:
625
+ return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
626
+
 
 
 
627
  test_postprocess = gr.Blocks()
628
 
629
  with test_postprocess:
630
  gr.Markdown(
631
  """
632
+ 后处理Simple或Full标签页输出的Transcription里的文本,也可以单独使用
633
  """
634
  )
635
  with gr.Row():
636
  with gr.Column():
637
  input_text = gr.TextArea(label="输入文本", placeholder="在此处粘贴你的待处理文本")
638
+ tokens_count = gr.Markdown(label="Tokens 计数: 0", visible=False)
639
+ use_chatgpt = gr.Checkbox(label="使用ChatGPT自动纠正错别字和添加标点符号", value=False)
640
+ user_token = gr.Textbox(value='', placeholder="OpenAI API Key", type="password", visible=False,
641
+ label="输入你的 OpenAI API Key. 你可以从这里(https://platform.openai.com/account/api-keys)获取.\
642
+ \n⚠ 注意!使用ChatGPT来处理文本会消耗大量的tokens,免费版用户谨慎使用!")
643
+ apply_correction = gr.Checkbox(label="使用pycorrector纠正错别字", value=False)
644
+ auto_punc = gr.Checkbox(label="使用paddle auto punc自动添加标点符号", value=False)
645
+ separator = gr.Text(label="使用统一的标点符号(比如逗号,或换行\\n)", value=",")
646
+ remove_words = gr.Text(label="去掉的语气助词", value="呢,啊,哦,嗯,嘛,吧,呀,哈,哇,呐,噢,嘞,嘛")
647
  submit_btn = gr.Button("提交")
648
  with gr.Column():
649
  output_text = gr.TextArea(label="输出文本", interactive=True).style(show_copy_button=True)
 
652
  target_word = gr.Text(label="替换后的字符")
653
  replace_btn = gr.Button("替换")
654
  copy_btn = gr.Button("复制到剪贴板")
655
+ input_text.change(num_tokens_from_messages, inputs=[input_text], outputs=[tokens_count])
656
+ auto_punc.change(switch_punc_method, inputs=[use_chatgpt, auto_punc], outputs=[tokens_count, user_token, separator])
657
+ use_chatgpt.change(switch_punc_method, inputs=[use_chatgpt, auto_punc], outputs=[tokens_count, user_token, separator])
658
+ user_token.change(on_token_change, inputs=[user_token], outputs=[])
659
+ submit_btn.click(post_processing, inputs=[input_text, use_chatgpt, user_token, apply_correction, auto_punc, separator, remove_words], outputs=output_text)
660
  replace_btn.click(replace, inputs=[output_text, src_word, target_word], outputs=output_text)
661
  copy_btn.click(copy_text, inputs=output_text)
662
+
663
  demo = gr.TabbedInterface([simple_transcribe, full_transcribe, test_postprocess], tab_names=["Simple", "Full", "Text Postprocess"])
664
 
665
  # Queue up the demo
docs/text_postprocess.png ADDED
requirements.txt CHANGED
@@ -10,6 +10,8 @@ more_itertools
10
  pycorrector
11
  paddlepaddle == 2.4.0
12
  paddlehub
13
- aiobotocore
14
- botocore
15
- pyperclip
 
 
 
10
  pycorrector
11
  paddlepaddle == 2.4.0
12
  paddlehub
13
+ aiobotocore
14
+ botocore
15
+ pyperclip
16
+ openai
17
+ tiktoken
src/utils.py CHANGED
@@ -1,12 +1,14 @@
1
  import textwrap
2
  import unicodedata
3
  import re
4
-
5
  import zlib
6
  from typing import Iterator, TextIO, Union
7
  import tqdm
8
-
9
  import urllib3
 
 
10
 
11
 
12
  def exact_div(x, y):
@@ -242,4 +244,151 @@ def download_file(url: str, destination: str):
242
  break
243
 
244
  output.write(buffer)
245
- loop.update(len(buffer))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import textwrap
2
  import unicodedata
3
  import re
4
+ import time
5
  import zlib
6
  from typing import Iterator, TextIO, Union
7
  import tqdm
8
+ import pyperclip
9
  import urllib3
10
+ import openai
11
+ import tiktoken
12
 
13
 
14
  def exact_div(x, y):
 
244
  break
245
 
246
  output.write(buffer)
247
+ loop.update(len(buffer))
248
+
249
+ # -------------used for text post processing tab----------------
250
+ system_prompt = "You are a helpful assistant."
251
+ user_prompt = "请帮我把下面的文本纠正错别字并添加合适的标点符号,返回的消息只要处理后的文本:"
252
+
253
+ def get_chunks(s, maxlength, separator=None):
254
+ start = 0
255
+ end = 0
256
+ while start + maxlength < len(s) and end != -1:
257
+ if separator is not None:
258
+ end = s.rfind(separator, start, start + maxlength + 1)
259
+ segment = s[start:end]
260
+ yield segment.replace(separator, "")
261
+ start = end +1
262
+ else:
263
+ end = start + maxlength
264
+ yield s[start:end]
265
+ start = end
266
+
267
+ yield s[start:]
268
+
269
+ def post_processing(text, use_chatgpt, user_token, apply_correction, auto_punc, separator, remove_words):
270
+ # print(f"==>> separator: {separator}")
271
+ original_separator1 = " "
272
+ original_separator2 = ","
273
+
274
+ if use_chatgpt == True:
275
+ if user_token == "":
276
+ text = "请先设置你的OpenAI API Key,然后再重试"
277
+ return text
278
+ else:
279
+ text = chat_with_gpt(text, system_prompt, user_prompt)
280
+ return text
281
+ # 对于长文本需要先分段再推理,推理完再合并
282
+ elif auto_punc == True:
283
+ # 自动分段文本之前先去除原有的标点符号
284
+ text = text.replace(original_separator1, "")
285
+ text = text.replace(original_separator2, "")
286
+ import paddlehub as hub
287
+ model = hub.Module(name='auto_punc', version='1.0.0')
288
+ t3 = time.time()
289
+ # split long text to short text less than max_length and store them in list
290
+ max_length = 256
291
+ chunks = list(get_chunks(text, max_length))
292
+ results = []
293
+ results = model.add_puncs(chunks, max_length=max_length)
294
+ text = ",".join(results) # 分段处硬编码成使用中文逗号分割
295
+ t4 = time.time()
296
+ print("Auto punc finished. Cost time: {:.2f}s".format(t4-t3))
297
+ # print(f"==>> text after auto punc: {text}")
298
+ else:
299
+ # 将空格全部统一替换成一种分隔符
300
+ if separator == "\\n":
301
+ # 直接使用separator会无法换行
302
+ text = text.replace(original_separator1, "\n")
303
+ text = text.replace(original_separator2, "\n")
304
+ else:
305
+ text = text.replace(original_separator1, separator)
306
+ text = text.replace(original_separator2, separator)
307
+
308
+ if apply_correction == True:
309
+ import pycorrector
310
+ print("Start correcting...")
311
+ t1 = time.time()
312
+ text, detail = pycorrector.correct(text)
313
+ t2 = time.time()
314
+ print("Correcting finished. Cost time: {:.2f}s".format(t2-t1))
315
+ print(f"==>> detail: {detail}")
316
+
317
+ # 去掉语气词
318
+ t5 = time.time()
319
+ remove_words = remove_words.split(",") + remove_words.split(",") + remove_words.split(" ")
320
+ for word in remove_words:
321
+ text = text.replace(word, "")
322
+ t6 = time.time()
323
+ print("Remove words finished. Cost time: {:.2f}s".format(t6-t5))
324
+
325
+ return text
326
+
327
+ def replace(text, src_word, target_word):
328
+ text = text.replace(src_word, target_word)
329
+ return text
330
+
331
+ def copy_text(text):
332
+ pyperclip.copy(text)
333
+
334
def num_tokens_from_messages(message, model="gpt-3.5-turbo-0613"):
    """Return a display string "Tokens 计数: N" for the ChatGPT request built from *message*.

    Token accounting follows OpenAI's cookbook recipe for chat-format messages:
    a fixed per-message overhead, plus the encoded prompt (the module-level
    ``user_prompt`` prepended to *message*), plus the 3-token reply primer.

    Args:
        message: The raw user text to be sent to ChatGPT.
        model: Model name used to pick the tokenizer and per-message overhead.
            Defaulted so existing single-argument callers keep working; the
            fallback branches below override it on recursion.

    Returns:
        A markdown-friendly string "Tokens 计数: <count>".

    Raises:
        NotImplementedError: If *model* is not a recognized chat model family.
    """
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        print("Warning: model not found. Using cl100k_base encoding.")
        encoding = tiktoken.get_encoding("cl100k_base")
    if model in {
        "gpt-3.5-turbo-0613",
        "gpt-3.5-turbo-16k-0613",
        "gpt-4-0314",
        "gpt-4-32k-0314",
        "gpt-4-0613",
        "gpt-4-32k-0613",
    }:
        tokens_per_message = 3
        tokens_per_name = 1
    elif model == "gpt-3.5-turbo-0301":
        tokens_per_message = 4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
        tokens_per_name = -1  # if there's a name, the role is omitted
    elif "gpt-3.5-turbo" in model:
        print("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
        # BUGFIX: the original signature had no ``model`` parameter, so this
        # recursive call raised TypeError; ``model`` is now a keyword argument.
        return num_tokens_from_messages(message, model="gpt-3.5-turbo-0613")
    elif "gpt-4" in model:
        print("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
        return num_tokens_from_messages(message, model="gpt-4-0613")
    else:
        raise NotImplementedError(
            f"""num_tokens_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
        )
    num_tokens = 0
    num_tokens += tokens_per_message
    # The actual request sends user_prompt + message as one user turn.
    message = user_prompt + message
    num_tokens += len(encoding.encode(message))
    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
    return f"Tokens 计数: {num_tokens}"
371
+
372
+ def on_token_change(user_token):
373
+ openai.api_key = user_token
374
+
375
+ def chat_with_gpt(input_message, system_prompt, user_prompt, temperature=0, max_tokens=4096):
376
+ system_content = [{ "role": "system", "content": system_prompt }]
377
+ user_content = [{ "role": "user", "content": user_prompt + input_message }]
378
+ try:
379
+ completion = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=system_content + user_content, temperature=temperature, max_tokens=max_tokens)
380
+ response_msg = completion.choices[0].message['content']
381
+
382
+ prompt_tokens = completion['usage']['prompt_tokens']
383
+ completion_tokens = completion['usage']['completion_tokens']
384
+ total_tokens = completion['usage']['total_tokens']
385
+ print(f"==>> prompt_tokens: {prompt_tokens}")
386
+ print(f"==>> completion_tokens: {completion_tokens}")
387
+ print(f"==>> total_tokens: {total_tokens}")
388
+ return response_msg
389
+
390
+ except Exception as e:
391
+ return f"Error: {e}"
392
+
393
+
394
+ # -------------used for text post processing tab----------------
src/vad.py CHANGED
@@ -203,8 +203,8 @@ class AbstractTranscription(ABC):
203
  # Detected language
204
  detected_language = languageCounter.most_common(1)[0][0] if len(languageCounter) > 0 else None
205
 
206
- print("Running whisper from ", format_timestamp(segment_start), " to ", format_timestamp(segment_end), ", duration: ",
207
- segment_duration, "expanded: ", segment_expand_amount, "prompt: ", segment_prompt, "language: ", detected_language)
208
 
209
  perf_start_time = time.perf_counter()
210
 
@@ -213,7 +213,7 @@ class AbstractTranscription(ABC):
213
  segment_result = whisperCallable.invoke(segment_audio, segment_index, segment_prompt, detected_language, progress_listener=scaled_progress_listener)
214
 
215
  perf_end_time = time.perf_counter()
216
- print("Whisper took {} seconds".format(perf_end_time - perf_start_time))
217
 
218
  adjusted_segments = self.adjust_timestamp(segment_result["segments"], adjust_seconds=segment_start, max_source_time=segment_duration)
219
 
 
203
  # Detected language
204
  detected_language = languageCounter.most_common(1)[0][0] if len(languageCounter) > 0 else None
205
 
206
+ # print("Running whisper from ", format_timestamp(segment_start), " to ", format_timestamp(segment_end), ", duration: ",
207
+ # segment_duration, "expanded: ", segment_expand_amount, "prompt: ", segment_prompt, "language: ", detected_language)
208
 
209
  perf_start_time = time.perf_counter()
210
 
 
213
  segment_result = whisperCallable.invoke(segment_audio, segment_index, segment_prompt, detected_language, progress_listener=scaled_progress_listener)
214
 
215
  perf_end_time = time.perf_counter()
216
+ # print("Whisper took {} seconds".format(perf_end_time - perf_start_time))
217
 
218
  adjusted_segments = self.adjust_timestamp(segment_result["segments"], adjust_seconds=segment_start, max_source_time=segment_duration)
219