File size: 4,469 Bytes
5672777
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Layers for Mixture of Experts (MoE) routing.

For MoE routing, we need to separate a set of tokens to sets of tokens.
Later on, different sets of tokens can potentially go to different experts.
"""

import tensorflow as tf, tf_keras


@tf_keras.utils.register_keras_serializable(package="Text")
class TokenImportanceWithMovingAvg(tf_keras.layers.Layer):
  """Routing based on per-token importance value.

  The layer owns a non-trainable table with one importance value per
  vocabulary id. Calling the layer looks up the stored value for each token
  id; `update_token_importance` blends newly observed values into the table
  with an exponential moving average.
  """

  def __init__(self,
               vocab_size,
               init_importance,
               moving_average_beta=0.995,
               **kwargs):
    """Stores the configuration; the table itself is created in `build`.

    Args:
      vocab_size: Number of entries in the importance table.
      init_importance: Constant every table entry is initialized with.
      moving_average_beta: Weight kept for the stored value on each update;
        the incoming value is weighted by `1 - moving_average_beta`.
      **kwargs: Forwarded to the base layer constructor.
    """
    self._vocab_size = vocab_size
    self._init_importance = init_importance
    self._moving_average_beta = moving_average_beta
    super().__init__(**kwargs)

  def build(self, input_shape):
    """Creates the importance table; `input_shape` is not used."""
    self._importance_embedding = self.add_weight(
        name="importance_embed",
        # Trailing comma makes this a 1-D shape tuple rather than a bare int.
        shape=(self._vocab_size,),
        initializer=tf_keras.initializers.Constant(self._init_importance),
        trainable=False)

  def get_config(self):
    """Returns the base layer config extended with the constructor args."""
    config = {
        "vocab_size": self._vocab_size,
        "init_importance": self._init_importance,
        "moving_average_beta": self._moving_average_beta,
    }
    base_config = super().get_config()
    return {**base_config, **config}

  def update_token_importance(self, token_ids, importance):
    """Blends new importance values into the stored table.

    Both arguments are flattened before use. For every id in `token_ids` the
    stored value becomes `old * beta + new * (1 - beta)`, where `beta` is
    `moving_average_beta`. The table is created in `build`, so the layer must
    have been built before this method is called.

    NOTE: if the same id appears more than once, each occurrence writes to the
    same table slot via `tf.tensor_scatter_nd_update`, which does not define
    which of the duplicate updates is kept.

    Args:
      token_ids: Integer tensor of vocabulary ids.
      importance: Numeric tensor holding one value per entry of `token_ids`.
    """
    token_ids = tf.reshape(token_ids, shape=[-1])
    # Cast before scaling so integer or low-precision inputs are blended in
    # float32 instead of failing or losing precision in the multiplication.
    importance = tf.cast(tf.reshape(importance, shape=[-1]), dtype=tf.float32)

    beta = self._moving_average_beta
    old_importance = tf.gather(self._importance_embedding, token_ids)
    blended = old_importance * beta + importance * (1.0 - beta)
    self._importance_embedding.assign(tf.tensor_scatter_nd_update(
        self._importance_embedding,
        tf.expand_dims(token_ids, axis=1),
        blended))

  def call(self, inputs):
    """Looks up the stored importance for every token id in `inputs`."""
    return tf.gather(self._importance_embedding, inputs)


@tf_keras.utils.register_keras_serializable(package="Text")
class SelectTopK(tf_keras.layers.Layer):
  """Select top-k + random-k tokens according to importance.

  Given a 2-D tensor of scores, the layer returns, for every row, the column
  indices that were selected and the column indices that were not. Selection
  is the `top_k` highest-scoring columns, `random_k` uniformly random columns,
  or the top-k followed by random-k drawn from the remaining columns.
  """

  def __init__(self,
               top_k=None,
               random_k=None,
               **kwargs):
    """Stores the selection sizes.

    Args:
      top_k: Number of highest-scoring columns to select per row, or `None`
        to select purely at random.
      random_k: Number of randomly chosen columns to select per row, or
        `None` to select purely by score.
      **kwargs: Forwarded to the base layer constructor.
    """
    self._top_k = top_k
    self._random_k = random_k
    super().__init__(**kwargs)

  def get_config(self):
    """Returns the base layer config extended with `top_k` and `random_k`."""
    config = {
        "top_k": self._top_k,
        "random_k": self._random_k,
    }
    base_config = super().get_config()
    return {**base_config, **config}

  @staticmethod
  def _split_columns(pos, k):
    """Splits a 2-D index tensor into its first `k` columns and the rest."""
    head = tf.slice(pos, [0, 0], [-1, k])
    tail = tf.slice(pos, [0, k], [-1, -1])
    return head, tail

  @staticmethod
  def _random_order(reference):
    """Returns a random permutation per row, shaped like `reference`."""
    return tf.argsort(
        tf.random.uniform(shape=tf.shape(reference)),
        direction="DESCENDING")

  def call(self, inputs):
    """Computes the selected and not-selected column indices.

    Args:
      inputs: 2-D tensor of scores (presumably `[batch, num_tokens]` -- the
        slicing below fixes the rank at two, the axis meaning is assumed).

    Returns:
      A `(selected, not_selected)` tuple of index tensors. `selected` has
      `top_k`, `random_k` or `top_k + random_k` columns depending on which
      arguments were set; `not_selected` holds the remaining columns.

    Raises:
      ValueError: If neither `top_k` nor `random_k` was set.
    """
    if self._top_k is None and self._random_k is None:
      # Without this check the failure surfaces inside `tf.slice` with a
      # message that does not mention the layer's configuration.
      raise ValueError(
          "At least one of `top_k` and `random_k` must be set.")

    if self._random_k is None:
      # Pure top-k, no randomness.
      pos = tf.argsort(inputs, direction="DESCENDING")
      return self._split_columns(pos, self._top_k)

    if self._top_k is None:
      # Pure randomness, no top-k.
      pos = self._random_order(inputs)
      return self._split_columns(pos, self._random_k)

    # Top-k plus randomness: take the best columns first, then shuffle what
    # is left and take the first `random_k` of the shuffled remainder.
    pos = tf.argsort(inputs, direction="DESCENDING")
    selected_top_k, pos_left = self._split_columns(pos, self._top_k)

    sort_index = self._random_order(pos_left)
    pos_left = tf.gather(pos_left, sort_index, batch_dims=1, axis=1)

    selected_rand, not_selected = self._split_columns(pos_left, self._random_k)
    selected = tf.concat([selected_top_k, selected_rand], axis=1)

    # Return the indices of selected and not-selected tokens.
    return selected, not_selected