File size: 8,499 Bytes
2d8da09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import json
import os

from nemo.collections.asr.metrics.der import evaluate_der
from nemo.collections.asr.parts.utils.diarization_utils import OfflineDiarWithASR
from nemo.collections.asr.parts.utils.manifest_utils import read_file
from nemo.collections.asr.parts.utils.speaker_utils import (
    get_uniqname_from_filepath,
    labels_to_pyannote_object,
    rttm_to_labels,
)


"""
Evaluation script for diarization with ASR.
Calculates Diarization Error Rate (DER) from RTTM files, and WER and cpWER from CTM files.
Session-level DER, WER, cpWER and speaker counting accuracies are reported in the
ctm_eval.csv file in the output folder.

- Evaluation mode

diar_eval_mode == "full":
    DIHARD challenge style evaluation, the most strict way of evaluating diarization
    (collar, ignore_overlap) = (0.0, False)
diar_eval_mode == "fair":
    Evaluation setup used in VoxSRC challenge
    (collar, ignore_overlap) = (0.25, False)
diar_eval_mode == "forgiving":
    Traditional evaluation setup
    (collar, ignore_overlap) = (0.25, True)
diar_eval_mode == "all":
    Compute all three modes (default)


Use CTM files to calculate WER and cpWER
```
python eval_diar_with_asr.py \
 --hyp_rttm_list="/path/to/hypothesis_rttm_filepaths.list" \
 --ref_rttm_list="/path/to/reference_rttm_filepaths.list" \
 --hyp_ctm_list="/path/to/hypothesis_ctm_filepaths.list" \
 --ref_ctm_list="/path/to/reference_ctm_filepaths.list" \
 --root_path="/path/to/output/directory"
```

Use hypothesis .json files (together with reference CTM files) to calculate WER and cpWER
```
python eval_diar_with_asr.py \
 --hyp_rttm_list="/path/to/hypothesis_rttm_filepaths.list" \
 --ref_rttm_list="/path/to/reference_rttm_filepaths.list" \
 --hyp_json_list="/path/to/hypothesis_json_filepaths.list" \
 --ref_ctm_list="/path/to/reference_ctm_filepaths.list" \
 --root_path="/path/to/output/directory"
```

Only use RTTMs to calculate DER
```
python eval_diar_with_asr.py \
 --hyp_rttm_list="/path/to/hypothesis_rttm_filepaths.list" \
 --ref_rttm_list="/path/to/reference_rttm_filepaths.list" \
 --root_path="/path/to/output/directory"
```

"""


def get_pyannote_objs_from_rttms(rttm_file_path_list):
    """Generate PyAnnote objects from an RTTM file list.

    Args:
        rttm_file_path_list: Iterable of RTTM file paths. Entries may carry
            surrounding whitespace (e.g. a trailing newline) and may be `None`.

    Returns:
        A list of `[uniq_id, obj]` pairs in input order, where `obj` is the value
        returned by `labels_to_pyannote_object` for that file. `None` entries and
        paths that do not exist on disk are skipped.
    """
    pyannote_obj_list = []
    for rttm_file in rttm_file_path_list:
        # The `None` guard must come before `.strip()`; otherwise a `None` entry
        # raises AttributeError instead of being skipped.
        if rttm_file is None:
            continue
        rttm_file = rttm_file.strip()
        if not os.path.exists(rttm_file):
            continue
        uniq_id = get_uniqname_from_filepath(rttm_file)
        labels = rttm_to_labels(rttm_file)
        pyannote_obj = labels_to_pyannote_object(labels, uniq_name=uniq_id)
        pyannote_obj_list.append([uniq_id, pyannote_obj])
    return pyannote_obj_list


def make_meta_dict(hyp_rttm_list, ref_rttm_list):
    """Create a temporary `audio_rttm_map_dict` for evaluation.

    Args:
        hyp_rttm_list: Hypothesis RTTM file paths, or `None`. Entries are paired
            with `ref_rttm_list` by position.
        ref_rttm_list: Reference RTTM file paths.

    Returns:
        A dict keyed by the unique name derived from each reference path. Each
        value holds the stripped reference path under "rttm_filepath" and, when
        `hyp_rttm_list` is not `None`, the stripped hypothesis path at the same
        index under "hyp_rttm_filepath".

    Raises:
        IndexError: If `hyp_rttm_list` is shorter than `ref_rttm_list`.
    """
    meta_dict = {}
    for k, rttm_file in enumerate(ref_rttm_list):
        # Strip once, up front, so the key is derived from the same cleaned path
        # that gets stored (stray whitespace must not leak into the unique name).
        rttm_file = rttm_file.strip()
        uniq_id = get_uniqname_from_filepath(rttm_file)
        meta_dict[uniq_id] = {"rttm_filepath": rttm_file}
        if hyp_rttm_list is not None:
            meta_dict[uniq_id]["hyp_rttm_filepath"] = hyp_rttm_list[k].strip()
    return meta_dict


def make_trans_info_dict(hyp_json_list_path):
    """Create `trans_info_dict` from the `.json` files.

    Args:
        hyp_json_list_path: Iterable of JSON file paths. Despite the name, this is
            the collection of paths itself (each element is opened as a file),
            not the path to a list file.

    Returns:
        A dict mapping the unique name derived from each (stripped) path to the
        parsed content of that JSON file.
    """
    trans_info_dict = {}
    for json_file in hyp_json_list_path:
        json_file = json_file.strip()
        # Read as UTF-8 explicitly so non-ASCII transcripts do not depend on the
        # platform's locale encoding.
        with open(json_file, encoding="utf-8") as jsf:
            json_data = json.load(jsf)
        uniq_id = get_uniqname_from_filepath(json_file)
        trans_info_dict[uniq_id] = json_data
    return trans_info_dict


def read_file_path(list_path):
    """Return the entries that `read_file` yields for `list_path`, stripped and sorted.

    Leading/trailing whitespace (including the line break) is removed from every
    entry before the entries are sorted.
    """
    stripped_entries = [entry.strip() for entry in read_file(list_path)]
    stripped_entries.sort()
    return stripped_entries


def main(
    hyp_rttm_list_path: str,
    ref_rttm_list_path: str,
    hyp_ctm_list_path: str,
    ref_ctm_list_path: str,
    hyp_json_list_path: str,
    diar_eval_mode: str = "all",
    root_path: str = "./",
):
    """Evaluate DER from RTTM lists and, when reference CTMs are given, WER and cpWER.

    Every `*_list_path` argument is the path to a filelist; a falsy value means
    that list was not provided.

    Args:
        hyp_rttm_list_path: Filelist of hypothesis RTTM files.
        ref_rttm_list_path: Filelist of reference RTTM files.
        hyp_ctm_list_path: Filelist of hypothesis CTM files.
        ref_ctm_list_path: Filelist of reference CTM files.
        hyp_json_list_path: Filelist of hypothesis JSON files.
        diar_eval_mode: Evaluation mode forwarded to `evaluate_der`.
        root_path: Directory forwarded to the result gathering and CSV writing steps.

    Raises:
        ValueError: If a reference CTM list is given but neither a hypothesis CTM
            list nor a hypothesis JSON list is.
    """

    def _read_optional(list_path):
        # A falsy path means the corresponding list was not supplied.
        if list_path:
            return read_file_path(list_path)
        return None

    # Load each filepath list (or None when it was not supplied)
    hyp_rttm_list = _read_optional(hyp_rttm_list_path)
    ref_rttm_list = _read_optional(ref_rttm_list_path)
    hyp_ctm_list = _read_optional(hyp_ctm_list_path)
    ref_ctm_list = _read_optional(ref_ctm_list_path)
    hyp_json_list = _read_optional(hyp_json_list_path)

    audio_rttm_map_dict = make_meta_dict(hyp_rttm_list, ref_rttm_list)

    # Transcription info is only built when a non-empty JSON list was read
    trans_info_dict = None
    if hyp_json_list:
        trans_info_dict = make_trans_info_dict(hyp_json_list)

    all_hypothesis = get_pyannote_objs_from_rttms(hyp_rttm_list)
    all_reference = get_pyannote_objs_from_rttms(ref_rttm_list)

    diar_score = evaluate_der(
        audio_rttm_map_dict=audio_rttm_map_dict,
        all_reference=all_reference,
        all_hypothesis=all_hypothesis,
        diar_eval_mode=diar_eval_mode,
    )

    # Session-level diarization error rate and speaker counting error
    der_results = OfflineDiarWithASR.gather_eval_results(
        diar_score=diar_score,
        audio_rttm_map_dict=audio_rttm_map_dict,
        trans_info_dict=trans_info_dict,
        root_path=root_path,
    )

    if ref_ctm_list is None:
        # Without reference CTM files there is nothing to score WER/cpWER against
        wer_results = {}
    else:
        # Hypothesis CTM files take precedence over hypothesis JSON files
        if hyp_ctm_list is not None:
            hyp_source = {"hyp_trans_info_dict": None, "hyp_ctm_file_list": hyp_ctm_list}
        elif hyp_json_list is not None:
            hyp_source = {"hyp_trans_info_dict": trans_info_dict, "hyp_ctm_file_list": None}
        else:
            raise ValueError("Hypothesis information is not provided in the correct format.")
        wer_results = OfflineDiarWithASR.evaluate(
            audio_file_list=hyp_rttm_list,
            ref_ctm_file_list=ref_ctm_list,
            **hyp_source,
        )

    # Report average DER, WER and cpWER
    OfflineDiarWithASR.print_errors(der_results=der_results, wer_results=wer_results)

    # Persist the detailed session-level evaluation results under `root_path`
    csv_columns = OfflineDiarWithASR.get_csv_columns()
    OfflineDiarWithASR.write_session_level_result_in_csv(
        der_results=der_results,
        wer_results=wer_results,
        root_path=root_path,
        csv_columns=csv_columns,
    )


if __name__ == "__main__":
    # Command-line entry point: parse the filelist paths and options, then run `main`.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--hyp_rttm_list", help="path to the filelist of hypothesis RTTM files", type=str, required=True
    )
    parser.add_argument(
        "--ref_rttm_list", help="path to the filelist of reference RTTM files", type=str, required=True
    )
    parser.add_argument(
        "--hyp_ctm_list", help="path to the filelist of hypothesis CTM files", type=str, required=False, default=None
    )
    parser.add_argument(
        "--ref_ctm_list", help="path to the filelist of reference CTM files", type=str, required=False, default=None
    )
    parser.add_argument(
        "--hyp_json_list",
        help="(Optional) path to the filelist of hypothesis JSON files",
        type=str,
        required=False,
        default=None,
    )
    parser.add_argument(
        "--diar_eval_mode",
        help='evaluation mode: "all", "full", "fair", "forgiving"',
        type=str,
        required=False,
        default="all",
        # Reject unsupported modes at parse time instead of after the input files are read.
        choices=["all", "full", "fair", "forgiving"],
    )
    parser.add_argument(
        "--root_path", help='directory for saving result files', type=str, required=False, default="./"
    )

    args = parser.parse_args()

    # Keyword arguments keep the five similarly named list paths from being mixed up.
    main(
        hyp_rttm_list_path=args.hyp_rttm_list,
        ref_rttm_list_path=args.ref_rttm_list,
        hyp_ctm_list_path=args.hyp_ctm_list,
        ref_ctm_list_path=args.ref_ctm_list,
        hyp_json_list_path=args.hyp_json_list,
        diar_eval_mode=args.diar_eval_mode,
        root_path=args.root_path,
    )