File size: 8,052 Bytes
74e8f2f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 |
# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A version of ViT with flexible seqlen ((internal link))."""
from typing import Optional, Sequence
from absl import logging
from big_vision import utils
from big_vision.models import common
from big_vision.models import vit
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import tensorflow as tf
def resample_patchemb(old, new_hw):
"""Resample the weights of the patch embedding kernel to target resolution.
We resample the patch embedding kernel by approximately inverting the effect
of patch resizing. Colab with detailed explanation:
(internal link)
With this resizing, we can for example load a B/8 filter into a B/16 model
and, on 2x larger input image, the result will match.
See (internal link)
Args:
old: original parameter to be resized.
new_hw: target shape (height, width)-only.
Returns:
Resized patch embedding kernel.
"""
assert len(old.shape) == 4, "Four dimensions expected"
assert len(new_hw) == 2, "New shape should only be hw"
if tuple(old.shape[:2]) == tuple(new_hw):
return old
logging.info("FlexiViT: resize embedding %s to %s", old.shape, new_hw)
def resize(x_np, new_shape):
x_tf = tf.constant(x_np)[None, ..., None]
# NOTE: we are using tf.image.resize here to match the resize operations in
# the data preprocessing pipeline.
x_upsampled = tf.image.resize(
x_tf, new_shape, method="bilinear")[0, ..., 0].numpy()
return x_upsampled
def get_resize_mat(old_shape, new_shape):
mat = []
for i in range(np.prod(old_shape)):
basis_vec = np.zeros(old_shape)
basis_vec[np.unravel_index(i, old_shape)] = 1.
mat.append(resize(basis_vec, new_shape).reshape(-1))
return np.stack(mat).T
resize_mat = get_resize_mat(old.shape[:2], new_hw)
resize_mat_pinv = np.linalg.pinv(resize_mat.T)
def resample_kernel(kernel):
resampled_kernel = resize_mat_pinv @ kernel.reshape(-1)
return resampled_kernel.reshape(new_hw)
v_resample_kernel = jax.vmap(jax.vmap(resample_kernel, 2, 2), 3, 3)
return v_resample_kernel(old)
class Patchify(nn.Module):
"""As a class just to match param names with original ViT."""
patch_size: Sequence[int] = (32, 32)
width: int = 768
seqhw: Optional[int] = None
@nn.compact
def __call__(self, image, seqhw=None):
n, h, w, c = image.shape # pylint: disable=unused-variable
w_emb = self.param(
"kernel", nn.initializers.normal(stddev=1/np.sqrt(self.width)),
(*self.patch_size, c, self.width), image.dtype)
b_emb = self.param("bias", nn.initializers.zeros, self.width, image.dtype)
# Compute required patch-size to reach `seqhw` given `image` size.
seqhw = seqhw or self.seqhw
if seqhw is None and self.is_initializing():
patch_size = self.patch_size
else:
patch_size = tuple(np.array((h, w)) // np.array((seqhw, seqhw)))
if patch_size != self.patch_size:
w_emb = resample_patchemb(old=w_emb, new_hw=patch_size)
x = jax.lax.conv_general_dilated(
image, w_emb, window_strides=patch_size, padding="VALID",
dimension_numbers=("NHWC", "HWIO", "NHWC"))
return x + b_emb
class _Model(nn.Module):
"""ViT model."""
num_classes: int
patch_size: Sequence[int] = (32, 32)
posemb_size: Sequence[int] = (7, 7)
width: int = 768
depth: int = 12
mlp_dim: Optional[int] = None # Defaults to 4x input dim
num_heads: int = 12
posemb: str = "learn" # Can also be "sincos2d"
pool_type: str = "gap" # Can also be "map" or "tok"
head_zeroinit: bool = True
seqhw: Optional[int] = None
@nn.compact
def __call__(self, image, *, seqhw=None, train=False):
out = {}
x = out["stem"] = Patchify(
self.patch_size, self.width, self.seqhw, name="embedding")(image, seqhw)
# == Flattening + posemb
n, h, w, c = x.shape
x = jnp.reshape(x, [n, h * w, c])
pos_emb = vit.get_posemb(
self, self.posemb, self.posemb_size, c, "pos_embedding", x.dtype)
if pos_emb.shape[1] != h * w:
pos_emb = jnp.reshape(pos_emb, (1, *self.posemb_size, c))
pos_emb = jax.image.resize(pos_emb, (1, h, w, c), "linear")
pos_emb = jnp.reshape(pos_emb, (1, h * w, c))
x = out["with_posemb"] = x + pos_emb
# == Optional [cls] token
if self.pool_type == "tok":
cls = self.param("cls", nn.initializers.zeros, (1, 1, c), x.dtype)
x = jnp.concatenate([jnp.tile(cls, [n, 1, 1]), x], axis=1)
# == Encoder
n, l, c = x.shape # pylint: disable=unused-variable
x, out["encoder"] = vit.Encoder(
depth=self.depth,
mlp_dim=self.mlp_dim,
num_heads=self.num_heads,
name="Transformer")(x)
encoded = out["encoded"] = x
if self.pool_type == "map":
x = out["head_input"] = vit.MAPHead(
num_heads=self.num_heads, mlp_dim=self.mlp_dim)(x)
elif self.pool_type == "gap":
x = out["head_input"] = jnp.mean(x, axis=1)
elif self.pool_type == "tok":
x = out["head_input"] = x[:, 0]
encoded = encoded[:, 1:]
else:
raise ValueError(f"Unknown pool type: '{self.pool_type}'")
x_2d = jnp.reshape(encoded, [n, h, w, -1])
out["pre_logits_2d"] = x_2d
out["pre_logits"] = x
if self.num_classes:
kw = {"kernel_init": nn.initializers.zeros} if self.head_zeroinit else {}
head = nn.Dense(self.num_classes, name="head", **kw)
x_2d = out["logits_2d"] = head(x_2d)
x = out["logits"] = head(x)
return x, out
def Model(num_classes, *, variant=None, **kw): # pylint: disable=invalid-name
"""Factory function, because linen really don't like what I'm doing!"""
return _Model(num_classes, **{**vit.decode_variant(variant), **kw})
def load(init_params, init_file, model_cfg, dont_load=()): # pylint: disable=invalid-name because we had to CamelCase above.
"""Load init from checkpoint, both old model and this one. +Hi-res posemb."""
init_file = {**vit.VANITY_NAMES, **VANITY_NAMES}.get(init_file, init_file)
restored_params = utils.load_params(init_file)
restored_params = vit.fix_old_checkpoints(restored_params)
# Potentially resize the position embedings if seqlen differs.
restored_params["pos_embedding"] = vit.resample_posemb(
old=restored_params["pos_embedding"],
new=init_params["pos_embedding"])
# Potentially resize the patch embedding kernel.
old_patchemb = restored_params["embedding"]["kernel"]
restored_params["embedding"]["kernel"] = resample_patchemb(
old=old_patchemb, new_hw=model_cfg.patch_size)
# possibly use the random init for some of the params (such as, the head).
restored_params = common.merge_params(restored_params, init_params, dont_load)
return restored_params
# Shortcut names for some canonical paper checkpoints:
VANITY_NAMES = {
# pylint: disable=line-too-long
"FlexiViT-L i1k": "gs://big_vision/flexivit/flexivit_l_i1k.npz",
"FlexiViT-B i1k": "gs://big_vision/flexivit/flexivit_b_i1k.npz",
"FlexiViT-S i1k": "gs://big_vision/flexivit/flexivit_s_i1k.npz",
"FlexiViT-B i21k 90ep": "gs://big_vision/flexivit/flexivit_b_i21k_90ep.npz",
"FlexiViT-B i21k 300ep": "gs://big_vision/flexivit/flexivit_b_i21k_300ep.npz",
"FlexiViT-B i21k 1000ep": "gs://big_vision/flexivit/flexivit_b_i21k_1000ep.npz",
"ViT-B/16 i21k": "gs://big_vision/flexivit/vit_b16_i21k_300ep.npz",
"ViT-B/30 i21k": "gs://big_vision/flexivit/vit_b30_i21k_300ep.npz",
# pylint: enable=line-too-long
}
|