File size: 5,166 Bytes
f0a085b 4691b00 f0a085b 4691b00 ada1a34 f0a085b 5e3d623 f0a085b b825d1e f0a085b b825d1e f0a085b 5e3d623 f0a085b 21fcf42 f0a085b cfd7673 0a7d03f f0a085b 0a7d03f f0a085b 8b09827 f0a085b ada1a34 5e3d623 a702d26 5e3d623 f5abfaa cfd7673 815053b cfd7673 21fcf42 f0a085b c98790c cfd7673 f0a085b cfd7673 ada1a34 5e3d623 ada1a34 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 |
# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import subprocess
from dataclasses import dataclass
from datetime import timedelta
from typing import Optional
from transformers import pipeline, MarianMTModel, MarianTokenizer
import numpy as np
import sherpa_onnx
from model import sample_rate
@dataclass
class Segment:
    """One recognized speech segment with SRT-style rendering via ``str()``."""

    start: float  # segment start time, in seconds
    duration: float  # segment length, in seconds
    text: str = ""  # recognized (English) transcript
    cn_text: str = ""  # Chinese translation of ``text``

    @property
    def end(self) -> float:
        """Segment end time in seconds (start + duration)."""
        return self.start + self.duration

    @staticmethod
    def _format_timestamp(seconds: float) -> str:
        """Render *seconds* as an SRT timestamp ``HH:MM:SS,mmm``."""
        ms = int(round(seconds * 1000))
        hours, ms = divmod(ms, 3_600_000)
        minutes, ms = divmod(ms, 60_000)
        secs, ms = divmod(ms, 1_000)
        return f"{hours:02}:{minutes:02}:{secs:02},{ms:03}"

    def __str__(self) -> str:
        # Bug fix: the previous code sliced 3 chars off str(timedelta(...)),
        # which corrupted the timestamp whenever the fractional second part
        # was zero (timedelta then omits ".ffffff", so [:-3] ate the seconds)
        # and mis-padded durations of 10 hours or more.
        return (
            f"{self._format_timestamp(self.start)} --> "
            f"{self._format_timestamp(self.end)}\n"
            f"{self.text}\n{self.cn_text}"
        )
def decode(
    recognizer: sherpa_onnx.OfflineRecognizer,
    vad: sherpa_onnx.VoiceActivityDetector,
    punct: Optional[sherpa_onnx.OfflinePunctuation],
    filename: str,
) -> "tuple[str, str]":
    """Transcribe *filename* and return ``(srt_text, full_text)``.

    The media file is decoded by an ffmpeg subprocess to raw 16-bit mono
    PCM at ``sample_rate``.  The stream is fed through *vad* to find
    speech segments; each segment is recognized with *recognizer* and
    translated to Chinese via the module-level ``_llm_translator``.

    Args:
      recognizer: offline recognizer; one stream is created per segment.
      vad: voice-activity detector that chunks audio into speech segments.
      punct: optional punctuation model applied to each segment text and
        to the concatenated full transcript.
      filename: path (or URL) handed to ffmpeg as input.

    Returns:
      A 2-tuple of the SRT-formatted subtitles (segments numbered from 1)
      and the concatenated plain-text transcript.
    """
    # ffmpeg writes raw little-endian signed 16-bit mono PCM at the model's
    # sample rate to stdout ("-" output); stderr is discarded.
    ffmpeg_cmd = [
        "ffmpeg",
        "-i",
        filename,
        "-f",
        "s16le",
        "-acodec",
        "pcm_s16le",
        "-ac",
        "1",
        "-ar",
        str(sample_rate),
        "-",
    ]
    process = subprocess.Popen(
        ffmpeg_cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL
    )
    # Number of samples fetched per read (~100 seconds of audio).
    frames_per_read = int(sample_rate * 100)
    window_size = 512  # samples handed to the VAD per call
    buffer = []  # leftover samples not yet consumed by the VAD
    segment_list = []  # Segment objects that ended up with non-empty text
    logging.info("Started!")
    all_text = []  # transcript pieces, joined into one string at the end
    is_last = False
    while True:
        # *2 because int16_t has two bytes
        data = process.stdout.read(frames_per_read * 2)
        if not data:
            if is_last:
                break
            # ffmpeg is done: push one second of silence through the VAD so
            # it flushes any trailing speech segment, then do a final pass.
            is_last = True
            data = np.zeros(sample_rate, dtype=np.int16)
        samples = np.frombuffer(data, dtype=np.int16)
        # Normalize int16 PCM to float32 in [-1, 1).
        samples = samples.astype(np.float32) / 32768
        buffer = np.concatenate([buffer, samples])
        # Feed the VAD in fixed-size windows; keep the remainder buffered.
        while len(buffer) > window_size:
            vad.accept_waveform(buffer[:window_size])
            buffer = buffer[window_size:]
        streams = []
        segments = []
        # Drain every segment the VAD has detected so far; queue one
        # recognizer stream per segment (decoded in a batch below).
        while not vad.empty():
            segment = Segment(
                start=vad.front.start / sample_rate,
                duration=len(vad.front.samples) / sample_rate,
            )
            segments.append(segment)
            stream = recognizer.create_stream()
            stream.accept_waveform(sample_rate, vad.front.samples)
            streams.append(stream)
            vad.pop()
        for s in streams:
            recognizer.decode_stream(s)
        for seg, stream in zip(segments, streams):
            en_text = stream.result.text.strip()
            seg.text = en_text
            if len(seg.text) == 0:
                # Nothing recognized in this segment; drop it entirely.
                logging.info("Skip empty segment")
                continue
            seg.cn_text = _llm_translator.translate(en_text)
            if len(all_text) == 0:
                all_text.append(seg.text)
            # A single-byte UTF-8 first character means ASCII/Latin script;
            # insert a space between consecutive Latin-script pieces.
            # NOTE(review): this tests the *first* char of the previous
            # piece — checking its *last* char may have been intended;
            # confirm before changing.
            elif len(all_text[-1][0].encode()) == 1 and len(seg.text[0].encode()) == 1:
                all_text.append(" ")
                all_text.append(seg.text)
            else:
                all_text.append(seg.text)
            if punct is not None:
                seg.text = punct.add_punctuation(seg.text)
            segment_list.append(seg)
    all_text = "".join(all_text)
    if punct is not None:
        all_text = punct.add_punctuation(all_text)
    # SRT body: "index\nstart --> end\ntext\ncn_text" blocks, blank-line
    # separated (Segment.__str__ renders the timing and text lines).
    return "\n\n".join(f"{i}\n{seg}" for i, seg in enumerate(segment_list, 1)), all_text
def translate_en_to_cn(src_text: str) -> str:
    """Translate English *src_text* to Chinese using Helsinki-NLP/opus-mt-en-zh.

    Loads the Marian tokenizer/model on every call; prefer the cached
    ``LLMTranslator`` for repeated translations.

    Args:
      src_text: English source text.

    Returns:
      The decoded Chinese translation as a single string.
    """
    model_name = "Helsinki-NLP/opus-mt-en-zh"
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(model_name)
    translated = model.generate(**tokenizer(src_text, return_tensors="pt", padding=True))
    res = [tokenizer.decode(t, skip_special_tokens=True) for t in translated]
    # Bug fix: the function is annotated ``-> str`` but previously returned
    # the list of decoded sequences; join them, matching LLMTranslator.translate.
    return "".join(res)
class LLMTranslator:
    """English-to-Chinese translator backed by Helsinki-NLP/opus-mt-en-zh.

    Loads the Marian tokenizer and model once at construction so repeated
    ``translate`` calls reuse them.
    """

    _tokenizer: MarianTokenizer
    _model: MarianMTModel

    def __init__(self):
        pretrained = "Helsinki-NLP/opus-mt-en-zh"
        self._tokenizer = MarianTokenizer.from_pretrained(pretrained)
        self._model = MarianMTModel.from_pretrained(pretrained)

    def translate(self, src_text: str) -> str:
        """Return the Chinese translation of English *src_text*."""
        encoded = self._tokenizer(src_text, return_tensors="pt", padding=True)
        output_ids = self._model.generate(**encoded)
        decoded = [
            self._tokenizer.decode(ids, skip_special_tokens=True)
            for ids in output_ids
        ]
        # decode() already yields str, so joining directly is equivalent
        # to the original str()-wrapped join.
        return "".join(decoded)
# Module-level singleton used by decode(); constructing it here means the
# translation model is loaded once at import time rather than per call.
_llm_translator = LLMTranslator()