# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import subprocess
from dataclasses import dataclass
from datetime import timedelta
from typing import Optional, Tuple

import numpy as np
import sherpa_onnx

from model import sample_rate
@dataclass
class Segment:
    """A single recognized speech segment, printable as an SRT cue body.

    Times are in seconds relative to the start of the input stream.
    """

    # Segment start time in seconds.
    start: float
    # Segment length in seconds.
    duration: float
    # Recognized text; filled in after decoding.
    text: str = ""

    @property
    def end(self) -> float:
        """Return the end time of the segment in seconds."""
        return self.start + self.duration

    def __str__(self) -> str:
        """Format as ``HH:MM:SS,mmm --> HH:MM:SS,mmm`` plus the text.

        str(timedelta) yields e.g. ``0:00:01.500000``; the leading "0"
        pads hours to two digits and ``[:-3]`` trims microseconds down
        to milliseconds.
        """
        # NOTE(review): whole-second values render without a ".ffffff"
        # suffix, so [:-3] would trim the seconds instead — timestamps
        # here are presumably always fractional; confirm upstream.
        s = f"0{timedelta(seconds=self.start)}"[:-3]
        s += " --> "
        s += f"0{timedelta(seconds=self.end)}"[:-3]
        # SRT uses a comma as the decimal separator.
        s = s.replace(".", ",")
        s += "\n"
        s += self.text
        return s
def decode(
    recognizer: sherpa_onnx.OfflineRecognizer,
    vad: sherpa_onnx.VoiceActivityDetector,
    punct: Optional[sherpa_onnx.OfflinePunctuation],
    filename: str,
) -> Tuple[str, str]:
    """Transcribe a media file into SRT subtitles and a full transcript.

    The file is decoded by ffmpeg to raw 16-bit mono PCM at
    ``sample_rate``, split into speech segments by the VAD, and each
    segment is decoded with the offline recognizer.

    Args:
      recognizer: Offline recognizer used to decode each speech segment.
      vad: Voice activity detector that splits the audio into segments.
      punct: Optional punctuation model applied to each segment's text
        and to the concatenated transcript; skipped when ``None``.
      filename: Path (or anything ffmpeg accepts) of the input media.

    Returns:
      A tuple ``(srt, all_text)``: ``srt`` is the full SRT document
      (numbered cues separated by blank lines) and ``all_text`` is the
      concatenated transcript.
    """
    # Ask ffmpeg to emit raw little-endian 16-bit mono PCM at the
    # model's expected sample rate on stdout.
    ffmpeg_cmd = [
        "ffmpeg",
        "-i",
        filename,
        "-f",
        "s16le",
        "-acodec",
        "pcm_s16le",
        "-ac",
        "1",
        "-ar",
        str(sample_rate),
        "-",
    ]

    process = subprocess.Popen(
        ffmpeg_cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL
    )

    frames_per_read = int(sample_rate * 100)  # read ~100 seconds at a time

    window_size = 512  # samples fed to the VAD per call

    buffer = []
    segment_list = []

    logging.info("Started!")

    all_text = []
    is_last = False
    while True:
        # *2 because int16_t has two bytes
        data = process.stdout.read(frames_per_read * 2)
        if not data:
            if is_last:
                break
            # Feed one second of silence so the VAD flushes any segment
            # that is still open when the stream ends.
            is_last = True
            data = np.zeros(sample_rate, dtype=np.int16)

        samples = np.frombuffer(data, dtype=np.int16)
        # Normalize int16 PCM to float in [-1, 1).
        samples = samples.astype(np.float32) / 32768

        buffer = np.concatenate([buffer, samples])
        while len(buffer) > window_size:
            vad.accept_waveform(buffer[:window_size])
            buffer = buffer[window_size:]

        streams = []
        segments = []
        while not vad.empty():
            segment = Segment(
                start=vad.front.start / sample_rate,
                duration=len(vad.front.samples) / sample_rate,
            )
            segments.append(segment)

            stream = recognizer.create_stream()
            stream.accept_waveform(sample_rate, vad.front.samples)

            streams.append(stream)

            vad.pop()

        for s in streams:
            recognizer.decode_stream(s)

        for seg, stream in zip(segments, streams):
            seg.text = stream.result.text.strip()
            if len(seg.text) == 0:
                logging.info("Skip empty segment")
                continue

            if len(all_text) == 0:
                all_text.append(seg.text)
            elif len(all_text[-1][0].encode()) == 1 and len(seg.text[0].encode()) == 1:
                # Both fragments look single-byte (ASCII, e.g. English),
                # so keep a space between them.
                # NOTE(review): this inspects the FIRST char of the
                # previous fragment; the LAST char (all_text[-1][-1]) may
                # be what was intended — confirm against upstream.
                all_text.append(" ")
                all_text.append(seg.text)
            else:
                all_text.append(seg.text)

            if punct is not None:
                seg.text = punct.add_punctuation(seg.text)

            segment_list.append(seg)

    # EOF reached: reap ffmpeg so it does not linger as a zombie.
    process.wait()

    all_text = "".join(all_text)
    if punct is not None:
        all_text = punct.add_punctuation(all_text)

    return "\n\n".join(f"{i}\n{seg}" for i, seg in enumerate(segment_list, 1)), all_text