|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Utilities for applying `tf.function` only when executing eagerly."""
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
|
|
from __future__ import print_function |
|
|
|
import functools |
|
|
|
import tensorflow as tf |
|
|
|
|
|
class TfFunctionIfEagerDecorator(object):
  """Helper decorator function to optionally apply the @tf.function annotation.

  Calling an instance on a function returns a wrapper. At call time, the
  wrapper checks `tf.compat.v1.executing_eagerly_outside_functions()`:

  - If that returns True, or the API is unavailable in the installed
    TensorFlow, the function runs through `tf.function`. That `tf.function`
    is built with the keyword arguments given to this decorator's
    constructor.
  - Otherwise the plain Python function is called directly.
  """

  def __init__(self, **kwargs):
    """Stores keyword arguments to forward to `tf.function`.

    Args:
      **kwargs: Keyword arguments passed to `tf.function` when the wrapped
        function is compiled.
    """
    self.func_kwargs = kwargs

  def __call__(self, func):
    """Wraps `func` so that `tf.function` is applied only in eager mode.

    Args:
      func: The Python callable to wrap.

    Returns:
      A wrapper carrying `func`'s metadata (via `functools.wraps`) that
      accepts the same positional and keyword arguments as `func`. Decorating
      the same `func` again with this instance returns the same wrapper.
    """
    # Reuse the cached wrapper only when it was built for this exact function.
    # Without the identity check, any later function decorated by the same
    # instance would silently get the first function's wrapper back.
    if (getattr(self, "_call_impl", None) is not None and
        getattr(self, "_call_impl_source", None) is func):
      return self._call_impl

    # Holds the single `tf.function` built for `func`. It is created lazily on
    # the first eager call and then reused. Building a new `tf.function` per
    # call would discard its trace cache and force a retrace every time.
    compiled = []

    @functools.wraps(func)
    def wrapped_func(*args, **kwargs):
      # Fall back to `tf.function` when the eager-check API does not exist.
      if (not hasattr(tf.compat.v1, "executing_eagerly_outside_functions") or
          tf.compat.v1.executing_eagerly_outside_functions()):
        if not compiled:
          compiled.append(tf.function(func=func, **self.func_kwargs))
        return compiled[0](*args, **kwargs)
      return func(*args, **kwargs)

    self._call_impl = wrapped_func
    self._call_impl_source = func
    return wrapped_func
|
|
|
|
|
def tf_function_if_eager(**tf_function_kwargs):
  """Builds a decorator that applies `tf.function` only in eager mode.

  Args:
    **tf_function_kwargs: Keyword arguments stored on the returned decorator
      and forwarded to `tf.function` when a decorated function is compiled.

  Returns:
    A `TfFunctionIfEagerDecorator` configured with `tf_function_kwargs`.
  """
  decorator = TfFunctionIfEagerDecorator(**tf_function_kwargs)
  return decorator
|
|