# Spaces:
# Sleeping
# Sleeping
import os
import sys
import time
# Append the parent of this script's directory to sys.path so that the
# project-local packages imported below (``utils``, ``models``) can be found.
# Presumably the script lives one level below the repo root — TODO confirm.
p = os.path.split(os.path.dirname(os.path.abspath(__file__)))[0]
sys.path.append(p)
import argparse
import numpy as np
import tensorflow as tf
# Switch TensorFlow to TF1-style (graph/session) behaviour. This file uses
# ``tf.compat.v1`` APIs; the model code imported next presumably expects graph
# mode too, so this call is kept before that import — TODO confirm.
tf.compat.v1.disable_v2_behavior()
from utils.hparams import HParams
from models import get_model
import torch
# set_size: total number of rows in every set handed to the model
#   (observed rows followed by zero-padded rows the model fills in).
# threshold: maximum number of matching-set rows observed per set; larger
#   matching sets are randomly subsampled down to this many rows.
set_size, threshold = 200, 100
def fast_cosine_dist(source_feats, matching_pool):
    """Pairwise cosine distance between two sets of row vectors.

    Dot products are recovered from squared Euclidean distances via
    ``a.b = (|a|^2 + |b|^2 - |a - b|^2) / 2`` and then normalised, giving
    ``1 - cos(a, b)`` for every (source row, pool row) pair.

    Args:
        source_feats: 2-D tensor ``(N, D)``.
        matching_pool: 2-D tensor ``(M, D)``.

    Returns:
        ``(N, M)`` tensor of cosine distances. Zero-norm rows produce NaN/inf
        because of the division by the norm product.
    """
    src_len = source_feats.norm(p=2, dim=-1)
    pool_len = matching_pool.norm(p=2, dim=-1)
    # cdist works on batched inputs, so add and then drop a batch axis of 1.
    sq_dist = torch.cdist(source_feats.unsqueeze(0), matching_pool.unsqueeze(0), p=2).squeeze(0) ** 2
    dot = (-sq_dist + src_len.unsqueeze(1) ** 2 + pool_len.unsqueeze(0) ** 2) / 2
    return 1 - dot / (src_len.unsqueeze(1) * pool_len.unsqueeze(0))
def evaluate(batch, model):
    """Run the model's sampling op on ``batch`` and hand back the result.

    ``model.execute`` is invoked with ``model.sample`` as the thing to run and
    ``batch`` as its input; its return value is passed through unchanged.
    """
    run = model.execute
    return run(model.sample, batch)
def prematch(path, expanded, topk=None):
    """Replace every source row with the mean of its nearest pool rows.

    The matching pool is every other ``*.pt`` file found (recursively) under
    ``path``'s parent directory, concatenated row-wise, plus the ``expanded``
    vectors. For each row of the tensor stored at ``path``, the ``topk`` pool
    rows with the smallest cosine distance are averaged.

    Args:
        path: ``pathlib.Path`` or str pointing at a ``.pt`` file holding a 2-D
            tensor (assumed ``(num_rows, dim)`` — TODO confirm against caller).
        expanded: array-like of extra pool rows (e.g. ``single_expand`` output)
            with the same feature width as the stored tensors.
        topk: number of neighbours to average. Defaults to the ``--topk``
            command-line value (module-level ``args``) for compatibility.

    Returns:
        float32 tensor with one averaged row per source row.
    """
    from pathlib import Path

    path = Path(path)
    if topk is None:
        topk = args.topk  # only defined when this file runs as a script

    # rglob is already recursive; dedupe and exclude the source file by its
    # resolved path instead of list.remove(), which raises if it is absent.
    source_key = path.resolve()
    siblings = sorted(p for p in set(path.parent.rglob('*.pt')) if p.resolve() != source_key)

    # Load everything on CPU as float32 so dtypes/devices match for cdist.
    pool = [torch.load(p, map_location=torch.device('cpu')).to(torch.float32) for p in siblings]
    pool.append(torch.as_tensor(expanded).to(torch.float32))
    candidates = torch.cat(pool, 0)  # works even when there are no siblings

    source_feats = torch.load(path, map_location=torch.device('cpu')).to(torch.float32)
    dists = fast_cosine_dist(source_feats, candidates)  # (src_len, pool_len)
    # topk raises if k exceeds the pool size; clamp for tiny pools.
    k = min(topk, candidates.shape[0])
    best = dists.topk(k=k, dim=-1, largest=False)  # (src_len, k)
    return candidates[best.indices].mean(dim=1)  # (src_len, dim)
def single_expand(path, model, num_samples, seed=1234, out_path=None):
    """Generate ``num_samples`` synthetic rows conditioned on a matching set.

    The tensor stored at ``path`` is scaled down by 10 and used as the
    observed part of a ``set_size``-row set. The remaining rows are zero-padded
    and the model (via ``evaluate``) fills them in. Generated rows are scaled
    back up by 10 and collected round by round until at least ``num_samples``
    exist; the result is truncated to exactly ``num_samples`` rows.

    If the matching set has fewer than ``threshold`` rows, all of them are
    observed. Otherwise a fresh random subset of ``threshold`` rows (without
    replacement) is observed in each round.

    Args:
        path: location of a ``.pt`` file holding a 2-D tensor.
        model: object exposing ``execute`` and ``sample`` as used by ``evaluate``.
        num_samples: number of rows to return (``<= 0`` gives an empty array).
        seed: seeds NumPy's global RNG and TF's graph-level seed.
        out_path: if given, the result is also written with ``np.save``.

    Returns:
        NumPy array of generated rows.

    Raises:
        ValueError: if ``set_size`` leaves no rows for the model to generate.
    """
    np.random.seed(seed)
    tf.compat.v1.set_random_seed(seed)

    matching_set = torch.load(path, map_location=torch.device('cpu')).numpy()
    matching_set = matching_set / 10  # model works on values scaled down by 10
    matching_size = matching_set.shape[0]
    feat_dim = matching_set.shape[1]

    # Observed rows per set, and the padded slots the model must fill in.
    num_obs = min(matching_size, threshold)
    num_new_samples = set_size - num_obs
    if num_new_samples <= 0:
        # The loop below would otherwise never make progress.
        raise ValueError("set_size must be larger than the number of observed rows")
    padded_data = np.zeros((num_new_samples, feat_dim))

    new_samples = []
    cur_num_samples = 0
    while cur_num_samples < num_samples:
        if matching_size < threshold:
            obs_data = matching_set
        else:
            ind = np.random.choice(matching_size, threshold, replace=False)
            obs_data = matching_set[ind]
        batch = dict()
        batch['x'] = np.concatenate([obs_data, padded_data], 0)[None, ...]
        batch['b'] = np.concatenate([np.ones_like(obs_data), np.zeros_like(padded_data)], 0)[None, ...]
        batch['m'] = np.ones_like(batch['b'])
        sample = evaluate(batch, model)
        # Generated rows follow the observed ones, so slice by the observed count.
        new_samples.append(sample[0, num_obs:] * 10)
        cur_num_samples += num_new_samples

    if new_samples:
        new_samples = np.concatenate(new_samples, 0)[:num_samples]
    else:
        # num_samples <= 0: nothing was generated.
        new_samples = np.zeros((0, feat_dim), dtype=matching_set.dtype)

    if out_path:
        out_dir = os.path.dirname(out_path)
        if out_dir:  # os.makedirs('') raises for bare file names
            os.makedirs(out_dir, exist_ok=True)
        # NOTE(review): np.save appends '.npy' when out_path lacks it, so the
        # CLI default 'expanded_set.pt' becomes 'expanded_set.pt.npy' — confirm.
        np.save(out_path, new_samples)
    return new_samples
def single_expand_fast(path, model=None, num_samples=None):
    """Batched variant of ``single_expand``: one model call for all sets.

    Builds enough ``set_size``-row sets to yield at least ``num_samples``
    generated rows, stacks them into one batch, runs ``evaluate`` once, and
    returns the first ``num_samples`` generated rows (scaled back up by 10).
    Unlike ``single_expand`` it does not seed any RNG.

    Args:
        path: location of a ``.pt`` file holding a 2-D tensor.
        model: model passed to ``evaluate``; defaults to the module-level
            ``model`` created when this file runs as a script.
        num_samples: number of rows to return; defaults to ``args.num_samples``.

    Returns:
        NumPy array of generated rows (empty when ``num_samples <= 0``).

    Raises:
        ValueError: if ``set_size`` leaves no rows for the model to generate.
    """
    if model is None:
        model = globals()['model']  # script-level model built under __main__
    if num_samples is None:
        num_samples = args.num_samples

    matching_set = torch.load(path, map_location=torch.device('cpu')).numpy()
    matching_set = matching_set / 10  # model works on values scaled down by 10
    matching_size = matching_set.shape[0]
    feat_dim = matching_set.shape[1]

    num_obs = min(matching_size, threshold)
    num_new_samples = set_size - num_obs
    if num_samples <= 0:
        return np.zeros((0, feat_dim), dtype=matching_set.dtype)
    if num_new_samples <= 0:
        raise ValueError("set_size must be larger than the number of observed rows")
    # Ceiling division: plain `//` rounds down and would under-produce rows.
    batch_size = -(-num_samples // num_new_samples)

    if matching_size < threshold:
        obs_data = np.repeat(matching_set[None], batch_size, axis=0)
    else:
        # Fresh random subset of `threshold` rows for every set in the batch.
        obs_data = np.stack([
            matching_set[np.random.choice(matching_size, threshold, replace=False)]
            for _ in range(batch_size)
        ], 0)
    padded_data = np.zeros((batch_size, num_new_samples, feat_dim))

    batch = dict()
    batch['x'] = np.concatenate([obs_data, padded_data], 1)
    batch['b'] = np.concatenate([np.ones_like(obs_data), np.zeros_like(padded_data)], 1)
    batch['m'] = np.ones_like(batch['b'])
    sample = evaluate(batch, model)

    # Generated rows sit after the num_obs observed rows of every set.
    new_samples = sample[:, num_obs:, :] * 10
    new_samples = new_samples.reshape((-1, new_samples.shape[-1]))
    return new_samples[:num_samples]
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg_file', type=str)
    parser.add_argument('--seed', type=int, default=1234)
    parser.add_argument('--gpu', type=str, default='0')
    parser.add_argument('--num_samples', type=int, default=100)
    parser.add_argument('--path', type=str, default="matching_set.pt")
    parser.add_argument('--out_path', type=str, default="expanded_set.pt")
    parser.add_argument('--topk', type=int, default=4)
    args = parser.parse_args()

    # CUDA reads CUDA_VISIBLE_DEVICES when it initialises, so select the GPU
    # before the model (and presumably its TF session) is created and loaded.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    params = HParams(args.cfg_file)

    t0 = time.time()
    model = get_model(params)
    model.load()
    t1 = time.time()
    print(f"{t1-t0:.2f}s to load the model")

    path = args.path
    if not path.endswith(".pt"):
        # Previously a non-.pt path made the script exit silently.
        parser.error(f"--path must point to a .pt file, got {path!r}")
    t0 = time.time()
    expanded = single_expand(path, model, args.num_samples, args.seed, args.out_path)
    t1 = time.time()
    print(f"{t1-t0:.2f}s to expand the set")