"""BiT models as in the paper (ResNet V2) w/ loading of public weights. |
|
|
|
See reproduction proof: http://(internal link)/qY70qs6j944 |
|
""" |
|
import functools
import re
from typing import Optional, Sequence, Union

from big_vision import utils as u
from big_vision.models import bit
from big_vision.models import common
import flax.linen as nn
import jax.numpy as jnp
|
|
def standardize(x, axis, eps):
  """Standardizes `x` to zero mean and unit variance along `axis`."""
  x = x - jnp.mean(x, axis=axis, keepdims=True)
  x = x / jnp.sqrt(jnp.mean(jnp.square(x), axis=axis, keepdims=True) + eps)
  return x
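# For instance (a worked sketch, with eps=0 for clarity):
# standardize(jnp.array([0., 2.]), axis=0, eps=0.) gives [-1., 1.],
# i.e. the input is centered, then scaled to unit variance.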
|
|
class GroupNorm(nn.Module):
  """Group normalization (arxiv.org/abs/1803.08494)."""
  ngroups: int = 32

  @nn.compact
  def __call__(self, x):
    input_shape = x.shape
    group_shape = x.shape[:-1] + (self.ngroups, x.shape[-1] // self.ngroups)

    # Split the channels into groups and standardize over the spatial dims
    # and the channels within each group, then restore the original shape.
    x = x.reshape(group_shape)
    x = standardize(x, axis=[1, 2, 4], eps=1e-5)
    x = x.reshape(input_shape)

    # Learned per-channel scale and bias, broadcast over (N, H, W).
    bias_scale_shape = (1, 1, 1, input_shape[-1])
    x = x * self.param('scale', nn.initializers.ones, bias_scale_shape)
    x = x + self.param('bias', nn.initializers.zeros, bias_scale_shape)
    return x
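# A minimal standalone usage sketch (shapes are illustrative, not from this
# file); standard flax init/apply, nothing module-specific assumed:
#
#   import jax
#   gn = GroupNorm(ngroups=32)
#   x = jnp.ones([1, 8, 8, 64])  # NHWC; channels must divide by ngroups.
#   variables = gn.init(jax.random.PRNGKey(0), x)
#   y = gn.apply(variables, x)   # Same shape as x.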
|
|
class StdConv(nn.Conv):
  """Convolution with Weight Standardization (arxiv.org/abs/1903.10520)."""

  def param(self, name, *a, **kw):
    param = super().param(name, *a, **kw)
    if name == 'kernel':
      # Standardize the kernel over its (h, w, c_in) dimensions each time it
      # is read, so the convolution always sees standardized weights.
      param = standardize(param, axis=[0, 1, 2], eps=1e-10)
    return param
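# StdConv is a drop-in replacement for nn.Conv, e.g. (a sketch):
#
#   conv = StdConv(features=64, kernel_size=(3, 3), use_bias=False)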
|
|
class RootBlock(nn.Module):
  """Root block of ResNet."""
  width: int

  @nn.compact
  def __call__(self, x):
    x = StdConv(self.width, (7, 7), (2, 2), padding=[(3, 3), (3, 3)],
                use_bias=False, name='conv_root')(x)
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding=[(1, 1), (1, 1)])
    return x
|
|
class ResidualUnit(nn.Module):
  """Bottleneck ResNet block (pre-activation, v2 ordering)."""
  nmid: Optional[int] = None
  strides: Sequence[int] = (1, 1)

  @nn.compact
  def __call__(self, x):
    nmid = self.nmid or x.shape[-1] // 4
    nout = nmid * 4
    conv = functools.partial(StdConv, use_bias=False)

    residual = x
    x = GroupNorm(name='gn1')(x)
    x = nn.relu(x)

    # Project the shortcut when the channel count or resolution changes.
    if x.shape[-1] != nout or self.strides != (1, 1):
      residual = conv(nout, (1, 1), self.strides, name='conv_proj')(x)

    x = conv(nmid, (1, 1), name='conv1')(x)
    x = GroupNorm(name='gn2')(x)
    x = nn.relu(x)
    x = conv(nmid, (3, 3), self.strides, padding=[(1, 1), (1, 1)],
             name='conv2')(x)
    x = GroupNorm(name='gn3')(x)
    x = nn.relu(x)
    x = conv(nout, (1, 1), name='conv3')(x)

    return x + residual
|
|
class ResNetStage(nn.Module):
  """A stage (sequence of same-resolution blocks)."""
  block_size: int
  nmid: Optional[int] = None
  first_stride: Sequence[int] = (1, 1)

  @nn.compact
  def __call__(self, x):
    out = {}
    x = out['unit01'] = ResidualUnit(
        self.nmid, strides=self.first_stride, name='unit01')(x)
    for i in range(1, self.block_size):
      x = out[f'unit{i+1:02d}'] = ResidualUnit(
          self.nmid, name=f'unit{i+1:02d}')(x)
    return x, out
|
|
class Model(nn.Module):
  """ResNetV2."""
  num_classes: Optional[int] = None
  width: int = 1
  depth: Union[int, Sequence[int]] = 50
  head_zeroinit: bool = True

  @nn.compact
  def __call__(self, image, *, train=False):
    blocks = bit.get_block_desc(self.depth)
    width = int(64 * self.width)
    out = {}

    x = out['stem'] = RootBlock(width=width, name='root_block')(image)

    # Stages: the first keeps the resolution, each later one halves it and
    # doubles the width.
    x, out['stage1'] = ResNetStage(blocks[0], nmid=width, name='block1')(x)
    for i, block_size in enumerate(blocks[1:], 1):
      x, out[f'stage{i + 1}'] = ResNetStage(
          block_size, width * 2 ** i,
          first_stride=(2, 2), name=f'block{i + 1}')(x)

    # Pre-head normalization and global average pooling.
    x = out['norm_pre_head'] = GroupNorm(name='norm-pre-head')(x)
    x = out['pre_logits_2d'] = nn.relu(x)
    x = out['pre_logits'] = jnp.mean(x, axis=(1, 2))

    if self.num_classes:
      kw = {'kernel_init': nn.initializers.zeros} if self.head_zeroinit else {}
      head = nn.Dense(self.num_classes, name='head', **kw)
      out['logits_2d'] = head(out['pre_logits_2d'])
      x = out['logits'] = head(out['pre_logits'])

    return x, out
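# A minimal forward-pass sketch (input shape and class count illustrative):
#
#   import jax
#   model = Model(num_classes=1000, width=1, depth=50)
#   imgs = jnp.zeros([2, 224, 224, 3])
#   variables = model.init(jax.random.PRNGKey(0), imgs)
#   logits, out = model.apply(variables, imgs)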
|
|
def load(init_params, init_file, model_cfg, dont_load=()):
  """Loads the TF-dumped NumPy or big_vision checkpoint.

  Args:
    init_params: random init params from which the new head is taken.
    init_file: comes from `config.model_init`, can either be an absolute
      path (i.e. starts with /) to the checkpoint, or a string like
      "L-imagenet2012" describing one of the variants from the paper.
    model_cfg: the model configuration.
    dont_load: list of param names to be reset to init.

  Returns:
    The loaded parameters.
  """
  vanity = {
      'FunMatch-224px-i1k82.8': 'gs://bit_models/distill/R50x1_224.npz',
      'FunMatch-160px-i1k80.5': 'gs://bit_models/distill/R50x1_160.npz',
  }
  if init_file[0] in ('L', 'M', 'S'):
    # One of the paper's upstream models ('L', 'M', 'S'), optionally followed
    # by the dataset it was fine-tuned on, e.g. 'M-imagenet2012'.
    if '-' in init_file:
      up, down = init_file[0], init_file[1:]
    else:
      up, down = init_file, ''
    down = {'-imagenet2012': '-ILSVRC2012'}.get(down, down)
    fname = f'BiT-{up}-R{model_cfg.depth}x{model_cfg.width}{down}.npz'
    fname = f'gs://bit_models/{fname}'
  else:
    fname = vanity.get(init_file, init_file)

  params = u.load_params(fname)
  params = maybe_convert_big_transfer_format(params)
  return common.merge_params(params, init_params, dont_load)
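# Typical usage sketch. `model_cfg` only needs `depth` and `width` attributes
# when using the paper's "L/M/S" naming scheme; reusing the Model dataclass
# for that here is purely illustrative:
#
#   model_cfg = Model(num_classes=1000, width=1, depth=50)
#   params = load(init_params, 'M-imagenet2012', model_cfg)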
|
|
def maybe_convert_big_transfer_format(params_tf):
  """If the checkpoint comes from legacy codebase, convert it."""
  # Only do anything at all if we recognize the (TF-dumped) format.
  if 'resnet' not in params_tf:
    return params_tf

  # For ease of processing, flatten the tree into a {name: array} dict.
  params_tf = dict(u.tree_flatten_with_names(params_tf)[0])

  # Works around some files containing inconsistent naming of variables:
  for k in list(params_tf):
    k2 = re.sub(r'/standardized_conv2d_\d+/', '/standardized_conv2d/', k)
    if k2 != k:
      params_tf[k2] = params_tf[k]
      del params_tf[k]

  params = {
      'root_block': {'conv_root': {'kernel': params_tf[
          'resnet/root_block/standardized_conv2d/kernel']}},
      'norm-pre-head': {
          'bias': params_tf['resnet/group_norm/beta'][None, None, None],
          'scale': params_tf['resnet/group_norm/gamma'][None, None, None],
      },
      'head': {
          # The TF head was a 1x1 conv; squeeze it into a Dense kernel.
          'kernel': params_tf['resnet/head/conv2d/kernel'][0, 0],
          'bias': params_tf['resnet/head/conv2d/bias'],
      }
  }

  for block in ('block1', 'block2', 'block3', 'block4'):
    params[block] = {}
    units = {re.findall(r'unit\d+', p)[0] for p in params_tf.keys()
             if p.find(block) >= 0}
    for unit in units:
      params[block][unit] = {}
      # TF groups 'a'/'b'/'c' map to conv1/gn1 .. conv3/gn3.
      for i, group in enumerate('abc', 1):
        params[block][unit][f'conv{i}'] = {
            'kernel': params_tf[f'resnet/{block}/{unit}/{group}/standardized_conv2d/kernel']
        }
        params[block][unit][f'gn{i}'] = {
            'bias': params_tf[f'resnet/{block}/{unit}/{group}/group_norm/beta'][None, None, None],
            'scale': params_tf[f'resnet/{block}/{unit}/{group}/group_norm/gamma'][None, None, None],
        }

      # At most one projection shortcut per unit, stored under group 'a'.
      projs = [p for p in params_tf.keys()
               if p.find(f'{block}/{unit}/a/proj') >= 0]
      assert len(projs) <= 1
      if projs:
        params[block][unit]['conv_proj'] = {
            'kernel': params_tf[projs[0]]
        }

  return params