Sreeja123 committed on
Commit
c13bbb7
·
1 Parent(s): 3b3d551
Files changed (1) hide show
  1. Time_Distr.py +356 -0
Time_Distr.py ADDED
@@ -0,0 +1,356 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Wrapper layer to apply every temporal slice of an input."""

import tensorflow.compat.v2 as tf

from keras import backend
from keras.engine.base_layer import Layer
from keras.engine.input_spec import InputSpec
from keras.layers.rnn.base_wrapper import Wrapper
from keras.utils import generic_utils
from keras.utils import layer_utils
from keras.utils import tf_utils

# isort: off
from tensorflow.python.util.tf_export import keras_export

@keras_export("keras.layers.TimeDistributed")
class TimeDistributed(Wrapper):
    """This wrapper allows to apply a layer to every temporal slice of an input.

    Every input should be at least 3D, and the dimension of index one of the
    first input will be considered to be the temporal dimension.

    Consider a batch of 32 video samples, where each sample is a 128x128 RGB
    image with `channels_last` data format, across 10 timesteps.
    The batch input shape is `(32, 10, 128, 128, 3)`.

    You can then use `TimeDistributed` to apply the same `Conv2D` layer to each
    of the 10 timesteps, independently:

    >>> inputs = tf.keras.Input(shape=(10, 128, 128, 3))
    >>> conv_2d_layer = tf.keras.layers.Conv2D(64, (3, 3))
    >>> outputs = tf.keras.layers.TimeDistributed(conv_2d_layer)(inputs)
    >>> outputs.shape
    TensorShape([None, 10, 126, 126, 64])

    Because `TimeDistributed` applies the same instance of `Conv2D` to each of
    the timesteps, the same set of weights are used at each timestep.

    Args:
      layer: a `tf.keras.layers.Layer` instance.

    Call arguments:
      inputs: Input tensor of shape (batch, time, ...) or nested tensors,
        and each of which has shape (batch, time, ...).
      training: Python boolean indicating whether the layer should behave in
        training mode or in inference mode. This argument is passed to the
        wrapped layer (only if the layer supports this argument).
      mask: Binary tensor of shape `(samples, timesteps)` indicating whether
        a given timestep should be masked. This argument is passed to the
        wrapped layer (only if the layer supports this argument).

    Raises:
      ValueError: If not initialized with a `tf.keras.layers.Layer` instance.
    """

    def __init__(self, layer, **kwargs):
        """Wraps `layer` so that it is applied to every timestep.

        Args:
          layer: the `Layer` instance to apply to each temporal slice.
          **kwargs: forwarded unchanged to the `Wrapper` constructor.

        Raises:
          ValueError: if `layer` is not a `Layer` instance.
        """
        if not isinstance(layer, Layer):
            raise ValueError(
                "Please initialize `TimeDistributed` layer with a "
                f"`tf.keras.layers.Layer` instance. Received: {layer}"
            )
        super().__init__(layer, **kwargs)
        self.supports_masking = True

        # It is safe to use the fast, reshape-based approach with all of our
        # built-in Layers, unless the wrapped layer declares itself stateful.
        self._always_use_reshape = layer_utils.is_builtin_layer(
            layer
        ) and not getattr(layer, "stateful", False)

    def _get_shape_tuple(self, init_tuple, tensor, start_idx):
        """Finds non-specific dimensions in the static shapes.

        The static shapes are replaced with the corresponding dynamic shapes of
        the tensor.

        Args:
          init_tuple: a tuple, the first part of the output shape
          tensor: the tensor from which to get the (static and dynamic) shapes
            as the last part of the output shape
          start_idx: int, which indicate the first dimension to take from
            the static shape of the tensor

        Returns:
          The new shape with the first part from `init_tuple` and the last part
          from `tensor.shape`, where every `None` is replaced by the
          corresponding dimension from `tf.shape(tensor)`.
        """
        int_shape = backend.int_shape(tensor)[start_idx:]
        if not any(s is None for s in int_shape):
            # Fully static tail: no dynamic shape lookup is needed.
            return init_tuple + int_shape
        # Replace every unknown (None) static dimension by the matching entry
        # of the dynamic shape tensor.
        shape = backend.shape(tensor)
        int_shape = list(int_shape)
        for i, s in enumerate(int_shape):
            if s is None:
                int_shape[i] = shape[start_idx + i]
        return init_tuple + tuple(int_shape)

    def _remove_timesteps(self, dims):
        """Returns `dims` without its axis 1 (the time axis).

        Args:
          dims: a shape object exposing `as_list()` (e.g. `tf.TensorShape`).

        Returns:
          A `tf.TensorShape` made of `dims[0]` followed by `dims[2:]`.
        """
        dims = dims.as_list()
        return tf.TensorShape([dims[0]] + dims[2:])

    def build(self, input_shape):
        """Builds the wrapper from the (possibly nested) input shape.

        Sets `self.input_spec` with the batch and time dimensions left
        unconstrained, then calls the parent `build` with the input shape
        stripped of its time axis.

        Args:
          input_shape: shape, or nested structure of shapes, of the inputs.

        Raises:
          ValueError: if any input shape has fewer than 3 dimensions.
        """
        input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)
        input_dims = tf.nest.flatten(
            tf.nest.map_structure(lambda x: x.ndims, input_shape)
        )
        if any(dim < 3 for dim in input_dims):
            raise ValueError(
                "`TimeDistributed` Layer should be passed an `input_shape ` "
                f"with at least 3 dimensions, received: {input_shape}"
            )
        # Don't enforce the batch or time dimension.
        self.input_spec = tf.nest.map_structure(
            lambda x: InputSpec(shape=[None, None] + x.as_list()[2:]),
            input_shape,
        )
        child_input_shape = tf.nest.map_structure(
            self._remove_timesteps, input_shape
        )
        child_input_shape = tf_utils.convert_shapes(child_input_shape)
        super().build(tuple(child_input_shape))
        self.built = True

    def compute_output_shape(self, input_shape):
        """Computes the output shape by re-inserting the time axis.

        The wrapped layer's `compute_output_shape` is called on the input shape
        without its time axis, and the timestep count is then inserted back at
        axis 1 of every resulting shape.

        Args:
          input_shape: shape, or nested structure of shapes, of the inputs.

        Returns:
          A `tf.TensorShape` (or nested structure of them) laid out as
          `(child_output[0], timesteps, *child_output[1:])`.
        """
        input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)

        child_input_shape = tf.nest.map_structure(
            self._remove_timesteps, input_shape
        )
        child_output_shape = self.layer.compute_output_shape(child_input_shape)
        child_output_shape = tf_utils.convert_shapes(
            child_output_shape, to_tuples=False
        )
        # Entry 1 of the flattened shapes is the time dimension of the first
        # input.
        timesteps = tf_utils.convert_shapes(input_shape)
        timesteps = tf.nest.flatten(timesteps)[1]

        def insert_timesteps(dims):
            dims = dims.as_list()
            return tf.TensorShape([dims[0], timesteps] + dims[1:])

        return tf.nest.map_structure(insert_timesteps, child_output_shape)

    def call(self, inputs, training=None, mask=None):
        """Applies the wrapped layer to every temporal slice of `inputs`.

        Two strategies are used:

        * When the static batch size is known and the reshape approach is not
          forced (`self._always_use_reshape` is false), the wrapped layer is
          stepped over the time axis through `backend.rnn`.
        * Otherwise the batch and time axes are merged with a reshape, the
          wrapped layer is called once, and the result is reshaped back.
          Inputs that are all `tf.RaggedTensor` are handled through their
          flat `values` and rebuilt with `from_row_lengths`.

        Args:
          inputs: tensor, or nested structure of tensors.
          training: forwarded to the wrapped layer only if its `call` accepts
            a `training` argument.
          mask: optional mask; passed to `backend.rnn` in the first strategy,
            or reshaped and forwarded to the wrapped layer in the second one
            when its `call` accepts a `mask` argument.

        Returns:
          The wrapped layer's outputs with the time axis restored.

        Raises:
          ValueError: if `inputs` mixes ragged and non-ragged tensors (only
            checked on the reshape-based path).
        """
        kwargs = {}
        if generic_utils.has_arg(self.layer.call, "training"):
            kwargs["training"] = training

        input_shape = tf.nest.map_structure(
            lambda x: tf.TensorShape(backend.int_shape(x)), inputs
        )
        batch_size = tf_utils.convert_shapes(input_shape)
        batch_size = tf.nest.flatten(batch_size)[0]
        if batch_size and not self._always_use_reshape:
            inputs, row_lengths = backend.convert_inputs_if_ragged(inputs)
            is_ragged_input = row_lengths is not None
            input_length = tf_utils.convert_shapes(input_shape)
            input_length = tf.nest.flatten(input_length)[1]

            # batch size matters, use rnn-based implementation
            def step(x, _):
                output = self.layer(x, **kwargs)
                return output, []

            _, outputs, _ = backend.rnn(
                step,
                inputs,
                initial_states=[],
                input_length=row_lengths[0]
                if is_ragged_input
                else input_length,
                mask=mask,
                unroll=False,
            )

            y = tf.nest.map_structure(
                lambda output: backend.maybe_convert_to_ragged(
                    is_ragged_input, output, row_lengths
                ),
                outputs,
            )
        else:
            # No batch size specified, therefore the layer will be able
            # to process batches of any size.
            # We can go with reshape-based implementation for performance.
            is_ragged_input = tf.nest.map_structure(
                lambda x: isinstance(x, tf.RaggedTensor), inputs
            )
            is_ragged_input = tf.nest.flatten(is_ragged_input)
            if all(is_ragged_input):
                # Run the layer on the flat values and rebuild the ragged
                # structure from the outermost row lengths.
                input_values = tf.nest.map_structure(lambda x: x.values, inputs)
                input_row_lengths = tf.nest.map_structure(
                    lambda x: x.nested_row_lengths()[0], inputs
                )
                y = self.layer(input_values, **kwargs)
                y = tf.nest.map_structure(
                    tf.RaggedTensor.from_row_lengths, y, input_row_lengths
                )
            elif any(is_ragged_input):
                raise ValueError(
                    "All inputs has to be either ragged or not, "
                    f"but not mixed. Received: {inputs}"
                )
            else:
                input_length = tf_utils.convert_shapes(input_shape)
                input_length = tf.nest.flatten(input_length)[1]
                if not input_length:
                    # Static time dimension unknown: fall back to the dynamic
                    # shape of the first input.
                    input_length = tf.nest.map_structure(
                        lambda x: tf.shape(x)[1], inputs
                    )
                    input_length = generic_utils.to_list(
                        tf.nest.flatten(input_length)
                    )[0]

                inner_input_shape = tf.nest.map_structure(
                    lambda x: self._get_shape_tuple((-1,), x, 2), inputs
                )
                # Merge batch and time axes.
                # Shape: (num_samples * timesteps, ...)
                inputs = tf.__internal__.nest.map_structure_up_to(
                    inputs, tf.reshape, inputs, inner_input_shape
                )
                # The mask gets the same (num_samples * timesteps, ...) layout.
                if (
                    generic_utils.has_arg(self.layer.call, "mask")
                    and mask is not None
                ):
                    inner_mask_shape = self._get_shape_tuple((-1,), mask, 2)
                    kwargs["mask"] = backend.reshape(mask, inner_mask_shape)

                y = self.layer(inputs, **kwargs)

                # Reconstruct the output shape by re-splitting the 0th dimension
                # back into (num_samples, timesteps, ...)
                # We use batch_size when available so that the 0th dimension is
                # set in the static shape of the reshaped output
                reshape_batch_size = batch_size if batch_size else -1
                output_shape = tf.nest.map_structure(
                    lambda tensor: self._get_shape_tuple(
                        (reshape_batch_size, input_length), tensor, 1
                    ),
                    y,
                )
                y = tf.__internal__.nest.map_structure_up_to(
                    y, tf.reshape, y, output_shape
                )

        return y

    def compute_mask(self, inputs, mask=None):
        """Computes an output mask tensor for the `TimeDistributed` layer.

        This is based on the inputs, mask, and the inner layer.

        If the static batch size is known and the reshape approach is not
        forced, or if any input is a ragged tensor:
          Simply return the input `mask`. (An rnn-based implementation with
          more than one rnn inputs is required but not supported in tf.keras
          yet.)
        Otherwise we call `compute_mask` of the inner layer on the inputs with
        batch and time axes merged.
          If the inner output mask is not `None`:
            (E.g., inner layer is Masking or RNN)
            Reshape it back to `(batch size, timesteps, ...)` and return it.
          If the inner output mask is `None` and the input mask is not `None`:
            (E.g., inner layer is Dense)
            Reduce the input mask to 2 dimensions and return it.
          Otherwise (both the output mask and the input mask are `None`):
            (E.g., `mask` is not used at all)
            Return `None`.

        Args:
          inputs: Tensor with shape [batch size, timesteps, ...] indicating the
            input to TimeDistributed.
          mask: Either None (indicating no masking) or a Tensor indicating the
            input mask for TimeDistributed. The shape can be static or dynamic.

        Returns:
          Either None (no masking), the input `mask` unmodified, the input
          `mask` reduced with `any` over every axis beyond the second, or the
          inner layer's mask reshaped to [batch size, timesteps, ...].
        """
        # cases need to call the layer.compute_mask when input_mask is None:
        # Masking layer and Embedding layer with mask_zero
        input_shape = tf.nest.map_structure(
            lambda x: tf.TensorShape(backend.int_shape(x)), inputs
        )
        input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)
        batch_size = tf_utils.convert_shapes(input_shape)
        batch_size = tf.nest.flatten(batch_size)[0]
        is_ragged_input = tf.nest.map_structure(
            lambda x: isinstance(x, tf.RaggedTensor), inputs
        )
        is_ragged_input = generic_utils.to_list(
            tf.nest.flatten(is_ragged_input)
        )
        if batch_size and not self._always_use_reshape or any(is_ragged_input):
            # Either the rnn-based path is in use (batch size known and
            # reshape not forced) or the input is ragged: the mask is not
            # handled explicitly and is passed through.
            return mask
        inner_mask = mask
        if inner_mask is not None:
            inner_mask_shape = self._get_shape_tuple((-1,), mask, 2)
            inner_mask = backend.reshape(inner_mask, inner_mask_shape)
        inner_input_shape = tf.nest.map_structure(
            lambda tensor: self._get_shape_tuple((-1,), tensor, 2), inputs
        )
        inner_inputs = tf.__internal__.nest.map_structure_up_to(
            inputs, tf.reshape, inputs, inner_input_shape
        )
        output_mask = self.layer.compute_mask(inner_inputs, inner_mask)
        if output_mask is None:
            if mask is None:
                return None
            # input_mask is not None, and output_mask is None:
            # we should return a not-None mask
            output_mask = mask
            for _ in range(2, len(backend.int_shape(mask))):
                output_mask = backend.any(output_mask, axis=-1)
        else:
            # output_mask is not None. We need to reshape it
            input_length = tf_utils.convert_shapes(input_shape)
            input_length = tf.nest.flatten(input_length)[1]
            if not input_length:
                input_length = tf.nest.map_structure(
                    lambda x: backend.shape(x)[1], inputs
                )
                input_length = tf.nest.flatten(input_length)[0]
            reshape_batch_size = batch_size if batch_size else -1
            output_mask_shape = self._get_shape_tuple(
                (reshape_batch_size, input_length), output_mask, 1
            )
            output_mask = backend.reshape(output_mask, output_mask_shape)
        return output_mask