|
|
|
|
|
|
|
|
|
from typing import Optional

import torch
import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm

from src.efficientvit.models.utils import build_kwargs_from_config
|
|
|
# Names exported by ``from <this module> import *``.
__all__ = ["LayerNorm2d", "build_norm", "reset_bn", "set_norm_eps"]
|
|
|
|
|
class LayerNorm2d(nn.LayerNorm):
    """LayerNorm over the channel dimension (dim 1) of channel-first tensors.

    Each position of an ``(N, C, *spatial)`` input is normalized across its
    ``C`` channel values, so the layer should be built with
    ``normalized_shape=C``. The result matches permuting to channels-last,
    applying ``nn.LayerNorm`` and permuting back, without the permutes.

    Only ``eps``, ``elementwise_affine``, ``weight`` and ``bias`` from the
    parent class are used; the reduction is always over dim 1.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = x - torch.mean(x, dim=1, keepdim=True)
        # Biased (population) variance, as in nn.LayerNorm.
        out = out / torch.sqrt(torch.square(out).mean(dim=1, keepdim=True) + self.eps)
        if self.elementwise_affine:
            # Broadcast the per-channel affine params over batch and all
            # trailing (spatial) dims: (1, C, 1, ..., 1).
            shape = (1, -1) + (1,) * (x.dim() - 2)
            out = out * self.weight.view(shape)
            # ``bias`` is None when the layer is built with ``bias=False``
            # (supported by newer PyTorch versions).
            if self.bias is not None:
                out = out + self.bias.view(shape)
        return out
|
|
|
|
|
|
|
# Registry consulted by ``build_norm``: short name -> normalization class.
REGISTERED_NORM_DICT: dict[str, type] = dict(
    bn2d=nn.BatchNorm2d,
    ln=nn.LayerNorm,
    ln2d=LayerNorm2d,
)
|
|
|
|
|
def build_norm(name="bn2d", num_features=None, **kwargs) -> Optional[nn.Module]:
    """Instantiate a registered normalization layer by its short name.

    Args:
        name: key into ``REGISTERED_NORM_DICT`` (e.g. ``"bn2d"``, ``"ln"``,
            ``"ln2d"``).
        num_features: channel count. Passed as ``normalized_shape`` to
            LayerNorm-derived classes and as ``num_features`` to the others.
        **kwargs: additional constructor options, routed through
            ``build_kwargs_from_config`` for the chosen class (presumably
            keeping only the ones that class accepts — confirm in that helper).

    Returns:
        The constructed module, or ``None`` when ``name`` is not registered
        (including ``name=None``).
    """
    norm_cls = REGISTERED_NORM_DICT.get(name)
    if norm_cls is None:
        return None
    # LayerNorm and its subclasses (e.g. LayerNorm2d) are sized by
    # ``normalized_shape``; the other registered layers (e.g. nn.BatchNorm2d)
    # by ``num_features``.
    if issubclass(norm_cls, nn.LayerNorm):
        kwargs["normalized_shape"] = num_features
    else:
        kwargs["num_features"] = num_features
    args = build_kwargs_from_config(kwargs, norm_cls)
    return norm_cls(**args)
|
|
|
|
|
def reset_bn(
    model: nn.Module,
    data_loader: list,
    sync=True,
    progress_bar=False,
) -> None:
    """Re-estimate the BatchNorm running statistics of ``model`` from data.

    A deep copy of ``model`` is run over ``data_loader`` with every
    ``_BatchNorm`` layer patched to normalize with the current batch's
    statistics while accumulating per-channel batch mean and variance. The
    averaged statistics are then copied into ``running_mean`` /
    ``running_var`` of the matching layers of ``model`` (in place). ``model``
    itself is never run forward.

    Args:
        model: network whose BN buffers are updated in place.
        data_loader: sized iterable of input batches (tensors); each batch is
            moved to the model's device and passed to the model as-is.
        sync: if True, per-rank statistics are gathered with ``sync_tensor``
            and averaged before use.
        progress_bar: show a tqdm progress bar (only on the master process).

    NOTE(review): the stored variance is the biased (population) estimate,
    whereas nn.BatchNorm's own running_var updates use the unbiased one.
    """
    import copy

    import torch.nn.functional as F
    from tqdm import tqdm

    try:
        from efficientvit.apps.utils import AverageMeter, is_master, sync_tensor
        from efficientvit.models.utils import get_device, list_join
    except ImportError:
        # This module's own top-level import uses the ``src.efficientvit``
        # layout; fall back to it when ``efficientvit`` is not importable.
        from src.efficientvit.apps.utils import AverageMeter, is_master, sync_tensor
        from src.efficientvit.models.utils import get_device, list_join

    def channel_mean(t: torch.Tensor) -> torch.Tensor:
        # Mean over every dim except the channel dim (1), keeping dims so the
        # result broadcasts against ``t``: works for (N, C), (N, C, L),
        # (N, C, H, W), (N, C, D, H, W).
        return t.mean(dim=(0, *range(2, t.dim())), keepdim=True)

    def new_forward(bn, mean_est, var_est):
        def lambda_forward(x):
            x = x.contiguous()

            batch_mean = channel_mean(x)
            if sync:
                # Average the gathered per-rank means.
                batch_mean = torch.mean(sync_tensor(batch_mean, reduce="cat"), dim=0, keepdim=True)
            centered = x - batch_mean
            batch_var = channel_mean(centered * centered)
            if sync:
                batch_var = torch.mean(sync_tensor(batch_var, reduce="cat"), dim=0, keepdim=True)

            # Flatten (1, C, 1, ...) -> (C,). Unlike torch.squeeze, this keeps
            # the channel dim when C == 1.
            batch_mean = batch_mean.reshape(-1)
            batch_var = batch_var.reshape(-1)

            mean_est.update(batch_mean.detach(), x.size(0))
            var_est.update(batch_var.detach(), x.size(0))

            # The input may carry fewer channels than ``bn`` was built with;
            # use the leading slice of the affine params. They are None when
            # the layer was built with affine=False.
            feature_dim = batch_mean.shape[0]
            weight = None if bn.weight is None else bn.weight[:feature_dim]
            bias = None if bn.bias is None else bn.bias[:feature_dim]
            return F.batch_norm(x, batch_mean, batch_var, weight, bias, False, 0.0, bn.eps)

        return lambda_forward

    bn_mean = {}
    bn_var = {}

    tmp_model = copy.deepcopy(model)
    for name, m in tmp_model.named_modules():
        if isinstance(m, _BatchNorm):
            bn_mean[name] = AverageMeter(is_distributed=False)
            bn_var[name] = AverageMeter(is_distributed=False)
            m.forward = new_forward(m, bn_mean[name], bn_var[name])

    if not bn_mean:
        return

    tmp_model.eval()
    device = get_device(tmp_model)
    with torch.no_grad():
        with tqdm(
            total=len(data_loader),
            desc="reset bn",
            disable=not progress_bar or not is_master(),
        ) as t:
            for images in data_loader:
                images = images.to(device)
                tmp_model(images)
                t.set_postfix(
                    {
                        "bs": images.size(0),
                        "res": list_join(images.shape[-2:], "x"),
                    }
                )
                t.update()

    for name, m in model.named_modules():
        if name in bn_mean and bn_mean[name].count > 0:
            assert isinstance(m, _BatchNorm)
            if m.running_mean is None or m.running_var is None:
                # track_running_stats=False: there are no buffers to write.
                continue
            feature_dim = bn_mean[name].avg.size(0)
            m.running_mean.data[:feature_dim].copy_(bn_mean[name].avg)
            m.running_var.data[:feature_dim].copy_(bn_var[name].avg)
|
|
|
|
|
def set_norm_eps(model: nn.Module, eps: Optional[float] = None) -> None:
    """Overwrite ``eps`` on every normalization layer inside ``model``.

    Affects ``nn.GroupNorm``, ``nn.LayerNorm`` (and subclasses such as
    ``LayerNorm2d``) and every ``_BatchNorm`` subclass; other modules are
    left untouched. Passing ``eps=None`` is a no-op.

    Args:
        model: module tree to update in place.
        eps: new epsilon value, or ``None`` to leave everything unchanged.
    """
    if eps is None:
        return
    for m in model.modules():
        if isinstance(m, (nn.GroupNorm, nn.LayerNorm, _BatchNorm)):
            m.eps = eps
|
|