"""Utilities shared across models.""" |
|
|
|
from absl import logging |
|
import big_vision.utils as u |
|
import flax.linen as nn |
|
import jax |
|
import jax.numpy as jnp |
|
|
|
|
|
def merge_params(loaded, inited, dont_load=(), match_dtype=False): |
|
"""Makes `loaded` pytree match `init`, warning or failing on mismatch. |
|
|
|
Args: |
|
loaded: pytree of parameters, typically loaded from a checkpoint. |
|
inited: pytree of parameter, typically coming from model init. |
|
dont_load: List of regexes for parameters which shall not be taken |
|
from `loaded`, either because they should remain at their init value, |
|
or because they are missing on either side. |
|
match_dtype: returned pytree as leaves converted to dtype from `inited`. |
|
|
|
Returns: |
|
If successful, a new pytree which matches the structure of `init` |
|
but contains values from `loaded`, except for `dont_load`. |
|
|
|
If structures don't match and mismatches are not covered by regexes in |
|
`dont_load` argument, then raises an exception with more information. |
|
""" |
|
  if inited is None:
    return loaded

  dont_load = u.check_and_compile_patterns(dont_load)

  def should_merge(name):
    return not any(pattern.fullmatch(name) for pattern in dont_load)

  loaded_flat, _ = u.tree_flatten_with_names(loaded)
  inited_flat, _ = u.tree_flatten_with_names(inited)
  loaded_flat = {k: v for k, v in loaded_flat}
  inited_flat = {k: v for k, v in inited_flat}

  # Take every model parameter from the checkpoint when it is present there
  # and not excluded by `dont_load`; otherwise keep its init value.
  merged = {}
  for name, init_val in inited_flat.items():
    if name in loaded_flat and should_merge(name):
      merged[name] = loaded_flat[name]
      if match_dtype:
        merged[name] = loaded_flat[name].astype(init_val.dtype)
    else:
      logging.info("Ignoring checkpoint and using init value for %s", name)
      merged[name] = init_val
|
  def pp(title, names, indent=" "):  # Just pretty-printing.
    if names:
      return f"{title}:\n" + "\n".join(f"{indent}{k}" for k in sorted(names))
    else:
      return ""

  # Log the full diff between checkpoint and model.
  not_in_loaded = inited_flat.keys() - loaded_flat.keys()
  not_in_inited = loaded_flat.keys() - inited_flat.keys()
  logging.info(pp("Parameters in model but not in checkpoint", not_in_loaded))
  logging.info(pp("Parameters in checkpoint but not in model", not_in_inited))

  # A mismatch is only an error if it is not covered by `dont_load`.
  not_in_loaded = {k for k in not_in_loaded if should_merge(k)}
  not_in_inited = {k for k in not_in_inited if should_merge(k)}

  if not_in_loaded or not_in_inited:
    raise ValueError(
        pp("Params in checkpoint", loaded_flat.keys()) + "\n" +
        pp("Params in model (code)", inited_flat.keys()) + "\n" +
        pp("Params in model (code) but not in checkpoint and not `dont_load`ed",
           not_in_loaded, indent=" - ") + "\n" +
        pp("Params in checkpoint but not in model (code) and not `dont_load`ed",
           not_in_inited, indent=" + "))

  return u.recover_tree(merged.keys(), merged.values())
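
# A hypothetical usage sketch for `merge_params` (the model, the checkpoint
# pytree, and the "head/.*" pattern below are illustrative assumptions, not
# part of this module):
#
#   inited = model.init(jax.random.PRNGKey(0), dummy_input)["params"]
#   loaded = ...  # params pytree restored from a checkpoint.
#   # Take everything from the checkpoint except the head, which keeps its
#   # freshly-initialized value (e.g. when fine-tuning on new classes):
#   params = merge_params(loaded, inited, dont_load=("head/.*",))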


class AddPositionEmbs(nn.Module):
  """Adds positional embeddings to the inputs, supports caching for decode.

  Attributes:
    decode: whether to run in single-position autoregressive mode.
  """
  decode: bool = False

  @nn.compact
  def __call__(self, inputs, posemb):
    """Applies AddPositionEmbs module.

    Adds posemb to the inputs, supports single-position autoregressive mode.

    Args:
      inputs: input data [batch_size, seq_len, emb_dim].
      posemb: positional embeddings.

    Returns:
      output: inputs modulated by pos-embeddings [batch_size, seq_len, emb_dim].
    """
    assert inputs.ndim == 3, f"Unexpected inputs shape: {inputs.shape}"
    _, seq_len, emb_dim = inputs.shape
    pe = posemb[:, :seq_len, :]
|
    if self.decode:
      is_initialized = self.has_variable("cache", "cache_index")
      # Tracks the current decoding position across calls.
      cache_index = self.variable("cache", "cache_index",
                                  lambda: jnp.array(0, dtype=jnp.uint32))
      if is_initialized:
        i = cache_index.value
        cache_index.value = i + 1
        # Pick out the positional embedding for the current position only,
        # i.e. posemb[:, i:i+1, :]; `i` is traced, hence dynamic_slice.
        pe = jax.lax.dynamic_slice(posemb,
                                   start_indices=jnp.array((0, i, 0)),
                                   slice_sizes=(1, 1, emb_dim))
    return inputs + pe
|
|
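# A hypothetical usage sketch for `AddPositionEmbs` (shapes and names are
# illustrative assumptions). In normal mode the module is stateless:
#
#   posemb = jnp.zeros((1, 16, 32))   # [1, max_seq_len, emb_dim]
#   x = jnp.ones((2, 16, 32))         # [batch_size, seq_len, emb_dim]
#   y = AddPositionEmbs().apply({}, x, posemb)
#
# In decode mode, init once to create the cache, then feed one position at
# a time with mutable=["cache"] so `cache_index` advances per call:
#
#   module = AddPositionEmbs(decode=True)
#   variables = module.init(jax.random.PRNGKey(0), x[:, :1, :], posemb)
#   tok_out, updates = module.apply(variables, x[:, :1, :], posemb,
#                                   mutable=["cache"])
#   variables = {**variables, **updates}  # carry the advanced index forward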