Spaces:
Runtime error
Runtime error
File size: 24,127 Bytes
5672777 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 |
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Funnel Transformer network."""
# pylint: disable=g-classes-have-attributes
import math
from typing import Any, Callable, Optional, Sequence, Union
from absl import logging
import numpy as np
import tensorflow as tf, tf_keras
from official.modeling import tf_utils
from official.nlp.modeling import layers
_Initializer = Union[str, tf_keras.initializers.Initializer]
_Activation = Union[str, Callable[..., Any]]

# Supported values of the `pool_type` constructor argument.
_MAX = 'max'
_AVG = 'avg'
_TRUNCATED_AVG = 'truncated_avg'

# Name -> class for the transformer blocks that can be selected by string,
# plus the inverse mapping used when writing the layer config.
_str2transformer_cls = {
    'TransformerEncoderBlock': layers.TransformerEncoderBlock,
    'ReZeroTransformer': layers.ReZeroTransformer,
}
_transformer_cls2str = {
    block_cls: block_name
    for block_name, block_cls in _str2transformer_cls.items()
}

# GELU using the tanh approximation; default feed-forward activation.
# Kept as a lambda (not a `def`) so its serialized name is unchanged.
_approx_gelu = lambda x: tf_keras.activations.gelu(x, approximate=True)
def _get_policy_dtype():
  """Returns the compute dtype of the global Keras mixed-precision policy.

  Falls back to `tf.float32` when the policy reports no compute dtype, or when
  the `global_policy` API is unavailable (TF1).
  """
  try:
    compute_dtype = tf_keras.mixed_precision.global_policy().compute_dtype
  except AttributeError:  # tf1 has no attribute 'global_policy'
    return tf.float32
  return compute_dtype or tf.float32
def _pool_and_concat(mask, unpool_length: int, strides: Union[Sequence[int],
                                                               int],
                     axes: Union[Sequence[int], int]):
  """Subsamples `mask` along the given axes, keeping leading elements intact.

  Along each axis in `axes`, the first `unpool_length` elements are kept as
  they are. From the remainder, every `stride`-th element (starting right
  after the kept prefix) is selected. The two parts are then concatenated back
  along that axis. The remainder is strided-sliced, not max/avg pooled.

  Args:
    mask: Tensor to be pooled.
    unpool_length: Number of leading elements along each axis that are never
      pooled.
    strides: Stride for each axis in `axes`; a single int applies to all axes.
    axes: Axis or axes along which to pool.

  Returns:
    The pooled and concatenated Tensor, or `mask` itself if every stride is 1.

  Raises:
    ValueError: If `strides` is a sequence whose length differs from `axes`.
  """
  # Wraps the axes as a list.
  if isinstance(axes, int):
    axes = [axes]
  if isinstance(strides, int):
    strides = [strides] * len(axes)
  elif len(strides) != len(axes):
    raise ValueError('The lengths of strides and axes need to match.')

  # Bypass no pooling cases.
  if all(stride == 1 for stride in strides):
    return mask

  for axis, stride in zip(axes, strides):
    # Select everything on the axes before `axis`. Multi-dimensional indices
    # are built as tuples; list-of-slices indexing is deprecated in NumPy.
    leading = (slice(None),) * axis
    # Skips first `unpool_length` tokens.
    unpool_tensor = mask[leading + (slice(None, unpool_length),)]
    # Strided subsample of the remaining tokens.
    pool_tensor = mask[leading + (slice(unpool_length, None, stride),)]
    mask = tf.concat((unpool_tensor, pool_tensor), axis=axis)
  return mask
def _create_fractional_pool_transform(sl: int, pool_factor: float):
  """Creates the pooling transform for a fractional pooling factor.

  The transform maps a length-`sl` sequence to one of length
  `psl = int(sl / pool_factor)`. It is built by tiling a
  [chunk_sl, chunk_psl] sub-matrix `gcd(sl, psl)` times along the diagonal,
  where `chunk_sl = sl // gcd` and `chunk_psl = psl // gcd`.

  Within each sub-matrix:
    * The first `chunk_psl - 1` rows each copy one input position into their
      own output column (entry 1.0).
    * The remaining `chunk_sl - chunk_psl + 1` rows are averaged into the
      last output column of the chunk.
  Each sub-matrix therefore sums to `chunk_psl`.

  Args:
    sl: Input sequence length.
    pool_factor: Pooling factor; must be greater than 1.0.

  Returns:
    A constant Tensor of shape [sl, int(sl / pool_factor)] in the global
    mixed-precision policy compute dtype.

  Raises:
    ValueError: If `pool_factor` is not greater than 1.0.
  """
  # Explicit check instead of `assert`, which is stripped under `python -O`.
  if not pool_factor > 1.0:
    raise ValueError('`pool_factor` should be > 1.0.')
  psl = int(sl / pool_factor)
  # The transform is of shape [sl, psl].
  transform = np.zeros((sl, psl))
  if psl == 0:
    # Nothing survives pooling (sl < pool_factor, or sl == 0). The tiling
    # below would otherwise index with negative offsets (gcd(sl, 0) == sl) or
    # divide by zero (gcd(0, 0) == 0).
    return tf.constant(transform, dtype=_get_policy_dtype())

  gcd_ = math.gcd(sl, psl)
  # It is expected chunk_sl and chunk_psl are small integers.
  chunk_sl = sl // gcd_
  chunk_psl = psl // gcd_
  num_one_entries = chunk_psl - 1
  num_frac_entries = chunk_sl - num_one_entries
  for i in range(gcd_):
    row_start = chunk_sl * i
    col_start = chunk_psl * i
    for idx in range(num_one_entries):
      transform[row_start + idx][col_start + idx] = 1.0
    # The last `num_frac_entries` rows of the chunk share one output column.
    frac_col = col_start + num_one_entries
    transform[row_start + num_one_entries:row_start + chunk_sl, frac_col] = (
        1.0 / num_frac_entries)
  return tf.constant(transform, dtype=_get_policy_dtype())
def _create_truncated_avg_transforms(
    seq_length: int, pool_strides: Sequence[int]
):
  """Computes pooling transforms.

  For an integer-valued `pool_stride`, the transform is of shape
  [seq_length, int(seq_length / pool_stride)] with

    pooling_transform[i, j] = 1.0 / pool_stride  if i // pool_stride == j
                              0.0                otherwise.

  It is in essence average pooling. When `seq_length % pool_stride != 0`, the
  trailing incomplete window is dropped: its rows are all zeros.

  For seq_length == 6 and pool_stride == 2, it is
    [[ 0.5, 0.0, 0.0 ],
     [ 0.5, 0.0, 0.0 ],
     [ 0.0, 0.5, 0.0 ],
     [ 0.0, 0.5, 0.0 ],
     [ 0.0, 0.0, 0.5 ],
     [ 0.0, 0.0, 0.5 ]]

  Non-integer strides use `_create_fractional_pool_transform`. A stride of 1
  means no pooling and yields `None`. Each layer's input length is the pooled
  length produced by the previous pooling layer.

  Args:
    seq_length: int, sequence length.
    pool_strides: Sequence of pooling strides for each layer.

  Returns:
    pooling_transforms: Sequence of pooling transforms (Tensors, or `None` for
      stride 1) for each layer.
  """
  pooling_transforms = []
  for pool_stride in pool_strides:
    if pool_stride == 1:
      pooling_transforms.append(None)
      continue
    pooled_seq_length = int(seq_length / pool_stride)
    if (1.0 * pool_stride).is_integer():
      # Vectorized form of `1.0 if i // pool_stride == j else 0.0`; avoids an
      # O(seq_length * pooled_seq_length) Python loop on every call.
      window_index = np.arange(seq_length) // pool_stride
      one_hot = np.equal.outer(window_index, np.arange(pooled_seq_length))
      # The division stays in TF (as before) so the values are unchanged.
      transform = (
          tf.constant(one_hot.astype(np.float64), dtype=_get_policy_dtype())
          / pool_stride
      )
    else:
      transform = _create_fractional_pool_transform(seq_length, pool_stride)
    pooling_transforms.append(transform)
    seq_length = pooled_seq_length
  return pooling_transforms
def _create_truncated_avg_masks(input_mask: tf.Tensor,
                                pool_strides: Sequence[int],
                                transforms: Sequence[tf.Tensor]):
  """Computes the per-layer attention masks for truncated-average pooling.

  For each layer, the query length is pooled when its stride differs from 1,
  while the keys keep the previous layer's resolution. Each mask is the current
  key mask repeated for every query position. After a pooling layer, the key
  mask is pooled through that layer's transform. A pooled position is marked
  valid when the transformed value is > 0.

  Args:
    input_mask: Tensor of shape [batch_size, seq_length].
    pool_strides: Sequence of pooling strides for each layer.
    transforms: Sequence of pooling transforms for each layer (as produced by
      the transform builder; `None` entries are only paired with stride 1).

  Returns:
    attention_masks: Sequence of [batch, query_length, key_length] attention
      masks, one per layer.
  """

  def _broadcast_to_2d(query_length, key_mask):
    # [B, T] -> [B, query_length, T]: the same key mask for every query.
    ones = tf.ones([query_length], dtype=key_mask.dtype)
    return tf.einsum('F,BT->BFT', ones, key_mask)

  masks = []
  current_length = tf.shape(input_mask)[-1]
  key_mask = tf.cast(input_mask, dtype=_get_policy_dtype())
  for stride, transform in zip(pool_strides, transforms):
    if stride == 1:
      masks.append(_broadcast_to_2d(current_length, key_mask))
      continue
    query_length = tf.cast(
        tf.cast(current_length, tf.float32) / tf.cast(stride, tf.float32),
        tf.int32,
    )
    masks.append(_broadcast_to_2d(query_length, key_mask))
    # Pool the key mask for the next layer; with non-negative transform
    # weights, a pooled slot is valid if any contributing slot was valid.
    pooled = tf.einsum('BF,FT->BT', key_mask, transform)
    key_mask = tf.cast(pooled > 0.0, dtype=key_mask.dtype)
    current_length = query_length
  return masks
@tf_keras.utils.register_keras_serializable(package='Text')
class FunnelTransformerEncoder(tf_keras.layers.Layer):
  """Funnel Transformer-based encoder network.

  Funnel Transformer Implementation of https://arxiv.org/abs/2006.03236.

  This implementation utilizes the base framework with Bert
  (https://arxiv.org/abs/1810.04805).
  Its output is compatible with `BertEncoder`.

  Args:
    vocab_size: The size of the token vocabulary.
    hidden_size: The size of the transformer hidden layers.
    num_layers: The number of transformer layers.
    num_attention_heads: The number of attention heads for each transformer.
      The hidden size must be divisible by the number of attention heads.
    max_sequence_length: The maximum sequence length that this encoder can
      consume. This determines the variable shape for positional embeddings.
      With `pool_type='truncated_avg'` the pooling transforms are also sized
      for this length.
    type_vocab_size: The number of types that the 'type_ids' input can take.
    inner_dim: The output dimension of the first Dense layer in a two-layer
      feedforward network for each transformer.
    inner_activation: The activation for the first Dense layer in a two-layer
      feedforward network for each transformer.
    output_dropout: Dropout probability for the post-attention and output
      dropout.
    attention_dropout: The dropout rate to use for the attention layers within
      the transformer layers.
    pool_type: Pooling type. Choose from ['max', 'avg', 'truncated_avg'].
    pool_stride: A number or a list of numbers: the pooling stride(s) used to
      compress the sequence length.
      * A scalar gives every layer the same stride.
      * A list must have `num_layers` elements.
      * A stride of 1 disables pooling for that layer.
      * Non-integer (fractional) strides are only supported with
        `pool_type='truncated_avg'`.
    unpool_length: Leading n tokens to be skipped from pooling. Must be 0 for
      `truncated_avg`.
    initializer: The initializer to use for all weights in this encoder.
    output_range: Deprecated here; pass `output_range` to `call()` instead.
      The sequence output range, [0, output_range), by slicing the target
      sequence of the last transformer layer. `None` means the entire target
      sequence will attend to the source sequence, which yields the full
      output.
    embedding_width: The width of the word embeddings. If the embedding width
      is not equal to hidden size, embedding parameters will be factorized into
      two matrices in the shape of ['vocab_size', 'embedding_width'] and
      ['embedding_width', 'hidden_size'] ('embedding_width' is usually much
      smaller than 'hidden_size').
    embedding_layer: An optional Layer instance which will be called to
      generate embeddings for the input word IDs.
    norm_first: Whether to normalize inputs to attention and intermediate dense
      layers. If set False, output of attention and intermediate dense layers
      is normalized. This does not apply to ReZero.
    transformer_cls: str or a keras Layer. This is the base TransformerBlock
      the funnel encoder relies on.
    share_rezero: bool. Whether to share ReZero alpha between the attention
      layer and the ffn layer. This option is specific to ReZero.
    append_dense_inputs: Controls where dense embeddings passed to `call()`
      (`dense_inputs` / `dense_mask` / `dense_type_ids`) are placed.
      * True: they are appended after the word embeddings. This requires
        `unpool_length == 0` (checked in `call()`).
      * False: they are prepended, so `unpool_length` can keep them from
        being pooled.
    **kwargs: Forwarded to `tf_keras.layers.Layer`.
  """

  def __init__(
      self,
      vocab_size: int,
      hidden_size: int = 768,
      num_layers: int = 12,
      num_attention_heads: int = 12,
      max_sequence_length: int = 512,
      type_vocab_size: int = 16,
      inner_dim: int = 3072,
      inner_activation: _Activation = _approx_gelu,
      output_dropout: float = 0.1,
      attention_dropout: float = 0.1,
      pool_type: str = _MAX,
      pool_stride: Union[int, float, Sequence[Union[int, float]]] = 2,
      unpool_length: int = 0,
      initializer: _Initializer = tf_keras.initializers.TruncatedNormal(
          stddev=0.02
      ),
      output_range: Optional[int] = None,
      embedding_width: Optional[int] = None,
      embedding_layer: Optional[tf_keras.layers.Layer] = None,
      norm_first: bool = False,
      transformer_cls: Union[
          str, tf_keras.layers.Layer
      ] = layers.TransformerEncoderBlock,
      share_rezero: bool = False,
      append_dense_inputs: bool = False,
      **kwargs
  ):
    super().__init__(**kwargs)

    if output_range is not None:
      logging.warning('`output_range` is available as an argument for `call()`. '
                      'The `output_range` as __init__ argument is deprecated.')

    activation = tf_keras.activations.get(inner_activation)
    initializer = tf_keras.initializers.get(initializer)

    if embedding_width is None:
      embedding_width = hidden_size

    if embedding_layer is None:
      self._embedding_layer = layers.OnDeviceEmbedding(
          vocab_size=vocab_size,
          embedding_width=embedding_width,
          initializer=tf_utils.clone_initializer(initializer),
          name='word_embeddings')
    else:
      self._embedding_layer = embedding_layer

    self._position_embedding_layer = layers.PositionEmbedding(
        initializer=tf_utils.clone_initializer(initializer),
        max_length=max_sequence_length,
        name='position_embedding')

    self._type_embedding_layer = layers.OnDeviceEmbedding(
        vocab_size=type_vocab_size,
        embedding_width=embedding_width,
        initializer=tf_utils.clone_initializer(initializer),
        use_one_hot=True,
        name='type_embeddings')

    self._embedding_norm_layer = tf_keras.layers.LayerNormalization(
        name='embeddings/layer_norm', axis=-1, epsilon=1e-12, dtype=tf.float32)

    self._embedding_dropout = tf_keras.layers.Dropout(
        rate=output_dropout, name='embedding_dropout')

    # We project the 'embedding' output to 'hidden_size' if it is not already
    # 'hidden_size'.
    self._embedding_projection = None
    if embedding_width != hidden_size:
      self._embedding_projection = tf_keras.layers.EinsumDense(
          '...x,xy->...y',
          output_shape=hidden_size,
          bias_axes='y',
          kernel_initializer=tf_utils.clone_initializer(initializer),
          name='embedding_projection')

    self._transformer_layers = []
    self._attention_mask_layer = layers.SelfAttentionMask(
        name='self_attention_mask')
    # Will raise an error if the string is not supported.
    if isinstance(transformer_cls, str):
      transformer_cls = _str2transformer_cls[transformer_cls]
    self._num_layers = num_layers
    for i in range(num_layers):
      # Both the legacy (`intermediate_*`) and current (`inner_*`) argument
      # names are passed so either supported block class can consume them.
      layer = transformer_cls(
          num_attention_heads=num_attention_heads,
          intermediate_size=inner_dim,
          inner_dim=inner_dim,
          intermediate_activation=inner_activation,
          inner_activation=inner_activation,
          output_dropout=output_dropout,
          attention_dropout=attention_dropout,
          norm_first=norm_first,
          kernel_initializer=tf_utils.clone_initializer(initializer),
          share_rezero=share_rezero,
          name='transformer/layer_%d' % i)
      self._transformer_layers.append(layer)

    self._pooler_layer = tf_keras.layers.Dense(
        units=hidden_size,
        activation='tanh',
        kernel_initializer=tf_utils.clone_initializer(initializer),
        name='pooler_transform')

    if isinstance(pool_stride, (int, float)):
      # TODO(b/197133196): Pooling layer can be shared.
      pool_strides = [pool_stride] * num_layers
    else:
      if len(pool_stride) != num_layers:
        raise ValueError('Lengths of pool_stride and num_layers are not equal.')
      pool_strides = pool_stride

    is_fractional_pooling = not all(
        (1.0 * layer_stride).is_integer() for layer_stride in pool_strides
    )
    if is_fractional_pooling and pool_type in [_MAX, _AVG]:
      raise ValueError(
          'Fractional pooling is only supported for'
          ' `pool_type`=`truncated_avg`'
      )

    # TODO(crickwu): explore tf_keras.layers.serialize method.
    if pool_type == _MAX:
      pool_cls = tf_keras.layers.MaxPooling1D
    elif pool_type == _AVG:
      pool_cls = tf_keras.layers.AveragePooling1D
    elif pool_type == _TRUNCATED_AVG:
      # TODO(b/203665205): unpool_length should be implemented.
      if unpool_length != 0:
        raise ValueError('unpool_length is not supported by truncated_avg now.')
    else:
      raise ValueError('pool_type not supported.')

    if pool_type in (_MAX, _AVG):
      self._att_input_pool_layers = []
      for layer_pool_stride in pool_strides:
        att_input_pool_layer = pool_cls(
            pool_size=layer_pool_stride,
            strides=layer_pool_stride,
            padding='same',
            name='att_input_pool_layer')
        self._att_input_pool_layers.append(att_input_pool_layer)

    self._max_sequence_length = max_sequence_length
    self._pool_strides = pool_strides  # This is a list here.
    self._unpool_length = unpool_length
    self._pool_type = pool_type
    self._append_dense_inputs = append_dense_inputs

    self._config = {
        'vocab_size': vocab_size,
        'hidden_size': hidden_size,
        'num_layers': num_layers,
        'num_attention_heads': num_attention_heads,
        'max_sequence_length': max_sequence_length,
        'type_vocab_size': type_vocab_size,
        'inner_dim': inner_dim,
        'inner_activation': tf_keras.activations.serialize(activation),
        'output_dropout': output_dropout,
        'attention_dropout': attention_dropout,
        'initializer': tf_keras.initializers.serialize(initializer),
        'output_range': output_range,
        'embedding_width': embedding_width,
        'embedding_layer': embedding_layer,
        'norm_first': norm_first,
        'pool_type': pool_type,
        'pool_stride': pool_stride,
        'unpool_length': unpool_length,
        'transformer_cls': _transformer_cls2str.get(
            transformer_cls, str(transformer_cls)
        ),
        # These two were previously missing, so `from_config` silently reset
        # them to their defaults.
        'share_rezero': share_rezero,
        'append_dense_inputs': append_dense_inputs,
    }

    self.inputs = dict(
        input_word_ids=tf_keras.Input(shape=(None,), dtype=tf.int32),
        input_mask=tf_keras.Input(shape=(None,), dtype=tf.int32),
        input_type_ids=tf_keras.Input(shape=(None,), dtype=tf.int32))

  def call(self, inputs, output_range: Optional[tf.Tensor] = None):
    """Runs the encoder.

    Args:
      inputs: One of the following:
        * A list/tuple of 3 tensors `[word_ids, mask, type_ids]`.
        * A list/tuple of 6 tensors `[word_ids, mask, type_ids, dense_inputs,
          dense_mask, dense_type_ids]`.
        * A dict with keys `input_word_ids`, `input_mask`, `input_type_ids`,
          and optionally `input_word_embeddings`, `dense_inputs`,
          `dense_mask`, `dense_type_ids`.
      output_range: Optional output range forwarded to the last transformer
        layer only.

    Returns:
      A dict with `word_embeddings`, `embedding_output`, `sequence_output`,
      `pooled_output` and `encoder_outputs` (the output of every layer).

    Raises:
      ValueError: On an unsupported input container or list length, or when
        dense inputs are appended while `unpool_length != 0`.
    """
    # inputs are [word_ids, mask, type_ids]
    word_embeddings = None
    if isinstance(inputs, (list, tuple)):
      logging.warning('List inputs to %s are discouraged.', self.__class__)
      if len(inputs) == 3:
        word_ids, mask, type_ids = inputs
        dense_inputs = None
        dense_mask = None
        dense_type_ids = None
      elif len(inputs) == 6:
        word_ids, mask, type_ids, dense_inputs, dense_mask, dense_type_ids = (
            inputs
        )
      else:
        raise ValueError(
            'Unexpected inputs to %s with length at %d.'
            % (self.__class__, len(inputs))
        )
    elif isinstance(inputs, dict):
      word_ids = inputs.get('input_word_ids')
      mask = inputs.get('input_mask')
      type_ids = inputs.get('input_type_ids')
      word_embeddings = inputs.get('input_word_embeddings', None)
      dense_inputs = inputs.get('dense_inputs', None)
      dense_mask = inputs.get('dense_mask', None)
      dense_type_ids = inputs.get('dense_type_ids', None)
    else:
      raise ValueError('Unexpected inputs type to %s.' % self.__class__)

    if word_embeddings is None:
      word_embeddings = self._embedding_layer(word_ids)

    if dense_inputs is not None:
      # Allow concatenation of the dense embeddings at sequence end if
      # requested and `unpool_length` is set as zero.
      if self._append_dense_inputs:
        if self._unpool_length != 0:
          raise ValueError(
              'unpool_length is not supported by append_dense_inputs now.'
          )
        word_embeddings = tf.concat([word_embeddings, dense_inputs], axis=1)
        type_ids = tf.concat([type_ids, dense_type_ids], axis=1)
        mask = tf.concat([mask, dense_mask], axis=1)
      else:
        # Concat the dense embeddings at sequence begin so unpool_len can
        # control embedding not being pooled.
        word_embeddings = tf.concat([dense_inputs, word_embeddings], axis=1)
        type_ids = tf.concat([dense_type_ids, type_ids], axis=1)
        mask = tf.concat([dense_mask, mask], axis=1)

    # absolute position embeddings
    position_embeddings = self._position_embedding_layer(word_embeddings)
    type_embeddings = self._type_embedding_layer(type_ids)

    embeddings = tf_keras.layers.add(
        [word_embeddings, position_embeddings, type_embeddings])
    embeddings = self._embedding_norm_layer(embeddings)
    embeddings = self._embedding_dropout(embeddings)

    if self._embedding_projection is not None:
      embeddings = self._embedding_projection(embeddings)

    attention_mask = self._attention_mask_layer(embeddings, mask)

    encoder_outputs = []
    x = embeddings
    # TODO(b/195972228): attention_mask can be co-generated with pooling.
    if self._pool_type in (_MAX, _AVG):
      # The first layer's queries are pooled by its own stride.
      attention_mask = _pool_and_concat(
          attention_mask,
          unpool_length=self._unpool_length,
          strides=self._pool_strides[0],
          axes=[1])
      for i, layer in enumerate(self._transformer_layers):
        transformer_output_range = None
        if i == self._num_layers - 1:
          transformer_output_range = output_range
        # Bypass no pooling cases.
        if self._pool_strides[i] == 1:
          x = layer(
              [x, x, attention_mask], output_range=transformer_output_range
          )
        else:
          # Pools layer for compressing the query length.
          pooled_inputs = self._att_input_pool_layers[i](
              x[:, self._unpool_length:, :])
          query_inputs = tf.concat(
              values=(tf.cast(
                  x[:, :self._unpool_length, :],
                  dtype=pooled_inputs.dtype), pooled_inputs),
              axis=1)
          x = layer([query_inputs, x, attention_mask],
                    output_range=transformer_output_range)
        # Pools the corresponding attention_mask: the query axis by the next
        # layer's stride, the key axis by this layer's stride.
        if i < len(self._transformer_layers) - 1:
          attention_mask = _pool_and_concat(
              attention_mask,
              unpool_length=self._unpool_length,
              strides=[self._pool_strides[i + 1], self._pool_strides[i]],
              axes=[1, 2])
        encoder_outputs.append(x)
    elif self._pool_type == _TRUNCATED_AVG:
      # Compute the attention masks and pooling transforms.
      # Note we do not compute this in __init__ due to inference converter
      # issue b/215659399.
      # The transforms are sized for `max_sequence_length`, so the (possibly
      # dense-augmented) input length must equal it for the einsums below.
      pooling_transforms = _create_truncated_avg_transforms(
          self._max_sequence_length, self._pool_strides)
      attention_masks = _create_truncated_avg_masks(mask, self._pool_strides,
                                                    pooling_transforms)
      for i, layer in enumerate(self._transformer_layers):
        attention_mask = attention_masks[i]
        transformer_output_range = None
        if i == self._num_layers - 1:
          transformer_output_range = output_range
        # Bypass no pooling cases.
        if self._pool_strides[i] == 1:
          x = layer([x, x, attention_mask],
                    output_range=transformer_output_range)
        else:
          pooled_inputs = tf.einsum(
              'BFD,FT->BTD',
              tf.cast(x[:, self._unpool_length:, :], _get_policy_dtype()
                     ),  # extra casting for faster mixed computation.
              pooling_transforms[i])
          query_inputs = tf.concat(
              values=(tf.cast(
                  x[:, :self._unpool_length, :],
                  dtype=pooled_inputs.dtype), pooled_inputs),
              axis=1)
          x = layer([query_inputs, x, attention_mask],
                    output_range=transformer_output_range)
        encoder_outputs.append(x)

    last_encoder_output = encoder_outputs[-1]
    first_token_tensor = last_encoder_output[:, 0, :]
    pooled_output = self._pooler_layer(first_token_tensor)

    return dict(
        word_embeddings=word_embeddings,
        embedding_output=embeddings,
        sequence_output=encoder_outputs[-1],
        pooled_output=pooled_output,
        encoder_outputs=encoder_outputs)

  def get_embedding_table(self):
    return self._embedding_layer.embeddings

  def get_embedding_layer(self):
    return self._embedding_layer

  def get_config(self):
    # Returns a copy so callers cannot mutate the stored config.
    return dict(self._config)

  @property
  def transformer_layers(self):
    """List of Transformer layers in the encoder."""
    return self._transformer_layers

  @property
  def pooler_layer(self):
    """The pooler dense layer after the transformer layers."""
    return self._pooler_layer

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Re-creates the encoder from `get_config()` output.

    Warns when the config carries an `embedding_layer`, since the reloaded
    layer will no longer be shared with anything else.
    """
    if 'embedding_layer' in config and config['embedding_layer'] is not None:
      warn_string = (
          'You are reloading a model that was saved with a '
          'potentially-shared embedding layer object. If you contine to '
          'train this model, the embedding layer will no longer be shared. '
          'To work around this, load the model outside of the Keras API.')
      print('WARNING: ' + warn_string)
      logging.warning(warn_string)

    return cls(**config)
|