|
import torch.nn as nn |
|
import torch |
|
import torch.distributed as dist |
|
|
|
class GlobalAvgPool2d(nn.Module):
    """Global average pooling over the input's spatial dimensions.

    Every dimension after the first two is flattened into a single axis and
    averaged, so an input of shape ``(d0, d1, ...)`` yields an output of
    shape ``(d0, d1)``.
    """

    def __init__(self):
        super(GlobalAvgPool2d, self).__init__()

    def forward(self, inputs):
        """Return the mean of ``inputs`` over all dimensions past the first two.

        Args:
            inputs: Tensor with at least two dimensions.

        Returns:
            Tensor of shape ``(inputs.size(0), inputs.size(1))``.
        """
        in_size = inputs.size()
        # ``reshape`` instead of ``view``: ``view`` raises on inputs whose
        # trailing dimensions cannot be merged without a copy (e.g. after a
        # transpose of the last two axes); ``reshape`` copies only if needed.
        flattened = inputs.reshape(in_size[0], in_size[1], -1)
        return flattened.mean(dim=2)
|
|
|
class SingleGPU(nn.Module):
    """Wrapper that transfers its input to a CUDA device before delegating.

    The wrapped module is stored as the ``module`` attribute (and is thereby
    registered as a submodule).
    """

    def __init__(self, module):
        """Store ``module`` as the callable that ``forward`` delegates to."""
        super(SingleGPU, self).__init__()
        self.module = module

    def forward(self, input):
        """Call the wrapped module on ``input`` after moving it with ``.cuda``.

        The transfer is requested with ``non_blocking=True``.
        """
        on_device = input.cuda(non_blocking=True)
        return self.module(on_device)
|
|
|
|