|
import logging
import sys

import numpy as np
import torch
from starlette.exceptions import HTTPException
from torchaudio import functional as F
from transformers.pipelines.audio_utils import ffmpeg_read

logger = logging.getLogger(__name__)
|
|
def preprocess_inputs(inputs, sampling_rate):
    """Decode raw audio bytes and prepare a mono 16 kHz waveform for ASR and diarization."""
    inputs = ffmpeg_read(inputs, sampling_rate)

    # The diarization model expects 16 kHz audio; resample if the input differs.
    if sampling_rate != 16000:
        inputs = F.resample(
            torch.from_numpy(inputs), sampling_rate, 16000
        ).numpy()

    if len(inputs.shape) != 1:
        logger.error(f"Diarization pipeline expects single channel audio, received {inputs.shape}")
        raise HTTPException(
            status_code=400,
            detail=f"Diarization pipeline expects single channel audio, received {inputs.shape}",
        )

    # pyannote expects a float tensor of shape (channels, num_samples).
    diarizer_inputs = torch.from_numpy(inputs).float()
    diarizer_inputs = diarizer_inputs.unsqueeze(0)

    return inputs, diarizer_inputs
|
|
def diarize_audio(diarizer_inputs, diarization_pipeline, parameters):
    """Run speaker diarization and merge consecutive turns from the same speaker."""
    # The waveform was resampled to 16 kHz in preprocess_inputs, so report that rate here.
    diarization = diarization_pipeline(
        {"waveform": diarizer_inputs, "sample_rate": 16000},
        num_speakers=parameters.num_speakers,
        min_speakers=parameters.min_speakers,
        max_speakers=parameters.max_speakers,
    )

    # Flatten the pyannote annotation into a list of raw speaker turns.
    segments = []
    for segment, track, label in diarization.itertracks(yield_label=True):
        segments.append(
            {
                "segment": {"start": segment.start, "end": segment.end},
                "track": track,
                "label": label,
            }
        )

    # Merge consecutive turns with the same speaker label: a merged turn runs from the
    # start of the first raw turn up to the start of the next turn with a different speaker.
    new_segments = []
    prev_segment = cur_segment = segments[0]

    for i in range(1, len(segments)):
        cur_segment = segments[i]

        if cur_segment["label"] != prev_segment["label"]:
            new_segments.append(
                {
                    "segment": {
                        "start": prev_segment["segment"]["start"],
                        "end": cur_segment["segment"]["start"],
                    },
                    "speaker": prev_segment["label"],
                }
            )
            prev_segment = cur_segment

    # Close out the final speaker turn.
    new_segments.append(
        {
            "segment": {
                "start": prev_segment["segment"]["start"],
                "end": cur_segment["segment"]["end"],
            },
            "speaker": prev_segment["label"],
        }
    )

    return new_segments
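
# For reference, each merged entry returned by diarize_audio has the form
#     {"segment": {"start": 0.0, "end": 12.3}, "speaker": "SPEAKER_00"}
# with times in seconds. The example values are illustrative; the label format is
# whatever the diarization pipeline emits (e.g. pyannote's "SPEAKER_XX").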
|
|
|
|
|
def post_process_segments_and_transcripts(new_segments, transcript, group_by_speaker) -> list:
    """Assign a speaker to each ASR chunk by aligning chunk end times with diarization turns."""
    # Chunks with an open-ended timestamp (None) are treated as ending at +inf.
    end_timestamps = np.array(
        [
            chunk["timestamp"][-1] if chunk["timestamp"][-1] is not None else sys.float_info.max
            for chunk in transcript
        ]
    )
    segmented_preds = []

    for segment in new_segments:
        # Find the transcript chunk whose end time is closest to the end of this speaker turn.
        end_time = segment["segment"]["end"]
        upto_idx = np.argmin(np.abs(end_timestamps - end_time))

        if group_by_speaker:
            # Emit one entry per speaker turn, concatenating all chunks assigned to it.
            segmented_preds.append(
                {
                    "speaker": segment["speaker"],
                    "text": "".join(
                        [chunk["text"] for chunk in transcript[: upto_idx + 1]]
                    ),
                    "timestamp": (
                        transcript[0]["timestamp"][0],
                        transcript[upto_idx]["timestamp"][1],
                    ),
                }
            )
        else:
            # Emit one entry per ASR chunk, each tagged with the current speaker.
            for i in range(upto_idx + 1):
                segmented_preds.append({"speaker": segment["speaker"], **transcript[i]})

        # Drop the chunks that have been assigned and continue with the remainder.
        transcript = transcript[upto_idx + 1:]
        end_timestamps = end_timestamps[upto_idx + 1:]

        if len(end_timestamps) == 0:
            break

    return segmented_preds
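
# For reference: with group_by_speaker=True the result looks like
#     [{"speaker": "SPEAKER_00", "text": " Hello there.", "timestamp": (0.0, 1.5)}, ...]
# and with group_by_speaker=False each original ASR chunk is kept as-is and simply
# gains a "speaker" key. The example values above are illustrative only.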
|
|
|
|
|
def diarize(diarization_pipeline, file, parameters, asr_outputs):
    """Diarize an audio file and attach speaker labels to the ASR output chunks."""
    _, diarizer_inputs = preprocess_inputs(file, parameters.sampling_rate)

    segments = diarize_audio(
        diarizer_inputs,
        diarization_pipeline,
        parameters,
    )

    return post_process_segments_and_transcripts(
        segments, asr_outputs["chunks"], group_by_speaker=False
    )
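

# Illustrative usage sketch: wiring these helpers to a pyannote diarization pipeline and a
# Whisper ASR pipeline. The model names, the SimpleNamespace standing in for the request
# parameters, and the file path are assumptions made for this example only.
if __name__ == "__main__":
    from types import SimpleNamespace

    from pyannote.audio import Pipeline
    from transformers import pipeline

    # Gated model: requires accepting the license and a Hugging Face auth token.
    diarization_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1")
    asr = pipeline("automatic-speech-recognition", model="openai/whisper-small", chunk_length_s=30)

    with open("sample.wav", "rb") as f:
        audio_bytes = f.read()

    # Hypothetical stand-in for the parameters object this module expects.
    parameters = SimpleNamespace(
        sampling_rate=16000, num_speakers=None, min_speakers=None, max_speakers=None
    )

    asr_outputs = asr(audio_bytes, return_timestamps=True)
    print(diarize(diarization_pipeline, audio_bytes, parameters, asr_outputs))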