|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""ResNet V1 with GroupNorm.""" |
|
|
|
from typing import Optional, Sequence, Union |
|
|
|
from big_vision import utils |
|
from big_vision.models import common |
|
import flax |
|
import flax.linen as nn |
|
import flax.training.checkpoints |
|
import jax.numpy as jnp |
|
import numpy as np |
|
|
|
|
|
def weight_standardize(w, axis, eps):
  """Standardizes `w` to zero mean and unit std over `axis`.

  Args:
    w: Weight array (e.g. a conv kernel).
    axis: Axis or list of axes to reduce over.
    eps: Small constant added to the std to avoid division by zero.

  Returns:
    The standardized array, same shape as `w`.
  """
  centered = w - jnp.mean(w, axis=axis)
  # Std is computed on the already-centered values, matching the original
  # two-step formulation exactly.
  return centered / (jnp.std(centered, axis=axis) + eps)
|
|
|
|
|
class StdConv(nn.Conv):
  """`nn.Conv` whose kernel is weight-standardized when fetched."""

  def param(self, name, *args, **kwargs):
    # Intercept parameter creation/lookup so the kernel is standardized
    # every time the module reads it; other params (e.g. bias) pass through.
    p = super().param(name, *args, **kwargs)
    if name != "kernel":
      return p
    return weight_standardize(p, axis=[0, 1, 2], eps=1e-5)
|
|
|
|
|
class ResidualUnit(nn.Module):
  """Bottleneck ResNet block.

  Attributes:
    nmid: Number of mid (bottleneck) channels; defaults to in_channels // 4.
    strides: Strides for the 3x3 conv (and the projection shortcut, if any).
  """
  nmid: Optional[int] = None
  strides: Sequence[int] = (1, 1)

  @nn.compact
  def __call__(self, x):
    mid_features = self.nmid or x.shape[-1] // 4
    out_features = mid_features * 4

    # Project the shortcut only when the shape changes (channels or stride).
    shortcut = x
    if x.shape[-1] != out_features or self.strides != (1, 1):
      shortcut = StdConv(out_features, (1, 1), self.strides, use_bias=False,
                         name="conv_proj")(shortcut)
      shortcut = nn.GroupNorm(name="gn_proj")(shortcut)

    # 1x1 reduce -> 3x3 (strided) -> 1x1 expand.
    y = StdConv(mid_features, (1, 1), use_bias=False, name="conv1")(x)
    y = nn.GroupNorm(name="gn1")(y)
    y = nn.relu(y)
    y = StdConv(mid_features, (3, 3), self.strides, use_bias=False,
                name="conv2")(y)
    y = nn.GroupNorm(name="gn2")(y)
    y = nn.relu(y)
    y = StdConv(out_features, (1, 1), use_bias=False, name="conv3")(y)

    # Zero-init scale makes the block start out as (roughly) identity.
    y = nn.GroupNorm(name="gn3", scale_init=nn.initializers.zeros)(y)
    return nn.relu(shortcut + y)
|
|
|
|
|
class ResNetStage(nn.Module):
  """One stage of ResNet.

  Attributes:
    block_size: Number of residual units in this stage.
    first_stride: Stride of the first unit (later units always use (1, 1)).
    nmid: Bottleneck width passed to every unit.
  """
  block_size: int
  first_stride: Sequence[int] = (1, 1)
  nmid: Optional[int] = None

  @nn.compact
  def __call__(self, x):
    # Only the first unit may downsample via `first_stride`.
    x = ResidualUnit(self.nmid, strides=self.first_stride, name="unit1")(x)
    for unit_idx in range(2, self.block_size + 1):
      x = ResidualUnit(self.nmid, name=f"unit{unit_idx}")(x)
    return x
|
|
|
|
|
class Model(nn.Module):
  """ResNetV1.

  Attributes:
    num_classes: If set, append a zero-initialized linear head of this size.
    width: Channel multiplier applied to the base width of 64.
    depth: Canonical depth (26/50/101/152/200) or an explicit per-stage
      sequence of block counts.
  """
  num_classes: Optional[int] = None
  width: float = 1
  depth: Union[int, Sequence[int]] = 50

  @nn.compact
  def __call__(self, image, *, train=False):
    del train  # Unused: GroupNorm model, no train-time branches here.
    stage_sizes = get_block_desc(self.depth)
    root_width = int(64 * self.width)

    # Intermediate features are collected here and returned alongside
    # the final output.
    out = {}

    # Root block: 7x7/2 conv, GN, ReLU, 3x3/2 max-pool.
    x = StdConv(root_width, (7, 7), (2, 2), use_bias=False,
                name="conv_root")(image)
    x = nn.GroupNorm(name="gn_root")(x)
    x = nn.relu(x)
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME")
    out["stem"] = x

    # Stages: the first keeps stride 1; each later stage halves the
    # resolution and doubles the bottleneck width.
    x = ResNetStage(stage_sizes[0], nmid=root_width, name="block1")(x)
    out["stage1"] = x
    for stage_idx, num_units in enumerate(stage_sizes[1:], 1):
      x = ResNetStage(num_units, nmid=root_width * 2 ** stage_idx,
                      first_stride=(2, 2), name=f"block{stage_idx + 1}")(x)
      out[f"stage{stage_idx + 1}"] = x
    out["pre_logits_2d"] = x

    # Global average pool over the spatial dimensions.
    x = out["pre_logits"] = jnp.mean(x, axis=(1, 2))

    if self.num_classes:
      # Zero-init head: the model starts out predicting all-zeros logits.
      head = nn.Dense(self.num_classes, name="head",
                      kernel_init=nn.initializers.zeros)
      out["logits_2d"] = head(out["pre_logits_2d"])
      x = out["logits"] = head(out["pre_logits"])

    return x, out
|
|
|
|
|
|
|
|
|
|
|
def get_block_desc(depth):
  """Maps a canonical ResNet depth to its per-stage block counts.

  Args:
    depth: A known depth (26, 50, 101, 152, 200) or an explicit sequence of
      per-stage block counts, which is passed through.

  Returns:
    A sequence of block counts; unknown inputs are returned as given
    (lists normalized to tuples).
  """
  canonical = {
      26: [2, 2, 2, 2],
      50: [3, 4, 6, 3],
      101: [3, 4, 23, 3],
      152: [3, 8, 36, 3],
      200: [3, 24, 36, 3],
  }
  # Lists are unhashable, so normalize to a tuple before the dict lookup.
  key = tuple(depth) if isinstance(depth, list) else depth
  return canonical.get(key, key)
|
|
|
|
|
def fix_old_checkpoints(params):
  """Modifies params from old checkpoints to run with current implementation."""
  # Convert pre-linen param trees and drop the FrozenDict wrapper.
  params = flax.core.unfreeze(
      flax.training.checkpoints.convert_pre_linen(params))

  # Squeeze singleton dims off GroupNorm params (presumably stored with
  # extra axes by older checkpoints — other entries pass through untouched).
  gn_names = {"gn_root", "gn_proj", "gn1", "gn2", "gn3"}
  flat = flax.traverse_util.flatten_dict(params)
  fixed = {
      path: np.squeeze(value) if set(path) & gn_names else value
      for path, value in flat.items()
  }
  return flax.traverse_util.unflatten_dict(fixed)
|
|
|
|
|
def load(init_params, init_file, model_cfg, dont_load=()):
  """Load init from checkpoint.

  Args:
    init_params: Current model params; entries in `dont_load` keep these.
    init_file: Checkpoint path understood by `utils.load_params`.
    model_cfg: Unused; kept for a uniform loader signature.
    dont_load: Param name patterns to keep from `init_params`.

  Returns:
    The merged, checkpoint-compatibility-fixed param tree.
  """
  del model_cfg
  restored = utils.load_params(init_file)
  restored = common.merge_params(restored, init_params, dont_load)
  return fix_old_checkpoints(restored)
|
|