# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities shared across models."""

from absl import logging
import big_vision.utils as u
import flax.linen as nn
import jax
import jax.numpy as jnp


def merge_params(loaded, inited, dont_load=(), match_dtype=False):
  """Makes `loaded` pytree match `init`, warning or failing on mismatch.

  Args:
    loaded: pytree of parameters, typically loaded from a checkpoint.
    inited: pytree of parameter, typically coming from model init.
    dont_load: List of regexes for parameters which shall not be taken
      from `loaded`, either because they should remain at their init value,
      or because they are missing on either side.
    match_dtype: returned pytree as leaves converted to dtype from `inited`.

  Returns:
    If successful, a new pytree which matches the structure of `init`
    but contains values from `loaded`, except for `dont_load`.

    If structures don't match and mismatches are not covered by regexes in
    `dont_load` argument, then raises an exception with more information.
  """
  if inited is None:  # A useful shortcut for example for colabs.
    return loaded

  dont_load = u.check_and_compile_patterns(dont_load)

  def should_merge(name):
    return not any(pattern.fullmatch(name) for pattern in dont_load)

  loaded_flat, _ = u.tree_flatten_with_names(loaded)
  inited_flat, _ = u.tree_flatten_with_names(inited)
  loaded_flat = {k: v for k, v in loaded_flat}
  inited_flat = {k: v for k, v in inited_flat}

  # Let's first build the pytree from all common keys.
  merged = {}
  for name, init_val in inited_flat.items():
    # If the param is present in the checkpoint and not excluded, load it.
    if name in loaded_flat and should_merge(name):
      merged[name] = loaded_flat[name]
      if match_dtype:
        merged[name] = loaded_flat[name].astype(init_val.dtype)
    else:
      logging.info("Ignoring checkpoint and using init value for %s", name)
      merged[name] = init_val

  def pp(title, names, indent="  "):  # Just pretty-printing
    if names:
      return f"{title}:\n" + "\n".join(f"{indent}{k}" for k in sorted(names))
    else:
      return ""

  # Now, if there are keys that only exist in inited or loaded, be helpful:
  not_in_loaded = inited_flat.keys() - loaded_flat.keys()
  not_in_inited = loaded_flat.keys() - inited_flat.keys()
  logging.info(pp("Parameters in model but not in checkpoint", not_in_loaded))
  logging.info(pp("Parameters in checkpoint but not in model", not_in_inited))

  # And now see if any of them are not explicitly ignored => an error
  not_in_loaded = {k for k in not_in_loaded if should_merge(k)}
  not_in_inited = {k for k in not_in_inited if should_merge(k)}

  if not_in_loaded or not_in_inited:
    raise ValueError(
        pp("Params in checkpoint", loaded_flat.keys()) + "\n" +
        pp("Params in model (code)", inited_flat.keys()) + "\n" +
        pp("Params in model (code) but not in checkpoint and not `dont_load`ed",
           not_in_loaded, indent=" - ") + "\n" +  # Special indent for tests.
        pp("Params in checkpoint but not in model (code) and not `dont_load`ed",
           not_in_inited, indent=" + "))  # Special indent for tests.

  return u.recover_tree(merged.keys(), merged.values())
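

# A minimal usage sketch for `merge_params` (hypothetical checkpoint path and
# `dont_load` pattern, kept as a comment so the module stays import-safe):
# merge checkpoint params into a freshly initialized pytree while keeping the
# head at its init value.
#
#   inited = model.init(jax.random.PRNGKey(0), dummy_input)["params"]
#   loaded = u.load_params("gs://bucket/checkpoint.npz")  # assumed helper
#   params = merge_params(loaded, inited, dont_load=("head/.*",))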


class AddPositionEmbs(nn.Module):
  """Adds positional embeddings to the inputs, supports caching for decode.

  Attributes:
    decode: whether to run in single-position autoregressive mode.
  """
  decode: bool = False

  @nn.compact
  def __call__(self, inputs, posemb):
    """Applies AddPositionEmbs module.

    Adds posemb to the inputs; supports single-position autoregressive mode.

    Args:
      inputs: input data [batch_size, seq_len, emb_dim].
      posemb: positional embeddings.

    Returns:
      output: inputs with posemb added [batch_size, seq_len, emb_dim].
    """
    assert inputs.ndim == 3, f"Unexpected inputs shape: {inputs.shape}"
    _, seq_len, emb_dim = inputs.shape
    pe = posemb[:, :seq_len, :]

    if self.decode:
      is_initialized = self.has_variable("cache", "cache_index")
      # We use a cache position index for tracking decoding position.
      cache_index = self.variable("cache", "cache_index",
                                  lambda: jnp.array(0, dtype=jnp.uint32))
      if is_initialized:
        i = cache_index.value
        cache_index.value = i + 1
        # Slice out the positional embedding for the current decoding
        # position; shape (1, 1, emb_dim) so it broadcasts over the batch.
        pe = jax.lax.dynamic_slice(posemb,
                                   start_indices=jnp.array((0, i, 0)),
                                   slice_sizes=(1, 1, emb_dim))
    return inputs + pe
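

# A minimal usage sketch for `AddPositionEmbs` in decode mode (hypothetical
# shapes, kept as a comment so the module stays import-safe). The "cache"
# collection tracks the current decoding position across calls.
#
#   posemb = jnp.zeros((1, 16, 8))             # [1, max_seq_len, emb_dim]
#   tok = jnp.ones((2, 1, 8))                  # one token per decode step
#   model = AddPositionEmbs(decode=True)
#   variables = model.init(jax.random.PRNGKey(0), tok, posemb)
#   out, mutated = model.apply(variables, tok, posemb, mutable=["cache"])
#   # mutated["cache"]["cache_index"] now holds the next decoding position.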