File size: 11,772 Bytes
2d8da09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import multiprocessing
import shutil
from collections import OrderedDict
from pathlib import Path
from pprint import pprint
from typing import Dict

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import sox
from scipy.stats import expon
from tqdm import tqdm

from nemo.collections.asr.parts.utils.vad_utils import (
    get_nonspeech_segments,
    load_speech_overlap_segments_from_rttm,
    plot_sample_from_rttm,
)

"""
This script analyzes multi-speaker speech dataset and generates statistics.
The input directory </path/to/rttm_and_wav_directory> is required to contain the following files:
    - rttm files (*.rttm)
    - wav files (*.wav)

Usage:
    python <NEMO_ROOT>/scripts/speaker_tasks/multispeaker_data_analysis.py \
        </path/to/rttm_and_wav_directory> \
        --session_dur 20 \
        --silence_mean 0.2 \
        --silence_var 100 \
        --overlap_mean 0.15 \
        --overlap_var 50 \
        --num_workers 8 \
        --num_samples 10 \
        --output_dir <path/to/output_directory>
"""


def process_sample(sess_dict: Dict) -> Dict:
    """
    Process a single multi-speaker session and compute its silence/overlap statistics.

    Args:
        sess_dict (dict): dictionary containing the following keys
            rttm_file (str or Path): path to the rttm file
            session_dur (float): duration of the session (specified by argument); if None,
                the duration is measured from the wav file next to the rttm file
            precise (bool): whether to measure the precise duration of the session using sox

    Returns:
        results (dict): dictionary containing the following keys
            session_dur (float): duration of the session
            silence_len_list (list): list of silence durations of each silence occurrence
            silence_dur (float): total silence duration in a session
            silence_ratio (float): ratio of silence duration to session duration
                (0.0 when the session duration is not positive)
            overlap_len_list (list): list of overlap durations of each overlap occurrence
            overlap_dur (float): total overlap duration
            overlap_ratio (float): ratio of overlap duration to speech (non-silence) duration
                (0.0 when the session contains no speech)
    """
    rttm_file = sess_dict["rttm_file"]
    session_dur = sess_dict["session_dur"]
    precise = sess_dict["precise"]
    if precise or session_dur is None:
        # The audio is expected in the same directory, with the same stem as the rttm file.
        rttm_path = Path(rttm_file)
        wav_file = rttm_path.parent / (rttm_path.stem + ".wav")
        session_dur = sox.file_info.duration(str(wav_file))

    speech_seg, overlap_seg = load_speech_overlap_segments_from_rttm(rttm_file)
    # Each segment is indexed as [start, end].
    speech_dur = sum(seg[1] - seg[0] for seg in speech_seg)

    silence_seg = get_nonspeech_segments(speech_seg, session_dur)
    silence_len_list = [seg[1] - seg[0] for seg in silence_seg]
    silence_dur = max(0, session_dur - speech_dur)
    # Guard against degenerate sessions so a single file cannot abort the whole analysis
    # with a ZeroDivisionError.
    silence_ratio = silence_dur / session_dur if session_dur > 0 else 0.0

    overlap_len_list = [seg[1] - seg[0] for seg in overlap_seg]
    overlap_dur = sum(overlap_len_list)
    # A session without any speech has no overlap either; report 0.0 instead of dividing by zero.
    overlap_ratio = overlap_dur / speech_dur if speech_dur > 0 else 0.0

    results = {
        "session_dur": session_dur,
        "silence_len_list": silence_len_list,
        "silence_dur": silence_dur,
        "silence_ratio": silence_ratio,
        "overlap_len_list": overlap_len_list,
        "overlap_dur": overlap_dur,
        "overlap_ratio": overlap_ratio,
    }

    return results


def _fit_exponential_mean(lengths) -> float:
    """
    Fit an exponential distribution (location fixed at 0) to ``lengths`` and return its scale (= mean).

    Returns NaN when ``lengths`` is empty, because ``expon.fit`` cannot fit an empty sample
    (e.g. a dataset without any overlapping speech).
    """
    if len(lengths) == 0:
        return float("nan")
    _, scale = expon.fit(lengths, floc=0)
    return scale


def run_multispeaker_data_analysis(
    input_dir,
    session_dur=None,
    silence_mean=None,
    silence_var=None,
    overlap_mean=None,
    overlap_var=None,
    precise=False,
    save_path=None,
    num_workers=1,
) -> Dict:
    """
    Analyze the multispeaker data and plot the distribution of silence and overlap durations.

    Args:
        input_dir (str): path to the directory containing the rttm files (and the matching wav files)
        session_dur (float): duration of each session in seconds; if None, it is measured from the wav files
        silence_mean (float): expected mean of the silence ratio (only used for reporting)
        silence_var (float): expected variance of the silence ratio (only used for reporting)
        overlap_mean (float): expected mean of the overlap ratio (only used for reporting)
        overlap_var (float): expected variance of the overlap ratio (only used for reporting)
        precise (bool): whether to measure the precise duration of the session using sox
        save_path (str): path to save the statistics figure; if falsy, the figure is not saved
        num_workers (int): number of worker processes; values <= 1 process sessions sequentially

    Returns:
        stats (dict): ordered dictionary containing the statistics of the analyzed data

    Raises:
        ValueError: if no rttm files are found in ``input_dir``
    """
    # Sort so processing order (and thus printed/plotted output) is deterministic across filesystems.
    rttm_list = sorted(Path(input_dir).glob("*.rttm"))

    print(f"Found {len(rttm_list)} files to be processed")
    if not rttm_list:
        raise ValueError(f"No rttm files found in {input_dir}")

    queue = [{"rttm_file": rttm_file, "session_dur": session_dur, "precise": precise} for rttm_file in rttm_list]

    if num_workers <= 1:
        results = [process_sample(sess_dict) for sess_dict in tqdm(queue)]
    else:
        with multiprocessing.Pool(processes=num_workers) as p:
            results = list(tqdm(p.imap(process_sample, queue), total=len(queue), desc='Processing', leave=True))

    total_duration = 0.0
    silence_duration = 0.0
    overlap_duration = 0.0

    silence_ratio_all = []
    overlap_ratio_all = []
    silence_length_all = []
    overlap_length_all = []

    for item in results:
        total_duration += item["session_dur"]
        silence_duration += item["silence_dur"]
        overlap_duration += item["overlap_dur"]

        silence_length_all += item["silence_len_list"]
        overlap_length_all += item["overlap_len_list"]

        silence_ratio_all.append(item["silence_ratio"])
        overlap_ratio_all.append(item["overlap_ratio"])

    speech_duration = total_duration - silence_duration
    # Guard the dataset-level ratios so a dataset without duration / speech reports 0.0 instead of crashing.
    actual_silence_mean = silence_duration / total_duration if total_duration > 0 else 0.0
    actual_silence_var = np.var(silence_ratio_all)
    actual_overlap_mean = overlap_duration / speech_duration if speech_duration > 0 else 0.0
    actual_overlap_var = np.var(overlap_ratio_all)

    stats = OrderedDict()
    stats["total duration (hours)"] = f"{total_duration / 3600:.2f}"
    stats["number of sessions"] = len(rttm_list)
    stats["average session duration (seconds)"] = f"{total_duration / len(rttm_list):.2f}"
    stats["actual silence ratio mean/var"] = f"{actual_silence_mean:.4f}/{actual_silence_var:.4f}"
    stats["actual overlap ratio mean/var"] = f"{actual_overlap_mean:.4f}/{actual_overlap_var:.4f}"
    stats["expected silence ratio mean/var"] = f"{silence_mean}/{silence_var}"
    stats["expected overlap ratio mean/var"] = f"{overlap_mean}/{overlap_var}"
    stats["save_path"] = save_path

    print("-----------------------------------------------")
    print("                    Results                    ")
    print("-----------------------------------------------")
    for k, v in stats.items():
        print(k, ": ", v)
    print("-----------------------------------------------")

    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(14, 14))
    fig.suptitle(
        f"Average session={total_duration/len(rttm_list):.2f} seconds, num sessions={len(rttm_list)}, total={total_duration/3600:.2f} hours"
    )
    sns.histplot(silence_ratio_all, ax=ax1)
    ax1.set_xlabel("Silence ratio in a session")
    ax1.set_title(
        f"Target silence mean={silence_mean}, var={silence_var}. \nActual silence ratio={actual_silence_mean:.4f}, var={actual_silence_var:.4f}"
    )

    scale = _fit_exponential_mean(silence_length_all)
    sns.histplot(silence_length_all, ax=ax2)
    ax2.set_xlabel("Per-silence length in seconds")
    ax2.set_title(f"Per-silence length histogram, \nfitted exponential distribution with mean={scale:.4f}")

    sns.histplot(overlap_ratio_all, ax=ax3)
    ax3.set_title(
        f"Target overlap mean={overlap_mean}, var={overlap_var}. \nActual ratio={actual_overlap_mean:.4f}, var={actual_overlap_var:.4f}"
    )
    ax3.set_xlabel("Overlap ratio in a session")
    scale2 = _fit_exponential_mean(overlap_length_all)
    sns.histplot(overlap_length_all, ax=ax4)
    ax4.set_title(f"Per overlap length histogram, \nfitted exponential distribution with mean={scale2:.4f}")
    ax4.set_xlabel("Duration in seconds")

    if save_path:
        fig.savefig(save_path)
        print(f"Figure saved at: {save_path}")
        # Release the figure once it is on disk so repeated calls do not accumulate open figures.
        plt.close(fig)

    return stats


def visualize_multispeaker_data(input_dir: str, output_dir: str, num_samples: int = 10) -> None:
    """
    Visualize a set of randomly sampled sessions in the input directory.

    For each sampled rttm file, the wav file with the same stem (in the same directory) is
    plotted with its rttm annotation, and the figure is saved as ``<output_dir>/<stem>.png``.

    Args:
        input_dir (str): Path to the input directory containing the rttm and wav files
        output_dir (str): Path to the output directory (created if it does not exist)
        num_samples (int): Maximum number of samples to visualize; if the directory holds fewer
            rttm files, all of them are visualized. Non-positive values visualize nothing.
    """
    rttm_list = sorted(Path(input_dir).glob("*.rttm"))
    # Clamp to 0 so a negative value does not slice from the end of the permutation.
    sample_indices = np.random.permutation(len(rttm_list))[: max(num_samples, 0)]

    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Report the number actually plotted, which can be smaller than num_samples.
    print(f"Visualizing {len(sample_indices)} random samples")
    for idx in sample_indices:
        rttm_file = rttm_list[idx]
        audio_file = rttm_file.parent / (rttm_file.stem + ".wav")
        output_file = out_dir / (rttm_file.stem + ".png")
        plot_sample_from_rttm(audio_file=audio_file, rttm_file=rttm_file, save_path=str(output_file), show=False)
    print(f"Sample plots saved at: {output_dir}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("input_dir", default="", help="Input directory")
    parser.add_argument("-sd", "--session_dur", default=None, type=float, help="Duration per session in seconds")
    parser.add_argument("-sm", "--silence_mean", default=None, type=float, help="Expected silence ratio mean")
    parser.add_argument("-sv", "--silence_var", default=None, type=float, help="Expected silence ratio variance")
    parser.add_argument("-om", "--overlap_mean", default=None, type=float, help="Expected overlap ratio mean")
    parser.add_argument("-ov", "--overlap_var", default=None, type=float, help="Expected overlap ratio variance")
    parser.add_argument("-w", "--num_workers", default=1, type=int, help="Number of CPU workers to use")
    parser.add_argument("-s", "--num_samples", default=10, type=int, help="Number of random samples to plot")
    parser.add_argument("-o", "--output_dir", default="analysis/", type=str, help="Directory for saving output figure")
    parser.add_argument(
        "--precise", action="store_true", help="Set to get precise duration, with significant time cost"
    )
    args = parser.parse_args()

    print("Running with params:")
    pprint(vars(args))

    output_dir = Path(args.output_dir)
    if output_dir.exists():
        print(f"Removing existing output directory: {args.output_dir}")
        shutil.rmtree(str(output_dir))
    output_dir.mkdir(parents=True)

    run_multispeaker_data_analysis(
        input_dir=args.input_dir,
        session_dur=args.session_dur,
        silence_mean=args.silence_mean,
        silence_var=args.silence_var,
        overlap_mean=args.overlap_mean,
        overlap_var=args.overlap_var,
        precise=args.precise,
        save_path=str(Path(args.output_dir, "statistics.png")),
        num_workers=args.num_workers,
    )

    visualize_multispeaker_data(input_dir=args.input_dir, output_dir=args.output_dir, num_samples=args.num_samples)

    print("The multispeaker data analysis has been completed.")
    print(f"Please check the output directory: \n{args.output_dir}")