|
from typing import Optional, List |
|
|
|
import torch |
|
from torch import Tensor |
|
from torch.nn import ModuleList, Module, Upsample |
|
|
|
from tha3.nn.common.conv_block_factory import ConvBlockFactory |
|
from tha3.nn.nonlinearity_factory import ReLUFactory |
|
from tha3.nn.normalization import InstanceNorm2dFactory |
|
from tha3.nn.util import BlockArgs |
|
|
|
|
|
class ResizeConvUNetArgs:
    """Plain settings container consumed by ``ResizeConvUNet``.

    Every constructor argument is stored unchanged on the instance under the
    same name. The only processing is the ``block_args`` default described
    below.

    Args:
        image_size: Starting size; ``ResizeConvUNet`` halves it repeatedly
            while it is larger than ``bottleneck_image_size``.
        input_channels: Input channel count of the first convolution block.
        start_channels: Output channel count of the first convolution block.
        bottleneck_image_size: Size at (or below) which halving stops.
        num_bottleneck_blocks: Number of resnet blocks built for the
            bottleneck.
        max_channels: Upper bound applied when the channel count is doubled.
        upsample_mode: Passed as ``mode`` to ``torch.nn.Upsample``.
        block_args: Shared block settings. When ``None``, a ``BlockArgs``
            built from ``InstanceNorm2dFactory()`` and
            ``ReLUFactory(inplace=False)`` is stored instead.
        use_separable_convolution: Forwarded to ``ConvBlockFactory``.
    """

    def __init__(self,
                 image_size: int,
                 input_channels: int,
                 start_channels: int,
                 bottleneck_image_size: int,
                 num_bottleneck_blocks: int,
                 max_channels: int,
                 upsample_mode: str = 'bilinear',
                 block_args: Optional[BlockArgs] = None,
                 use_separable_convolution: bool = False):
        # Size and channel settings.
        self.image_size = image_size
        self.input_channels = input_channels
        self.start_channels = start_channels
        self.max_channels = max_channels

        # Bottleneck settings.
        self.bottleneck_image_size = bottleneck_image_size
        self.num_bottleneck_blocks = num_bottleneck_blocks

        # Block construction settings; the default is built lazily so that a
        # fresh BlockArgs is created per instance rather than shared.
        self.upsample_mode = upsample_mode
        self.use_separable_convolution = use_separable_convolution
        self.block_args = block_args if block_args is not None else BlockArgs(
            normalization_layer_factory=InstanceNorm2dFactory(),
            nonlinearity_factory=ReLUFactory(inplace=False))
|
|
|
|
|
class ResizeConvUNet(Module):
    """U-Net whose decoder upsamples with ``torch.nn.Upsample`` + convolution.

    Construction tracks a nominal size starting at ``args.image_size``:

    * Encoder: one conv3 block (``input_channels`` -> ``start_channels``),
      then one downsample block per halving of the tracked size while it is
      larger than ``args.bottleneck_image_size``. The channel count doubles
      at each step, capped at ``args.max_channels``.
    * Bottleneck: ``args.num_bottleneck_blocks`` resnet blocks at the final
      channel count.
    * Decoder: one conv3 block per doubling of the tracked size until it
      reaches ``args.image_size`` again. Each block takes the upsampled
      feature concatenated with the encoder feature recorded at that size.

    NOTE(review): the decoder looks up channel counts by size, so doubling
    must retrace the sizes recorded while halving. If a size that gets halved
    is odd, the lookup in ``__init__`` fails with ``KeyError``.

    Attributes:
        output_image_sizes: Tracked size of each tensor returned by
            ``forward``, in order (bottleneck first).
        output_num_channels: Channel count of each tensor returned by
            ``forward``, in the same order.
    """

    # Modes for which torch.nn.functional.interpolate raises if align_corners
    # is anything other than None.
    _NON_INTERPOLATING_MODES = frozenset({'nearest', 'area', 'nearest-exact'})

    def __init__(self, args: "ResizeConvUNetArgs"):
        """Build encoder, bottleneck and decoder blocks from ``args``."""
        super().__init__()
        self.args = args
        conv_block_factory = ConvBlockFactory(args.block_args, args.use_separable_convolution)

        # ----- Encoder -----
        self.downsample_blocks = ModuleList()
        self.downsample_blocks.append(conv_block_factory.create_conv3_block(
            self.args.input_channels,
            self.args.start_channels))
        current_channels = self.args.start_channels
        current_size = self.args.image_size

        # Remember the channel count used at every size so the decoder can
        # mirror the encoder.
        size_to_channel = {current_size: current_channels}
        while current_size > self.args.bottleneck_image_size:
            next_size = current_size // 2
            next_channels = min(self.args.max_channels, current_channels * 2)
            # presumably halves the spatial resolution -- confirm against
            # ConvBlockFactory.create_downsample_block
            self.downsample_blocks.append(conv_block_factory.create_downsample_block(
                current_channels,
                next_channels,
                is_output_1x1=False))
            current_size = next_size
            current_channels = next_channels
            size_to_channel[current_size] = current_channels

        # ----- Bottleneck -----
        self.bottleneck_blocks = ModuleList()
        for _ in range(self.args.num_bottleneck_blocks):
            self.bottleneck_blocks.append(
                conv_block_factory.create_resnet_block(current_channels, is_1x1=False))

        # ----- Decoder -----
        self.output_image_sizes = [current_size]
        self.output_num_channels = [current_channels]
        self.upsample_blocks = ModuleList()
        while current_size < self.args.image_size:
            next_size = current_size * 2
            next_channels = size_to_channel[next_size]
            # Input is the upsampled feature (current_channels) concatenated
            # with the encoder skip feature (next_channels).
            self.upsample_blocks.append(conv_block_factory.create_conv3_block(
                current_channels + next_channels,
                next_channels))
            current_size = next_size
            current_channels = next_channels
            self.output_image_sizes.append(current_size)
            self.output_num_channels.append(current_channels)

        self.double_resolution = Upsample(
            scale_factor=2,
            mode=args.upsample_mode,
            align_corners=self._align_corners_for(args.upsample_mode))

    @staticmethod
    def _align_corners_for(upsample_mode: str) -> Optional[bool]:
        """Return the ``align_corners`` value that is valid for a mode.

        ``torch.nn.functional.interpolate`` only accepts an explicit
        ``align_corners`` for the interpolating modes (linear, bilinear,
        bicubic, trilinear). For 'nearest', 'area' and 'nearest-exact' it
        must be ``None``, otherwise a ``ValueError`` is raised on the first
        forward pass. Any other mode keeps the previous value of ``False``.
        """
        if upsample_mode in ResizeConvUNet._NON_INTERPOLATING_MODES:
            return None
        return False

    def forward(self, feature: Tensor) -> List[Tensor]:
        """Run the network and return the features of every decoder stage.

        Args:
            feature: Input tensor fed to the first encoder block.

        Returns:
            A list whose first element is the bottleneck output, followed by
            the output of each decoder block in the order they are applied.
        """
        downsampled_features = []
        for block in self.downsample_blocks:
            feature = block(feature)
            downsampled_features.append(feature)

        for block in self.bottleneck_blocks:
            feature = block(feature)

        outputs = [feature]
        for i, block in enumerate(self.upsample_blocks):
            feature = self.double_resolution(feature)
            # Index -1 is the deepest encoder feature (the bottleneck input),
            # so the skip connections start at -2 and walk backwards.
            skip = downsampled_features[-i - 2]
            feature = torch.cat([feature, skip], dim=1)
            feature = block(feature)
            outputs.append(feature)

        return outputs
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test and timing benchmark for ResizeConvUNet.
    # Requires a CUDA device: the model is moved to 'cuda' and the timing
    # below uses torch.cuda.Event.
    device = torch.device('cuda')

    image_size = 512
    image_channels = 4
    num_pose_params = 6
    # NOTE(review): the original passed the literal 10 here while leaving
    # image_channels and num_pose_params unused; 4 + 6 gives the same value,
    # so presumably this is what was intended -- confirm.
    input_channels = image_channels + num_pose_params

    args = ResizeConvUNetArgs(
        image_size=image_size,
        input_channels=input_channels,
        start_channels=32,
        bottleneck_image_size=32,
        num_bottleneck_blocks=6,
        max_channels=512,
        upsample_mode='nearest',
        use_separable_convolution=False,
        block_args=BlockArgs(
            initialization_method='he',
            use_spectral_norm=False,
            normalization_layer_factory=InstanceNorm2dFactory(),
            nonlinearity_factory=ReLUFactory(inplace=False)))
    module = ResizeConvUNet(args).to(device)

    image_count = 8
    batch = torch.zeros(image_count, input_channels, image_size, image_size, device=device)

    # Print the shape of every tensor returned by the forward pass.
    outputs = module(batch)
    for output in outputs:
        print(output.shape)

    # Time `repeat` forward passes; the first `warmup` iterations run but are
    # left out of the printed timings and the average.
    repeat = 100
    warmup = 2
    acc = 0.0
    for i in range(repeat + warmup):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)

        start.record()
        module(batch)
        end.record()
        # Wait for the GPU to finish so that both events have completed
        # before the elapsed time is read.
        torch.cuda.synchronize()
        if i >= warmup:
            elapsed_time = start.elapsed_time(end)
            print("%d:" % i, elapsed_time)
            acc = acc + elapsed_time

    print("average:", acc / repeat)