import json

import git
from omegaconf import OmegaConf, open_dict
from utils import cal_target_metadata_wer, run_asr_inference

from nemo.collections.asr.parts.utils.eval_utils import cal_write_wer
from nemo.core.config import hydra_runner
from nemo.utils import logging

""" |
|
This script serves as evaluator of ASR models |
|
Usage: |
|
python asr_evaluator.py \ |
|
engine.pretrained_name="stt_en_conformer_transducer_large" \ |
|
engine.inference.mode="offline" \ |
|
engine.test_ds.augmentor.noise.manifest_path=<manifest file for noise data> \ |
|
..... |
|
|
|
Check out parameters in ./conf/eval.yaml |
|
""" |
|
|
|
|
|
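# Expected config structure, inferred from the keys this script reads (see
# conf/eval.yaml for the authoritative defaults):
#   env.save_git_hash                             - record the current git commit hash in the report
#   engine.*                                      - ASR inference settings (model, inference mode, test_ds, ...)
#   analyst.metric_calculator.exist_pred_manifest - reuse an existing prediction manifest, or null to run inference
#   analyst.metric_calculator.{clean_groundtruth_text, langid, use_cer, output_filename}
#   analyst.metadata.<target>.enable              - per-metadata-target metric breakdown
#   writer.report_filename                        - where to append the JSON report (defaults to report.json)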
@hydra_runner(config_path="conf", config_name="eval.yaml")
def main(cfg):
    report = {}
    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

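    # Record the current git commit hash so results can be traced back to the
    # exact code state that produced them.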
    if cfg.env.save_git_hash:
        repo = git.Repo(search_parent_directories=True)
        report['git_hash'] = repo.head.object.hexsha

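    # Engine: run ASR inference unless an existing prediction manifest is
    # supplied, in which case inference is skipped and that manifest is used.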
    if cfg.analyst.metric_calculator.exist_pred_manifest is None:
        cfg.engine = run_asr_inference(cfg=cfg.engine)
    else:
        logging.info(
            f"Using the existing prediction manifest {cfg.analyst.metric_calculator.exist_pred_manifest} and skipping the engine stage"
        )
        with open_dict(cfg):
            cfg.engine.output_filename = cfg.analyst.metric_calculator.exist_pred_manifest

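    # Analyst: compute the overall WER (or CER, if use_cer is set) and write a
    # new manifest annotated with per-utterance metrics.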
    output_manifest_w_wer, total_res, eval_metric = cal_write_wer(
        pred_manifest=cfg.engine.output_filename,
        clean_groundtruth_text=cfg.analyst.metric_calculator.clean_groundtruth_text,
        langid=cfg.analyst.metric_calculator.langid,
        use_cer=cfg.analyst.metric_calculator.use_cer,
        output_filename=cfg.analyst.metric_calculator.output_filename,
    )
    with open_dict(cfg):
        cfg.analyst.metric_calculator.output_filename = output_manifest_w_wer

    report.update({"res": total_res})

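    # Break down the metric by each enabled metadata target defined under
    # analyst.metadata and add the per-target results to the report.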
    for target in cfg.analyst.metadata:
        if cfg.analyst.metadata[target].enable:
            occ_avg_wer = cal_target_metadata_wer(
                manifest=cfg.analyst.metric_calculator.output_filename,
                target=target,
                meta_cfg=cfg.analyst.metadata[target],
                eval_metric=eval_metric,
            )
            report[target] = occ_avg_wer

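    # Fold the engine and metric_calculator configs into the report so each
    # entry records the exact settings that produced it.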
    config_engine = OmegaConf.to_object(cfg.engine)
    report.update(config_engine)

    config_metric_calculator = OmegaConf.to_object(cfg.analyst.metric_calculator)
    report.update(config_metric_calculator)

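    # Log the full report and the overall metric as a percentage.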
    pretty = json.dumps(report, indent=4)
    res = "%.3f" % (report["res"][eval_metric] * 100)
    logging.info(pretty)
    logging.info(f"Overall {eval_metric} is {res} %")

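    # Writer: append the report as a single JSON line, one entry per run.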
    report_file = "report.json"
    if "report_filename" in cfg.writer and cfg.writer.report_filename:
        report_file = cfg.writer.report_filename

    with open(report_file, "a") as fout:
        json.dump(report, fout)
        fout.write('\n')
        fout.flush()
|
|
if __name__ == "__main__":
    main()