import os
import json
import numpy as np
import torch
import hydra
import logging
from omegaconf import DictConfig, OmegaConf

from funasr_detach.register import tables
from funasr_detach.download.download_from_hub import download_model
from funasr_detach.train_utils.set_all_random_seed import set_all_random_seed


@hydra.main(config_name=None, version_base=None)
def main_hydra(kwargs: DictConfig):
    """Hydra entry point: resolve the configuration, then call :func:`main`.

    If the config has no ``model_conf`` entry, ``download_model`` is called
    and its return value replaces the configuration before ``main`` runs.

    Args:
        kwargs: Hydra-composed configuration. Must contain a ``model`` key.
    """
    if kwargs.get("debug", False):
        import pdb

        pdb.set_trace()

    assert "model" in kwargs
    if "model_conf" not in kwargs:
        logging.info(
            "download models from model hub: {}".format(kwargs.get("model_hub", "ms"))
        )
        # Merge ``is_training`` into a single mapping: passing it as an explicit
        # keyword next to ``**kwargs`` raises TypeError ("multiple values for
        # keyword argument") whenever the config itself defines ``is_training``.
        download_kwargs = {**kwargs, "is_training": kwargs.get("is_training", True)}
        kwargs = download_model(**download_kwargs)

    main(**kwargs)


def _format_vector(values):
    """Render a 1-D sequence as ``"[ v0 v1 ... ]"`` (space separated).

    ``str()`` is applied per element rather than ``str(list(values))`` because
    the latter uses ``repr`` of each NumPy scalar, which on NumPy >= 2 yields
    ``np.float32(...)`` tokens instead of bare numbers.
    """
    return "[ " + " ".join(str(v) for v in values) + " ]"


def _write_am_mvn(path, mean, var):
    """Write the shift (``mean``) and scale (``var``) vectors to ``path``.

    NOTE(review): the literal header/prefix strings are reproduced exactly as
    found in this file; several are empty or whitespace-only. Confirm they
    match what the consumer of ``am.mvn`` expects.
    """
    dims = mean.shape[0]
    with open(path, "w") as fout:
        fout.write(
            ""
            + "\n"
            + " "
            + str(dims)
            + " "
            + str(dims)
            + "\n"
            + "[ 0 ]"
            + "\n"
            + " "
            + str(dims)
            + " "
            + str(dims)
            + "\n"
        )
        fout.write(" 0 " + _format_vector(mean) + "\n")
        fout.write(" " + str(dims) + " " + str(dims) + "\n")
        fout.write(" 0 " + _format_vector(var) + "\n")
        fout.write("" + "\n")


def main(**kwargs):
    """Accumulate per-dimension feature statistics and write CMVN files.

    Iterates the training dataloader one example per batch, sums the features
    and their squares over the frame axis, and writes:

    * ``cmvn_file`` (default ``"cmvn.json"``): JSON with ``mean_stats``,
      ``var_stats`` and ``total_frames``.
    * ``am.mvn`` in the same directory as ``cmvn_file``: the negated mean and
      the inverse standard deviation per dimension.

    Keys read from ``kwargs``: ``seed``, ``cudnn_enabled``, ``cudnn_benchmark``,
    ``cudnn_deterministic``, ``frontend``, ``frontend_conf``, ``dataset``,
    ``dataset_conf``, ``train_data_set_list``, ``scale``, ``cmvn_file``.
    ``dataset_conf`` is modified in place (``batch_type`` and ``batch_size``
    are overridden) when a batch sampler is configured.

    Raises:
        RuntimeError: if no feature frames were accumulated.
    """
    print(kwargs)
    tables.print()

    # set random seed
    set_all_random_seed(kwargs.get("seed", 0))
    torch.backends.cudnn.enabled = kwargs.get(
        "cudnn_enabled", torch.backends.cudnn.enabled
    )
    torch.backends.cudnn.benchmark = kwargs.get(
        "cudnn_benchmark", torch.backends.cudnn.benchmark
    )
    torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)

    # build the frontend if one is configured
    frontend = kwargs.get("frontend", None)
    if frontend is not None:
        frontend_class = tables.frontend_classes.get(frontend)
        frontend = frontend_class(**kwargs["frontend_conf"])
        kwargs["frontend"] = frontend
        kwargs["input_size"] = frontend.output_size()

    # dataset
    dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
    dataset_train = dataset_class(
        kwargs.get("train_data_set_list"),
        frontend=frontend,
        tokenizer=None,
        is_training=False,
        **kwargs.get("dataset_conf")
    )

    # dataloader
    batch_sampler = kwargs["dataset_conf"].get(
        "batch_sampler", "DynamicBatchLocalShuffleSampler"
    )
    batch_sampler_train = None
    if batch_sampler is not None:
        batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
        dataset_conf = kwargs.get("dataset_conf")
        # One example per batch so that ``[0, :, :]`` below sees every example.
        dataset_conf["batch_type"] = "example"
        dataset_conf["batch_size"] = 1
        batch_sampler_train = batch_sampler_class(
            dataset_train, is_training=False, **dataset_conf
        )

    dataloader_train = torch.utils.data.DataLoader(
        dataset_train,
        collate_fn=dataset_train.collator,
        batch_sampler=batch_sampler_train,
        num_workers=int(kwargs.get("dataset_conf").get("num_workers", 4)),
        pin_memory=True,
    )

    # ``scale`` limits the pass to a leading fraction of the batches.
    iter_stop = int(kwargs.get("scale", 1.0) * len(dataloader_train))

    total_frames = 0
    mean_stats = None
    var_stats = None
    for batch_idx, batch in enumerate(dataloader_train):
        if batch_idx >= iter_stop:
            break

        fbank = batch["speech"].numpy()[0, :, :]
        frame_sum = np.sum(fbank, axis=0)
        frame_sq_sum = np.sum(np.square(fbank), axis=0)
        if mean_stats is None:
            mean_stats = frame_sum
            var_stats = frame_sq_sum
        else:
            mean_stats += frame_sum
            var_stats += frame_sq_sum
        total_frames += fbank.shape[0]

    if mean_stats is None or total_frames == 0:
        # Fail clearly instead of hitting an unbound name or a division by zero.
        raise RuntimeError(
            "no feature frames were accumulated; check 'train_data_set_list' "
            "and 'scale'"
        )

    cmvn_info = {
        "mean_stats": mean_stats.tolist(),
        "var_stats": var_stats.tolist(),
        "total_frames": total_frames,
    }
    cmvn_file = kwargs.get("cmvn_file", "cmvn.json")
    with open(cmvn_file, "w") as fout:
        fout.write(json.dumps(cmvn_info))

    # Stored as an additive shift (-mean) and a multiplicative scale (1/std).
    mean = -1.0 * mean_stats / total_frames
    var = 1.0 / np.sqrt(var_stats / total_frames - mean * mean)

    # os.path.join keeps the file next to ``cmvn_file``; plain concatenation
    # with "/am.mvn" would point at the filesystem root when ``cmvn_file`` has
    # no directory component (as with the default "cmvn.json").
    am_mvn = os.path.join(os.path.dirname(cmvn_file), "am.mvn")
    _write_am_mvn(am_mvn, mean, var)


"""
python funasr/bin/compute_audio_cmvn.py \
--config-path "/Users/zhifu/funasr1.0/examples/aishell/paraformer/conf" \
--config-name "train_asr_paraformer_conformer_12e_6d_2048_256.yaml" \
++train_data_set_list="/Users/zhifu/funasr1.0/data/list/audio_datasets.jsonl" \
++cmvn_file="/Users/zhifu/funasr1.0/data/list/cmvn.json" \
++dataset_conf.num_workers=0
"""

if __name__ == "__main__":
    main_hydra()