crystal-technologies's picture
Upload 1287 files
2d8da09
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os
import warnings

from nemo.collections.asr.metrics.der import evaluate_der
from nemo.collections.asr.parts.utils.diarization_utils import OfflineDiarWithASR
from nemo.collections.asr.parts.utils.manifest_utils import read_file
from nemo.collections.asr.parts.utils.speaker_utils import (
    get_uniqname_from_filepath,
    labels_to_pyannote_object,
    rttm_to_labels,
)
"""
Evaluation script for diarization with ASR.
Calculates the Diarization Error Rate (DER) from RTTM files, and WER and cpWER from CTM (or JSON) files.
Session-level DER, WER, cpWER and speaker counting accuracy are written to the ctm_eval.csv file
in the output folder.
- Evaluation mode
diar_eval_mode == "full":
DIHARD challenge style evaluation, the most strict way of evaluating diarization
(collar, ignore_overlap) = (0.0, False)
diar_eval_mode == "fair":
Evaluation setup used in VoxSRC challenge
(collar, ignore_overlap) = (0.25, False)
diar_eval_mode == "forgiving":
Traditional evaluation setup
(collar, ignore_overlap) = (0.25, True)
diar_eval_mode == "all":
Compute all three modes (default)
Use CTM files to calculate WER and cpWER
```
python eval_diar_with_asr.py \
--hyp_rttm_list="/path/to/hypothesis_rttm_filepaths.list" \
--ref_rttm_list="/path/to/reference_rttm_filepaths.list" \
--hyp_ctm_list="/path/to/hypothesis_ctm_filepaths.list" \
--ref_ctm_list="/path/to/reference_ctm_filepaths.list" \
--root_path="/path/to/output/directory"
```
Use .json files to calculate WER and cpWER
```
python eval_diar_with_asr.py \
--hyp_rttm_list="/path/to/hypothesis_rttm_filepaths.list" \
--ref_rttm_list="/path/to/reference_rttm_filepaths.list" \
--hyp_json_list="/path/to/hypothesis_json_filepaths.list" \
--ref_ctm_list="/path/to/reference_ctm_filepaths.list" \
--root_path="/path/to/output/directory"
```
Only use RTTMs to calculate DER
```
python eval_diar_with_asr.py \
--hyp_rttm_list="/path/to/hypothesis_rttm_filepaths.list" \
--ref_rttm_list="/path/to/reference_rttm_filepaths.list" \
--root_path="/path/to/output/directory"
```
"""
def get_pyannote_objs_from_rttms(rttm_file_path_list):
    """Generate PyAnnote annotation objects from a list of RTTM file paths.

    Args:
        rttm_file_path_list (list of str or None): RTTM file paths. Surrounding
            whitespace is stripped from each entry. ``None`` is treated as an
            empty list.

    Returns:
        list: ``[uniq_id, annotation]`` pairs, one per existing RTTM file, in the
        order of ``rttm_file_path_list``. ``uniq_id`` is derived from the file
        path via ``get_uniqname_from_filepath``.

    Entries that are ``None`` or blank are skipped silently. Entries pointing to
    a missing file are skipped with a warning, because dropping a session from
    only one of the reference/hypothesis lists can misalign them if the caller
    pairs entries by position.
    """
    pyannote_obj_list = []
    if rttm_file_path_list is None:
        return pyannote_obj_list
    for rttm_file in rttm_file_path_list:
        # Check for None before calling .strip(), which would otherwise raise.
        if rttm_file is None:
            continue
        rttm_file = rttm_file.strip()
        if not rttm_file:
            continue
        if not os.path.exists(rttm_file):
            warnings.warn(f"RTTM file not found, skipping: {rttm_file}")
            continue
        uniq_id = get_uniqname_from_filepath(rttm_file)
        labels = rttm_to_labels(rttm_file)
        annotation = labels_to_pyannote_object(labels, uniq_name=uniq_id)
        pyannote_obj_list.append([uniq_id, annotation])
    return pyannote_obj_list
def make_meta_dict(hyp_rttm_list, ref_rttm_list):
    """Create a temporary ``audio_rttm_map_dict`` for evaluation.

    Args:
        hyp_rttm_list (list of str or None): Hypothesis RTTM paths, paired with
            ``ref_rttm_list`` by position. ``None`` means no hypothesis paths are
            recorded.
        ref_rttm_list (list of str): Reference RTTM paths.

    Returns:
        dict: Maps each session ``uniq_id`` (derived from the stripped reference
        RTTM path) to ``{"rttm_filepath": <ref path>}``. When ``hyp_rttm_list``
        is given, each entry also has ``"hyp_rttm_filepath": <hyp path>``.

    Raises:
        ValueError: If ``hyp_rttm_list`` is given and its length differs from
            ``ref_rttm_list``. Positional pairing would be wrong in that case.
    """
    if hyp_rttm_list is not None and len(hyp_rttm_list) != len(ref_rttm_list):
        raise ValueError(
            f"Number of hypothesis RTTM files ({len(hyp_rttm_list)}) does not match "
            f"number of reference RTTM files ({len(ref_rttm_list)})."
        )
    meta_dict = {}
    for k, ref_rttm_file in enumerate(ref_rttm_list):
        # Strip once so the derived uniq_id and the stored path agree.
        ref_rttm_file = ref_rttm_file.strip()
        uniq_id = get_uniqname_from_filepath(ref_rttm_file)
        meta_dict[uniq_id] = {"rttm_filepath": ref_rttm_file}
        if hyp_rttm_list is not None:
            meta_dict[uniq_id]["hyp_rttm_filepath"] = hyp_rttm_list[k].strip()
    return meta_dict
def make_trans_info_dict(hyp_json_list_path):
    """Create ``trans_info_dict`` from hypothesis ``.json`` files.

    Args:
        hyp_json_list_path (list of str): Paths to hypothesis JSON files. Despite
            the parameter name, this is the list of file paths itself, not the
            path of a list file. Blank entries are ignored.

    Returns:
        dict: Maps each session ``uniq_id`` (derived from the JSON file path) to
        the parsed content of that JSON file.
    """
    trans_info_dict = {}
    for json_file in hyp_json_list_path:
        json_file = json_file.strip()
        if not json_file:
            # Skip blank lines from the list file instead of trying to open "".
            continue
        # JSON text is UTF-8 (RFC 8259); don't depend on the platform locale.
        with open(json_file, encoding="utf-8") as jsf:
            json_data = json.load(jsf)
        uniq_id = get_uniqname_from_filepath(json_file)
        trans_info_dict[uniq_id] = json_data
    return trans_info_dict
def read_file_path(list_path):
    """Read a list file of paths, one per line.

    Each line is stripped of surrounding whitespace, including the trailing
    newline. Blank lines are dropped, so an empty trailing line in the list
    file does not produce an empty path.

    Args:
        list_path (str): Path to the list file.

    Returns:
        list of str: The non-empty, stripped paths in sorted order.
    """
    stripped_lines = (line.strip() for line in read_file(list_path))
    return sorted(path for path in stripped_lines if path)
def main(
    hyp_rttm_list_path: str,
    ref_rttm_list_path: str,
    hyp_ctm_list_path: str,
    ref_ctm_list_path: str,
    hyp_json_list_path: str,
    diar_eval_mode: str = "all",
    root_path: str = "./",
):
    """Evaluate diarization (DER) and, when reference CTMs are given, ASR (WER/cpWER).

    Each ``*_list_path`` argument is the path of a list file containing one file
    path per line. It may be empty or ``None`` when that input is not provided.

    Args:
        hyp_rttm_list_path: List file of hypothesis RTTM paths.
        ref_rttm_list_path: List file of reference RTTM paths.
        hyp_ctm_list_path: List file of hypothesis CTM paths (optional).
        ref_ctm_list_path: List file of reference CTM paths (optional). When
            given, WER and cpWER are computed as well.
        hyp_json_list_path: List file of hypothesis JSON paths (optional). Used
            for WER/cpWER only when no hypothesis CTM list is given.
        diar_eval_mode: Forwarded to ``evaluate_der``. The script documents
            "all", "full", "fair" and "forgiving".
        root_path: Directory for the result files. Created if it does not exist.

    Raises:
        ValueError: If reference CTMs are given but neither hypothesis CTMs nor
            hypothesis JSON files are. This is checked before any scoring runs.
    """
    # Read filepath list files
    hyp_rttm_list = read_file_path(hyp_rttm_list_path) if hyp_rttm_list_path else None
    ref_rttm_list = read_file_path(ref_rttm_list_path) if ref_rttm_list_path else None
    hyp_ctm_list = read_file_path(hyp_ctm_list_path) if hyp_ctm_list_path else None
    ref_ctm_list = read_file_path(ref_ctm_list_path) if ref_ctm_list_path else None
    hyp_json_list = read_file_path(hyp_json_list_path) if hyp_json_list_path else None

    # Fail fast on an unusable WER configuration instead of after computing DER.
    if ref_ctm_list is not None and hyp_ctm_list is None and hyp_json_list is None:
        raise ValueError("Hypothesis information is not provided in the correct format.")

    # Make sure the output directory exists before anything is written into it.
    if root_path:
        os.makedirs(root_path, exist_ok=True)

    audio_rttm_map_dict = make_meta_dict(hyp_rttm_list, ref_rttm_list)
    trans_info_dict = make_trans_info_dict(hyp_json_list) if hyp_json_list else None

    all_hypothesis = get_pyannote_objs_from_rttms(hyp_rttm_list)
    all_reference = get_pyannote_objs_from_rttms(ref_rttm_list)

    diar_score = evaluate_der(
        audio_rttm_map_dict=audio_rttm_map_dict,
        all_reference=all_reference,
        all_hypothesis=all_hypothesis,
        diar_eval_mode=diar_eval_mode,
    )

    # Get session-level diarization error rate and speaker counting error
    der_results = OfflineDiarWithASR.gather_eval_results(
        diar_score=diar_score,
        audio_rttm_map_dict=audio_rttm_map_dict,
        trans_info_dict=trans_info_dict,
        root_path=root_path,
    )

    wer_results = {}
    if ref_ctm_list is not None:
        # Calculate WER and cpWER. Hypothesis CTM files take precedence over
        # hypothesis JSON files. The validation above guarantees one of them exists.
        use_hyp_ctm = hyp_ctm_list is not None
        wer_results = OfflineDiarWithASR.evaluate(
            audio_file_list=hyp_rttm_list,
            hyp_trans_info_dict=None if use_hyp_ctm else trans_info_dict,
            hyp_ctm_file_list=hyp_ctm_list,
            ref_ctm_file_list=ref_ctm_list,
        )

    # Print average DER, WER and cpWER
    OfflineDiarWithASR.print_errors(der_results=der_results, wer_results=wer_results)

    # Save detailed session-level evaluation results in `root_path`.
    OfflineDiarWithASR.write_session_level_result_in_csv(
        der_results=der_results,
        wer_results=wer_results,
        root_path=root_path,
        csv_columns=OfflineDiarWithASR.get_csv_columns(),
    )
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Evaluate diarization (DER) and diarization-with-ASR (WER, cpWER) results."
    )
    parser.add_argument(
        "--hyp_rttm_list", help="path to the filelist of hypothesis RTTM files", type=str, required=True
    )
    parser.add_argument(
        "--ref_rttm_list", help="path to the filelist of reference RTTM files", type=str, required=True
    )
    parser.add_argument(
        "--hyp_ctm_list", help="path to the filelist of hypothesis CTM files", type=str, required=False, default=None
    )
    parser.add_argument(
        "--ref_ctm_list", help="path to the filelist of reference CTM files", type=str, required=False, default=None
    )
    parser.add_argument(
        "--hyp_json_list",
        help="(Optional) path to the filelist of hypothesis JSON files",
        type=str,
        required=False,
        default=None,
    )
    parser.add_argument(
        "--diar_eval_mode",
        help='evaluation mode: "all", "full", "fair", "forgiving"',
        type=str,
        required=False,
        # Reject unsupported modes at parse time with a clear usage message.
        choices=["all", "full", "fair", "forgiving"],
        default="all",
    )
    parser.add_argument(
        "--root_path", help='directory for saving result files', type=str, required=False, default="./"
    )
    args = parser.parse_args()

    main(
        hyp_rttm_list_path=args.hyp_rttm_list,
        ref_rttm_list_path=args.ref_rttm_list,
        hyp_ctm_list_path=args.hyp_ctm_list,
        ref_ctm_list_path=args.ref_ctm_list,
        hyp_json_list_path=args.hyp_json_list,
        diar_eval_mode=args.diar_eval_mode,
        root_path=args.root_path,
    )