import argparse
import json
import os
from pathlib import Path

import torch
from tqdm import tqdm

import hw_asr.model as module_model
from hw_asr.trainer import Trainer
from hw_asr.utils import ROOT_PATH
from hw_asr.utils.object_loading import get_dataloaders
from hw_asr.utils.parse_config import ConfigParser
from hw_asr.metric.utils import calc_wer

DEFAULT_CHECKPOINT_PATH = ROOT_PATH / "default_test_model" / "checkpoint.pth"


def main(config, out_file):
    logger = config.get_logger("test")

    # use GPU if available, otherwise fall back to CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # text encoder (vocabulary + CTC decoding utilities)
    text_encoder = config.get_text_encoder()

    # setup data_loader instances
    dataloaders = get_dataloaders(config, text_encoder)

    # build model architecture
    model = config.init_obj(config["arch"], module_model, n_class=len(text_encoder))
    logger.info(model)

    logger.info("Loading checkpoint: {} ...".format(config.resume))
    checkpoint = torch.load(config.resume, map_location=device)
    state_dict = checkpoint["state_dict"]
    if config["n_gpu"] > 1:
        model = torch.nn.DataParallel(model)
    model.load_state_dict(state_dict)

    # prepare model for testing
    model = model.to(device)
    model.eval()

    results = []
    argmax_wer_sum = 0
    beam_search_wer_sum = 0
    lm_wer_sum = 0

    with torch.no_grad():
        for batch_num, batch in enumerate(tqdm(dataloaders["test"])):
            batch = Trainer.move_batch_to_device(batch, device)
            output = model(**batch)
            if isinstance(output, dict):
                batch.update(output)
            else:
                batch["logits"] = output
            batch["log_probs"] = torch.log_softmax(batch["logits"], dim=-1)
            batch["log_probs_length"] = model.transform_input_lengths(
                batch["spectrogram_length"]
            )
            batch["probs"] = batch["log_probs"].exp().cpu()
            batch["argmax"] = batch["probs"].argmax(-1)

            for i in range(len(batch["text"])):
                length = int(batch["log_probs_length"][i])
                ground_truth = batch["text"][i]

                # greedy (argmax) CTC decoding
                argmax = batch["argmax"][i][:length].cpu().numpy()
                text_argmax = text_encoder.ctc_decode(argmax)

                # CTC beam search over per-frame probabilities
                probs = batch["probs"][i][:length].detach().cpu().numpy()
                text_beam_search = text_encoder.ctc_beam_search(probs, beam_size=4)

                # CTC beam search with a language model
                logits = batch["logits"][i][:length].detach().cpu().numpy()
                text_lm = text_encoder.ctc_lm_beam_search(logits)

                # per-utterance WER in percent for each decoding strategy
                argmax_wer = calc_wer(ground_truth, text_argmax) * 100
                beam_search_wer = calc_wer(ground_truth, text_beam_search) * 100
                lm_wer = calc_wer(ground_truth, text_lm) * 100

                argmax_wer_sum += argmax_wer
                beam_search_wer_sum += beam_search_wer
                lm_wer_sum += lm_wer

                results.append(
                    {
                        "ground_truth": ground_truth,
                        "pred_text_argmax": text_argmax,
                        "pred_text_beam_search": text_beam_search,
                        "pred_text_lm": text_lm,
                        "argmax_wer": argmax_wer,
                        "beam_search_wer": beam_search_wer,
                        "lm_wer": lm_wer,
                    }
                )

    # report the mean per-utterance WER for each decoding strategy
    n = len(results)
    logger.info("argmax_wer_mean:")
    logger.info(argmax_wer_sum / n)
    logger.info("beam_search_wer_mean:")
    logger.info(beam_search_wer_sum / n)
    logger.info("lm_wer_mean:")
    logger.info(lm_wer_sum / n)

    with Path(out_file).open("w") as f:
        json.dump(results, f, indent=2)


if __name__ == "__main__":
    args = argparse.ArgumentParser(description="PyTorch Template")
    args.add_argument(
        "-c",
        "--config",
        default=None,
        type=str,
        help="config file path (default: None)",
    )
    args.add_argument(
        "-r",
        "--resume",
        default=str(DEFAULT_CHECKPOINT_PATH.absolute().resolve()),
        type=str,
        help="path to checkpoint to evaluate (default: bundled default_test_model checkpoint)",
    )
    args.add_argument(
        "-d",
        "--device",
        default=None,
        type=str,
        help="indices of GPUs to enable (default: all)",
    )
    args.add_argument(
        "-o",
        "--output",
        default="output.json",
        type=str,
        help="file to write results to (.json)",
    )
    args.add_argument(
        "-t",
"--test-data-folder", default=None, type=str, help="Path to dataset", ) args.add_argument( "-b", "--batch-size", default=20, type=int, help="Test dataset batch size", ) args.add_argument( "-j", "--jobs", default=1, type=int, help="Number of workers for test dataloader", ) args = args.parse_args() # set GPUs if args.device is not None: os.environ["CUDA_VISIBLE_DEVICES"] = args.device # first, we need to obtain config with model parameters # we assume it is located with checkpoint in the same folder model_config = Path(args.resume).parent / "config.json" with model_config.open() as f: config = ConfigParser(json.load(f), resume=args.resume) # update with addition configs from `args.config` if provided if args.config is not None: with Path(args.config).open() as f: config.config.update(json.load(f)) # if `--test-data-folder` was provided, set it as a default test set if args.test_data_folder is not None: test_data_folder = Path(args.test_data_folder).absolute().resolve() assert test_data_folder.exists() config.config["data"] = { "test": { "batch_size": args.batch_size, "num_workers": args.jobs, "datasets": [ { "type": "CustomDirAudioDataset", "args": { "audio_dir": str(test_data_folder / "audio"), "transcription_dir": str(test_data_folder / "transcriptions"), }, } ], } } assert config.config.get("data", {}).get("test", None) is not None config["data"]["test"]["batch_size"] = args.batch_size config["data"]["test"]["n_jobs"] = args.jobs main(config, args.output)