Spaces:
Running
Running
File size: 1,277 Bytes
cec5823 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
from rstor.architecture.base import BaseModel
from rstor.architecture.convolution_blocks import BaseConvolutionBlock, ResConvolutionBlock
from rstor.properties import LEAKY_RELU
import torch
class StackedConvolutions(BaseModel):
    """Sequential stack of convolution blocks.

    The stack is assembled, in order, from:

    * one ``BaseConvolutionBlock`` mapping ``ch_in`` -> ``h_dim``,
    * ``num_layers - 2`` ``ResConvolutionBlock`` instances mapping
      ``h_dim`` -> ``h_dim`` (created with ``residual=True``),
    * one ``BaseConvolutionBlock`` mapping ``h_dim`` -> ``ch_out``, created
      with ``activation=None``.

    The first and last blocks are reachable both through the
    ``conv_in_modality`` / ``conv_out_modality`` attributes and as the first
    and last entries of ``conv_stack``; they are the same module objects.
    """

    def __init__(self,
                 ch_in: int = 3,
                 ch_out: int = 3,
                 h_dim: int = 64,
                 num_layers: int = 8,
                 k_size: int = 3,
                 activation: str = LEAKY_RELU,
                 bias: bool = True,
                 ) -> None:
        """Build the convolution stack.

        Args:
            ch_in: Channel count passed to the input block.
            ch_out: Channel count produced by the output block.
            h_dim: Channel count used between the input and output blocks.
            num_layers: Must be even. ``num_layers - 2`` residual blocks are
                inserted between the input and output blocks; for values
                of 2 or less no residual block is created.
            k_size: Kernel size forwarded to every block.
            activation: Activation identifier forwarded to the input block
                and to the residual blocks (the output block gets ``None``).
            bias: Bias flag forwarded to every block.

        Raises:
            AssertionError: If ``num_layers`` is odd (only when assertions
                are enabled).
        """
        super().__init__()
        assert num_layers % 2 == 0, "Number of layers should be even"

        # Keyword arguments shared by the input block and the residual blocks.
        hidden_kwargs = {"activation": activation, "bias": bias}

        # Blocks are instantiated in stack order (input, residual, output);
        # keep this order so parameter initialisation consumes the RNG in the
        # same sequence.
        self.conv_in_modality = BaseConvolutionBlock(ch_in, h_dim, k_size, **hidden_kwargs)
        residual_blocks = [
            ResConvolutionBlock(h_dim, h_dim, k_size, residual=True, **hidden_kwargs)
            for _ in range(num_layers - 2)
        ]
        self.conv_out_modality = BaseConvolutionBlock(
            h_dim, ch_out, k_size, activation=None, bias=bias)

        stages = [self.conv_in_modality, *residual_blocks, self.conv_out_modality]
        self.conv_stack = torch.nn.Sequential(*stages)

    def forward(self, x_in: torch.Tensor) -> torch.Tensor:
        """Run ``x_in`` through every block of ``conv_stack`` and return the result."""
        x_out = self.conv_stack(x_in)
        return x_out
|