# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Layers for Mixture of Experts (MoE) routing.
For MoE routing, we need to separate a set of tokens to sets of tokens.
Later on, different sets of tokens can potentially go to different experts.
"""
import tensorflow as tf, tf_keras


@tf_keras.utils.register_keras_serializable(package="Text")
class TokenImportanceWithMovingAvg(tf_keras.layers.Layer):
"""Routing based on per-token importance value."""
def __init__(self,
vocab_size,
init_importance,
moving_average_beta=0.995,
**kwargs):
self._vocab_size = vocab_size
self._init_importance = init_importance
self._moving_average_beta = moving_average_beta
super().__init__(**kwargs)

  def build(self, input_shape):
    # Per-token importance scores. They are not trained by gradient descent;
    # they are updated manually via `update_token_importance`, hence
    # trainable=False. Note the shape must be a 1-element tuple, not a
    # parenthesized scalar.
    self._importance_embedding = self.add_weight(
        name="importance_embed",
        shape=(self._vocab_size,),
        initializer=tf_keras.initializers.Constant(self._init_importance),
        trainable=False)
    super().build(input_shape)

  def get_config(self):
    config = {
        "vocab_size": self._vocab_size,
        "init_importance": self._init_importance,
        "moving_average_beta": self._moving_average_beta,
    }
    base_config = super().get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def update_token_importance(self, token_ids, importance):
    """Updates importance scores with an exponential moving average."""
    token_ids = tf.reshape(token_ids, shape=[-1])
    importance = tf.reshape(importance, shape=[-1])
    beta = self._moving_average_beta
    old_importance = tf.gather(self._importance_embedding, token_ids)
    # Exponential moving average update:
    #   new_importance = beta * old_importance + (1 - beta) * importance.
    self._importance_embedding.assign(tf.tensor_scatter_nd_update(
        self._importance_embedding,
        tf.expand_dims(token_ids, axis=1),
        old_importance * beta + tf.cast(importance * (1.0 - beta),
                                        dtype=tf.float32)))

  def call(self, inputs):
    # Look up the current importance score for each input token id.
    return tf.gather(self._importance_embedding, inputs)
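
# A minimal usage sketch (illustrative only, not part of the library code):
# build the layer on a toy vocabulary and apply one moving-average update.
# All values below are hypothetical.
#
#   importance_layer = TokenImportanceWithMovingAvg(
#       vocab_size=8, init_importance=0.0, moving_average_beta=0.9)
#   token_ids = tf.constant([[1, 3]])
#   _ = importance_layer(token_ids)  # First call builds the weight.
#   importance_layer.update_token_importance(
#       token_ids=token_ids, importance=tf.constant([[1.0, 2.0]]))
#   importance_layer(token_ids)  # Token 1: 0.9 * 0.0 + 0.1 * 1.0 = 0.1.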


@tf_keras.utils.register_keras_serializable(package="Text")
class SelectTopK(tf_keras.layers.Layer):
  """Select top-k + random-k tokens according to importance."""

  def __init__(self,
top_k=None,
random_k=None,
**kwargs):
self._top_k = top_k
self._random_k = random_k
super().__init__(**kwargs)

  def get_config(self):
    config = {
        "top_k": self._top_k,
        "random_k": self._random_k,
    }
    base_config = super().get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self, inputs):
    if self._random_k is None:
      # Pure top-k, no randomness.
pos = tf.argsort(inputs, direction="DESCENDING")
selected = tf.slice(pos, [0, 0], [-1, self._top_k])
not_selected = tf.slice(pos, [0, self._top_k], [-1, -1])
elif self._top_k is None:
# Pure randomness, no top-k.
pos = tf.argsort(tf.random.uniform(shape=tf.shape(inputs)),
direction="DESCENDING")
selected = tf.slice(pos, [0, 0], [-1, self._random_k])
not_selected = tf.slice(pos, [0, self._random_k], [-1, -1])
else:
# Top-k plus randomness.
pos = tf.argsort(inputs, direction="DESCENDING")
selected_top_k = tf.slice(pos, [0, 0], [-1, self._top_k])
pos_left = tf.slice(pos, [0, self._top_k], [-1, -1])
      # Randomly shuffle the remaining positions before drawing random_k.
sort_index = tf.argsort(
tf.random.uniform(shape=tf.shape(pos_left)),
direction="DESCENDING")
pos_left = tf.gather(pos_left, sort_index, batch_dims=1, axis=1)
selected_rand = tf.slice(pos_left, [0, 0], [-1, self._random_k])
not_selected = tf.slice(pos_left, [0, self._random_k], [-1, -1])
selected = tf.concat([selected_top_k, selected_rand], axis=1)
# Return the indices of selected and not-selected tokens.
return selected, not_selected
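
# A minimal usage sketch (illustrative only, not part of the library code):
# with top_k=2 and random_k=1, the two highest-scoring positions are always
# selected, plus one position drawn uniformly from the remainder. The scores
# below are hypothetical.
#
#   router = SelectTopK(top_k=2, random_k=1)
#   scores = tf.constant([[0.1, 0.9, 0.4, 0.7, 0.2]])
#   selected, not_selected = router(scores)
#   # selected has shape [1, 3]: positions 1 and 3 (the top-2), plus one of
#   # positions {0, 2, 4}; not_selected holds the remaining two positions.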