File size: 12,103 Bytes
2d8da09 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 |
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from dataclasses import dataclass, field, is_dataclass
from pathlib import Path
from typing import Optional
import pytorch_lightning as pl
import torch
from omegaconf import MISSING, OmegaConf
from sklearn.model_selection import ParameterGrid
from nemo.collections.asr.metrics.rnnt_wer import RNNTDecodingConfig
from nemo.collections.asr.metrics.wer import CTCDecodingConfig
from nemo.collections.asr.models import ASRModel, EncDecRNNTModel
from nemo.collections.asr.parts.utils.asr_confidence_benchmarking_utils import (
apply_confidence_parameters,
run_confidence_benchmark,
)
from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceConfig
from nemo.core.config import hydra_runner
from nemo.utils import logging, model_utils
"""
Get confidence metrics and curve plots for a given model, dataset, and confidence parameters.
# Arguments
model_path: Path to .nemo ASR checkpoint
pretrained_name: Name of pretrained ASR model (from NGC registry)
dataset_manifest: Path to dataset JSON manifest file (in NeMo format)
output_dir: Output directory to store a report and curve plot directories
batch_size: batch size during inference
num_workers: number of workers during inference
cuda: Optional int selecting the CUDA device index used for inference; if None, CUDA device 0 is used when available (CPU otherwise)
amp: Bool to decide if Automatic Mixed Precision should be used during inference
audio_type: Str filetype of the audio. Supported = wav, flac, mp3
target_level: Word- or token-level confidence. Supported = word, token, auto (for computing both word and token)
confidence_cfg: Config with confidence parameters
grid_params: JSON string encoding a dictionary of parameter lists; the benchmark is run once for every combination of the listed values
# Usage
ASR model can be specified by either "model_path" or "pretrained_name".
Data for transcription are defined with "dataset_manifest".
Results are written to output_dir as a CSV report (report.csv) and curve plot directories.
python benchmark_asr_confidence.py \
model_path=null \
pretrained_name=null \
dataset_manifest="" \
output_dir="" \
batch_size=64 \
num_workers=8 \
cuda=0 \
amp=True \
target_level="word" \
confidence_cfg.exclude_blank=False \
'grid_params="{\"aggregation\": [\"min\", \"prod\"], \"alpha\": [0.33, 0.5]}"'
"""
def get_experiment_params(cfg):
    """Describe a confidence config as report-table columns and an experiment name.

    Args:
        cfg: Confidence config exposing ``exclude_blank``, ``aggregation`` and ``method_cfg``.
            ``method_cfg`` must have ``name`` and ``alpha``. When ``method_cfg.name == "entropy"``
            it must also have ``entropy_type`` and ``entropy_norm``.

    Returns:
        A pair ``(params, experiment_name)``:

        - ``params`` holds six columns: aggregation, ``str(exclude_blank)``, method name,
          entropy type, entropy norm and ``str(alpha)``. For non-entropy methods the two
          entropy columns are ``"-"``.
        - ``experiment_name`` joins the meaningful values with ``"-"``. It uses
          ``"no_blank"`` / ``"blank"`` in place of ``exclude_blank`` and omits the
          placeholder entropy columns.
    """
    method = cfg.method_cfg
    agg = cfg.aggregation
    method_name = method.name
    alpha_text = str(method.alpha)

    if method_name == "entropy":
        entropy_columns = [method.entropy_type, method.entropy_norm]
        name_parts = [method_name, *entropy_columns]
    else:
        # Non-entropy methods have no entropy settings; keep the table aligned with placeholders.
        entropy_columns = ["-", "-"]
        name_parts = [method_name]

    params = [agg, str(cfg.exclude_blank), method_name, *entropy_columns, alpha_text]
    blank_tag = "no_blank" if cfg.exclude_blank else "blank"
    experiment_name = "-".join([agg, blank_tag, *name_parts, alpha_text])
    return params, experiment_name
@dataclass
class ConfidenceBenchmarkingConfig:
    """Hydra/OmegaConf schema for the confidence benchmarking script.

    At least one of ``model_path`` / ``pretrained_name`` must be set; ``main`` raises ValueError
    when both are None and prefers ``model_path`` when both are given.
    """

    # Required configs
    model_path: Optional[str] = None  # Path to a .nemo file
    pretrained_name: Optional[str] = None  # Name of a pretrained model
    dataset_manifest: str = MISSING  # JSON-lines manifest; each line needs `audio_filepath` and `text`
    output_dir: str = MISSING  # Directory for report.csv and the per-experiment plot directories

    # General configs
    batch_size: int = 32
    num_workers: int = 4

    # Set `cuda` to int to define CUDA device. If 'None', will look for CUDA
    # device anyway, and do inference on CPU only if CUDA device is not found.
    # If `cuda` is a negative number, inference will be on CPU only.
    cuda: Optional[int] = None
    amp: bool = False  # Use torch.cuda.amp.autocast during inference (only if CUDA is available)
    audio_type: str = "wav"  # NOTE(review): not read anywhere in this file; confirm whether still needed

    # Confidence configs
    target_level: str = "auto"  # Choices: "word", "token", "auto" (for both word- and token-level confidence)
    # Both word- and token-level confidence are preserved by default.
    confidence_cfg: ConfidenceConfig = field(
        default_factory=lambda: ConfidenceConfig(preserve_word_confidence=True, preserve_token_confidence=True)
    )
    # JSON string (parsed with json.loads) mapping parameter names to lists of values;
    # every combination is benchmarked. None runs a single benchmark with `confidence_cfg`.
    grid_params: Optional[str] = None
def _select_device(cuda):
    """Translate the ``cuda`` config value into ``(devices, accelerator)`` for ``pl.Trainer``.

    Args:
        cuda: ``None`` to use CUDA device 0 when available (CPU otherwise), a negative int to
            force CPU-only inference, or a non-negative CUDA device index.

    Returns:
        ``([index], 'gpu')`` for GPU inference or ``(1, 'cpu')`` for CPU inference.
    """
    if cuda is None:
        if torch.cuda.is_available():
            return [0], 'gpu'  # use 0th CUDA device
        return 1, 'cpu'
    if cuda < 0:
        # Negative values mean CPU only; previously this produced an invalid 'cuda:-1' device.
        return 1, 'cpu'
    return [cuda], 'gpu'


def _load_asr_model(cfg, map_location):
    """Restore the ASR model from ``cfg.model_path`` if set, otherwise from ``cfg.pretrained_name``."""
    if cfg.model_path is not None:
        # restore model from .nemo file path, instantiating the class recorded in its config
        model_cfg = ASRModel.restore_from(restore_path=cfg.model_path, return_config=True)
        classpath = model_cfg.target  # original class path
        imported_class = model_utils.import_class_by_path(classpath)  # type: ASRModel
        logging.info(f"Restoring model : {imported_class.__name__}")
        return imported_class.restore_from(restore_path=cfg.model_path, map_location=map_location)
    # restore model by name
    return ASRModel.from_pretrained(model_name=cfg.pretrained_name, map_location=map_location)


def _read_manifest(manifest_path):
    """Read audio paths and reference texts from a NeMo JSON-lines manifest.

    A relative ``audio_filepath`` that is not a file relative to the current working directory
    is resolved against the manifest's directory. Blank lines are skipped.

    Returns:
        Tuple ``(filepaths, reference_texts)`` of equally long lists; paths are absolute strings.
    """
    filepaths = []
    reference_texts = []
    manifest_dir = Path(manifest_path).parent
    with open(manifest_path, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue  # json.loads would fail on an empty line
            item = json.loads(line)
            audio_file = Path(item['audio_filepath'])
            if not audio_file.is_file() and not audio_file.is_absolute():
                audio_file = manifest_dir / audio_file
            filepaths.append(str(audio_file.absolute()))
            reference_texts.append(item['text'])
    return filepaths, reference_texts


def _write_results(f, model_typename, param_list, results):
    """Write one CSV row per level in ``results`` to the open report file ``f`` and flush it."""
    for level, result in results.items():
        f.write(f"{model_typename},{','.join(param_list)},{level},{','.join([str(r) for r in result])}\n")
    f.flush()


@hydra_runner(config_name="ConfidenceBenchmarkingConfig", schema=ConfidenceBenchmarkingConfig)
def main(cfg: ConfidenceBenchmarkingConfig):
    """Benchmark ASR confidence estimation on a manifest and write a CSV report.

    The model is loaded from ``cfg.model_path`` (preferred) or ``cfg.pretrained_name``. Decoding
    is configured with ``cfg.confidence_cfg``, and ``run_confidence_benchmark`` is called once,
    or once per combination of ``cfg.grid_params``. Each call's results are written to
    ``<output_dir>/report.csv``, one row per returned level. Curve plots go to
    ``<output_dir>/<experiment name>``.

    Returns:
        None. It returns early, after logging an error, if the manifest file is empty.

    Raises:
        ValueError: if both ``cfg.model_path`` and ``cfg.pretrained_name`` are None.
        RuntimeError: if the model has no ``change_decoding_strategy`` method.
    """
    torch.set_grad_enabled(False)

    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

    if is_dataclass(cfg):
        cfg = OmegaConf.structured(cfg)

    if cfg.model_path is None and cfg.pretrained_name is None:
        raise ValueError("Both cfg.model_path and cfg.pretrained_name cannot be None!")

    # get filenames and reference texts from manifest (before the expensive model load)
    if os.stat(cfg.dataset_manifest).st_size == 0:
        logging.error(f"The input dataset_manifest {cfg.dataset_manifest} is empty. Exiting!")
        return None
    filepaths, reference_texts = _read_manifest(cfg.dataset_manifest)

    # setup GPU
    device, accelerator = _select_device(cfg.cuda)
    map_location = torch.device('cuda:{}'.format(device[0]) if accelerator == 'gpu' else 'cpu')

    # setup model
    asr_model = _load_asr_model(cfg, map_location)  # type: ASRModel
    trainer = pl.Trainer(devices=device, accelerator=accelerator)
    asr_model.set_trainer(trainer)
    asr_model = asr_model.eval()

    # Check if ctc or rnnt model
    is_rnnt = isinstance(asr_model, EncDecRNNTModel)

    # Check that the model has the `change_decoding_strategy` method
    if not hasattr(asr_model, 'change_decoding_strategy'):
        raise RuntimeError("The asr_model you are using must have the `change_decoding_strategy` method.")

    # setup AMP (optional); CUDA autocast only makes sense when the model runs on a GPU
    autocast = None
    if (
        cfg.amp
        and accelerator == 'gpu'
        and torch.cuda.is_available()
        and hasattr(torch.cuda, 'amp')
        and hasattr(torch.cuda.amp, 'autocast')
    ):
        logging.info("AMP enabled!\n")
        autocast = torch.cuda.amp.autocast

    work_dir = Path(cfg.output_dir)
    os.makedirs(work_dir, exist_ok=True)
    report_columns = [
        "model_type",
        "aggregation",
        "blank",
        "method_name",
        "entropy_type",
        "entropy_norm",
        "alpha",
        "target_level",
        "auc_roc",
        "auc_pr",
        "auc_nt",
        "nce",
        "ece",
        "auc_yc",
        "std_yc",
        "max_yc",
    ]
    report_legend = ",".join(report_columns) + "\n"
    model_typename = "RNNT" if is_rnnt else "CTC"
    report_file = work_dir / Path("report.csv")

    # Both modes start from the base decoding config carrying the user's confidence parameters.
    asr_model.change_decoding_strategy(
        RNNTDecodingConfig(fused_batch_size=-1, strategy="greedy_batch", confidence_cfg=cfg.confidence_cfg)
        if is_rnnt
        else CTCDecodingConfig(confidence_cfg=cfg.confidence_cfg)
    )

    def _benchmark(plot_dir):
        # Single call site so both modes pass arguments in the same order:
        # (model, target_level, filepaths, reference_texts, batch_size, num_workers, plot_dir, autocast).
        return run_confidence_benchmark(
            asr_model,
            cfg.target_level,
            filepaths,
            reference_texts,
            cfg.batch_size,
            cfg.num_workers,
            plot_dir,
            autocast,
        )

    # do grid-based benchmarking if grid_params is provided, otherwise a regular one
    if cfg.grid_params:
        hp_grid = list(ParameterGrid(json.loads(cfg.grid_params)))
        logging.info("==============================Running a benchmarking with grid search=========================")
        logging.info(f"Grid search size: {len(hp_grid)}")
        logging.info(f"Results will be written to:\nreport file `{report_file}`\nand plot directories near the file")
        logging.info("==============================================================================================")
        with open(report_file, "tw", encoding="utf-8") as f:
            f.write(report_legend)
            f.flush()
            for i, hp in enumerate(hp_grid):
                logging.info(f"Run # {i + 1}, grid: `{hp}`")
                asr_model.change_decoding_strategy(apply_confidence_parameters(asr_model.cfg.decoding, hp))
                param_list, experiment_name = get_experiment_params(asr_model.cfg.decoding.confidence_cfg)
                results = _benchmark(work_dir / Path(experiment_name))
                _write_results(f, model_typename, param_list, results)
    else:
        param_list, experiment_name = get_experiment_params(asr_model.cfg.decoding.confidence_cfg)
        plot_dir = work_dir / Path(experiment_name)
        logging.info("==============================Running a single benchmarking===================================")
        logging.info(f"Results will be written to:\nreport file `{report_file}`\nand plot directory `{plot_dir}`")
        with open(report_file, "tw", encoding="utf-8") as f:
            f.write(report_legend)
            f.flush()
            results = _benchmark(plot_dir)
            _write_results(f, model_typename, param_list, results)
    logging.info("===========================================Done===============================================")


if __name__ == '__main__':
    main()
|