|
|
|
|
|
|
|
|
|
|
|
import csv |
|
from pathlib import Path |
|
import zipfile |
|
from functools import reduce |
|
from multiprocessing import cpu_count |
|
from typing import Any, Dict, List, Optional, Union |
|
import io |
|
|
|
import numpy as np |
|
import pandas as pd |
|
import sentencepiece as sp |
|
from fairseq.data.audio.audio_utils import ( |
|
convert_waveform, _get_kaldi_fbank, _get_torchaudio_fbank, is_npy_data, |
|
is_sf_audio_data |
|
) |
|
import torch |
|
import soundfile as sf |
|
from tqdm import tqdm |
|
|
|
|
|
# Special tokens and the fixed ids they are forced to when training the
# SentencePiece model in gen_vocab (via --bos_id/--pad_id/--eos_id/--unk_id).
BOS_TOKEN = "<s>"
BOS_TOKEN_ID = 0

PAD_TOKEN = "<pad>"
PAD_TOKEN_ID = 1

EOS_TOKEN = "</s>"
EOS_TOKEN_ID = 2

UNK_TOKEN = "<unk>"
UNK_TOKEN_ID = 3
|
|
|
|
|
def gen_vocab(
    input_path: Path,
    output_path_prefix: Path,
    model_type="bpe",
    vocab_size=1000,
    special_symbols: Optional[List[str]] = None,
):
    """Train a SentencePiece model and export a fairseq-style dictionary.

    The trained model is written to ``<output_path_prefix>.model``. A
    ``<output_path_prefix>.txt`` file is then produced that lists every piece
    except the four special tokens, one ``"<piece> 1"`` line per piece, ordered
    by piece id.

    Args:
        input_path: text corpus used to train SentencePiece.
        output_path_prefix: path prefix for the generated files.
        model_type: SentencePiece ``--model_type`` (default ``"bpe"``).
        vocab_size: SentencePiece ``--vocab_size``.
        special_symbols: optional extra symbols passed as
            ``--user_defined_symbols`` (comma-joined).

    Raises:
        AssertionError: if the trained model does not place the special tokens
            at the ids requested via the ``--*_id`` flags.
    """
    arguments = [
        f"--input={input_path.as_posix()}",
        f"--model_prefix={output_path_prefix.as_posix()}",
        f"--model_type={model_type}",
        f"--vocab_size={vocab_size}",
        "--character_coverage=1.0",
        f"--num_threads={cpu_count()}",
        f"--unk_id={UNK_TOKEN_ID}",
        f"--bos_id={BOS_TOKEN_ID}",
        f"--eos_id={EOS_TOKEN_ID}",
        f"--pad_id={PAD_TOKEN_ID}",
    ]
    if special_symbols is not None:
        _special_symbols = ",".join(special_symbols)
        arguments.append(f"--user_defined_symbols={_special_symbols}")
    # NOTE(review): flags are joined into one space-separated string, so paths
    # containing spaces are likely to be mis-split by the trainer — confirm.
    sp.SentencePieceTrainer.Train(" ".join(arguments))

    spm = sp.SentencePieceProcessor()
    spm.Load(output_path_prefix.as_posix() + ".model")
    vocab = {i: spm.IdToPiece(i) for i in range(spm.GetPieceSize())}
    expected_specials = {
        UNK_TOKEN_ID: UNK_TOKEN,
        PAD_TOKEN_ID: PAD_TOKEN,
        BOS_TOKEN_ID: BOS_TOKEN,
        EOS_TOKEN_ID: EOS_TOKEN,
    }
    assert all(
        vocab.get(idx) == tok for idx, tok in expected_specials.items()
    ), f"special tokens not at expected ids: {expected_specials}"

    # fairseq adds its own special symbols, so they are left out of the dict.
    special_pieces = {UNK_TOKEN, BOS_TOKEN, EOS_TOKEN, PAD_TOKEN}
    vocab = {i: s for i, s in vocab.items() if s not in special_pieces}
    # Pieces contain non-ASCII characters (e.g. the "▁" word-boundary marker),
    # so the encoding must be explicit rather than the platform default.
    with open(
        output_path_prefix.as_posix() + ".txt", "w", encoding="utf-8"
    ) as f_out:
        for _, s in sorted(vocab.items(), key=lambda x: x[0]):
            f_out.write(f"{s} 1\n")
|
|
|
|
|
def extract_fbank_features(
    waveform: torch.FloatTensor,
    sample_rate: int,
    output_path: Optional[Path] = None,
    n_mel_bins: int = 80,
    overwrite: bool = False,
):
    """Compute filterbank features for *waveform*, optionally saving them.

    Args:
        waveform: input audio; it is converted to mono before extraction.
        sample_rate: sampling rate of *waveform*.
        output_path: if given, the features are written there with ``np.save``.
        n_mel_bins: number of Mel bins passed to the fbank backend.
        overwrite: recompute even when *output_path* already exists.

    Returns:
        The feature array, or ``None`` when *output_path* already exists and
        *overwrite* is False (nothing is computed in that case).

    Raises:
        ImportError: when neither the Kaldi nor the torchaudio fbank helper
            produced features.
    """
    if output_path is not None and output_path.is_file() and not overwrite:
        return None

    # fairseq's convert_waveform returns a (waveform, sample_rate) pair. Using
    # the pair directly would make `* 2**15` repeat the tuple instead of
    # scaling the samples, and `.numpy()` would then fail.
    _waveform, _ = convert_waveform(waveform, sample_rate, to_mono=True)
    # Scale to the 16-bit signed integer range; assumes float samples in
    # [-1, 1] (presumably for Kaldi compatibility — confirm upstream).
    _waveform = (_waveform * (2 ** 15)).numpy()

    # Prefer the Kaldi backend and fall back to torchaudio.
    features = _get_kaldi_fbank(_waveform, sample_rate, n_mel_bins)
    if features is None:
        features = _get_torchaudio_fbank(_waveform, sample_rate, n_mel_bins)
    if features is None:
        raise ImportError(
            "Please install pyKaldi or torchaudio to enable fbank feature extraction"
        )

    if output_path is not None:
        np.save(output_path.as_posix(), features)
    return features
|
|
|
|
|
def create_zip(data_root: Path, zip_path: Path):
    """Pack every ``*.npy`` file directly under *data_root* into *zip_path*.

    Members are stored uncompressed (``ZIP_STORED``) under their bare file
    names. Storing without compression matters because get_zip_manifest reads
    member bytes directly from the archive by offset. Paths are sorted so the
    archive layout does not depend on filesystem iteration order.
    """
    paths = sorted(data_root.glob("*.npy"))
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_STORED) as f:
        for path in tqdm(paths):
            f.write(path, arcname=path.name)
|
|
|
|
|
def _zip_member_data_offset(fp, info: zipfile.ZipInfo) -> int:
    """Return the absolute offset of *info*'s raw data in the open archive *fp*."""
    import struct

    fp.seek(info.header_offset)
    header = fp.read(30)
    # Fixed 30-byte local file header: signature b"PK\x03\x04"; the file-name
    # length and extra-field length are little-endian uint16 at offsets 26/28.
    # Both must be read from the local header: the extra field may be non-empty
    # (e.g. ZIP64), and the name length is in encoded bytes, not characters.
    if len(header) != 30 or header[:4] != b"PK\x03\x04":
        raise zipfile.BadZipFile(f"Bad local file header for {info.filename}")
    name_len, extra_len = struct.unpack("<HH", header[26:30])
    return info.header_offset + 30 + name_len + extra_len


def get_zip_manifest(
    zip_path: Path, zip_root: Optional[Path] = None, is_audio=False
):
    """Index the members of an uncompressed zip archive by utterance id.

    Args:
        zip_path: archive path; also used verbatim in the returned locators.
        zip_root: optional directory that *zip_path* is relative to.
        is_audio: members are audio files (length = frame count via
            soundfile) instead of ``.npy`` arrays (length = ``shape[0]``).

    Returns:
        ``(paths, lengths)``: ``paths[utt_id]`` is
        ``"<zip_path>:<byte offset>:<byte size>"`` and ``lengths[utt_id]`` is
        the member's length. ``utt_id`` is the member file-name stem.

    Raises:
        AssertionError: if a member is empty or not of the expected data type.
    """
    _zip_path = Path.joinpath(zip_root or Path(""), zip_path)
    with zipfile.ZipFile(_zip_path, mode="r") as f:
        info = f.infolist()
    paths, lengths = {}, {}
    # Open the archive once for all members instead of once per member.
    with open(_zip_path, "rb") as fp:
        for i in tqdm(info):
            utt_id = Path(i.filename).stem
            offset, file_size = _zip_member_data_offset(fp, i), i.file_size
            paths[utt_id] = f"{zip_path.as_posix()}:{offset}:{file_size}"
            fp.seek(offset)
            byte_data = fp.read(file_size)
            assert len(byte_data) > 1
            if is_audio:
                assert is_sf_audio_data(byte_data), i
            else:
                assert is_npy_data(byte_data), i
            byte_data_fp = io.BytesIO(byte_data)
            if is_audio:
                lengths[utt_id] = sf.info(byte_data_fp).frames
            else:
                lengths[utt_id] = np.load(byte_data_fp).shape[0]
    return paths, lengths
|
|
|
|
|
def gen_config_yaml(
    manifest_root: Path,
    spm_filename: Optional[str] = None,
    vocab_name: Optional[str] = None,
    yaml_filename: str = "config.yaml",
    specaugment_policy: Optional[str] = "lb",
    prepend_tgt_lang_tag: bool = False,
    sampling_alpha: Optional[float] = None,
    input_channels: Optional[int] = 1,
    input_feat_per_channel: Optional[int] = 80,
    audio_root: str = "",
    cmvn_type: str = "utterance",
    gcmvn_path: Optional[Path] = None,
    extra=None,
):
    """Write an S2T data-config YAML file to ``manifest_root / yaml_filename``.

    Args:
        manifest_root: directory holding the manifests (made absolute).
        spm_filename: SentencePiece model file name, relative to
            *manifest_root*. Its ``.model`` is swapped for ``.txt`` to derive
            the vocab name when *vocab_name* is not given.
        vocab_name: explicit vocabulary file name.
        yaml_filename: name of the YAML file to write.
        specaugment_policy: one of ``"lb"``, ``"ld"``, ``"sm"``, ``"ss"``, or
            ``None`` to disable SpecAugment on the training split.
        prepend_tgt_lang_tag: set ``prepend_tgt_lang_tag: true``.
        sampling_alpha: optional ``sampling_alpha`` value.
        input_channels / input_feat_per_channel: written when not ``None``.
        audio_root: written when non-empty.
        cmvn_type: ``"utterance"`` or ``"global"``.
        gcmvn_path: stats file, required when *cmvn_type* is ``"global"``.
        extra: optional dict merged into the top level of the config.

    Raises:
        ValueError: if neither *spm_filename* nor *vocab_name* is given,
            *specaugment_policy* is unknown, or *gcmvn_path* is missing for
            global CMVN.
        NotImplementedError: for an unsupported *cmvn_type*.
    """
    # Policy name -> S2TDataConfigWriter method that installs its parameters.
    specaugment_setters = {
        "lb": "set_specaugment_lb_policy",
        "ld": "set_specaugment_ld_policy",
        "sm": "set_specaugment_sm_policy",
        "ss": "set_specaugment_ss_policy",
    }

    # Validate everything before any config is built so bad input fails fast.
    if spm_filename is None and vocab_name is None:
        raise ValueError("Either spm_filename or vocab_name must be provided.")
    if cmvn_type not in ["global", "utterance"]:
        raise NotImplementedError(f"Unsupported cmvn_type: {cmvn_type!r}")
    if cmvn_type == "global" and gcmvn_path is None:
        raise ValueError("Please provide path of global cmvn file.")
    if (
        specaugment_policy is not None
        and specaugment_policy not in specaugment_setters
    ):
        # An unknown policy used to add a "specaugment" train transform
        # without any SpecAugment parameters; reject it instead.
        raise ValueError(
            f"Unknown specaugment_policy {specaugment_policy!r}; expected one "
            f"of {sorted(specaugment_setters)} or None."
        )

    manifest_root = manifest_root.absolute()
    writer = S2TDataConfigWriter(manifest_root / yaml_filename)
    if vocab_name is None:
        vocab_name = spm_filename.replace(".model", ".txt")
    writer.set_vocab_filename(vocab_name)
    if input_channels is not None:
        writer.set_input_channels(input_channels)
    if input_feat_per_channel is not None:
        writer.set_input_feat_per_channel(input_feat_per_channel)
    if specaugment_policy is not None:
        getattr(writer, specaugment_setters[specaugment_policy])()
    if spm_filename is not None:
        writer.set_bpe_tokenizer(
            {
                "bpe": "sentencepiece",
                "sentencepiece_model": (manifest_root / spm_filename).as_posix(),
            }
        )
    if prepend_tgt_lang_tag:
        writer.set_prepend_tgt_lang_tag(True)
    if sampling_alpha is not None:
        writer.set_sampling_alpha(sampling_alpha)

    # SpecAugment is applied only on the training split; CMVN on all splits.
    if specaugment_policy is not None:
        writer.set_feature_transforms(
            "_train", [f"{cmvn_type}_cmvn", "specaugment"]
        )
    writer.set_feature_transforms("*", [f"{cmvn_type}_cmvn"])

    if cmvn_type == "global":
        writer.set_global_cmvn(gcmvn_path.as_posix())

    if len(audio_root) > 0:
        writer.set_audio_root(audio_root)

    if extra is not None:
        writer.set_extra(extra)
    writer.flush()
|
|
|
|
|
def load_df_from_tsv(path: Union[str, Path]) -> pd.DataFrame:
    """Read a headered, tab-separated UTF-8 manifest into a DataFrame.

    Quoting is disabled, backslash is the escape character, and NA detection
    is off, so empty cells come back as ``""`` rather than NaN.
    """
    if not isinstance(path, str):
        path = path.as_posix()
    read_options = dict(
        sep="\t",
        header=0,
        encoding="utf-8",
        escapechar="\\",
        quoting=csv.QUOTE_NONE,
        na_filter=False,
    )
    return pd.read_csv(path, **read_options)
|
|
|
|
|
def save_df_to_tsv(dataframe, path: Union[str, Path]):
    """Write *dataframe* as a headered, tab-separated UTF-8 file.

    The index is dropped, quoting is disabled and backslash is used as the
    escape character, mirroring the options of load_df_from_tsv.
    """
    target = path.as_posix() if not isinstance(path, str) else path
    write_options = dict(
        sep="\t",
        header=True,
        index=False,
        encoding="utf-8",
        escapechar="\\",
        quoting=csv.QUOTE_NONE,
    )
    dataframe.to_csv(target, **write_options)
|
|
|
|
|
def load_tsv_to_dicts(path: Union[str, Path]) -> List[dict]:
    """Read a headered TSV file into a list of ``{column: value}`` dicts.

    Quoting is disabled, so quote characters are kept verbatim in values.
    """
    # utf-8 matches load_df_from_tsv/save_df_to_tsv instead of depending on the
    # platform locale; newline="" is what the csv module requires for readers.
    with open(path, "r", encoding="utf-8", newline="") as f:
        reader = csv.DictReader(
            f,
            delimiter="\t",
            quotechar=None,
            doublequote=False,
            lineterminator="\n",
            quoting=csv.QUOTE_NONE,
        )
        rows = [dict(e) for e in reader]
    return rows
|
|
|
|
|
def filter_manifest_df(
    df, is_train_split=False, extra_filters=None, min_n_frames=5, max_n_frames=3000
):
    """Return the rows of *df* that pass every filter, printing a summary.

    A row is dropped when its ``audio`` is empty, its ``n_frames`` is below
    *min_n_frames*, or its ``tgt_text`` is empty. On the training split, rows
    with ``n_frames`` above *max_n_frames* are dropped as well. *extra_filters*
    maps extra labels to boolean masks (True = drop) and is merged in last.
    """
    filters = {}
    filters["no speech"] = df["audio"] == ""
    filters[f"short speech (<{min_n_frames} frames)"] = df["n_frames"] < min_n_frames
    filters["empty sentence"] = df["tgt_text"] == ""
    if is_train_split:
        filters[f"long speech (>{max_n_frames} frames)"] = (
            df["n_frames"] > max_n_frames
        )
    if extra_filters is not None:
        filters.update(extra_filters)

    # OR all masks together, left to right.
    invalid = None
    for mask in filters.values():
        invalid = mask if invalid is None else (invalid | mask)
    valid = ~invalid

    per_filter = ", ".join(f"{name}: {mask.sum()}" for name, mask in filters.items())
    print(f"| {per_filter}, total {invalid.sum()} filtered, {valid.sum()} remained.")
    return df[valid]
|
|
|
|
|
def cal_gcmvn_stats(features_list):
    """Compute global mean/std statistics for CMVN.

    Args:
        features_list: non-empty sequence of arrays concatenated along axis 0
            (presumably ``(n_frames, feat_dim)`` feature matrices — confirm).

    Returns:
        ``{"mean": ..., "std": ...}`` per-dimension (axis 0 reduced) float32
        arrays. The variance is floored at 1e-8 so ``std`` is never zero.
    """
    features = np.concatenate(features_list)
    # Accumulate in float64 with a two-pass variance: the E[x^2] - E[x]^2
    # shortcut in float32 loses almost all precision when features have a
    # large offset relative to their spread.
    mean = features.mean(axis=0, dtype=np.float64)
    var = features.var(axis=0, dtype=np.float64)
    std = np.sqrt(np.maximum(var, 1e-8))
    return {"mean": mean.astype("float32"), "std": std.astype("float32")}
|
|
|
|
|
class S2TDataConfigWriter(object):
    """Build an S2T data-config dict through setters and dump it as YAML."""

    DEFAULT_VOCAB_FILENAME = "dict.txt"
    DEFAULT_INPUT_FEAT_PER_CHANNEL = 80
    DEFAULT_INPUT_CHANNELS = 1

    def __init__(self, yaml_path: Path):
        """Create an empty config that flush() will write to *yaml_path*.

        Raises:
            ImportError: if PyYAML is not installed.
        """
        try:
            import yaml
        except ImportError as e:
            # Fail here with a clear message. Printing and continuing only led
            # to an UnboundLocalError on the assignment below.
            raise ImportError(
                "Please install PyYAML for S2T data config YAML files"
            ) from e
        self.yaml = yaml
        self.yaml_path = yaml_path
        self.config = {}

    def flush(self):
        """Write the accumulated config to ``self.yaml_path`` (overwriting)."""
        with open(self.yaml_path, "w", encoding="utf-8") as f:
            self.yaml.dump(self.config, f)

    def set_audio_root(self, audio_root=""):
        """Set the ``audio_root`` key."""
        self.config["audio_root"] = audio_root

    def set_vocab_filename(self, vocab_filename: str = "dict.txt"):
        """Set the ``vocab_filename`` key."""
        self.config["vocab_filename"] = vocab_filename

    def set_specaugment(
        self,
        time_wrap_w: int,
        freq_mask_n: int,
        freq_mask_f: int,
        time_mask_n: int,
        time_mask_t: int,
        time_mask_p: float,
    ):
        """Set the ``specaugment`` parameter dict.

        The YAML key names (including ``time_wrap_W``) are kept exactly as
        written, since whatever reads the config presumably expects them.
        """
        self.config["specaugment"] = {
            "time_wrap_W": time_wrap_w,
            "freq_mask_N": freq_mask_n,
            "freq_mask_F": freq_mask_f,
            "time_mask_N": time_mask_n,
            "time_mask_T": time_mask_t,
            "time_mask_p": time_mask_p,
        }

    def set_specaugment_lb_policy(self):
        """SpecAugment "lb" preset: 1 freq mask (F=27), 1 time mask (T=100, p=1.0)."""
        self.set_specaugment(
            time_wrap_w=0,
            freq_mask_n=1,
            freq_mask_f=27,
            time_mask_n=1,
            time_mask_t=100,
            time_mask_p=1.0,
        )

    def set_specaugment_ld_policy(self):
        """SpecAugment "ld" preset: 2 freq masks (F=27), 2 time masks (T=100, p=1.0)."""
        self.set_specaugment(
            time_wrap_w=0,
            freq_mask_n=2,
            freq_mask_f=27,
            time_mask_n=2,
            time_mask_t=100,
            time_mask_p=1.0,
        )

    def set_specaugment_sm_policy(self):
        """SpecAugment "sm" preset: 2 freq masks (F=15), 2 time masks (T=70, p=0.2)."""
        self.set_specaugment(
            time_wrap_w=0,
            freq_mask_n=2,
            freq_mask_f=15,
            time_mask_n=2,
            time_mask_t=70,
            time_mask_p=0.2,
        )

    def set_specaugment_ss_policy(self):
        """SpecAugment "ss" preset: 2 freq masks (F=27), 2 time masks (T=70, p=0.2)."""
        self.set_specaugment(
            time_wrap_w=0,
            freq_mask_n=2,
            freq_mask_f=27,
            time_mask_n=2,
            time_mask_t=70,
            time_mask_p=0.2,
        )

    def set_input_channels(self, input_channels: int = 1):
        """Set the ``input_channels`` key."""
        self.config["input_channels"] = input_channels

    def set_input_feat_per_channel(self, input_feat_per_channel: int = 80):
        """Set the ``input_feat_per_channel`` key."""
        self.config["input_feat_per_channel"] = input_feat_per_channel

    def set_bpe_tokenizer(self, bpe_tokenizer: Dict[str, Any]):
        """Set the ``bpe_tokenizer`` dict."""
        self.config["bpe_tokenizer"] = bpe_tokenizer

    def set_global_cmvn(self, stats_npz_path: str):
        """Point ``global_cmvn`` at a statistics file."""
        self.config["global_cmvn"] = {"stats_npz_path": stats_npz_path}

    def set_feature_transforms(self, split: str, transforms: List[str]):
        """Set the transform list for *split* under ``transforms``.

        The ``transforms`` dict is created on first use.
        """
        if "transforms" not in self.config:
            self.config["transforms"] = {}
        self.config["transforms"][split] = transforms

    def set_prepend_tgt_lang_tag(self, flag: bool = True):
        """Set the ``prepend_tgt_lang_tag`` key."""
        self.config["prepend_tgt_lang_tag"] = flag

    def set_sampling_alpha(self, sampling_alpha: float = 1.0):
        """Set the ``sampling_alpha`` key."""
        self.config["sampling_alpha"] = sampling_alpha

    def set_extra(self, data):
        """Merge *data* into the top level of the config (overwriting keys)."""
        self.config.update(data)
|
|