|
from typing import Tuple

from torch import nn
from torch.nn.common_types import _size_2_t
|
|
|
|
|
def auto_pad(kernel_size: _size_2_t, dilation: _size_2_t = 1, **kwargs) -> Tuple[int, int]:
    """
    Compute 'same'-style padding (pad_h, pad_w) for a convolution block so that a
    stride-1 convolution preserves the spatial size. Extra keyword arguments are ignored.
    """
    # Normalise scalar arguments to (height, width) tuples.
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)
    if isinstance(dilation, int):
        dilation = (dilation, dilation)

    # The effective kernel extent along each axis is (k - 1) * d + 1, so padding by
    # half of (k - 1) * d keeps the output the same size as the input.
    pad_h = ((kernel_size[0] - 1) * dilation[0]) // 2
    pad_w = ((kernel_size[1] - 1) * dilation[1]) // 2
    return (pad_h, pad_w)
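

# Illustrative usage sketch (an assumption, not part of the original module): auto_pad is
# typically passed to nn.Conv2d so that a stride-1 convolution keeps its spatial size.
# The factory name `conv2d_same` and its defaults below are hypothetical, for demonstration only.
def conv2d_same(
    in_channels: int,
    out_channels: int,
    kernel_size: _size_2_t = 3,
    dilation: _size_2_t = 1,
    **kwargs,
) -> nn.Conv2d:
    """Build an nn.Conv2d whose padding is derived from auto_pad (hypothetical helper)."""
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        padding=auto_pad(kernel_size, dilation),
        dilation=dilation,
        **kwargs,
    )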
|
|
|
|
|
def get_activation(activation: str) -> nn.Module:
    """
    Retrieve an activation function from torch.nn by name, case-insensitively.
    Falsy values, "none", and "false" resolve to nn.Identity().
    """
    if not activation or activation.lower() in ["false", "none"]:
        return nn.Identity()

    # Map lowercase class names (e.g. "relu", "silu") to the activation classes
    # defined in torch.nn.modules.activation.
    activation_map = {
        name.lower(): obj
        for name, obj in nn.modules.activation.__dict__.items()
        if isinstance(obj, type) and issubclass(obj, nn.Module)
    }
    if activation.lower() in activation_map:
        return activation_map[activation.lower()]()
    raise ValueError(f"Activation function '{activation}' is not found in torch.nn")
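

if __name__ == "__main__":
    # Minimal smoke test (assumed usage, not part of the original module).
    # auto_pad(3, dilation=2) -> (2, 2): 'same' padding for a dilated 3x3 kernel.
    print(auto_pad(3, dilation=2))
    # Lookup is case-insensitive: "silu" resolves to nn.SiLU().
    print(get_activation("silu"))
    # "none" (or an empty string) falls back to nn.Identity().
    print(get_activation("none"))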
|
|