# Spaces: Sleeping  (appears to be a hosting-page status banner captured with this file; not part of the script)
import h5py | |
import numpy as np | |
from pathlib import Path | |
import tensorflow as tf | |
import torch | |
from tqdm import tqdm | |
import os | |
import sys | |
import glob | |
import random | |
import pdb | |
# Seed NumPy's global RNG, which the mask/sub-sampling helpers below draw from.
np.random.seed(0)

# Output directory for the TFRecord shards and root of the saved WavLM features.
tfrecord_path = "/path/to/save/your/tfrecord/"
path_to_wavlm_feat = "/path/to/your/wavlm/feat"

# Records are generated only when the output directory did not exist beforehand;
# the directory is created either way.
generate_tf_record = not os.path.exists(tfrecord_path)
os.makedirs(tfrecord_path, exist_ok=True)

# Shard-name prefixes, built by plain concatenation (so tfrecord_path is
# expected to end with a path separator).
train_filename = tfrecord_path + 'train'
valid_filename = tfrecord_path + 'valid'
test_filename = tfrecord_path + 'test'

# Per-split feature directories under the WavLM feature root.
train_path = Path(path_to_wavlm_feat) / "train-clean-100"
valid_path = Path(path_to_wavlm_feat) / "dev-clean"
test_path = Path(path_to_wavlm_feat) / "test-clean"

# Default split sizes, used when existing TFRecords are reused; they are
# overwritten with the real counts whenever the records are regenerated.
train_size = 27269
valid_size = 1940
test_size = 1850
def get_filenames(path):
    """Collect every ``*.pt`` entry found anywhere under ``path``.

    Args:
        path: ``pathlib.Path`` of the directory to search recursively.

    Returns:
        Sorted list of ``pathlib.Path`` objects, one per match.
    """
    # rglob() already recurses into sub-directories, so the former "**/"
    # prefix in the pattern was redundant.
    # Directory traversal order is platform-dependent; sorting makes the
    # downstream seeded shuffle and shard layout reproducible.
    return sorted(path.rglob("*.pt"))
def length_filter(paths, min_length=200):
    """Keep only the feature files whose first dimension is long enough.

    Args:
        paths: Iterable of paths to tensors loadable with ``torch.load``.
        min_length: Minimum size of the tensor's first dimension for the file
            to be kept. Defaults to 200, the previously hard-coded threshold.

    Returns:
        List of the paths that passed the filter, in input order.
    """
    print("filter short files")
    # NOTE(review): torch.load unpickles the file contents; only run this on
    # trusted feature dumps.
    # Only the shape is inspected, so the tensor is not converted to NumPy
    # or copied to float32 here.
    return [each for each in tqdm(paths) if torch.load(each).shape[0] >= min_length]
def _random_row_mask(x, count):
    """Return zeros shaped like ``x`` with ``count`` randomly chosen rows set to 1."""
    mask = np.zeros_like(x)
    rows = np.random.choice(x.shape[0], count, replace=False)
    mask[rows] = 1.
    return mask


def generate_mask(x, mask_type):
    """Build a 0/1 mask with the same shape as ``x``; whole rows are selected.

    Supported ``mask_type`` values (bytes, or the equivalent ``str``):

    * ``b'expand'``: a random number of rows in ``[rows // 8, rows)``.
    * ``b'few_expand'``: a random number of rows in ``[0, rows // 8)``.
    * ``b'arb_expand'``: a random number of rows in ``[0, rows)``.
    * ``b'det_expand'``: exactly 100 random rows.
    * ``b'complete'``: rows on the positive side of a random hyperplane,
      redrawn until at least ``rows // 8`` rows are selected.

    Args:
        x: 2-D NumPy array of shape ``(rows, features)``.
        mask_type: Name of the masking scheme.

    Returns:
        NumPy array shaped like ``x`` holding 1 in selected rows and 0 elsewhere.

    Raises:
        ValueError: If ``mask_type`` is not one of the supported names
            (NumPy also raises ValueError when ``x`` has too few rows for
            the requested scheme).
    """
    # Mask names arrive as bytes when passed through a TensorFlow py_func;
    # plain strings are accepted too for direct calls.
    if isinstance(mask_type, str):
        mask_type = mask_type.encode('utf-8')

    num_rows = x.shape[0]
    if mask_type == b'expand':
        return _random_row_mask(x, np.random.randint(num_rows // 8, num_rows))
    if mask_type == b'few_expand':
        return _random_row_mask(x, np.random.randint(num_rows // 8))
    if mask_type == b'arb_expand':
        return _random_row_mask(x, np.random.randint(num_rows))
    if mask_type == b'det_expand':
        return _random_row_mask(x, 100)
    if mask_type == b'complete':
        # The hyperplane has one coefficient per feature plus a bias term.
        # This used to be hard-coded for 3 features, which broke on any
        # other feature width.
        num_features = x.shape[1]
        mask = np.zeros_like(x)
        # Append a constant-1 column so the bias is part of the dot product.
        augmented = np.concatenate([x, np.ones([num_rows, 1])], axis=1)
        while np.sum(mask[:, 0]) < num_rows // 8:
            plane = np.random.uniform(-0.5, 0.5, size=num_features + 1)
            selected = (np.dot(augmented, plane) > 0).astype(np.float32)
            # Broadcast the per-row decision across every feature column.
            mask = np.repeat(np.expand_dims(selected, axis=1), num_features, axis=1)
        return mask
    raise ValueError(f"Unknown mask_type: {mask_type!r}")
def wrap_int64(value):
    """Wrap a single integer as a ``tf.train.Feature`` holding an ``Int64List``."""
    int_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int_list)
def wrap_bytes(value):
    """Wrap a single bytes object as a ``tf.train.Feature`` holding a ``BytesList``."""
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
def print_progress(count, total):
    """Write an in-place progress percentage to standard output.

    Args:
        count: Number of items finished so far.
        total: Total number of items.
    """
    fraction_done = float(count) / total
    out = sys.stdout
    # The leading carriage return moves the cursor back to the start of the
    # line, so each call overwrites the previous message.
    out.write("\r- Progress: {0:.1%}".format(fraction_done))
    # Flush right away because no newline is emitted.
    out.flush()
def convert(image_paths, out_path, max_files=1000):
    """Serialize saved feature tensors into sharded TFRecord files.

    Each input file is loaded with ``torch.load``, cast to float32 and written
    as one ``tf.train.Example`` with three features:

    * ``image``: the raw float32 bytes of the array,
    * ``length``: the size of the array's first dimension,
    * ``filename``: the file name without its extension, UTF-8 encoded.

    Args:
        image_paths: Sequence of ``pathlib.Path`` objects pointing at the
            saved tensors (the parameter name is historical; no images are
            involved).
        out_path: Prefix of the output files. Shard ``k`` of ``n`` is written
            to ``{out_path}_{k}_{n}.tfrecords``.
        max_files: Maximum number of examples stored in one shard.
    """
    print("Converting: " + out_path)

    num_images = len(image_paths)
    # Ceiling division: number of shards needed to hold every example.
    splits = -(-num_images // max_files)
    print(f"\nUsing {splits} shard(s) for {num_images} files, with up to {max_files} samples per shard")

    file_count = 0
    for i in tqdm(range(splits)):
        shard_paths = image_paths[i * max_files:(i + 1) * max_files]
        with tf.io.TFRecordWriter("{}_{}_{}.tfrecords".format(out_path, i + 1, splits)) as writer:
            for current_image in shard_paths:
                # NOTE(review): torch.load unpickles the file contents; only
                # run this on trusted feature dumps.
                img = torch.load(current_image).numpy().astype(np.float32)
                # tobytes() replaces the deprecated ndarray.tostring(); the
                # produced bytes are the same.
                data = {
                    'image': wrap_bytes(img.tobytes()),
                    'length': wrap_int64(img.shape[0]),
                    "filename": wrap_bytes(bytes(os.path.splitext(current_image.name)[0], 'utf-8')),
                }
                example = tf.train.Example(features=tf.train.Features(feature=data))
                writer.write(example.SerializeToString())
                file_count += 1
    print(f"\nWrote {file_count} elements to TFRecord")
if generate_tf_record:
    # Gather the feature files of each split and drop the ones that are too short.
    train_image_paths = length_filter(get_filenames(train_path))
    valid_image_paths = length_filter(get_filenames(valid_path))
    test_image_paths = length_filter(get_filenames(test_path))

    # Replace the default split sizes with the real post-filter counts.
    train_size = len(train_image_paths)
    valid_size = len(valid_image_paths)
    test_size = len(test_image_paths)
    print(f"Number of training data after length filering: {train_size}")
    print(f"Number of valid data after length filering: {valid_size}")
    print(f"Number of testing data after length filering: {test_size}")

    # Only the training list is shuffled, with a fixed seed.
    random.Random(4).shuffle(train_image_paths)

    # Write the shards for every split.
    for _paths, _prefix in (
        (train_image_paths, train_filename),
        (valid_image_paths, valid_filename),
        (test_image_paths, test_filename),
    ):
        convert(image_paths=_paths, out_path=_prefix)
def parse(serialized):
    """Decode one serialized ``tf.train.Example`` into a feature matrix and its name.

    Args:
        serialized: Scalar string tensor holding a serialized example with the
            ``image``, ``length`` and ``filename`` features.

    Returns:
        Tuple ``(features, filename)``: a float32 tensor reshaped to
        ``[length, 1024]`` and the stored file-name string tensor.
    """
    # The schema has to be restated here because TFRecord files carry no
    # description of the features they contain.
    schema = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'length': tf.io.FixedLenFeature([], tf.int64),
        'filename': tf.io.FixedLenFeature([], tf.string),
    }
    example = tf.io.parse_single_example(serialized=serialized,
                                         features=schema)

    # The array was stored as raw float32 bytes; decode it and restore the
    # 2-D layout using the stored row count.
    flat_values = tf.io.decode_raw(example['image'], tf.float32)
    features = tf.reshape(flat_values, [example['length'], 1024])
    return features, example['filename']
def process(x, filename, set_size, mask_type):
    """Scale a feature matrix, sub-sample its rows and build a matching mask.

    Args:
        x: 2-D NumPy array of features; its rows are sampled from.
        filename: Identifier passed through unchanged.
        set_size: Number of rows to draw, without replacement.
        mask_type: Mask scheme name forwarded to ``generate_mask``.

    Returns:
        Tuple ``(subset, mask, filename)`` where ``subset`` holds the sampled,
        scaled rows and ``mask`` is the array produced by ``generate_mask``.
    """
    scaled = x / 10
    # Draw set_size distinct row indices; their order is random, not sorted.
    rows = np.random.choice(scaled.shape[0], set_size, replace=False)
    subset = scaled[rows]
    mask = generate_mask(subset, mask_type)
    return subset, mask, filename
def get_dst(split, set_size, mask_type):
    """Build the ``tf.data`` pipeline for one data split.

    Args:
        split: ``'train'`` or ``'valid'``; any other value selects the test split.
        set_size: Number of rows sampled per example (forwarded to ``process``).
        mask_type: Mask scheme name (forwarded to ``process``).

    Returns:
        Tuple ``(dst, size)``: the unbatched dataset yielding
        ``(features, mask, filename)`` and the number of examples recorded
        for the split.
    """
    is_train = split == 'train'
    if is_train:
        prefix, size = train_filename, train_size
    elif split == 'valid':
        prefix, size = valid_filename, valid_size
    else:
        prefix, size = test_filename, test_size

    shard_files = glob.glob(prefix + "*.tfrecords", recursive=False)
    dst = tf.data.TFRecordDataset(shard_files).map(parse)
    # Only the training split is shuffled.
    if is_train:
        dst = dst.shuffle(256)

    def _apply_process(x, y):
        # process() is NumPy code, so it is wrapped as a py_func op.
        outputs = tf.compat.v1.py_func(
            process,
            [x, y, set_size, mask_type],
            [tf.float32, tf.float32, tf.string],
        )
        return tuple(outputs)

    dst = dst.map(_apply_process, num_parallel_calls=8)
    return dst, size
class Dataset(object):
    """Session-backed batch reader over one split of the TFRecord data."""

    def __init__(self, split, batch_size, set_size, mask_type):
        """Create a private graph and session and wire up the input pipeline.

        Args:
            split: Split name forwarded to ``get_dst``.
            batch_size: Number of examples per batch.
            set_size: Number of rows sampled per example.
            mask_type: Mask scheme name forwarded to ``get_dst``.
        """
        graph = tf.Graph()
        with graph.as_default():
            # Session bound to this graph; GPU memory is grown on demand and
            # device placement is logged.
            session_config = tf.compat.v1.ConfigProto()
            session_config.log_device_placement = True
            session_config.allow_soft_placement = True
            session_config.gpu_options.allow_growth = True
            self.sess = tf.compat.v1.Session(config=session_config, graph=graph)

            # Input pipeline for the requested split.
            pipeline, self.size = get_dst(split, set_size, mask_type)
            # Floor division: the trailing partial batch (kept, since
            # drop_remainder=False) is not included in this count.
            self.num_batches = self.size // batch_size
            pipeline = pipeline.batch(batch_size, drop_remainder=False).prefetch(1)

            iterator = tf.compat.v1.data.make_initializable_iterator(pipeline)
            self.x, self.b, self.filename = iterator.get_next()
            self.dimension = 1024
            self.initializer = iterator.initializer

    def initialize(self):
        """Run the iterator initializer so batches can be read from the start."""
        self.sess.run(self.initializer)

    def next_batch(self):
        """Fetch the next batch from the session.

        Returns:
            Dict with ``'x'`` (features), ``'b'`` (mask from the pipeline),
            ``'m'`` (array of ones shaped like ``'b'``) and ``'f'`` (file names).
        """
        features, mask, names = self.sess.run([self.x, self.b, self.filename])
        return {'x': features, 'b': mask, 'm': np.ones_like(mask), "f": names}