|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
import json |
|
import os |
|
|
|
from nemo.collections.asr.metrics.der import evaluate_der |
|
from nemo.collections.asr.parts.utils.diarization_utils import OfflineDiarWithASR |
|
from nemo.collections.asr.parts.utils.manifest_utils import read_file |
|
from nemo.collections.asr.parts.utils.speaker_utils import ( |
|
get_uniqname_from_filepath, |
|
labels_to_pyannote_object, |
|
rttm_to_labels, |
|
) |
|
|
|
|
|
""" |
|
Evaluation script for diarization with ASR. |
|
Calculates Diarization Error Rate (DER) from RTTM files, and WER and cpWER from CTM files
(hypothesis transcripts may alternatively be given as .json files).
Session-level DER, WER, cpWER and speaker counting accuracies are reported
in the ctm_eval.csv file in the output folder.
|
|
|
- Evaluation mode |
|
|
|
diar_eval_mode == "full": |
|
DIHARD challenge style evaluation, the most strict way of evaluating diarization |
|
(collar, ignore_overlap) = (0.0, False) |
|
diar_eval_mode == "fair": |
|
Evaluation setup used in VoxSRC challenge |
|
(collar, ignore_overlap) = (0.25, False) |
|
diar_eval_mode == "forgiving": |
|
Traditional evaluation setup |
|
(collar, ignore_overlap) = (0.25, True) |
|
diar_eval_mode == "all": |
|
Compute all three modes (default) |
|
|
|
|
|
Use CTM files to calculate WER and cpWER |
|
``` |
|
python eval_diar_with_asr.py \ |
|
--hyp_rttm_list="/path/to/hypothesis_rttm_filepaths.list" \ |
|
--ref_rttm_list="/path/to/reference_rttm_filepaths.list" \ |
|
--hyp_ctm_list="/path/to/hypothesis_ctm_filepaths.list" \ |
|
--ref_ctm_list="/path/to/reference_ctm_filepaths.list" \ |
|
--root_path="/path/to/output/directory" |
|
``` |
|
|
|
Use .json files to calculate WER and cpWER |
|
``` |
|
python eval_diar_with_asr.py \ |
|
--hyp_rttm_list="/path/to/hypothesis_rttm_filepaths.list" \ |
|
--ref_rttm_list="/path/to/reference_rttm_filepaths.list" \ |
|
--hyp_json_list="/path/to/hypothesis_json_filepaths.list" \ |
|
--ref_ctm_list="/path/to/reference_ctm_filepaths.list" \ |
|
--root_path="/path/to/output/directory" |
|
``` |
|
|
|
Only use RTTMs to calculate DER |
|
``` |
|
python eval_diar_with_asr.py \ |
|
--hyp_rttm_list="/path/to/hypothesis_rttm_filepaths.list" \ |
|
--ref_rttm_list="/path/to/reference_rttm_filepaths.list" \ |
|
--root_path="/path/to/output/directory" |
|
``` |
|
|
|
""" |
|
|
|
|
|
def get_pyannote_objs_from_rttms(rttm_file_path_list):
    """Generate PyAnnote objects from a list of RTTM file paths.

    Args:
        rttm_file_path_list: Iterable of RTTM file paths. Entries may carry
            surrounding whitespace (e.g. a trailing newline) and may be None.

    Returns:
        A list of ``[uniq_id, pyannote_object]`` pairs, in input order, one for
        every entry that is not None and points to an existing file. The
        ``uniq_id`` comes from ``get_uniqname_from_filepath`` and the object
        from ``labels_to_pyannote_object``.
    """
    pyannote_obj_list = []
    for rttm_file in rttm_file_path_list:
        # Test for None *before* calling `.strip()`; checking afterwards can
        # never succeed because `None.strip()` raises AttributeError first.
        if rttm_file is None:
            continue
        rttm_file = rttm_file.strip()
        # Paths that do not exist on disk are skipped rather than treated as errors.
        if not os.path.exists(rttm_file):
            continue
        uniq_id = get_uniqname_from_filepath(rttm_file)
        ref_labels = rttm_to_labels(rttm_file)
        reference = labels_to_pyannote_object(ref_labels, uniq_name=uniq_id)
        pyannote_obj_list.append([uniq_id, reference])
    return pyannote_obj_list
|
|
|
|
|
def make_meta_dict(hyp_rttm_list, ref_rttm_list):
    """Build a temporary `audio_rttm_map_dict` for evaluation.

    Every reference RTTM path produces one entry, keyed by the name returned
    from `get_uniqname_from_filepath`, holding the stripped path under
    "rttm_filepath". When `hyp_rttm_list` is not None, the hypothesis path at
    the same index is stored (stripped) under "hyp_rttm_filepath".

    Args:
        hyp_rttm_list: Hypothesis RTTM paths, index-aligned with
            `ref_rttm_list`, or None.
        ref_rttm_list: Reference RTTM paths.

    Returns:
        Dict mapping each unique session name to its dict of file paths.
    """
    has_hypothesis = hyp_rttm_list is not None
    session_map = {}
    for idx, ref_path in enumerate(ref_rttm_list):
        session_key = get_uniqname_from_filepath(ref_path)
        entry = {"rttm_filepath": ref_path.strip()}
        if has_hypothesis:
            # Hypothesis and reference are paired purely by list position.
            entry["hyp_rttm_filepath"] = hyp_rttm_list[idx].strip()
        session_map[session_key] = entry
    return session_map
|
|
|
|
|
def make_trans_info_dict(hyp_json_list_path):
    """Create `trans_info_dict` from the `.json` files.

    Args:
        hyp_json_list_path: Iterable of paths to hypothesis `.json` files.
            Surrounding whitespace on each path is stripped before use.

    Returns:
        Dict mapping the name returned by `get_uniqname_from_filepath` for
        each file to the parsed JSON content of that file.

    Raises:
        OSError: If a listed file cannot be opened.
        json.JSONDecodeError: If a file does not contain valid JSON.
    """
    trans_info_dict = {}
    for json_file in hyp_json_list_path:
        json_file = json_file.strip()
        # Decode explicitly as UTF-8 (the JSON interchange encoding) so that
        # parsing does not depend on the platform's locale default.
        with open(json_file, encoding="utf-8") as jsf:
            json_data = json.load(jsf)
        trans_info_dict[get_uniqname_from_filepath(json_file)] = json_data
    return trans_info_dict
|
|
|
|
|
def read_file_path(list_path):
    """Read a file list and return its entries stripped and sorted.

    Each entry returned by `read_file` is stripped of surrounding whitespace
    (including the line-change symbol). Entries that are empty after
    stripping are dropped: a blank line would otherwise become an empty path
    that sorts first and shifts the position of every real path in the list.

    Args:
        list_path: Path to a text file listing one file path per entry
            (presumably one per line -- confirm against `read_file`).

    Returns:
        Sorted list of non-empty, stripped path strings.
    """
    stripped_entries = (entry.strip() for entry in read_file(list_path))
    return sorted(entry for entry in stripped_entries if entry)
|
|
|
|
|
def main(
    hyp_rttm_list_path: str,
    ref_rttm_list_path: str,
    hyp_ctm_list_path: str,
    ref_ctm_list_path: str,
    hyp_json_list_path: str,
    diar_eval_mode: str = "all",
    root_path: str = "./",
):
    """Evaluate diarization (DER) and, when reference CTMs are given, WER/cpWER.

    Args:
        hyp_rttm_list_path: Path to the list of hypothesis RTTM files.
        ref_rttm_list_path: Path to the list of reference RTTM files.
        hyp_ctm_list_path: Path to the list of hypothesis CTM files, or a
            falsy value when not used.
        ref_ctm_list_path: Path to the list of reference CTM files, or a
            falsy value to skip WER evaluation.
        hyp_json_list_path: Path to the list of hypothesis JSON files, or a
            falsy value when not used.
        diar_eval_mode: Evaluation mode forwarded to `evaluate_der`.
        root_path: Directory forwarded to the result-gathering and CSV-writing
            helpers of `OfflineDiarWithASR`.

    Raises:
        ValueError: If reference CTMs are given but neither hypothesis CTMs
            nor hypothesis JSONs are.
    """

    def _load_list(list_path):
        # A falsy path means the corresponding input was not supplied.
        return read_file_path(list_path) if list_path else None

    hyp_rttms = _load_list(hyp_rttm_list_path)
    ref_rttms = _load_list(ref_rttm_list_path)
    hyp_ctms = _load_list(hyp_ctm_list_path)
    ref_ctms = _load_list(ref_ctm_list_path)
    hyp_jsons = _load_list(hyp_json_list_path)

    session_map = make_meta_dict(hyp_rttms, ref_rttms)
    hyp_transcripts = make_trans_info_dict(hyp_jsons) if hyp_jsons else None

    hypotheses = get_pyannote_objs_from_rttms(hyp_rttms)
    references = get_pyannote_objs_from_rttms(ref_rttms)

    # Diarization scoring always runs, independently of the WER inputs.
    der_results = OfflineDiarWithASR.gather_eval_results(
        diar_score=evaluate_der(
            audio_rttm_map_dict=session_map,
            all_reference=references,
            all_hypothesis=hypotheses,
            diar_eval_mode=diar_eval_mode,
        ),
        audio_rttm_map_dict=session_map,
        trans_info_dict=hyp_transcripts,
        root_path=root_path,
    )

    # WER/cpWER needs reference CTMs; hypothesis CTMs win over JSONs when both exist.
    if ref_ctms is None:
        wer_results = {}
    elif hyp_ctms is not None:
        wer_results = OfflineDiarWithASR.evaluate(
            audio_file_list=hyp_rttms,
            hyp_trans_info_dict=None,
            hyp_ctm_file_list=hyp_ctms,
            ref_ctm_file_list=ref_ctms,
        )
    elif hyp_jsons is not None:
        wer_results = OfflineDiarWithASR.evaluate(
            audio_file_list=hyp_rttms,
            hyp_trans_info_dict=hyp_transcripts,
            hyp_ctm_file_list=None,
            ref_ctm_file_list=ref_ctms,
        )
    else:
        raise ValueError("Hypothesis information is not provided in the correct format.")

    OfflineDiarWithASR.print_errors(der_results=der_results, wer_results=wer_results)

    OfflineDiarWithASR.write_session_level_result_in_csv(
        der_results=der_results,
        wer_results=wer_results,
        root_path=root_path,
        csv_columns=OfflineDiarWithASR.get_csv_columns(),
    )
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: the two RTTM lists are mandatory; the CTM/JSON
    # lists are optional and only needed for WER/cpWER evaluation.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--hyp_rttm_list", type=str, required=True, help="path to the filelist of hypothesis RTTM files"
    )
    cli.add_argument(
        "--ref_rttm_list", type=str, required=True, help="path to the filelist of reference RTTM files"
    )
    cli.add_argument("--hyp_ctm_list", type=str, help="path to the filelist of hypothesis CTM files")
    cli.add_argument("--ref_ctm_list", type=str, help="path to the filelist of reference CTM files")
    cli.add_argument(
        "--hyp_json_list", type=str, help="(Optional) path to the filelist of hypothesis JSON files"
    )
    cli.add_argument(
        "--diar_eval_mode",
        type=str,
        default="all",
        help='evaluation mode: "all", "full", "fair", "forgiving"',
    )
    cli.add_argument("--root_path", type=str, default="./", help='directory for saving result files')

    opts = cli.parse_args()

    main(
        hyp_rttm_list_path=opts.hyp_rttm_list,
        ref_rttm_list_path=opts.ref_rttm_list,
        hyp_ctm_list_path=opts.hyp_ctm_list,
        ref_ctm_list_path=opts.ref_ctm_list,
        hyp_json_list_path=opts.hyp_json_list,
        diar_eval_mode=opts.diar_eval_mode,
        root_path=opts.root_path,
    )
|
|