|
|
|
|
|
|
|
import io |
|
import json |
|
import matplotlib as mpl |
|
import matplotlib.pyplot as plt |
|
import mmap |
|
import numpy as np |
|
import soundfile |
|
import torchaudio |
|
import torch |
|
from pydub import AudioSegment |
|
|
|
|
|
|
|
import math |
|
from simuleval.data.segments import SpeechSegment, EmptySegment |
|
from seamless_communication.streaming.agents.seamless_streaming_s2st import ( |
|
SeamlessStreamingS2STVADAgent, |
|
) |
|
|
|
from simuleval.utils.arguments import cli_argument_list |
|
from simuleval import options |
|
|
|
|
|
from typing import Union, List |
|
from simuleval.data.segments import Segment, TextSegment |
|
from simuleval.agents.pipeline import TreeAgentPipeline |
|
from simuleval.agents.states import AgentStates |
|
|
|
|
|
|
|
SAMPLE_RATE = 16000 |
|
|
|
|
|
|
|
|
|
class AudioFrontEnd:
    """Feeds a wav file to the streaming pipeline in fixed-size segments.

    Mirrors the front-end logic of simuleval's instance.py: each call to
    ``send_segment`` yields the next ``segment_size`` milliseconds of audio
    as a SpeechSegment, then a terminal EmptySegment once the audio is
    exhausted.
    """

    def __init__(self, wav_file, segment_size) -> None:
        # soundfile.read returns (samples, sample_rate); samples is a 1-D
        # float array for mono input — assumes mono wav, TODO confirm.
        self.samples, self.sample_rate = soundfile.read(wav_file)
        print(self.sample_rate, "sample rate")
        # Explicit check instead of `assert`: asserts are stripped under -O.
        if self.sample_rate != SAMPLE_RATE:
            raise ValueError(
                f"expected {SAMPLE_RATE} Hz audio, got {self.sample_rate} Hz"
            )
        self.segment_size = segment_size  # segment length in milliseconds
        self.step = 0  # absolute read cursor (in samples) into self.samples

    def send_segment(self):
        """
        This is the front-end logic in simuleval instance.py
        """

        num_samples = math.ceil(self.segment_size / 1000 * self.sample_rate)

        if self.step < len(self.samples):
            if self.step + num_samples >= len(self.samples):
                samples = self.samples[self.step :]
                is_finished = True
            else:
                samples = self.samples[self.step : self.step + num_samples]
                is_finished = False
            # BUGFIX: the previous code also did
            #   self.samples = self.samples[self.step:]
            # here, dropping the consumed prefix while keeping `self.step`
            # as an absolute offset — later reads then indexed past the
            # already-consumed point and silently skipped audio. The cursor
            # alone is enough; the buffer must stay intact.
            self.step = min(self.step + num_samples, len(self.samples))
            segment = SpeechSegment(
                content=samples,
                sample_rate=self.sample_rate,
                finished=is_finished,
            )
        else:
            # Audio exhausted: emit a terminal EmptySegment and reset so the
            # front end can be reused (e.g. after add_segments).
            segment = EmptySegment(
                finished=True,
            )
            self.step = 0
            self.samples = []
        return segment

    def add_segments(self, wav):
        """Append the samples of another wav file to the pending audio."""
        new_samples, _ = soundfile.read(wav)
        self.samples = np.concatenate((self.samples, new_samples))
|
|
|
|
|
class OutputSegments:
    """Treats one or more segments produced by a pipeline step as a unit."""

    def __init__(self, segments: Union[List[Segment], Segment]):
        # Normalize the single-segment case to a one-element list.
        self.segments: List[Segment] = (
            [segments] if isinstance(segments, Segment) else list(segments)
        )

    @property
    def is_empty(self):
        # Empty only when every wrapped segment is empty.
        return all(s.is_empty for s in self.segments)

    @property
    def finished(self):
        # Finished only when every wrapped segment is finished.
        return all(s.finished for s in self.segments)
|
|
|
|
|
def get_audiosegment(samples, sr):
    """Convert raw samples into a pydub AudioSegment via an in-memory wav."""
    wav_buffer = io.BytesIO()
    soundfile.write(wav_buffer, samples, samplerate=sr, format="wav")
    # Rewind so pydub reads the wav header from the start.
    wav_buffer.seek(0)
    return AudioSegment.from_file(wav_buffer)
|
|
|
|
|
def reset_states(system, states):
    """Reset every agent state in the pipeline (e.g. between utterances)."""
    # TreeAgentPipeline keeps states in a dict keyed by agent module;
    # linear pipelines keep a plain sequence of states.
    all_states = states.values() if isinstance(system, TreeAgentPipeline) else states
    for agent_state in all_states:
        agent_state.reset()
|
|
|
|
|
def get_states_root(system, states) -> AgentStates:
    """Return the state object of the pipeline's source (root) agent."""
    if not isinstance(system, TreeAgentPipeline):
        # Linear pipeline: the first agent's state is the root.
        return system.states[0]
    # Tree pipeline: states are keyed by module; look up the source module.
    return states[system.source_module]
|
|
|
|
|
def build_streaming_system(model_configs, agent_class):
    """Instantiate a streaming agent pipeline of `agent_class` from a config dict."""
    parser = options.general_parser()
    # Absorbs the "-f <connection file>" flag that ipython injects into argv.
    parser.add_argument("-f", "--f", help="a dummy argument to fool ipython", default="1")

    agent_class.add_args(parser)
    known_args, _ = parser.parse_known_args(cli_argument_list(model_configs))
    return agent_class.from_args(known_args)
|
|
|
|
|
def run_streaming_inference(system, audio_frontend, system_states, tgt_lang):
    """Drive the streaming system over the front end's audio, one segment at a time.

    Returns (delays, prediction_lists, speech_durations, target_sample_rate),
    where delays/prediction_lists are dicts keyed by task ("s2st", "s2tt"),
    or (None, None, None, None) if the front end is already exhausted.
    """
    delays = {"s2st": [], "s2tt": []}
    prediction_lists = {"s2st": [], "s2tt": []}
    speech_durations = []  # duration (ms) of each emitted speech prediction
    curr_delay = 0  # total source audio consumed so far, in ms
    target_sample_rate = None  # sample rate of the predicted speech, set on first s2st output

    while True:
        input_segment = audio_frontend.send_segment()
        input_segment.tgt_lang = tgt_lang
        # Advance the delay clock by the duration of the audio just consumed.
        curr_delay += len(input_segment.content) / SAMPLE_RATE * 1000
        if input_segment.finished:
            # Mark the root agent's source as finished so the pipeline flushes.
            get_states_root(system, system_states).source_finished = True

        if isinstance(input_segment, EmptySegment):
            # Front end had no audio left at all: nothing to run.
            return None, None, None, None
        output_segments = OutputSegments(system.pushpop(input_segment, system_states))
        if not output_segments.is_empty:
            for segment in output_segments.segments:
                # Record each prediction together with the delay (ms of source
                # audio consumed) at which it became available.
                if isinstance(segment, SpeechSegment):
                    pred_duration = 1000 * len(segment.content) / segment.sample_rate
                    speech_durations.append(pred_duration)
                    delays["s2st"].append(curr_delay)
                    prediction_lists["s2st"].append(segment.content)
                    target_sample_rate = segment.sample_rate
                elif isinstance(segment, TextSegment):
                    delays["s2tt"].append(curr_delay)
                    prediction_lists["s2tt"].append(segment.content)
                    print(curr_delay, segment.content)
        if output_segments.finished:
            # Utterance complete: reset agent states for the next utterance.
            reset_states(system, system_states)
        if input_segment.finished:
            # All source audio has been pushed; stop after the final flush.
            break
    return delays, prediction_lists, speech_durations, target_sample_rate
|
|
|
|
|
def get_s2st_delayed_targets(delays, target_sample_rate, prediction_lists, speech_durations):
    """Lay out predicted speech chunks along the source-time axis.

    Each s2st chunk is placed at max(previous chunk's end, its delay), with
    gaps filled by silence. Returns (target_samples, intervals), where
    intervals holds one [start_ms, duration_ms] pair per chunk.
    """
    intervals = []

    first_delay = delays["s2st"][0]
    start = prev_end = first_delay
    # Lead-in silence up to the first prediction's delay.
    target_samples = [0.0] * int(target_sample_rate * first_delay / 1000)

    for i, delay_ms in enumerate(delays["s2st"]):
        start = max(prev_end, delay_ms)

        gap_ms = start - prev_end
        if gap_ms > 0:
            # Silence between the previous chunk's end and this chunk's start.
            target_samples.extend([0.0] * int(target_sample_rate * gap_ms / 1000))

        target_samples.extend(prediction_lists["s2st"][i])
        duration = speech_durations[i]
        prev_end = start + duration
        intervals.append([start, duration])
    return target_samples, intervals
|
|
|
|