File size: 28,339 Bytes
74e8f2f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 |
# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Decoder-only and encoder-decoder GIVT model.
Used abbreviations for dimension annotations:
B: batch size.
E: embedding size.
L: (soft) token sequence length.
D: soft token dimension.
P: number of patches (extracted by a ViT encoder in GIVT-based UViM)
"""
import enum
import itertools
from typing import Literal, Optional, Sequence, Any, Mapping
from absl import logging
from big_vision import utils
from big_vision.models import common
from big_vision.models import vit
import distrax
import einops
import flax.linen as nn
from flax.linen import partitioning
import jax
import jax.numpy as jnp
import numpy as np
class _SpecialLabel(enum.Enum):
  """Reserved label kinds that get dedicated rows in the label embedding table.

  `Model.setup` assigns each member it needs a free label id placed after the
  regular class labels; the corresponding embeddings are looked up via
  `Model.labels_emb`.
  """

  # Embedding placed at masked-out input positions (masked style).
  MASK = "mask"
  # Embedding marking kept positions (masked style, mask_style="concat").
  NOMASK = "nomask"
  # Content embedding used for masked positions (mask_style="concat").
  REPLACE = "replace"
  # Stand-in label used when class labels are dropped (For CFG).
  NOLABEL = "nolabel"
def _random_mask_with_ratios(rng, ratios: jax.Array, seq_len: int):
  """Generates masks where a fraction of tokens is uncovered.

  Row i of the result contains ceil(clip(ratios[i], 0, 1) * seq_len) `True`
  entries (up to float32 rounding), scattered at random positions; every row
  is permuted independently of the others.

  Args:
    rng: JAX PRNG key.
    ratios: Ratios, must be a 1D array of shape (B,). Values are clipped to
      [0, 1]; ratios[i] is the fraction of the i-th row's tokens that are
      uncovered (ie. equal to `True`).
    seq_len: How many tokens this mask has to cover (L).

  Returns:
    Mask of dtype bool, shape (B, L).

  Raises:
    ValueError: `ratios` is not 1D.
  """
  if ratios.ndim != 1:
    raise ValueError("Ratios must have shape (B,)!")
  ratios = jnp.clip(ratios, 0, 1)
  indices = jnp.arange(seq_len, dtype=jnp.float32)  # Shape: (L,)
  thresholds = ratios[:, jnp.newaxis] * seq_len  # Shape: (B, 1)
  # Sorted mask: row i is True at its first ceil(ratios[i] * seq_len)
  # positions. A comparison already yields dtype bool. Shape: (B, L).
  sorted_mask = indices < thresholds
  # Permute each row independently to scatter the True entries at random.
  # `permutation(..., independent=True)` is the documented replacement for the
  # deprecated `jax.random.shuffle`.
  return jax.random.permutation(rng, sorted_mask, axis=-1, independent=True)
def apply_mask_schedule(ratio: float | jax.Array, method: str) -> jax.Array:
"""Generate a mask rate by scheduling mask functions R."""
if method == "cosine":
mask_ratio = jax.lax.cos(jnp.pi / 2. * ratio)
elif "pow:" in method:
exponent = float(method.replace("pow:", ""))
mask_ratio = 1. - ratio**exponent
else:
raise NotImplementedError(method)
# Clamps mask into [epsilon, 1)
mask_ratio = jnp.clip(mask_ratio, 1e-6, 1.)
return mask_ratio
class EncoderDecoderBlock(nn.Module):
  """Transformer encoder-decoder layer.

  Pre-LayerNorm block: self-attention, then (only when `encoded` is given)
  cross-attention to `encoded`, then an MLP; each sub-block is wrapped in a
  residual connection. All LayerNorms are created without bias.
  """
  mlp_dim: int  # Hidden width of the MLP sub-block.
  num_heads: int  # Heads for both self- and cross-attention.
  dropout_rate: float = 0.  # Used for attention, residual and MLP dropout.
  decode: bool = False  # Forwarded to nn.SelfAttention (cached decoding).

  @nn.compact
  def __call__(
      self,
      targets: jax.Array,
      encoded: jax.Array | None = None,
      decoder_mask: jax.Array | None = None,
      deterministic: bool = True,
  ) -> tuple[jax.Array, jax.Array]:
    """Applies EncoderDecoderBlock module.

    Args:
      targets: input embeddings [B, L, E] (summed residually with the
        attention output, so the width is the residual-stream width).
      encoded: encoded image patches from encoder [B, P, E]. If None, the
        cross-attention sub-block is skipped entirely.
      decoder_mask: decoder self-attention mask.
      deterministic: bool, deterministic or not (to apply dropout).

    Returns:
      A pair (out, out), where out is the block output [B, L, E]. The same
      array is returned twice so the block can be used with nn.scan: the
      first element is threaded through the layers as the carry and the
      second is stacked across layers as the per-layer output.
    """
    # Helper function for axis annotation.
    def wlc(f):
      dim_names = ("act_batch", "act_len", "act_emb")
      return nn.with_logical_constraint(f, dim_names)

    # Decoder block.
    x = wlc(nn.LayerNorm(name="LayerNorm1", use_bias=False)(targets))
    x = wlc(nn.SelfAttention(
        num_heads=self.num_heads, use_bias=False, broadcast_dropout=False,
        dropout_rate=self.dropout_rate, decode=self.decode, name="SelfAttn")(
            x, decoder_mask, deterministic=deterministic))
    x = wlc(nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic))
    x = wlc(x + targets)

    if encoded is None:
      y = x
    else:
      # Encoder-Decoder block.
      y = wlc(nn.LayerNorm(name="LayerNorm2", use_bias=False)(x))
      y = wlc(nn.MultiHeadDotProductAttention(
          num_heads=self.num_heads, use_bias=False, broadcast_dropout=False,
          dropout_rate=self.dropout_rate, name="CrossAttn")(
              y, encoded, deterministic=deterministic))
      y = wlc(
          nn.Dropout(rate=self.dropout_rate)(y, deterministic=deterministic))
      y = wlc(y + x)

    # MLP block.
    z = wlc(nn.LayerNorm(name="LayerNorm3", use_bias=False)(y))
    z = wlc(vit.MlpBlock(mlp_dim=self.mlp_dim, dropout=self.dropout_rate,
                         name="MLP")(z, deterministic=deterministic))

    # nn.scan expects (carry, per-step output): the first element is the
    # carry fed to the next layer, the second is stacked across layers.
    out = wlc(y + z)
    return out, out
class Decoder(nn.Module):
  """Transformer decoder model with optional cross-attention.

  Adds learned position embeddings to already-embedded inputs, runs
  `num_layers` EncoderDecoderBlocks (either unrolled, or as a single `nn.scan`
  over rematerialized blocks when `scan=True`), applies a final LayerNorm and
  projects to `out_dim` with a zero-initialized Dense layer.
  """
  emb_dim: int  # Width of the input embeddings / residual stream.
  mlp_dim: int  # Hidden width of each block's MLP.
  num_heads: int
  num_layers: int
  out_dim: int  # Output size of the final (zero-initialized) projection.
  seq_len: int  # Length of the learned position embedding.
  style: Literal["ar", "masked"]  # "masked" cannot be used with decode=True.
  dropout_rate: float = 0.
  zero_embedding_init: bool = False  # NOTE(review): not read in this module.
  scan: bool = False  # Stack layers with nn.scan (and nn.remat) if True.
  # Attribute name in jax.checkpoint_policies; unknown names fall back to None.
  remat_policy: str = "nothing_saveable"

  @nn.compact
  def __call__(
      self,
      targets: jax.Array,
      encoded: jax.Array | None = None,
      decoder_mask: jax.Array | None = None,
      decode: bool = False,
      deterministic: bool = True,
      return_reps: bool = False,
  ) -> jax.Array | tuple[jax.Array, Mapping[str, jax.Array]]:
    """Applies Transformer model on the inputs.

    Args:
      targets: already-embedded decoder inputs [B, L, E] (the caller embeds
        labels and soft tokens before calling this module).
      encoded: encoded sequence from an encoder [B, P, E].
      decoder_mask: decoder self-attention mask.
      decode: bool, whether to perform fast autoregressive decoding with cache.
      deterministic: bool, deterministic or not (to apply dropout).
      return_reps: bool, whether to return intermediate representations.

    Returns:
      output of a transformer decoder [B, L, out_dim], where out_dim is usually
      a multiple of D. If `return_reps`, a tuple (output, reps) where reps has
      "block{i}_rep" (block i's output averaged over the sequence axis),
      "pre_logits" (sequence mean after the final LayerNorm) and "logits".

    Raises:
      ValueError: `style == "masked"` combined with `decode=True`.
    """
    if self.style == "masked" and decode:
      raise ValueError("Cannot run masked model in cached mode!")

    pos_emb = vit.get_posemb(
        self, "learn", self.seq_len, self.emb_dim,
        "pos_emb")
    y = common.AddPositionEmbs(
        decode=decode, name="PosEmbedTargets")(targets, pos_emb)

    out = {}
    if self.scan:
      # Mostly followed
      # https://github.com/google/maxtext/blob/4d99e30b3e0e0cb1d1aa11c7db7fffe18e301498/MaxText/layers.py#L1126
      # for the scanned version.
      # 1. remat
      # NOTE(review): with the positional call in step 3, (-1, -2) presumably
      # select `deterministic` and `decoder_mask` as static; an array-valued
      # (non-hashable) decoder_mask may not be valid as a static arg — confirm.
      enc_dec_block_remat = nn.remat(
          EncoderDecoderBlock,
          prevent_cse=False,
          static_argnums=(-1, -2),
          policy=getattr(jax.checkpoint_policies, self.remat_policy, None))
      # 2. scan
      # Params are stacked along axis 1 at init time; afterwards ScanIn tells
      # the partitioning machinery about that stacked axis.
      initializing = self.is_mutable_collection("params")
      param_scan_axis = 1
      params_spec = (param_scan_axis if initializing
                     else partitioning.ScanIn(param_scan_axis))
      # All call args (encoded, mask, deterministic) are broadcast to every
      # layer; only the carry `y` flows from layer to layer.
      dec_scanned = nn.scan(enc_dec_block_remat,
                            variable_axes={
                                "params": params_spec,
                                "cache": 0,
                            },
                            split_rngs={"params": True, "dropout": True},
                            in_axes=nn.broadcast,
                            length=self.num_layers)
      # 3. fprop
      y, out = dec_scanned(num_heads=self.num_heads, mlp_dim=self.mlp_dim,
                           dropout_rate=self.dropout_rate, decode=decode,
                           name="EncDecBlock")(
                               y, encoded, decoder_mask, deterministic)
      # Extracting the intermediate representation from the stacked activation
      # tensor `out`, which is a [num_layers, B, L, E] tensor. Indexing along
      # the first axis to extract individual layers, and then averaging across
      # the second axis, which corresponds to the sequence dimension after
      # indexing.
      assert out.shape[0] == self.num_layers and (
          decode or out.shape[2] == self.seq_len), (
              (out.shape, self.num_layers, self.seq_len))
      out = {f"block{l}_rep": jnp.mean(out[l], axis=1)
             for l in range(self.num_layers)}
    else:
      for lyr in range(self.num_layers):
        y, _ = EncoderDecoderBlock(
            num_heads=self.num_heads, mlp_dim=self.mlp_dim,
            dropout_rate=self.dropout_rate, decode=decode,
            name=f"EncDecBlock{lyr}")(y, encoded, decoder_mask=decoder_mask,
                                      deterministic=deterministic)
        out[f"block{lyr}_rep"] = jnp.mean(y, axis=1)
    y = nn.LayerNorm(name="LayerNorm")(y)
    out["pre_logits"] = jnp.mean(y, axis=1)

    # Zero-initialized kernel: the projection outputs only the Dense bias
    # (initially zeros by Flax default) until trained.
    logits = nn.Dense(
        self.out_dim,
        kernel_init=nn.initializers.zeros,
        name="LogitsDense",
    )(y)
    out["logits"] = logits
    if return_reps:
      return logits, out
    return logits
class Model(nn.Module):
  """GIVT model supporting decoder-only and encoder-decoder applications.

  The decoder consumes a sequence of continuous soft tokens [B, L, D] and
  predicts, for every position, the parameters of a Gaussian mixture over the
  D = `out_dim` channels (see `get_pdf`). Two decoding styles exist:
    - "ar": causal decoding; a label (or [BOS]) embedding is prepended and the
      last input token is dropped (teacher forcing).
    - "masked": bidirectional decoding where positions selected by
      `input_mask` are replaced by special mask embeddings.
  With `num_layers > 0`, a ViT encoder embeds an input image and the decoder
  cross-attends to it.
  """
  num_heads: int = 8
  # num_layers = 0 means no encoder
  num_layers: int = 0
  # Decoder depth; a falsy value falls back to `num_layers`.
  num_decoder_layers: int = 6
  mlp_dim: int = 2048
  enc_dropout_rate: float = 0.
  dec_dropout_rate: float = 0.
  # Decoder params:
  emb_dim: int = 512
  # Number of class labels. None (or 1) means unconditional: one shared
  # [BOS] embedding (label id 0) is used instead.
  num_labels: Optional[int] = 1000
  seq_len: int = 256
  # Encoder params:
  patches: Sequence[int] = (16, 16)
  input_size: Sequence[int] = (256, 256)
  # Only "learn" is accepted by `setup`.
  posemb_type: Literal["learn", "sincos2d"] = "learn"
  # NOTE(review): not read anywhere in this class.
  zero_decoder_seq: bool = False
  style: Literal["ar", "masked"] = "ar"
  zero_embedding_init: bool = False
  num_mixtures: int = 4
  multivariate: bool = False
  out_dim: int = 32
  # Lower bound applied to predicted scales in `get_pdf`.
  scale_tol: float = 1e-6

  # Mask specific params.
  mask_schedule_train: str = "cosine"
  # Results in at least 40% masked tokens with cosine.
  min_masking_rate_training: float = 0.3
  # How to fuse mask at input:
  # - replace: replace token[masked] with lookup(MASK)
  # - concat: replace token[mask] with lookup(REPLACE) and concat either
  #   lookup(NOMASK) or lookup(MASK).
  mask_style: str = "replace"

  # Set to >0 for CFG support.
  drop_labels_probability: float = 0.0

  # False keeps the original formula x + sqrt(x**2 + 4) / 2; True uses the
  # standard squareplus (x + sqrt(x**2 + 4)) / 2.
  fix_square_plus: bool = False

  # If True, and mixture >1, create a GMM per channel. Otherwise, create
  # a GMM of `dim`-dimensional Gaussians.
  per_channel_mixtures: bool = True

  scan: bool = False
  remat_policy: str = "nothing_saveable"

  @property
  def has_encoder(self) -> bool:
    """Whether the model has a ViT image encoder."""
    return self.num_layers > 0

  @property
  def num_logits(self) -> int:
    """Number of decoder outputs per position consumed by `get_pdf`."""
    if self.multivariate:
      assert self.num_mixtures == 1
      # d**2 covariance, d means.
      # Note: `round` makes pytype happy.
      return round(self.out_dim ** 2) + self.out_dim
    elif self.per_channel_mixtures:
      # One (mu, sigma, pi) per output dimension and mixture component.
      # Note that we predict a distribution for each output dimensions in
      # parallel.
      return 3 * self.num_mixtures * self.out_dim
    else:
      # Mixture weights plus mean/scale per mixture
      return self.num_mixtures + 2 * self.num_mixtures * self.out_dim

  def setup(self) -> None:
    """Creates the encoder (optional), label/target embeddings and decoder."""
    assert self.posemb_type == "learn"
    assert self.num_mixtures > 0

    if self.multivariate and self.num_mixtures != 1:
      raise ValueError("Cannot do multivariate GMM!")

    if self.has_encoder:
      grid_size = np.array(self.input_size) // np.array(self.patches)
      self.pos_emb_for_encoder = vit.get_posemb(
          self, self.posemb_type, grid_size, self.emb_dim,
          "pos_embedding_encoder")

      self.conv = nn.Conv(self.emb_dim, self.patches, padding="VALID",
                          strides=self.patches, name="EmbedPatches")

      self.encoder = vit.Encoder(
          depth=self.num_layers,
          mlp_dim=self.mlp_dim,
          num_heads=self.num_heads,
          dropout=self.enc_dropout_rate,
          scan=self.scan,
          remat_policy=self.remat_policy,)
    else:
      self.encoder = None

    # Iterator that will lead free label IDs: special labels are appended
    # after the regular class labels in the same embedding table.
    next_label = itertools.count(self.num_labels or 0)

    special_labels = {}
    if self.style == "ar":
      pass
    elif self.style == "masked":
      if self.mask_style == "replace":
        special_labels = {_SpecialLabel.MASK: next(next_label)}
      elif self.mask_style == "concat":
        special_labels = {
            _SpecialLabel.MASK: next(next_label),
            _SpecialLabel.NOMASK: next(next_label),
            _SpecialLabel.REPLACE: next(next_label),
        }
      else:
        raise NotImplementedError(self.mask_style)
    else:
      raise NotImplementedError(self.style)

    if self.drop_labels_probability > 0:
      special_labels[_SpecialLabel.NOLABEL] = next(next_label)

    self.special_labels = special_labels
    lookup_size = (self.num_labels or 1) + len(self.special_labels)
    self.labels_emb = nn.Embed(
        lookup_size,
        self.emb_dim,
        name="EmbedLabels",
        embedding_init=nn.initializers.zeros
        if self.zero_embedding_init
        else nn.initializers.normal(stddev=1.0),
    )

    self.targets_emb = nn.Dense(self.emb_dim, name="EmbedTargets")

    self.decoder = Decoder(
        num_layers=self.num_decoder_layers or self.num_layers,
        mlp_dim=self.mlp_dim,
        num_heads=self.num_heads,
        out_dim=self.num_logits,
        # In masked mode, we run with 1 more token at the input.
        seq_len=self.seq_len + int(self.style == "masked"),
        dropout_rate=self.dec_dropout_rate,
        emb_dim=self.emb_dim,
        zero_embedding_init=self.zero_embedding_init,
        style=self.style,
        scan=self.scan,
        remat_policy=self.remat_policy,
    )

  def encode(self, image: jax.Array, train: bool = False) -> jax.Array:
    """Encodes an image [B, H, W, C] into patch features [B, P, E]."""
    emb = self.conv(image)
    patch_embeddings = einops.rearrange(emb, "B PH PW E -> B (PH PW) E")
    encoded, _ = self.encoder(
        patch_embeddings + self.pos_emb_for_encoder, deterministic=not train)
    return encoded

  def embed_labels(
      self,
      labels: jax.Array | None = None,
      batch_size: int | None = None,
  ) -> jax.Array:
    """Embeds class labels [B], or a [BOS] token when no labels are given.

    Args:
      labels: integer label ids [B], or None for an unconditional model.
      batch_size: required when `labels` is None.

    Returns:
      Label embeddings of shape [B, 1, E].
    """
    if labels is not None:
      # Embed class label, add a sequence dim (output shape (B, 1, E))
      return self.labels_emb(labels)[:, None, :]

    assert ((self.num_labels == 1 or self.num_labels is None)
            and batch_size is not None)

    # Create [BOS] token embedding
    return self.labels_emb(jnp.zeros((batch_size,), jnp.int32))[:, None, :]

  def prefill(
      self, labels=None, batch_size=None, encoded=None, drop_labels=None
  ):
    """Runs the decoder with `decode=True` on only the label/[BOS] token.

    Presumably used to fill the autoregressive cache before sampling.
    Returns the decoder logits for that single position.
    """
    labels = self._drop_labels(drop_labels, labels)
    labels_for_prefill = self.embed_labels(labels=labels, batch_size=batch_size)
    return self.decoder(
        labels_for_prefill,
        encoded=encoded,
        decode=True)

  def _decode_ar(
      self,
      targets: jax.Array,
      labels: jax.Array | None = None,
      encoded: jax.Array | None = None,
      decode: bool = False,
      train: bool = False,
  ) -> tuple[jax.Array, Mapping[str, jax.Array]]:
    """Autoregressive decoding.

    The label embedding is prepended and the last target is dropped, so the
    output at position i only sees targets < i (teacher forcing).
    """
    targets_embedded = self.targets_emb(targets)

    if decode:
      decoder_mask = None
    else:
      decoder_mask = nn.make_causal_mask(targets[:, :, 0])
    b = targets.shape[0]
    labels_embedded = self.embed_labels(labels, b)
    assert labels_embedded.shape == (b, 1, self.emb_dim), (
        labels_embedded.shape, (b, 1, self.emb_dim))
    targets_embedded = jnp.concatenate(
        [labels_embedded, targets_embedded[:, : -1]], axis=1)

    logits, out = self.decoder(
        targets_embedded,
        encoded=encoded,
        decoder_mask=decoder_mask,
        decode=decode,
        deterministic=not train,
        return_reps=True)
    return logits, out

  def _get_special_label(self, size, label: _SpecialLabel):
    """Returns the embedding of special `label`, broadcast to `size` + (E,)."""
    return self.labels_emb(
        jnp.full(size, self.special_labels[label], jnp.int32)
    )

  def _decode_masked(
      self,
      targets,
      input_mask,
      labels=None,
      encoded=None,
      train=False,
  ):
    """Masked decoding.

    Positions where `input_mask` is True are replaced by special embeddings;
    the label embedding is prepended and its output is removed again.
    Returns logits [B, L, num_logits].
    """
    b, s, _ = targets.shape
    assert input_mask.shape == (b, s)
    if self.mask_style == "replace":
      targets_embedded = jnp.where(
          input_mask[:, :, None],
          self._get_special_label((b, s), _SpecialLabel.MASK),
          self.targets_emb(targets),
      )
    elif self.mask_style == "concat":
      masks = jnp.where(
          input_mask[:, :, None],
          self._get_special_label((b, s), _SpecialLabel.MASK),
          self._get_special_label((b, s), _SpecialLabel.NOMASK),
      )
      embedded_targets = self.targets_emb(targets)
      targets_embedded = jnp.where(
          input_mask[:, :, None],
          self._get_special_label((b, s), _SpecialLabel.REPLACE),
          embedded_targets,
      )
      # Only take half of each to get the right embedding size. This yields
      # exactly emb_dim channels only when emb_dim is even.
      targets_embedded = jnp.concatenate(
          [masks[..., ::2], targets_embedded[..., ::2]], axis=-1
      )
    else:
      raise ValueError(self.mask_style)

    labels_embedded = self.embed_labels(labels, b)
    assert labels_embedded.shape == (b, 1, self.emb_dim)
    # Note that we do not truncate the input here, so this has shape
    # (B, L+1, E).
    targets_embedded = jnp.concatenate(
        [labels_embedded, targets_embedded], axis=1)

    logits = self.decoder(
        targets_embedded,
        encoded=encoded,
        decoder_mask=None,
        decode=False,
        deterministic=not train)

    logits = logits[:, 1:, ...]  # Remove class label
    assert logits.shape[:2] == (b, s)
    return logits

  def _drop_labels(self, drop_labels_mask, labels):
    """Replaces labels by the NOLABEL id where `drop_labels_mask` is True.

    When `drop_labels_probability >= 0.999`, all labels are dropped regardless
    of the mask. Returns None if `labels` is None.
    """
    if labels is None:
      return None
    if self.drop_labels_probability >= 0.999:
      logging.warning("Dropping all labels...")
      return jnp.full_like(labels, self.special_labels[_SpecialLabel.NOLABEL])
    if drop_labels_mask is None:
      return labels
    assert _SpecialLabel.NOLABEL in self.special_labels
    nolabel = jnp.full_like(
        labels, self.special_labels[_SpecialLabel.NOLABEL]
    )
    return jnp.where(drop_labels_mask, nolabel, labels)

  def decode(
      self,
      targets: jax.Array,
      labels: jax.Array | None = None,
      encoded: jax.Array | None = None,
      decode: bool = False,
      train: bool = False,
      max_decode_length: int | None = None,
      input_mask: jax.Array | None = None,
      drop_labels: jax.Array | None = None,
      return_reps: bool = False,
  ) -> jax.Array | tuple[jax.Array, Mapping[str, jax.Array]]:
    """Applies Transformer decoder-branch on encoded-input and target.

    Args:
      targets: target soft tokens [B, L, out_dim].
      labels: optional class labels, [B].
      encoded: encoded image patches from encoder [B, P, E].
      decode: whether to prepare and use an autoregressive cache.
      train: whether it is training.
      max_decode_length: unused.
      input_mask: If given, mask input. Required for style=="masked".
        Shape [B, L], bool tensor. True means the token will be removed
        from the input.
      drop_labels: Drop labels at corresponding locations [B].
      return_reps: whether to return intermediate representations (only
        supported for style=="ar").

    Returns:
      logits array from transformer decoder [B, L, num_logits]; with
      `return_reps`, a tuple (logits, reps).
    """
    del max_decode_length
    labels = self._drop_labels(drop_labels, labels)
    if self.style == "ar":
      logits, out = self._decode_ar(
          targets, labels, encoded, decode, train)
      if return_reps:
        return logits, out
      return logits
    elif self.style == "masked":
      assert not decode  # Cache not supported.
      assert input_mask is not None
      assert not return_reps  # Not implemented.
      return self._decode_masked(targets, input_mask, labels, encoded, train)
    else:
      raise NotImplementedError(self.style)

  def _square_plus(self, x):
    """Smooth map used to turn raw logits into scales."""
    # Via https://twitter.com/jon_barron/status/1387167648669048833
    if self.fix_square_plus:
      return (x + jnp.sqrt(jnp.square(x) + 4)) / 2
    else:
      # Original formula: unlike squareplus it is negative for very negative
      # x; callers threshold the result with `scale_tol` afterwards.
      return x + jnp.sqrt(jnp.square(x) + 4) / 2

  def get_pdf(
      self,
      logits: jax.Array,
      temperature_scales: float | None = None,
      temperature_probs: float | None = None,
  ) -> distrax.Distribution:
    """Builds the per-position output distribution from decoder logits.

    Args:
      logits: decoder output [..., num_logits].
      temperature_scales: if given, multiplies all predicted scales.
      temperature_probs: if given, multiplies the mixture logits (so larger
        values sharpen the mixture weights).

    Returns:
      A multivariate normal (`multivariate`), a per-channel 1D GMM
      (`per_channel_mixtures`), or a GMM of diagonal `out_dim`-dim Gaussians.
    """
    assert logits.shape[-1] == self.num_logits
    if self.multivariate:
      scales = logits[..., :self.out_dim ** 2]
      locs = logits[..., self.out_dim ** 2:]
      assert locs.shape[-1] == self.out_dim
      scales = self._square_plus(scales)
      # Turn into a square matrix.
      *leading, _ = scales.shape
      scales = scales.reshape(*leading, self.out_dim, self.out_dim)
      # Make sure the diagonals are non zero: elementwise max with tol * I
      # bounds diagonal entries by scale_tol and off-diagonals by 0.
      diag_scale_tol = jnp.eye(self.out_dim) * self.scale_tol
      scales = jnp.maximum(scales, diag_scale_tol)
      if (t := temperature_scales) is not None:
        scales = scales * t
      # Note that there is `tfd.MultivariateNormalFullCovariance`` but it just
      # calls linalg.cholesky on the covariance and then uses the
      # MultivariateNormalTri class. Using ... direcly avoids having to
      # construct a hermetian matrix.
      #
      # Note that only the lower triag part of `scales` is used by applying
      # jnp.tril. The other elements are replaced with zeros.
      #
      # Note on output shapes:
      # - .sample() -> shape (..., seq_len, out_dim)
      # - .prob() -> shape (..., seq_len).
      return distrax.MultivariateNormalTri(locs, scales)
    elif self.per_channel_mixtures:
      # [..., 3 * num_mixtures * out_dim] -> [..., 3 * out_dim, num_mixtures]
      logits = jnp.reshape(logits, logits.shape[: -1] + (-1, self.num_mixtures))
      # 3 tensors with shape [..., out_dim, num_mixtures]
      probs, locs, scales = jnp.split(logits, 3, axis=-2)
      if (t := temperature_probs) is not None:
        probs = probs * t
      # normalize mixture probabilities
      probs = nn.softmax(probs)
      scales = self._square_plus(scales)
      # threshold scale
      scales = jnp.maximum(scales, self.scale_tol)
      if (t := temperature_scales) is not None:
        scales = scales * t
      # Note on output shapes:
      # - .sample() -> shape (..., seq_len, out_dim)
      # - .prob() -> shape (..., seq_len, out_dim).
      return distrax.MixtureSameFamily(
          mixture_distribution=distrax.Categorical(probs=probs),
          components_distribution=distrax.Normal(loc=locs, scale=scales),
      )
    else:
      *shape, num_logits = logits.shape
      assert num_logits == self.num_logits, (num_logits, self.num_logits)
      prob_logits, other_logits = (
          logits[..., : self.num_mixtures],
          logits[..., self.num_mixtures :],
      )
      if (t := temperature_probs) is not None:
        prob_logits = prob_logits * t
      other_logits = jnp.reshape(
          other_logits, (*shape, self.num_mixtures, 2, self.out_dim)
      )
      locs = other_logits[..., 0, :]
      scales = self._square_plus(other_logits[..., 1, :])
      scales = jnp.maximum(scales, self.scale_tol)  # Threshold scale
      if (t := temperature_scales) is not None:
        scales = scales * t

      # prob_logits has shape (b, seq_len, m)
      # locs/scales has shape (b, seq_len, m, d)
      assert prob_logits.ndim == locs.ndim - 1, (prob_logits.shape, locs.shape)
      assert locs.shape == scales.shape, (locs.shape, scales.shape)

      # Note on output shapes:
      # - .sample() -> shape (..., seq_len, out_dim)
      # - .prob() -> shape (..., seq_len,)
      # - .nll() -> shape (..., seq_len,)
      return distrax.MixtureSameFamily(
          mixture_distribution=distrax.Categorical(logits=prob_logits),
          components_distribution=distrax.MultivariateNormalDiag(
              loc=locs, scale_diag=scales
          ),
      )

  def __call__(
      self,
      sequence: jax.Array,
      labels: jax.Array | None = None,
      *,
      image: jax.Array | None = None,
      decode: bool = False,
      input_mask: jax.Array | None = None,
      drop_labels: jax.Array | None = None,
      train: bool = False,
  ) -> tuple[jax.Array, distrax.Distribution]:
    """Applies Transformer model on the inputs.

    Args:
      sequence: batch of soft-token sequences [B, L, out_dim].
      labels: class labels for class conditional generation [B].
      image: batch of images [B, H, W, 3]; required iff the model has an
        encoder.
      decode: whether to prepare and use an autoregressive cache.
      input_mask: If given, mask input. Required for style=="masked" [B, L].
      drop_labels: If given, drop labels of the corresponding batches [B].
      train: whether it is training.

    Returns:
      A tuple (logits, pdf): decoder logits [B, L, num_logits] and the
      distribution built from them by `get_pdf`.

    Raises:
      ValueError: style=="masked" without `input_mask`.
    """
    if self.style == "masked" and input_mask is None:
      raise ValueError("Cannot run masked model without input mask!")
    if self.encoder is not None:
      assert image is not None
      encoded = self.encode(image, train=train)
    else:
      assert image is None
      encoded = None
    # `drop_labels` must be forwarded, otherwise label dropping (CFG) requested
    # through __call__ would be silently ignored.
    logits = self.decode(sequence, labels=labels, encoded=encoded,
                         decode=decode, input_mask=input_mask,
                         drop_labels=drop_labels, train=train)
    pdf = self.get_pdf(logits)
    return logits, pdf

  def get_input_mask_training(
      self,
      rng: jax.Array,
      shape: tuple[int, int],
  ) -> jax.Array | None:
    """Creates a random mask of shape (B, L) for training masked models.

    Returns None for style=="ar". True entries mark masked positions.
    """
    if self.style == "ar":
      return None
    b, s = shape
    # Use independent keys for the ratio draw and the position shuffle;
    # reusing one key would correlate the two random draws.
    rng_ratio, rng_shuffle = jax.random.split(rng)
    # Sample b values in [0, 1-min_mask_ratio].
    keep = jax.random.uniform(
        rng_ratio, shape=(b,), maxval=1.0 - self.min_masking_rate_training
    )
    mask_ratio = apply_mask_schedule(keep, self.mask_schedule_train)
    return _random_mask_with_ratios(rng_shuffle, ratios=mask_ratio, seq_len=s)

  def get_input_mask_teacher_forced(
      self,
      shape: tuple[int, int],
  ) -> jax.Array | None:
    """Returns an all-False (nothing masked) mask of shape (B, L).

    Returns None for style=="ar".
    """
    if self.style == "ar":
      return None
    return jnp.zeros(shape, dtype=jnp.bool_)

  def get_drop_labels(
      self,
      rng: jax.Array,
      batch_size: int,
  ) -> jax.Array | None:
    """Samples a bool [batch_size] mask of labels to drop (for CFG).

    Each entry is True with probability `drop_labels_probability`; returns
    None when that probability is 0.
    """
    if (p := self.drop_labels_probability) > 0:
      return jax.random.uniform(rng, shape=(batch_size,)) <= p
    else:
      return None
def load(
    init_params: Any,
    init_files: str | Mapping[str, str],
    model_params: Any = None,
    dont_load: Sequence[str] = (),
    resample_encoder_posemb: bool = False,
    trim_decoder_posemb: bool = False,
) -> Any:
  """Loads params from init checkpoint and merges into init_params.

  Args:
    init_params: Freshly initialized params; the merge target and the source
      of target shapes for posemb resampling/trimming.
    init_files: Either one checkpoint path for the whole model, or a mapping
      whose only supported entry is "encoder" -> ViT checkpoint path.
    model_params: Unused.
    dont_load: Param names/patterns forwarded to the merge/load helpers as
      params not to load.
    resample_encoder_posemb: Full-model checkpoint only: resample the encoder
      position embedding to the one in `init_params`.
    trim_decoder_posemb: Full-model checkpoint only: truncate the checkpoint's
      "pos_embedding_decoder" along axis 1 to the length in `init_params`.

  Returns:
    The merged params.

  Raises:
    ValueError: `init_files` is a mapping without a truthy "encoder" entry.
  """
  del model_params

  if not isinstance(init_files, str):
    # Copy first so popping "encoder" leaves the caller's mapping untouched.
    remaining_files = {**init_files}
    encoder_ckpt = remaining_files.pop("encoder", None)
    if not encoder_ckpt:
      raise ValueError(
          "Only encoder init is supported: {}.".format(remaining_files))

    merged = init_params.copy()
    # Present our encoder params under the names `vit.load` expects.
    loaded_vit = vit.load(
        {
            "pos_embedding": merged["pos_embedding_encoder"],
            "Transformer": merged["encoder"],
            "embedding": merged["EmbedPatches"],
        },
        encoder_ckpt, model_cfg={},
        dont_load=dont_load)
    # Map the loaded ViT params back onto our param names.
    for ours, theirs in (("encoder", "Transformer"),
                         ("pos_embedding_encoder", "pos_embedding"),
                         ("EmbedPatches", "embedding")):
      merged[ours] = loaded_vit[theirs]
    return merged

  merged = common.merge_params(
      utils.load_params(init_files), init_params, dont_load)

  if (resample_encoder_posemb and init_params
      and "pos_embedding_encoder" in init_params):
    merged["pos_embedding_encoder"] = vit.resample_posemb(
        old=merged["pos_embedding_encoder"],
        new=init_params["pos_embedding_encoder"])

  # NOTE(review): `Decoder` registers its position embedding as "pos_emb"
  # inside the decoder submodule, so a top-level "pos_embedding_decoder" key
  # may never be present for this model — confirm this branch is still live.
  if (trim_decoder_posemb and init_params
      and "pos_embedding_decoder" in init_params):
    target_len = init_params["pos_embedding_decoder"].shape[1]
    merged["pos_embedding_decoder"] = (
        merged["pos_embedding_decoder"][:, :target_len, :])

  return merged
|