drown0315 commited on
Commit
ada1a34
·
1 Parent(s): 89a8d65

feat: 增加双语字幕

Browse files
Files changed (1) hide show
  1. decode.py +33 -1
decode.py CHANGED
@@ -19,6 +19,7 @@ import subprocess
19
  from dataclasses import dataclass
20
  from datetime import timedelta
21
  from typing import Optional
 
22
 
23
  import numpy as np
24
  import sherpa_onnx
@@ -122,7 +123,9 @@ def decode(
122
  recognizer.decode_stream(s)
123
 
124
  for seg, stream in zip(segments, streams):
125
- seg.text = stream.result.text.strip()
 
 
126
  if len(seg.text) == 0:
127
  logging.info("Skip empty segment")
128
  continue
@@ -143,3 +146,32 @@ def decode(
143
  all_text = punct.add_punctuation(all_text)
144
 
145
  return "\n\n".join(f"{i}\n{seg}" for i, seg in enumerate(segment_list, 1)), all_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  from dataclasses import dataclass
20
  from datetime import timedelta
21
  from typing import Optional
22
+ from transformers import pipeline, MarianMTModel, MarianTokenizer
23
 
24
  import numpy as np
25
  import sherpa_onnx
 
123
  recognizer.decode_stream(s)
124
 
125
  for seg, stream in zip(segments, streams):
126
+ en_text = stream.result.text.strip()
127
+ cn_text = _llm_translator.translate(en_text)
128
+ seg.text = en_text +"\n"+cn_text
129
  if len(seg.text) == 0:
130
  logging.info("Skip empty segment")
131
  continue
 
146
  all_text = punct.add_punctuation(all_text)
147
 
148
  return "\n\n".join(f"{i}\n{seg}" for i, seg in enumerate(segment_list, 1)), all_text
149
+
150
+
151
+
152
+
153
+ def translate_en_to_cn(src_text: str, ) -> str:
154
+
155
+ model_name = "Helsinki-NLP/opus-mt-en-zh"
156
+ tokenizer = MarianTokenizer.from_pretrained(model_name)
157
+ model = MarianMTModel.from_pretrained(model_name)
158
+ translated = model.generate(**tokenizer(src_text, return_tensors="pt", padding=True))
159
+ res = [tokenizer.decode(t, skip_special_tokens=True) for t in translated]
160
+ return res
161
+
162
+
163
+ class LLMTranslator:
164
+ _tokenizer: MarianTokenizer
165
+ _model: MarianMTModel
166
+ def __init__(self):
167
+ model_name = "Helsinki-NLP/opus-mt-en-zh"
168
+ self._tokenizer = MarianTokenizer.from_pretrained(model_name)
169
+ self._model = MarianMTModel.from_pretrained(model_name)
170
+
171
+ def translate(self, src_text: str) -> str:
172
+ translated = self._model.generate(**self._tokenizer(src_text, return_tensors="pt", padding=True))
173
+ res = [self._tokenizer.decode(t, skip_special_tokens=True) for t in translated]
174
+ return res
175
+
176
+
177
+ _llm_translator = LLMTranslator()