|
|
|
|
|
|
|
|
|
from typing import Any, Optional

import numpy as np
import torch
import torch.nn as nn

from src.efficientvit.apps.trainer.run_config import Scheduler
from src.efficientvit.models.nn.ops import IdentityLayer, ResidualBlock
from src.efficientvit.models.utils import build_kwargs_from_config
|
|
|
__all__ = ["apply_drop_func"] |
|
|
|
|
|
def apply_drop_func(network: nn.Module, drop_config: Optional[dict[str, Any]]) -> None:
    """Apply a configured drop regularizer (e.g. droppath) to ``network`` in place.

    Args:
        network: model whose eligible sub-blocks will be replaced/wrapped.
        drop_config: configuration dict whose ``"name"`` key selects the drop
            function; its remaining entries supply that function's kwargs.
            ``None`` disables dropping entirely (no-op).

    Raises:
        KeyError: if ``drop_config["name"]`` is not a known drop function.
    """
    if drop_config is None:
        return

    # Dispatch table: config name -> function that mutates the network in place.
    drop_lookup_table = {
        "droppath": apply_droppath,
    }

    drop_func = drop_lookup_table[drop_config["name"]]
    # Keep only the config entries that drop_func actually accepts as kwargs.
    drop_kwargs = build_kwargs_from_config(drop_config, drop_func)

    drop_func(network, **drop_kwargs)
|
|
|
|
|
def apply_droppath(
    network: nn.Module,
    drop_prob: float,
    linear_decay=True,
    scheduled=True,
    skip=0,
) -> None:
    """Replace every residual block with an identity shortcut by a droppath block.

    Args:
        network: model to modify in place.
        drop_prob: maximum per-block drop probability.
        linear_decay: if True, scale the probability linearly with block depth
            (shallower blocks drop less often).
        scheduled: forwarded to ``DropPathResidualBlock``; ties the probability
            to the global training progress.
        skip: number of leading eligible blocks to leave untouched.
    """
    candidates = []
    for parent in network.modules():
        for child_name, child in parent.named_children():
            # Only residual blocks whose shortcut is a pure identity qualify.
            eligible = isinstance(child, ResidualBlock) and isinstance(
                child.shortcut, IdentityLayer
            )
            if eligible:
                candidates.append((parent, child_name, child))

    candidates = candidates[skip:]
    total = len(candidates)
    for idx, (parent, child_name, child) in enumerate(candidates):
        if linear_decay:
            # Deeper blocks get a proportionally larger drop probability.
            prob = drop_prob * (idx + 1) / total
        else:
            prob = drop_prob
        parent._modules[child_name] = DropPathResidualBlock(
            child.main,
            child.shortcut,
            child.post_act,
            child.pre_norm,
            prob,
            scheduled,
        )
|
|
|
|
|
class DropPathResidualBlock(ResidualBlock):
    """ResidualBlock variant implementing stochastic depth (droppath).

    During training, the main branch's contribution is zeroed per-sample with
    probability ``drop_prob`` while the identity shortcut is always kept;
    surviving samples are rescaled by ``1 / keep_prob`` so the expected output
    matches the deterministic block.
    """

    def __init__(
        self,
        main: nn.Module,
        # Fixed: original annotations used `nn.Module or None`, which evaluates
        # to just `nn.Module` and silently drops the Optional part.
        shortcut: Optional[nn.Module],
        post_act=None,
        pre_norm: Optional[nn.Module] = None,
        drop_prob: float = 0,
        scheduled=True,
    ):
        super().__init__(main, shortcut, post_act, pre_norm)

        # Per-sample probability of dropping the main branch.
        self.drop_prob = drop_prob
        # If True, scale drop_prob by Scheduler.PROGRESS (clipped to [0, 1]);
        # NOTE(review): PROGRESS presumably tracks training progress — confirm.
        self.scheduled = scheduled

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Droppath only applies while training, with a nonzero probability,
        # and an identity shortcut; otherwise behave as a plain residual block.
        if (
            not self.training
            or self.drop_prob == 0
            or not isinstance(self.shortcut, IdentityLayer)
        ):
            return ResidualBlock.forward(self, x)
        else:
            drop_prob = self.drop_prob
            if self.scheduled:
                drop_prob *= np.clip(Scheduler.PROGRESS, 0, 1)
            keep_prob = 1 - drop_prob

            # Per-sample binary mask, broadcastable over all non-batch dims.
            shape = (x.shape[0],) + (1,) * (x.ndim - 1)
            random_tensor = keep_prob + torch.rand(
                shape, dtype=x.dtype, device=x.device
            )
            random_tensor.floor_()  # -> 1 with prob keep_prob, else 0

            # Rescale kept samples by 1/keep_prob to preserve the expectation.
            res = self.forward_main(x) / keep_prob * random_tensor + self.shortcut(x)
            if self.post_act:
                res = self.post_act(res)
            return res
|
|