File size: 4,751 Bytes
2d8da09 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 |
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# These tests are not included in CI because they take a moderate amount of time;
# they are meant to be run in the nightly pipeline instead.
import os
import shlex
import subprocess
import sys
from pathlib import Path

import pytest

from nemo.collections.asr.parts.utils.transcribe_utils import TextProcessingConfig
sys.path.append(str(Path(__file__).parents[2] / 'examples' / 'asr'))
import speech_to_text_eval
def _dev_manifest_overrides() -> str:
    """Return Hydra overrides pointing each ensemble member at its dev manifest.

    This is evaluated at test-collection time, when the parametrize list is built.
    If ``TEST_DATA_PATH`` is not set, the paths fall back to relative placeholders.
    ``test_confidence_ensemble`` raises before they are ever used.
    """
    data_path = Path(os.getenv('TEST_DATA_PATH', ''))
    return (
        f"ensemble.0.dev_manifest={data_path / 'es' / 'dev_manifest.json'} "
        f"ensemble.1.dev_manifest={data_path / 'it' / 'dev_manifest.json'} "
    )


@pytest.mark.parametrize(
    'build_args',
    [
        "ensemble.0.model=stt_es_conformer_ctc_large ensemble.1.model=stt_it_conformer_ctc_large",
        "ensemble.0.model=stt_es_conformer_transducer_large ensemble.1.model=stt_it_conformer_transducer_large",
        (
            "ensemble.0.model=stt_es_fastconformer_hybrid_large_pc ensemble.1.model=stt_it_fastconformer_hybrid_large_pc "
            "confidence.method_cfg.alpha=0.33 confidence.method_cfg.entropy_norm=exp "
        ),
        (
            "ensemble.0.model=stt_es_fastconformer_hybrid_large_pc "
            "ensemble.1.model=stt_it_fastconformer_hybrid_large_pc "
            "transcription.decoder_type=ctc "
        ),
        "ensemble.0.model=stt_es_conformer_ctc_large ensemble.1.model=stt_it_conformer_transducer_large",
        (
            "ensemble.0.model=stt_es_conformer_ctc_large "
            "ensemble.1.model=stt_it_conformer_ctc_large "
            + _dev_manifest_overrides()
            + "tune_confidence=True "
        ),
        (
            "ensemble.0.model=stt_es_conformer_transducer_large "
            "ensemble.1.model=stt_it_conformer_transducer_large "
            + _dev_manifest_overrides()
            + "tune_confidence=True "
        ),
    ],
    ids=[
        "CTC models",
        "Transducer models",
        "Hybrid models (Transducer mode)",
        "Hybrid models (CTC mode)",
        "CTC + Transducer",
        "CTC models + confidence tuning",
        "Transducer models + confidence tuning",
    ],
)
def test_confidence_ensemble(tmp_path, build_args):
    """Integration test for confidence ensembles.

    The test builds an ensemble with ``build_ensemble.py``, using the Hydra
    overrides in ``build_args``. It then evaluates the resulting ``.nemo`` file
    with ``speech_to_text_eval`` and requires a WER below 20%.

    To run it, set the ``TEST_DATA_PATH`` environment variable to the test-data
    directory. The following layout is assumed::

        $TEST_DATA_PATH
        |-- es
        |   |-- dev
        |   |-- dev_manifest.json
        |   |-- test
        |   |-- train
        |   `-- train_manifest.json
        |-- it
        |   |-- dev
        |   |-- dev_manifest.json
        |   |-- test
        |   |-- test_manifest.json
        |   |-- train
        |   `-- train_manifest.json
        `-- test_manifest.json

    This test directly references only these files:

    * ``test_manifest.json`` at the root, for the final evaluation.
    * ``es/dev_manifest.json`` and ``it/dev_manifest.json``, for the
      confidence-tuning cases.

    Raises:
        ValueError: If ``TEST_DATA_PATH`` is not defined.
    """
    # Fail immediately with a clear message rather than deep inside the build/eval.
    if not os.getenv("TEST_DATA_PATH"):
        raise ValueError("TEST_DATA_PATH env variable has to be defined!")
    test_data_path = Path(os.environ['TEST_DATA_PATH'])

    ensemble_path = tmp_path / 'ensemble.nemo'
    # Use the interpreter running pytest, not whatever ``python`` is on PATH.
    # Pass an argument list (no shell) so that paths with spaces or shell
    # metacharacters are handled safely. ``shlex.split`` reproduces the word
    # splitting the shell previously applied to ``build_args``.
    build_ensemble_cmd = [
        sys.executable,
        str(Path(__file__).parent / 'build_ensemble.py'),
        '--config-name=ensemble_config.yaml',
        f"output_path={ensemble_path}",
        *shlex.split(build_args),
    ]
    subprocess.run(build_ensemble_cmd, check=True)

    eval_cfg = speech_to_text_eval.EvaluationConfig(
        dataset_manifest=str(test_data_path / 'test_manifest.json'),
        output_filename=str(tmp_path / 'output.json'),
        model_path=str(ensemble_path),
        text_processing=TextProcessingConfig(punctuation_marks=".,?", do_lowercase=True, rm_punctuation=True),
    )
    results = speech_to_text_eval.main(eval_cfg)
    # Relaxed sanity check: the ensemble should reach better than 20% WER.
    assert results.metric_value < 0.20, f"WER {results.metric_value:.4f} is not below 0.20"
|