File size: 10,225 Bytes
9016314
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
import h5py
import numpy as np
from pathlib import Path
import tensorflow as tf
import torch
from tqdm import tqdm
import os
import sys
import glob
import random
import pdb
# Seed NumPy's global RNG, which ``process`` / ``generate_mask`` draw from.
# NOTE(review): those run inside parallel ``py_func`` calls
# (num_parallel_calls=8), so per-example draws may still differ between runs.
np.random.seed(0)

# Output directory for the TFRecord shards. Keep the trailing separator: the
# shard prefixes below are built by plain string concatenation.
tfrecord_path = "/path/to/save/your/tfrecord/"
# Directory with the extracted WavLM ``.pt`` features, one sub-directory per
# LibriSpeech split.
path_to_wavlm_feat = "/path/to/your/wavlm/feat"

os.makedirs(tfrecord_path, exist_ok=True)
# Shard-file prefixes; ``convert`` appends ``_<i>_<n>.tfrecords``.
train_filename = tfrecord_path + 'train'
valid_filename = tfrecord_path + 'valid'
test_filename = tfrecord_path + 'test'
train_path = Path(os.path.join(path_to_wavlm_feat, "train-clean-100"))
valid_path = Path(os.path.join(path_to_wavlm_feat, "dev-clean"))
test_path = Path(os.path.join(path_to_wavlm_feat, "test-clean"))

# Generate the TFRecords whenever any split has no shards yet. Checking only
# whether the directory exists is not enough: it is created above before any
# shard is written, so an interrupted run would otherwise suppress generation
# forever and leave the pipeline reading an empty file list.
generate_tf_record = any(
    not glob.glob(prefix + "*.tfrecords")
    for prefix in (train_filename, valid_filename, test_filename)
)

# Default example counts per split, used by ``get_dst`` when existing shards
# are reused; they are replaced by the actual post-filter counts whenever the
# records are regenerated.
# NOTE(review): presumably these match the shards already on disk — confirm.
train_size = 27269
valid_size = 1940
test_size = 1850

def get_filenames(path):
    """Return every ``*.pt`` feature file below *path*, searched recursively.

    The result is sorted, so later seeded shuffling (and therefore the shard
    contents) does not depend on the order in which the filesystem happens to
    list directory entries.

    Args:
        path: ``pathlib.Path`` of the directory to search.

    Returns:
        A sorted ``list`` of ``pathlib.Path`` objects.
    """
    # ``rglob`` already prepends ``**/``; the previous ``"**/*.pt"`` pattern
    # was equivalent to ``glob("**/**/*.pt")`` and could re-walk subtrees.
    return sorted(path.rglob("*.pt"))

def length_filter(paths, min_length=200):
    """Keep only the feature files whose first dimension is >= *min_length*.

    Each file is loaded with ``torch.load`` and only the size of its first
    dimension is inspected. ``process`` later samples ``set_size`` rows
    without replacement, so *min_length* should be at least the ``set_size``
    used for training/evaluation.

    Args:
        paths: iterable of paths to ``.pt`` files holding tensors.
        min_length: minimum number of rows a file must have to be kept
            (default 200, the previous hard-coded cutoff).

    Returns:
        A ``list`` of the kept paths, in their original order.
    """
    print("filter short files")
    kept = []
    for each in tqdm(paths):
        # Only the shape is needed: skip the NumPy/float32 copy, and map onto
        # the CPU so tensors saved from a GPU can still be inspected.
        data = torch.load(each, map_location="cpu")
        if data.shape[0] >= min_length:
            kept.append(each)
    return kept


def generate_mask(x, mask_type):
    """Draw a random row mask with the same shape as *x*.

    Every mask is constant along each row: selected rows are all ones, the
    others all zeros.

    Args:
        x: 2-D array of shape ``(n, dim)``.
        mask_type: strategy name, as ``bytes`` (what ``tf.compat.v1.py_func``
            passes for string arguments) or ``str``:

            - ``expand``: a random number of rows in ``[n // 8, n)``.
            - ``few_expand``: a random number of rows in ``[0, n // 8)``
              (always 0 rows when ``n < 8``).
            - ``arb_expand``: a random number of rows in ``[0, n)``.
            - ``det_expand``: exactly 100 rows (requires ``n >= 100``).
            - ``complete``: the rows on the positive side of a random affine
              hyperplane in feature space, redrawn until at least ``n // 8``
              rows are selected. Works for any ``dim`` (it previously assumed
              ``dim == 3``).

    Returns:
        Mask array shaped like *x*: dtype of *x* for the ``*expand`` types,
        ``float32`` for ``complete`` (all zeros of *x*'s dtype when
        ``n // 8 == 0``).

    Raises:
        ValueError: for an unknown *mask_type*, or when NumPy cannot draw the
            requested number of rows without replacement.
    """
    if isinstance(mask_type, str):
        mask_type = mask_type.encode()
    n = x.shape[0]

    if mask_type == b'complete':
        dim = x.shape[1]
        # Append a bias column so the random hyperplane need not pass through
        # the origin; this is loop-invariant, so build it once.
        xa = np.concatenate([x, np.ones([n, 1])], axis=1)
        m = np.zeros_like(x)
        while np.sum(m[:, 0]) < n // 8:
            p = np.random.uniform(-0.5, 0.5, size=dim + 1)
            side = (np.dot(xa, p) > 0).astype(np.float32)
            m = np.repeat(np.expand_dims(side, axis=1), dim, axis=1)
        return m

    if mask_type == b'expand':
        num_rows = np.random.randint(n // 8, n)
    elif mask_type == b'few_expand':
        # ``randint(0)`` raises, so tiny sets (n < 8) get an empty mask.
        num_rows = np.random.randint(max(n // 8, 1))
    elif mask_type == b'arb_expand':
        num_rows = np.random.randint(n)
    elif mask_type == b'det_expand':
        num_rows = 100
    else:
        raise ValueError(f"unknown mask_type: {mask_type!r}")

    m = np.zeros_like(x)
    rows = np.random.choice(n, num_rows, replace=False)
    m[rows] = 1.
    return m


def wrap_int64(value):
    """Wrap one integer as a ``tf.train.Feature`` carrying an Int64List."""
    int_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int_list)
def wrap_bytes(value):
    """Wrap one ``bytes`` value as a ``tf.train.Feature`` carrying a BytesList."""
    byte_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=byte_list)
def print_progress(count, total):
    """Write an in-place progress line (``- Progress: NN.N%``) to stdout.

    Args:
        count: number of items processed so far.
        total: total number of items. A non-positive total is reported as
            100% instead of raising ``ZeroDivisionError``.
    """
    pct_complete = float(count) / total if total > 0 else 1.0
    # The leading carriage return makes each call overwrite the previous line.
    sys.stdout.write("\r- Progress: {0:.1%}".format(pct_complete))
    sys.stdout.flush()
def _serialize_feature_file(feature_path):
    """Load one ``.pt`` feature file and return it as a serialized ``tf.train.Example``."""
    # Map onto the CPU so tensors saved from a GPU can be converted to NumPy.
    img = torch.load(feature_path, map_location="cpu").numpy().astype(np.float32)
    data = {
        # ``tobytes`` yields the same C-order bytes as the old ``tostring``,
        # which was removed in NumPy 2.0.
        'image': wrap_bytes(img.tobytes()),
        'length': wrap_int64(img.shape[0]),
        "filename": wrap_bytes(bytes(os.path.splitext(feature_path.name)[0], 'utf-8')),
    }
    example = tf.train.Example(features=tf.train.Features(feature=data))
    return example.SerializeToString()


def convert(image_paths, out_path, max_files=1000):
    """Serialize feature files into sharded TFRecord files.

    Files are written in the given order to ``ceil(len(image_paths) / max_files)``
    shards named ``"{out_path}_{i}_{n}.tfrecords"`` (``i`` is 1-based, ``n`` the
    shard count). Each record is a ``tf.train.Example`` with the features
    ``image`` (raw float32 bytes of the loaded tensor), ``length`` (size of its
    first dimension) and ``filename`` (file name without extension, UTF-8).
    NOTE(review): ``parse`` reshapes ``image`` to ``[length, 1024]``, so the
    tensors presumably must be 2-D with 1024 columns — confirm for new features.

    Args:
        image_paths: sequence of ``pathlib.Path`` objects pointing at ``.pt``
            feature files (``.name`` is used, so plain strings do not work).
        out_path: shard-file prefix.
        max_files: maximum number of records per shard; must be positive.

    Raises:
        ValueError: if *max_files* is not positive.
    """
    if max_files <= 0:
        raise ValueError(f"max_files must be positive, got {max_files}")

    print("Converting: " + out_path)

    num_images = len(image_paths)
    # Ceiling division; an empty input produces zero shards.
    splits = -(-num_images // max_files)
    print(f"\nUsing {splits} shard(s) for {num_images} files, with up to {max_files} samples per shard")
    file_count = 0
    for i in tqdm(range(splits)):
        shard = image_paths[i * max_files:(i + 1) * max_files]
        with tf.io.TFRecordWriter("{}_{}_{}.tfrecords".format(out_path, i+1, splits)) as writer:
            for current_image in shard:
                writer.write(_serialize_feature_file(current_image))
                file_count += 1
    print(f"\nWrote {file_count} elements to TFRecord")


if generate_tf_record:
    # Collect each split's *.pt files and drop the ones that are too short.
    # Unpacking the generator runs the splits in train -> valid -> test order.
    train_image_paths, valid_image_paths, test_image_paths = (
        length_filter(get_filenames(split_dir))
        for split_dir in (train_path, valid_path, test_path)
    )
    print(f"Number of training data after length filering: {len(train_image_paths)}")
    print(f"Number of valid data after length filering: {len(valid_image_paths)}")
    print(f"Number of testing data after length filering: {len(test_image_paths)}")

    # Shuffle only the training order, using a fixed seed.
    random.Random(4).shuffle(train_image_paths)

    # Replace the default split sizes with the actual post-filter counts.
    train_size, valid_size, test_size = (
        len(train_image_paths),
        len(valid_image_paths),
        len(test_image_paths),
    )

    # Write one sharded TFRecord set per split.
    convert(train_image_paths, train_filename)
    convert(valid_image_paths, valid_filename)
    convert(test_image_paths, test_filename)
    

def parse(serialized):
    """Decode one serialized ``tf.train.Example`` written by ``convert``.

    The feature spec must mirror the keys written by ``convert``: the raw
    float32 feature bytes, the number of rows, and the source file name.

    Returns:
        ``(features, filename)``: ``features`` is a float32 tensor reshaped to
        ``[length, 1024]``; ``filename`` is the stored string tensor.
    """
    feature_spec = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'length': tf.io.FixedLenFeature([], tf.int64),
        'filename': tf.io.FixedLenFeature([], tf.string),
    }
    example = tf.io.parse_single_example(serialized=serialized,
                                         features=feature_spec)
    flat = tf.io.decode_raw(example['image'], tf.float32)
    features = tf.reshape(flat, [example['length'], 1024])
    return features, example['filename']

def process(x, filename, set_size, mask_type):
    """Sample a random subset of rows from *x* and draw a mask for it.

    Used through ``tf.compat.v1.py_func`` in ``get_dst``, so the arguments
    arrive as NumPy values (string arguments as ``bytes``).

    Args:
        x: feature array whose rows are sampled; in the pipeline it has shape
            ``[length, 1024]`` as produced by ``parse``.
        filename: passed through unchanged.
        set_size: number of rows to sample without replacement.
        mask_type: mask strategy forwarded to ``generate_mask``.

    Returns:
        ``(x_subset, mask, filename)``: ``x_subset`` is ``x / 10`` restricted to
        ``set_size`` randomly chosen rows (in random order), and ``mask`` is
        shaped like ``x_subset``.

    Raises:
        ValueError: if *set_size* exceeds the number of rows in *x*.
    """
    num_rows = x.shape[0]
    if set_size > num_rows:
        raise ValueError(
            f"set_size ({set_size}) exceeds the number of frames ({num_rows}) in {filename!r}"
        )
    # Fixed down-scaling of the features, applied before sampling.
    x = x / 10
    ind = np.random.choice(num_rows, set_size, replace=False)
    x = x[ind]
    m = generate_mask(x, mask_type)
    return x, m, filename


def get_dst(split, set_size, mask_type):
    """Build the ``tf.data`` pipeline for one split.

    Reads every ``<prefix>*.tfrecords`` shard of the split, decodes it with
    ``parse`` and then samples rows / draws masks with ``process``. Only the
    training split is shuffled.

    Args:
        split: ``'train'`` or ``'valid'``; any other value selects the test split.
        set_size: number of rows sampled per example (see ``process``).
        mask_type: mask strategy name (see ``generate_mask``).

    Returns:
        ``(dataset, size)``: ``dataset`` yields ``(x, mask, filename)`` tuples
        and ``size`` is the split's module-level example count.
    """
    if split == 'train':
        prefix, size = train_filename, train_size
    elif split == 'valid':
        prefix, size = valid_filename, valid_size
    else:
        prefix, size = test_filename, test_size

    # Sort the shards so the read order does not depend on the filesystem's
    # directory-listing order.
    files = sorted(glob.glob(prefix + "*.tfrecords", recursive=False))
    dst = tf.data.TFRecordDataset(files)
    dst = dst.map(parse)
    if split == 'train':
        dst = dst.shuffle(256)
    dst = dst.map(
        lambda x, y: tuple(tf.compat.v1.py_func(
            process, [x, y, set_size, mask_type],
            [tf.float32, tf.float32, tf.string])),
        num_parallel_calls=8,
    )
    return dst, size

class Dataset(object):
    """Batched iterator over one split, evaluated with a private TF1 session.

    The pipeline from ``get_dst`` is built inside its own ``tf.Graph`` and
    pulled through an initializable iterator with a dedicated
    ``tf.compat.v1.Session``. Call ``close()`` when done to release it.

    Attributes:
        sess: session owning the graph.
        size: example count reported by ``get_dst`` for the split.
        num_batches: ``size // batch_size``. Batching keeps the final partial
            batch (``drop_remainder=False``), so looping ``num_batches`` times
            skips the last ``size % batch_size`` examples (assuming ``size``
            matches the records on disk).
        x, b, filename: ``get_next()`` tensors for features, mask and file name.
        dimension: feature dimension (1024, matching the reshape in ``parse``).
        initializer: op that (re)starts the iterator; run it via ``initialize()``.
    """

    def __init__(self, split, batch_size, set_size, mask_type):
        g = tf.Graph()
        with g.as_default():
            # open a session
            config = tf.compat.v1.ConfigProto()
            # Logs the device of every op; set to False to silence this output.
            config.log_device_placement = True
            config.allow_soft_placement = True
            config.gpu_options.allow_growth = True
            self.sess = tf.compat.v1.Session(config=config, graph=g)
            # build dataset
            dst, size = get_dst(split, set_size, mask_type)
            self.size = size
            self.num_batches = self.size // batch_size
            dst = dst.batch(batch_size, drop_remainder=False)
            dst = dst.prefetch(1)

            dst_it = tf.compat.v1.data.make_initializable_iterator(dst)
            x, b, filename = dst_it.get_next()
            self.x = x
            self.b = b
            self.filename = filename
            self.dimension = 1024
            self.initializer = dst_it.initializer

    def initialize(self):
        """(Re)initialize the iterator; call before each pass over the split."""
        self.sess.run(self.initializer)

    def next_batch(self):
        """Fetch the next batch.

        Returns:
            dict with ``'x'`` (features), ``'b'`` (mask from ``generate_mask``),
            ``'m'`` (an all-ones array shaped like ``'b'``) and ``'f'``
            (file names).
        """
        x, b, filename = self.sess.run([self.x, self.b, self.filename])
        m = np.ones_like(b)
        return {'x':x, 'b':b, 'm':m, "f":filename}

    def close(self):
        """Close the underlying session and release its resources."""
        self.sess.close()