File size: 8,061 Bytes
2d8da09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script creates a manifest file for diarization training. If `--pairwise_rttm_output_folder` is specified, the script
also generates two-speaker subsets of the original RTTM files: every pair of speakers in an RTTM file gets its own RTTM
file. For example, an RTTM file with 4 speakers yields 6 speaker pairs and therefore 6 two-speaker RTTM files.

Args:
   --input_manifest_path: Input manifest file name (.json)
   --output_manifest_path: Output manifest file name (.json); by default derived from the input name and step_count
   --pairwise_rttm_output_folder: Folder in which the two-speaker pair RTTM files are saved
   --window: Window length for segmentation
   --shift: Shift length for segmentation
   --step_count: Number of unit segments to create per utterance
   --decimals: Rounding decimals
"""

import argparse
import copy
import itertools
import os
import random

from tqdm import tqdm

from nemo.collections.asr.parts.utils.manifest_utils import (
    get_input_manifest_dict,
    get_subsegment_dict,
    rreplace,
    write_truncated_subsegments,
)
from nemo.collections.asr.parts.utils.speaker_utils import (
    audio_rttm_map,
    rttm_to_labels,
    segments_manifest_to_subsegments_manifest,
    write_rttm2manifest,
)
from nemo.utils import logging

# Fix the seed of Python's global `random` generator so runs are reproducible.
# NOTE(review): nothing visible in this file calls `random`; presumably this is kept for imported helpers — confirm
# before removing.
random.seed(42)


def labels_to_rttmfile(labels, uniq_id, filename, out_rttm_dir):
    """
    Write speaker labels to `<out_rttm_dir>/<filename>.rttm` in RTTM format.

    Args:
        labels (list of str):
            Label lines with exactly three whitespace-separated fields, `start end speaker`. Surrounding whitespace,
            including a trailing newline, is ignored.
        uniq_id (str):
            Identifier written into the second (file id) column of every RTTM line.
        filename (str):
            Output file name without the `.rttm` extension.
        out_rttm_dir (str):
            Existing directory in which the file is created.

    Returns:
        rttm_path (str): Path of the written RTTM file, `os.path.join(out_rttm_dir, filename + '.rttm')`.

    Raises:
        ValueError: If a label line does not have exactly three fields or its start/end values are not numbers.
    """
    rttm_path = os.path.join(out_rttm_dir, filename + '.rttm')

    # Parse every label before touching the file, so a malformed line cannot leave a half-written RTTM behind.
    rttm_lines = []
    for label in labels:
        start, end, speaker = label.split()
        onset = float(start)
        # RTTM stores the segment duration rather than its end time.
        duration = float(end) - onset
        rttm_lines.append(
            'SPEAKER {} 1   {:.3f}   {:.3f} <NA> <NA> {} <NA> <NA>\n'.format(uniq_id, onset, duration, speaker)
        )

    with open(rttm_path, 'w') as f:
        f.writelines(rttm_lines)

    return rttm_path


def _collect_speakers(rttm_labels, num_speakers=None):
    """
    Return the distinct speaker ids found in `rttm_labels`, in order of first appearance.

    Args:
        rttm_labels (list of str):
            Label lines whose third whitespace-separated field is the speaker id (`start end speaker`).
        num_speakers (int or None):
            Stop scanning once this many distinct speakers have been collected. If None, collect every speaker.

    Returns:
        speakers (list of str): At most `num_speakers` distinct speaker ids.
    """
    speakers = []
    for entry in rttm_labels:
        if num_speakers is not None and len(speakers) >= num_speakers:
            break
        speaker = entry.split()[2]
        if speaker not in speakers:
            speakers.append(speaker)
    return speakers


def split_into_pairwise_rttm(audio_rttm_map, input_manifest_path, output_dir):
    """
    Create a two-speaker RTTM file for every speaker pair of every session and save them to `output_dir`.

    For each manifest entry, the first `num_speakers` distinct speakers of its RTTM file are collected (all speakers
    when `num_speakers` is missing or None). For every pair of those speakers, the RTTM lines belonging to either
    speaker of the pair are written to `<output_dir>/<audio file name without '.wav'>.<spk_a>_<spk_b>.rttm`.

    Args:
        audio_rttm_map (dict):
            A dictionary with keys of uniq id, which is being used to map audio files and corresponding rttm files
        input_manifest_path (str):
            Path of the input manifest file.
        output_dir (str):
            Path to the directory where the new RTTM files are saved.

    Returns:
        rttm_split_manifest_dict (dict):
            Deep copies of the manifest entries, keyed by `<uniq_id>.<spk_a>_<spk_b>`, whose `rttm_filepath` points to
            the corresponding two-speaker RTTM file.
        split_audio_rttm_map (dict):
            Deep copies of the `audio_rttm_map` entries with the same keys and the same updated `rttm_filepath`.

    Raises:
        ValueError: If an RTTM file contains fewer distinct speakers than the entry's `num_speakers`.
    """
    input_manifest_dict = get_input_manifest_dict(input_manifest_path)
    rttm_split_manifest_dict = {}
    split_audio_rttm_map = {}
    logging.info("Creating split RTTM files.")
    for uniq_id, line in tqdm(input_manifest_dict.items(), total=len(input_manifest_dict)):
        audiopath = line['audio_filepath']
        num_speakers = line.get('num_speakers')
        rttm_filepath = line['rttm_filepath']

        rttm = rttm_to_labels(rttm_filepath)
        speakers = _collect_speakers(rttm, num_speakers)
        if num_speakers is not None and len(speakers) < num_speakers:
            raise ValueError(
                f"num_speakers is {num_speakers} for '{audiopath}', but only {len(speakers)} distinct speaker(s) "
                f"were found in '{rttm_filepath}'."
            )
        if len(speakers) < 2:
            # No speaker pair can be formed, so this session contributes nothing to the split manifest.
            logging.warning(f"Skipping '{audiopath}': fewer than two speakers, no speaker pair can be formed.")
            continue

        base_fn = audiopath.split('/')[-1].replace('.wav', '')
        # Parse the speaker field once per label instead of once per label per pair.
        labelled = [(entry, entry.split()[2]) for entry in rttm]
        for pair in itertools.combinations(speakers, 2):
            target_rttm = [entry for entry, speaker in labelled if speaker in pair]

            pair_string = f".{pair[0]}_{pair[1]}"
            uniq_id_pair = uniq_id + pair_string
            filename = base_fn + pair_string
            # Use the path the writer actually produced so the manifest always points at the file on disk,
            # even if `output_dir` lacks a trailing '/'.
            rttm_path = labels_to_rttmfile(target_rttm, base_fn, filename, output_dir)

            line_mod = copy.deepcopy(line)
            line_mod['rttm_filepath'] = rttm_path
            meta = copy.deepcopy(audio_rttm_map[uniq_id])
            meta['rttm_filepath'] = rttm_path
            rttm_split_manifest_dict[uniq_id_pair] = line_mod
            split_audio_rttm_map[uniq_id_pair] = meta

    return rttm_split_manifest_dict, split_audio_rttm_map


def main(input_manifest_path, output_manifest_path, pairwise_rttm_output_folder, window, shift, step_count, decimals):
    """
    Create a diarization training manifest from `input_manifest_path`.

    The RTTM files referenced by the input manifest (or, when `pairwise_rttm_output_folder` is given, their
    two-speaker splits) are written to a segment manifest, which is then converted to a subsegment manifest using
    `window` and `shift`. The subsegments are passed to `write_truncated_subsegments` together with `step_count`,
    which writes `output_manifest_path`. Both intermediate manifests are removed afterwards, even if a step fails.

    Args:
        input_manifest_path (str): Input manifest path; must contain `.json`.
        output_manifest_path (str or None): Output manifest path; must contain `.json`. If empty or None, it is
            derived from `input_manifest_path` by replacing `.json` with `.<step_count>seg.json`.
        pairwise_rttm_output_folder (str or None): If given, folder (created if missing) where the two-speaker RTTM
            files are written; those files are then used instead of the original RTTM files.
        window (float): Window length for segmentation.
        shift (float): Shift length for segmentation.
        step_count (int or str): Number of unit segments per utterance; converted with `int()`.
        decimals (int): Rounding decimals.

    Raises:
        ValueError: If a manifest path does not contain `.json`, or `step_count` is not an integer.
    """
    if '.json' not in input_manifest_path:
        raise ValueError("input_manifest_path file should be .json file format")
    if output_manifest_path and '.json' not in output_manifest_path:
        raise ValueError("output_manifest_path file should be .json file format")

    # Normalise once so the default output name and the downstream writer see the same integer value.
    step_count = int(step_count)
    if not output_manifest_path:
        output_manifest_path = rreplace(input_manifest_path, '.json', f'.{step_count}seg.json')

    if pairwise_rttm_output_folder is not None:
        if not pairwise_rttm_output_folder.endswith('/'):
            pairwise_rttm_output_folder = f"{pairwise_rttm_output_folder}/"
        # The pairwise RTTM files are written straight into this folder, so it has to exist.
        os.makedirs(pairwise_rttm_output_folder, exist_ok=True)
        org_audio_rttm_map = audio_rttm_map(input_manifest_path)
        input_manifest_dict, session_rttm_map = split_into_pairwise_rttm(
            audio_rttm_map=org_audio_rttm_map,
            input_manifest_path=input_manifest_path,
            output_dir=pairwise_rttm_output_folder,
        )
    else:
        input_manifest_dict = get_input_manifest_dict(input_manifest_path)
        session_rttm_map = audio_rttm_map(input_manifest_path)

    # Intermediate manifests live next to the input manifest and are deleted once the output is written.
    segment_manifest_path = rreplace(input_manifest_path, '.json', '_seg.json')
    subsegment_manifest_path = rreplace(input_manifest_path, '.json', '_subseg.json')

    # TODO: decide whether this should be exposed as a command-line argument.
    min_subsegment_duration = 0.05

    try:
        segments_manifest_file = write_rttm2manifest(session_rttm_map, segment_manifest_path, decimals)

        logging.info("Creating subsegments.")
        segments_manifest_to_subsegments_manifest(
            segments_manifest_file=segments_manifest_file,
            subsegments_manifest_file=subsegment_manifest_path,
            window=window,
            shift=shift,
            min_subsegment_duration=min_subsegment_duration,
            include_uniq_id=True,
        )
        subsegments_dict = get_subsegment_dict(subsegment_manifest_path, window, shift, decimals)
        write_truncated_subsegments(input_manifest_dict, subsegments_dict, output_manifest_path, step_count, decimals)
    finally:
        # Clean up even on failure; a step may have raised before its file was created.
        for tmp_path in (segment_manifest_path, subsegment_manifest_path):
            if os.path.exists(tmp_path):
                os.remove(tmp_path)


if __name__ == "__main__":
    # Show the module docstring (with its per-argument summary) in `--help`, keeping its line breaks.
    parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument("--input_manifest_path", help="input json file name", type=str, required=True)
    parser.add_argument(
        "--output_manifest_path", help="output manifest_file name", type=str, default=None, required=False
    )
    parser.add_argument(
        "--pairwise_rttm_output_folder",
        help="Save two-speaker pair RTTM files",
        type=str,
        default=None,
        required=False,
    )
    parser.add_argument("--window", help="Window length for segmentation", type=float, required=True)
    parser.add_argument("--shift", help="Shift length for segmentation", type=float, required=True)
    parser.add_argument("--decimals", help="Rounding decimals", type=int, default=3, required=False)
    # type=int lets argparse reject non-integers with a usage error instead of a late traceback.
    parser.add_argument(
        "--step_count",
        help="Number of the unit segments you want to create per utterance",
        type=int,
        required=True,
    )
    args = parser.parse_args()

    main(
        args.input_manifest_path,
        args.output_manifest_path,
        args.pairwise_rttm_output_folder,
        args.window,
        args.shift,
        args.step_count,
        args.decimals,
    )