|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
import os |
|
import os.path as osp |
|
import math |
|
import numpy as np |
|
import tqdm |
|
import torch |
|
import torch.nn.functional as F |
|
from shutil import copyfile |
|
|
|
from npy_append_array import NpyAppendArray |
|
|
|
|
|
def get_parser():
    """Build the command-line parser for the mean-pooling script."""
    p = argparse.ArgumentParser(
        description="mean pools representations by compressing uniform splits of the data"
    )

    # Positional: where the dumped features live.
    p.add_argument("source", help="directory with features")

    # Required options: which split to process and where results go.
    p.add_argument("--split", required=True, help="which split to read")
    p.add_argument("--save-dir", required=True, help="where to save the output")

    # Pooling configuration.
    p.add_argument(
        "--subsample-rate",
        type=float,
        default=0.5,
        help="size to subsample data to",
    )
    p.add_argument(
        "--remove-extra",
        action="store_true",
        help="if true, removes extra states that cant be pooled, otherwise pads with 0s",
    )

    return p
|
|
|
|
|
def main(): |
|
parser = get_parser() |
|
args = parser.parse_args() |
|
|
|
source_path = osp.join(args.source, args.split) |
|
|
|
print(f"data path: {source_path}") |
|
|
|
features = np.load(source_path + ".npy", mmap_mode="r") |
|
|
|
os.makedirs(args.save_dir, exist_ok=True) |
|
save_path = osp.join(args.save_dir, args.split) |
|
|
|
copyfile(source_path + ".tsv", save_path + ".tsv") |
|
|
|
if os.path.exists(source_path + ".phn"): |
|
copyfile(source_path + ".phn", save_path + ".phn") |
|
if os.path.exists(source_path + ".wrd"): |
|
copyfile(source_path + ".wrd", save_path + ".wrd") |
|
|
|
if os.path.exists(osp.join(args.source, "dict.phn.txt")): |
|
copyfile( |
|
osp.join(args.source, "dict.phn.txt"), |
|
osp.join(args.save_dir, "dict.phn.txt"), |
|
) |
|
|
|
if osp.exists(save_path + ".npy"): |
|
os.remove(save_path + ".npy") |
|
npaa = NpyAppendArray(save_path + ".npy") |
|
|
|
with open(source_path + ".lengths", "r") as lf: |
|
lengths = lf.readlines() |
|
|
|
fsz = features.shape[-1] |
|
start = 0 |
|
with torch.no_grad(): |
|
with open(save_path + ".lengths", "w") as lengths_out: |
|
for length in tqdm.tqdm(lengths): |
|
length = int(length) |
|
end = start + length |
|
feats = features[start:end] |
|
start += length |
|
x = torch.from_numpy(feats).cuda() |
|
target_num = math.ceil(length * args.subsample_rate) |
|
rem = length % target_num |
|
|
|
if rem > 0: |
|
if args.remove_extra: |
|
to_rem = target_num - rem |
|
target_num -= 1 |
|
x = x[:-to_rem] |
|
else: |
|
to_add = target_num - rem |
|
x = F.pad(x, [0, 0, 0, to_add]) |
|
x[-to_add:] = x[-to_add - 1] |
|
|
|
x = x.view(target_num, -1, fsz) |
|
x = x.mean(dim=-2) |
|
print(target_num, file=lengths_out) |
|
npaa.append(x.cpu().numpy()) |
|
|
|
|
|
# Script entry point: run pooling only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|