|
import argparse |
|
import json |
|
import os |
|
from pathlib import Path |
|
|
|
import torch |
|
from tqdm import tqdm |
|
|
|
import hw_asr.model as module_model |
|
from hw_asr.trainer import Trainer |
|
from hw_asr.utils import ROOT_PATH |
|
from hw_asr.utils.object_loading import get_dataloaders |
|
from hw_asr.utils.parse_config import ConfigParser |
|
from hw_asr.metric.utils import calc_wer |
|
|
|
DEFAULT_CHECKPOINT_PATH = ROOT_PATH / "default_test_model" / "checkpoint.pth" |
|
|
|
|
|
def main(config, out_file):
    """Run inference on the test split and report WER for three decoders.

    For every utterance in ``dataloaders["test"]`` the model output is
    decoded three ways — greedy argmax CTC decode, CTC beam search
    (beam_size=4), and LM-rescored beam search — and the WER of each
    hypothesis against the ground-truth transcription is accumulated.
    Mean WERs are logged and per-utterance results are dumped to JSON.

    Args:
        config: ConfigParser holding the architecture, data settings and
            the checkpoint path in ``config.resume``.
        out_file: path of the JSON file to write per-utterance results to.
    """
    logger = config.get_logger("test")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    text_encoder = config.get_text_encoder()

    dataloaders = get_dataloaders(config, text_encoder)

    # Build the architecture and restore trained weights.
    model = config.init_obj(config["arch"], module_model, n_class=len(text_encoder))
    logger.info(model)

    logger.info("Loading checkpoint: %s ...", config.resume)
    checkpoint = torch.load(config.resume, map_location=device)
    state_dict = checkpoint["state_dict"]
    if config["n_gpu"] > 1:
        # NOTE(review): assumes the checkpoint was saved from a DataParallel
        # model (keys prefixed with "module.") — confirm against the trainer's
        # saving logic, otherwise load_state_dict will fail on key mismatch.
        model = torch.nn.DataParallel(model)
    model.load_state_dict(state_dict)

    model = model.to(device)
    model.eval()

    results = []

    argmax_wer_sum = 0
    beam_search_wer_sum = 0
    lm_wer_sum = 0

    with torch.no_grad():
        for batch in tqdm(dataloaders["test"]):
            batch = Trainer.move_batch_to_device(batch, device)
            output = model(**batch)
            # Models may return either a dict of tensors or a raw logits tensor.
            if isinstance(output, dict):
                batch.update(output)
            else:
                batch["logits"] = output
            batch["log_probs"] = torch.log_softmax(batch["logits"], dim=-1)
            batch["log_probs_length"] = model.transform_input_lengths(
                batch["spectrogram_length"]
            )
            batch["probs"] = batch["log_probs"].exp().cpu()
            batch["argmax"] = batch["probs"].argmax(-1)
            for i in range(len(batch["text"])):
                # Trim each sequence to its true (post-subsampling) length.
                length = int(batch["log_probs_length"][i])
                ground_truth = batch["text"][i]

                argmax = batch["argmax"][i][:length].cpu().numpy()
                text_argmax = text_encoder.ctc_decode(argmax)

                probs = batch["probs"][i][:length].detach().cpu().numpy()
                text_beam_search = text_encoder.ctc_beam_search(probs, beam_size=4)

                logits = batch["logits"][i][:length].detach().cpu().numpy()
                text_lm = text_encoder.ctc_lm_beam_search(logits)

                # WERs are reported as percentages.
                argmax_wer = calc_wer(ground_truth, text_argmax) * 100
                beam_search_wer = calc_wer(ground_truth, text_beam_search) * 100
                lm_wer = calc_wer(ground_truth, text_lm) * 100

                argmax_wer_sum += argmax_wer
                beam_search_wer_sum += beam_search_wer
                lm_wer_sum += lm_wer

                results.append(
                    {
                        "ground_truth": ground_truth,
                        "pred_text_argmax": text_argmax,
                        "pred_text_beam_search": text_beam_search,
                        "pred_text_lm": text_lm,
                        "argmax_wer": argmax_wer,
                        "beam_search_wer": beam_search_wer,
                        "lm_wer": lm_wer,
                    }
                )

    # FIX: guard against an empty test set — the original divided by zero here.
    n = len(results)
    if n == 0:
        logger.warning("Test dataloader produced no utterances; no WER to report.")
    else:
        logger.info("argmax_wer_mean:")
        logger.info(argmax_wer_sum / n)
        logger.info("beam_search_wer_mean:")
        logger.info(beam_search_wer_sum / n)
        logger.info("lm_wer_mean:")
        logger.info(lm_wer_sum / n)

    with Path(out_file).open("w") as f:
        json.dump(results, f, indent=2)
|
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: parse flags, assemble a ConfigParser from the
    # checkpoint's config (optionally overlaid with a user config), and run
    # inference via main().
    args = argparse.ArgumentParser(description="PyTorch Template")
    args.add_argument(
        "-c",
        "--config",
        default=None,
        type=str,
        help="config file path (default: None)",
    )
    args.add_argument(
        "-r",
        "--resume",
        default=str(DEFAULT_CHECKPOINT_PATH.absolute().resolve()),
        type=str,
        # FIX: the help text claimed the default is None, but it is the
        # bundled default checkpoint.
        help="path to checkpoint to evaluate (default: bundled default checkpoint)",
    )
    args.add_argument(
        "-d",
        "--device",
        default=None,
        type=str,
        help="indices of GPUs to enable (default: all)",
    )
    args.add_argument(
        "-o",
        "--output",
        default="output.json",
        type=str,
        help="File to write results (.json)",
    )
    args.add_argument(
        "-t",
        "--test-data-folder",
        default=None,
        type=str,
        help="Path to dataset",
    )
    args.add_argument(
        "-b",
        "--batch-size",
        default=20,
        type=int,
        help="Test dataset batch size",
    )
    args.add_argument(
        "-j",
        "--jobs",
        default=1,
        type=int,
        help="Number of workers for test dataloader",
    )

    args = args.parse_args()

    # Restrict visible GPUs before any CUDA initialization happens in main().
    if args.device is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.device

    # The training config is stored next to the checkpoint; it defines the
    # architecture, text encoder and default data settings.
    model_config = Path(args.resume).parent / "config.json"
    with model_config.open() as f:
        config = ConfigParser(json.load(f), resume=args.resume)

    # An explicit -c config overrides top-level keys of the stored config.
    # NOTE(review): dict.update is a shallow merge — nested sections given in
    # the override replace the stored section wholesale; confirm this is the
    # intended semantics.
    if args.config is not None:
        with Path(args.config).open() as f:
            config.config.update(json.load(f))

    # A custom test folder replaces the "data" section entirely with a
    # CustomDirAudioDataset pointing at <folder>/audio + <folder>/transcriptions.
    if args.test_data_folder is not None:
        test_data_folder = Path(args.test_data_folder).absolute().resolve()
        # FIX: explicit error instead of `assert`, which is stripped under -O.
        if not test_data_folder.exists():
            raise FileNotFoundError(
                f"Test data folder does not exist: {test_data_folder}"
            )
        config.config["data"] = {
            "test": {
                "batch_size": args.batch_size,
                "num_workers": args.jobs,
                "datasets": [
                    {
                        "type": "CustomDirAudioDataset",
                        "args": {
                            "audio_dir": str(test_data_folder / "audio"),
                            "transcription_dir": str(test_data_folder / "transcriptions"),
                        },
                    }
                ],
            }
        }

    # FIX: explicit error instead of `assert` for required config section.
    if config.config.get("data", {}).get("test", None) is None:
        raise ValueError('Config must define a "data.test" section for evaluation')
    config["data"]["test"]["batch_size"] = args.batch_size
    # NOTE(review): key "n_jobs" differs from the "num_workers" key used above
    # when building the custom dataset section — verify which one
    # get_dataloaders actually reads.
    config["data"]["test"]["n_jobs"] = args.jobs

    main(config, args.output)
|
|