# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Gemma wrapper to make it work for us."""

from big_vision.models.ppp import gemma
import flax.linen as nn
import jax
import jax.numpy as jnp


def _get_config(model):
  config = gemma.get_config(model.variant)
  config.scan = model.scan
  config.remat_policy = model.remat_policy
  if model.vocab_size is not None:
    config.vocab_size = model.vocab_size
  config.dropout = model.dropout
  config.dropout_bdims = model.dropout_bdims
  config.cache_dtype = model.cache_dtype
  return config


@jax.vmap
def _left_to_right_align(x, input_mask, attn_mask):
  """Converts input from left-align to right-aligned."""
  # Due to vmap, this is operating in a single example (not batch level).
  assert x.ndim == 2 and input_mask.ndim == 1 and attn_mask.ndim == 2
  assert x.shape[0] == input_mask.shape[0]
  assert attn_mask.shape[0] == attn_mask.shape[1], attn_mask.shape
  seqlen = jnp.sum(input_mask)
  x = jnp.roll(x, -seqlen, axis=0)
  input_mask = jnp.roll(input_mask, -seqlen, axis=0)
  attn_mask = jnp.roll(attn_mask, -seqlen, axis=(0, 1))
  return x, input_mask, attn_mask
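
# Illustrative worked example (not part of the original module): for a single
# example with 2 valid tokens in a length-4 sequence,
#   input_mask = [1, 1, 0, 0]  ->  jnp.roll(input_mask, -2) = [0, 0, 1, 1]
# i.e. the valid tokens end up right-aligned at the end of the sequence; `x`
# and `attn_mask` are rolled by the same offset so they stay consistent with
# the shifted tokens.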


class Model(nn.Module):
  """Wrapping gemma big_vision model."""
  variant: str = "gemma_2b"
  scan: bool = True
  remat_policy: str = "nothing_saveable"
  vocab_size: int | None = None

  dropout: float = 0.0
  dropout_bdims: tuple[int, ...] = ()  # Every float is dropped independently.
  cache_dtype: str | None = "bfloat16"  # bfloat16 to save memory and transfers.

  def setup(self):
    # Passing parent+name avoids unnecessary nesting in the params pytree.
    self.model = gemma.Model(**_get_config(self), parent=self.scope, name="")

  def embed_tokens(self, tokens, train=False):
    # Turns int32[B,T] tokens into float32[B,T,d_model] embeddings.
    # Really just the vocab embedding.
    return self.model(tokens, embed_only=True, deterministic=not train)

  def compute_logits(self, pre_logits, train=False):
    return self.model(None, pre_logits=pre_logits, deterministic=not train)[0]

  def __call__(self, embs, mask=None, train=False):
    # Turns a float32[B,T,d_model] embedding sequence into logits.
    # call(embed_tokens(tokens)) should be a forward pass.
    # Allows specifying an int32[B,T,T] attention mask. For convenience we may
    # default to a triangular autoregressive mask when None, but that is not P0.
    # Returns float32[B,T,vocab_size] logits and an out-dict.

    batch_size, _, d_model = embs.shape
    assert d_model == self.embdim
    logits, out = self.model(
        tokens=jnp.zeros([batch_size, 0], dtype=jnp.int32),
        embedded_prefix=embs,
        mask=mask,
        deterministic=not train,
    )
    return logits, out
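
  # Illustrative usage sketch (not from the original module; `params` and
  # int32[B,T] `tokens` are assumed to exist, e.g. restored via `load` below):
  #   model = Model(variant="gemma_2b")
  #   embs = model.apply({"params": params}, tokens, method=Model.embed_tokens)
  #   logits, out = model.apply({"params": params}, embs)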

  def prefill_cache(self, x, input_mask, attn_mask, *, cache_size):
    """Initializes decoding cache with `x` [B, N, E] as prompt.

    IMPORTANT: Inputs MUST be left-aligned and attn_mask should not allow
    input tokens to attend to padding tokens.

    TODO: Relax the left-align requirement by converting any input into a
    right-aligned input with no attention to padding tokens.

    Args:
      x: float[B, N, E] embedded prompt tokens.
      input_mask: bool[B, N]. True indicates tokens are part of the prompt.
        False indicates padding tokens. This class doesn't combine this with
        attn_mask, so mask out the attention to padding tokens beforehand.
      attn_mask: bool[B, N, N]. Indicates which tokens can attend to which while
        processing the prompt tokens. During extend_cache, it is assumed that
        tokens can attend to all previous valid tokens.
      cache_size: int. Size of the cache. The prompt consumes the first N
        entries of the cache and each subsequent extend_cache call consumes one
        more entry. Behaviour is undefined when prefill_len plus the number of
        extend_cache calls exceeds cache_size.

    Returns:
      logits of the last valid token (i.e. last logits where input_mask=True).
    """
    # To call the model with decode=True we need to be able to provide:
    #   (a) positions of tokens [B, N], ([B, 1] for extend)
    #   (b) attention mask [B, N, cache_size] ([B, 1, cache_size] for extend)
    #
    # To do so we track how many tokens each example has seen so far, and we
    # align the prompt to the right so that the cache usage of each example is
    # a contiguous range (cache_begin, cache_end] such that cache_end is the
    # same for all sequences (this allows faster row updates of the cache
    # during decoding).
    x, input_mask, attn_mask = _left_to_right_align(x, input_mask, attn_mask)

    # Track the per-example sequence length.
    seq_len = jnp.sum(input_mask, axis=-1)
    self.put_variable("cache", "seq_len", seq_len)
    positions = jnp.cumsum(input_mask, axis=-1) - 1

    # Initialize cache_begin and cache_end. Note: cache_end is the same for all
    # sequences but we keep it per example to allow easy sharding rules with
    # batch as the first axis.
    batch_size, prefill_len, _ = x.shape
    self.put_variable("cache", "cache_begin", prefill_len - seq_len)
    self.put_variable(
        "cache", "cache_end", jnp.full((batch_size,), prefill_len, jnp.int32)
    )

    # Pad attention to set the cache size.
    mask = jnp.pad(attn_mask, ((0, 0), (0, 0), (0, cache_size - prefill_len)))

    _, aux = self.model(
        tokens=None,
        embedded_prefix=x,
        positions=positions,
        mask=mask,
        decode=True,
    )
    return self.compute_logits(aux["pre_logits"][:, -1:])

  def extend_cache(self, x):
    """Extends decoding cache with `x` [B, 1, E] and returns logits."""
    assert x.shape[1] == 1, "Only supports extending the cache by one token."
    if self.model.scan:
      cache_size = self.variables["cache"]["layers"]["attn"]["k_cache"].shape[2]
    else:
      raise NotImplementedError("Cache size lookup is only implemented for scan=True.")

    # Lookup current token position and increment by one for next call.
    positions = self.get_variable("cache", "seq_len")
    self.put_variable("cache", "seq_len", positions + 1)

    # Update which cache positions are in use and construct attention mask.
    # Tokens can attend to all cache positions which are in use including self.
    cache_begin = self.get_variable("cache", "cache_begin")
    cache_end = self.get_variable("cache", "cache_end") + 1
    self.put_variable("cache", "cache_end", cache_end)
    mask = jnp.logical_and(
        jnp.arange(cache_size)[None, None, :] >= cache_begin[:, None, None],
        jnp.arange(cache_size)[None, None, :] < cache_end[:, None, None])

    logits, _ = self.model(
        tokens=None, embedded_prefix=x,
        positions=positions[:, None], mask=mask, decode=True)
    return logits
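
  # Illustrative greedy decoding loop (not from the original module),
  # continuing the prefill sketch above; `max_decode_len` is an assumed
  # constant:
  #   logits = last_logits
  #   for _ in range(max_decode_len):
  #     tok = jnp.argmax(logits[:, -1], axis=-1)            # [B] greedy pick
  #     emb = model.apply({"params": params}, tok[:, None],
  #                       method=Model.embed_tokens)        # [B, 1, E]
  #     logits, dec_vars = model.apply(
  #         {"params": params, **dec_vars}, emb,
  #         method=Model.extend_cache, mutable=["cache"])   # [B, 1, vocab]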

  @property
  def embdim(self):
    return _get_config(self).width


load = gemma.load