|
""" |
|
Nemo diarizer |
|
""" |
|
import os |
|
import json |
|
|
|
import wget |
|
import matplotlib.pyplot as plt |
|
from omegaconf import OmegaConf |
|
from nemo.collections.asr.models import ClusteringDiarizer |
|
from nemo.collections.asr.parts.utils.speaker_utils import rttm_to_labels, labels_to_pyannote_object |
|
from pyannote.core import notebook |
|
|
|
from diarizers.diarizer import Diarizer |
|
|
|
|
|
class NemoDiarizer(Diarizer):
    """Speaker diarization backed by NVIDIA NeMo's offline clustering diarizer."""

    def __init__(self, audio_path: str, data_dir: str):
        """
        Initialize the NeMo diarizer.

        If the default offline diarization config is not already present in
        ``data_dir``, it is downloaded from the NeMo repository.

        Args:
            audio_path (str): the path to the audio file
            data_dir (str): working directory for the manifest, the model
                config, and the diarization outputs
        """
        self.audio_path = audio_path
        self.data_dir = data_dir
        # Filled in by diarize_audio(); stays None until then.
        self.diarization = None
        self.manifest_dir = os.path.join(self.data_dir, 'input_manifest.json')
        self.model_config = os.path.join(self.data_dir, 'offline_diarization.yaml')
        if not os.path.exists(self.model_config):
            config_url = "https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/" \
                         "speaker_tasks/diarization/conf/offline_diarization.yaml"
            # wget.download returns the path of the downloaded file.
            self.model_config = wget.download(config_url, self.data_dir)
        self.config = OmegaConf.load(self.model_config)

    def _create_manifest_file(self):
        """
        Write the single-entry inference manifest file expected by NeMo.
        """
        meta = {
            'audio_filepath': self.audio_path,
            'offset': 0,
            'duration': None,        # None -> NeMo uses the full file duration
            'label': 'infer',
            'text': '-',
            'num_speakers': None,    # unknown; clustering estimates the count
            'rttm_filepath': None,   # no reference annotation available
            'uem_filepath': None
        }
        with open(self.manifest_dir, 'w', encoding='utf-8') as fp:
            json.dump(meta, fp)
            fp.write('\n')

    def _apply_config(self, pretrained_speaker_model: str = 'ecapa_tdnn'):
        """
        Edit the loaded inference configuration in place for this run.

        Args:
            pretrained_speaker_model (str): the pre-trained embedding model; options are
                ('ecapa_tdnn', vad_telephony_marblenet, titanet_large)
                https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/
                speaker_diarization/results.html
        """
        pretrained_vad = 'vad_marblenet'

        self.config.num_workers = 0
        output_dir = os.path.join(self.data_dir, 'outputs')

        self.config.diarizer.manifest_filepath = self.manifest_dir
        self.config.diarizer.out_dir = output_dir
        self.config.diarizer.ignore_overlap = False

        # Speaker-embedding settings; window/shift control segmentation granularity.
        self.config.diarizer.speaker_embeddings.model_path = pretrained_speaker_model
        self.config.diarizer.speaker_embeddings.parameters.window_length_in_sec = 0.5
        self.config.diarizer.speaker_embeddings.parameters.shift_length_in_sec = 0.25
        # No ground-truth VAD or speaker count: run system VAD and estimate speakers.
        self.config.diarizer.oracle_vad = False
        self.config.diarizer.clustering.parameters.oracle_num_speakers = False

        # Voice-activity-detection model and its post-processing thresholds.
        self.config.diarizer.vad.model_path = pretrained_vad
        self.config.diarizer.vad.window_length_in_sec = 0.15
        self.config.diarizer.vad.shift_length_in_sec = 0.01
        self.config.diarizer.vad.parameters.onset = 0.8
        self.config.diarizer.vad.parameters.offset = 0.6
        self.config.diarizer.vad.parameters.min_duration_on = 0.1
        self.config.diarizer.vad.parameters.min_duration_off = 0.4

    def diarize_audio(self, pretrained_speaker_model: str = 'ecapa_tdnn'):
        """
        Run diarization on the audio file and store the result.

        The prediction RTTM written by NeMo is converted to a pyannote
        annotation and kept in ``self.diarization``.

        Args:
            pretrained_speaker_model (str): the pre-trained embedding model; options are
                ('ecapa_tdnn', vad_telephony_marblenet, titanet_large)
                https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/
                speaker_diarization/results.html
        """
        self._create_manifest_file()
        self._apply_config(pretrained_speaker_model)
        sd_model = ClusteringDiarizer(cfg=self.config)
        sd_model.diarize()
        # NeMo names the prediction RTTM after the audio file's stem.
        audio_file_name_without_extension = os.path.splitext(
            os.path.basename(self.audio_path))[0]
        output_diarization_pred = os.path.join(
            self.data_dir, 'outputs', 'pred_rttms',
            f'{audio_file_name_without_extension}.rttm')
        pred_labels = rttm_to_labels(output_diarization_pred)
        self.diarization = labels_to_pyannote_object(pred_labels)

    def get_diarization_figure(self) -> plt.Figure:
        """
        Return a matplotlib figure showing the diarization timeline.

        Runs diarization with default settings first if it has not been
        computed yet.  (Checks ``is None`` rather than truthiness so an
        empty result does not trigger a redundant, expensive re-run.)
        """
        if self.diarization is None:
            self.diarize_audio()
        figure, ax = plt.subplots()
        notebook.plot_annotation(self.diarization, ax=ax, time=True, legend=True)
        # Return the figure created above instead of whatever plt.gcf() holds.
        return figure
|
|