|
|
|
|
|
|
|
|
|
|
|
import gc |
|
import os |
|
import random |
|
import shutil |
|
import numpy as np |
|
|
|
import torch |
|
import tqdm |
|
from examples.textless_nlp.gslm.speech2unit.pretrained.cpc_feature_reader import ( |
|
CpcFeatureReader, |
|
) |
|
from examples.textless_nlp.gslm.speech2unit.pretrained.hubert_feature_reader import ( |
|
HubertFeatureReader, |
|
) |
|
from examples.textless_nlp.gslm.speech2unit.pretrained.logmel_feature_reader import ( |
|
LogMelFeatureReader, |
|
) |
|
from examples.textless_nlp.gslm.speech2unit.pretrained.w2v2_feature_reader import ( |
|
Wav2VecFeatureReader, |
|
) |
|
|
|
|
|
def get_feature_reader(feature_type):
    """Return the feature-reader class registered for *feature_type*.

    Args:
        feature_type: one of "logmel", "hubert", "w2v2", or "cpc".

    Raises:
        NotImplementedError: if *feature_type* is not a supported key.
    """
    # Guard-clause dispatch; each branch returns the class itself (not an
    # instance), so the caller controls construction arguments.
    if feature_type == "cpc":
        return CpcFeatureReader
    if feature_type == "hubert":
        return HubertFeatureReader
    if feature_type == "logmel":
        return LogMelFeatureReader
    if feature_type == "w2v2":
        return Wav2VecFeatureReader
    raise NotImplementedError(f"{feature_type} is not supported.")
|
|
|
|
|
def get_feature_iterator(
    feature_type, checkpoint_path, layer, manifest_path, sample_pct
):
    """Build a lazy feature generator over the files listed in a manifest.

    The manifest's first line is the audio root directory; each following
    non-empty line holds a tab-separated record whose first field is a path
    relative to that root.

    Args:
        feature_type: key understood by ``get_feature_reader``.
        checkpoint_path: model checkpoint handed to the reader class.
        layer: layer index handed to the reader class.
        manifest_path: path to the tab-separated manifest file.
        sample_pct: in (0, 1] — fraction of files to randomly subsample.

    Returns:
        (iterate, num_files): a zero-arg generator factory yielding one
        numpy feature array per file, and the number of files it covers.
    """
    reader_cls = get_feature_reader(feature_type)
    with open(manifest_path, "r") as manifest:
        raw_lines = manifest.read().split("\n")
    root = raw_lines[0].strip()
    paths = [
        os.path.join(root, entry.split("\t")[0])
        for entry in raw_lines[1:]
        if entry
    ]
    if sample_pct < 1.0:
        # Subsample without replacement to cut extraction cost.
        paths = random.sample(paths, int(sample_pct * len(paths)))
    reader = reader_cls(checkpoint_path=checkpoint_path, layer=layer)

    def iterate():
        # Stream one file at a time; move features off the GPU before
        # yielding so the caller holds only CPU numpy arrays.
        for path in paths:
            yield reader.get_feats(path).cpu().numpy()

    return iterate, len(paths)
|
|
|
|
|
def get_features(
    feature_type, checkpoint_path, layer, manifest_path, sample_pct, flatten
):
    """Extract features for every manifest entry and collect them in memory.

    Args:
        feature_type: key understood by ``get_feature_reader``.
        checkpoint_path: model checkpoint for the reader.
        layer: layer index for the reader.
        manifest_path: path to the tab-separated manifest file.
        sample_pct: in (0, 1] — fraction of files to randomly subsample.
        flatten: if True, concatenate all per-file arrays along axis 0.

    Returns:
        A single ``np.ndarray`` when *flatten* is set, otherwise a list of
        per-file feature arrays.
    """
    make_iterator, total_files = get_feature_iterator(
        feature_type=feature_type,
        checkpoint_path=checkpoint_path,
        layer=layer,
        manifest_path=manifest_path,
        sample_pct=sample_pct,
    )
    feature_iter = make_iterator()
    collected = [feats for feats in tqdm.tqdm(feature_iter, total=total_files)]

    # Drop the generator (and with it the model reader it closes over)
    # before any large concatenation, then reclaim GPU memory.
    del feature_iter
    del make_iterator
    gc.collect()
    torch.cuda.empty_cache()

    return np.concatenate(collected) if flatten else collected
|
|
|
|
|
def get_and_dump_features(
    feature_type,
    checkpoint_path,
    layer,
    manifest_path,
    sample_pct,
    flatten,
    out_features_path,
):
    """Extract features, save them with ``np.save``, and archive the manifest.

    Args:
        feature_type: key understood by ``get_feature_reader``.
        checkpoint_path: model checkpoint for the reader.
        layer: layer index for the reader.
        manifest_path: path to the tab-separated manifest file.
        sample_pct: in (0, 1] — fraction of files to randomly subsample.
        flatten: if True, features are concatenated into one array.
        out_features_path: destination ``.npy`` path for the features.

    Returns:
        The extracted features (array if *flatten*, else list of arrays).
    """
    features_batch = get_features(
        feature_type=feature_type,
        checkpoint_path=checkpoint_path,
        layer=layer,
        manifest_path=manifest_path,
        sample_pct=sample_pct,
        flatten=flatten,
    )

    out_dir_path = os.path.dirname(out_features_path)
    # dirname is "" when out_features_path has no directory component;
    # os.makedirs("") raises FileNotFoundError, so only create real dirs.
    if out_dir_path:
        os.makedirs(out_dir_path, exist_ok=True)
    # Keep a copy of the manifest next to the dumped features so the
    # file order backing the array can be reconstructed later.
    shutil.copyfile(
        manifest_path,
        os.path.join(out_dir_path, os.path.basename(manifest_path)),
    )
    np.save(out_features_path, features_batch)

    return features_batch
|
|