# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities shared across models."""
from absl import logging
import big_vision.utils as u
import flax.linen as nn
import jax
import jax.numpy as jnp


def merge_params(loaded, inited, dont_load=(), match_dtype=False):
  """Makes `loaded` pytree match `inited`, warning or failing on mismatch.

  Args:
    loaded: pytree of parameters, typically loaded from a checkpoint.
    inited: pytree of parameters, typically coming from model init.
    dont_load: List of regexes for parameters which shall not be taken
      from `loaded`, either because they should remain at their init value,
      or because they are missing on either side.
    match_dtype: If True, leaves of the returned pytree are cast to the dtype
      of the corresponding leaves in `inited`.

  Returns:
    If successful, a new pytree which matches the structure of `inited`
    but contains values from `loaded`, except for `dont_load`.

    If structures don't match and mismatches are not covered by regexes in
    the `dont_load` argument, then raises an exception with more information.
  """
  if inited is None:  # A useful shortcut for example for colabs.
    return loaded

  dont_load = u.check_and_compile_patterns(dont_load)

  def should_merge(name):
    return not any(pattern.fullmatch(name) for pattern in dont_load)

  loaded_flat, _ = u.tree_flatten_with_names(loaded)
  inited_flat, _ = u.tree_flatten_with_names(inited)
  loaded_flat = {k: v for k, v in loaded_flat}
  inited_flat = {k: v for k, v in inited_flat}

  # Let's first build the pytree from all common keys.
  merged = {}
  for name, init_val in inited_flat.items():
    # param is present in both. Load or ignore it!
    if name in loaded_flat and should_merge(name):
      merged[name] = loaded_flat[name]
      if match_dtype:
        merged[name] = loaded_flat[name].astype(init_val.dtype)
    else:
      logging.info("Ignoring checkpoint and using init value for %s", name)
      merged[name] = init_val

  def pp(title, names, indent=" "):  # Just pretty-printing.
    if names:
      return f"{title}:\n" + "\n".join(f"{indent}{k}" for k in sorted(names))
    else:
      return ""

  # Now, if there are keys that only exist in inited or loaded, be helpful:
  not_in_loaded = inited_flat.keys() - loaded_flat.keys()
  not_in_inited = loaded_flat.keys() - inited_flat.keys()
  logging.info(pp("Parameters in model but not in checkpoint", not_in_loaded))
  logging.info(pp("Parameters in checkpoint but not in model", not_in_inited))

  # And now see if any of them are not explicitly ignored => an error.
  not_in_loaded = {k for k in not_in_loaded if should_merge(k)}
  not_in_inited = {k for k in not_in_inited if should_merge(k)}

  if not_in_loaded or not_in_inited:
    raise ValueError(
        pp("Params in checkpoint", loaded_flat.keys()) + "\n" +
        pp("Params in model (code)", inited_flat.keys()) + "\n" +
        pp("Params in model (code) but not in checkpoint and not `dont_load`ed",
           not_in_loaded, indent=" - ") + "\n" +  # Special indent for tests.
        pp("Params in checkpoint but not in model (code) and not `dont_load`ed",
           not_in_inited, indent=" + "))  # Special indent for tests.

  return u.recover_tree(merged.keys(), merged.values())
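

# Example (illustrative sketch): restoring a checkpoint that lacks the freshly
# initialized head. The parameter names ("embed/kernel", "head/...") and the
# shapes below are hypothetical and chosen only for illustration; `dont_load`
# regexes are matched against the "/"-joined flattened parameter names.
def _merge_params_example():
  inited = {"embed": {"kernel": jnp.zeros((4, 2))},
            "head": {"kernel": jnp.zeros((2, 2)), "bias": jnp.zeros((2,))}}
  loaded = {"embed": {"kernel": jnp.ones((4, 2))}}  # E.g. an older checkpoint.
  # "embed/kernel" is taken from `loaded`; the head stays at its init values
  # because it is excluded via `dont_load`, rather than raising an error.
  return merge_params(loaded, inited, dont_load=("head/.*",))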


class AddPositionEmbs(nn.Module):
  """Adds positional embeddings to the inputs, supports caching for decode.

  Attributes:
    decode: whether to run in single-position autoregressive mode.
  """
  decode: bool = False

  @nn.compact
  def __call__(self, inputs, posemb):
    """Applies AddPositionEmbs module.

    Adds posemb to the inputs, supports single-position autoregressive mode.

    Args:
      inputs: input data [batch_size, seq_len, emb_dim].
      posemb: positional embeddings.

    Returns:
      output: inputs modulated by pos-embeddings [batch_size, seq_len, emb_dim].
    """
    assert inputs.ndim == 3, f"Unexpected inputs shape: {inputs.shape}"
    _, seq_len, emb_dim = inputs.shape
    pe = posemb[:, :seq_len, :]

    if self.decode:
      is_initialized = self.has_variable("cache", "cache_index")
      # We use a cache position index for tracking decoding position.
      cache_index = self.variable("cache", "cache_index",
                                  lambda: jnp.array(0, dtype=jnp.uint32))
      if is_initialized:
        i = cache_index.value
        cache_index.value = i + 1
        # Returns posemb[0, i, :], the positional embedding for the
        # current decoding position.
        pe = jax.lax.dynamic_slice(posemb,
                                   start_indices=jnp.array((0, i, 0)),
                                   slice_sizes=(1, 1, emb_dim))
    return inputs + pe
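

# Example (illustrative sketch): adding position embeddings to a dummy batch.
# The shapes and the `posemb` table below are hypothetical; with the default
# decode=False the module keeps no variables, so `init` returns an empty
# collection and the call simply slices the first `seq_len` positions of the
# table and adds them to the inputs.
def _add_position_embs_example():
  x = jnp.zeros((1, 16, 32))       # [batch_size, seq_len, emb_dim]
  posemb = jnp.zeros((1, 64, 32))  # Positional table, longer than seq_len.
  module = AddPositionEmbs(decode=False)
  variables = module.init(jax.random.PRNGKey(0), x, posemb)
  return module.apply(variables, x, posemb)  # Same shape as `x`.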