Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
class AveragePooling(nn.Module):
    """Downsample a square grid of tokens with non-overlapping mean pooling.

    ``forward`` expects ``(batch, num_tokens, dim)`` where ``num_tokens`` is a
    perfect square. Tokens are laid out row-major on a ``side x side`` grid and
    averaged over ``pooling_size x pooling_size`` windows via ``F.avg_pool2d``.
    ``avg_pool2d`` is called with its default ``ceil_mode=False``. Any trailing
    rows or columns that do not fill a whole window are therefore dropped.
    """

    def __init__(self, pooling_size=2, device='cpu'):
        super().__init__()
        self.pooling_size = pooling_size
        self.device = device
        # No parameters or buffers are registered above, so this moves nothing.
        self.to(device)

    def forward(self, image_features):
        """Return pooled tokens shaped ``(batch, (side // pooling_size) ** 2, dim)``.

        The shape holds when ``side >= pooling_size``.
        """
        batch, tokens, channels = image_features.size()
        side = int(tokens ** 0.5)
        # (B, N, D) -> (B, side, side, D) -> channels-first (B, D, side, side).
        # The channels-first layout is what avg_pool2d expects.
        grid = image_features.view(batch, side, side, channels).movedim(-1, 1)
        pooled = F.avg_pool2d(grid, kernel_size=self.pooling_size)
        # Back to channels-last, then flatten the pooled grid into one token axis.
        return pooled.movedim(1, -1).view(batch, -1, channels)
class AttentionPooling(nn.Module):
    """Pool a square grid of tokens with learned, per-window attention weights.

    ``forward`` takes ``(batch, n, dim)`` tokens and lays them out row-major on
    a ``side x side`` grid, where ``side = int(n ** 0.5)``. ``n`` must be a
    perfect square, otherwise the ``view`` raises. The grid is split into
    non-overlapping ``pooling_size x pooling_size`` windows. Windows on the
    bottom/right edge are partial when ``side`` is not a multiple of
    ``pooling_size``. A small MLP scores every token, and the scores are
    softmax-normalised within each window. Each window is reduced to the
    score-weighted sum of its tokens.

    Args:
        input_dim: Feature size of each token (last input dimension).
        pooling_size: Side length of a pooling window.
        device: Stored as ``self.device``; the MLP is not moved to it here.
        dtype: Accepted for interface compatibility; currently unused.
    """

    def __init__(self, input_dim, pooling_size=2, device='cpu', dtype=torch.float32):
        super().__init__()
        self.pooling_size = pooling_size
        self.device = device
        # Keep this exact layer layout. It fixes the state_dict keys
        # (mlp.0.*, mlp.3.*), so existing checkpoints keep loading.
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, input_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(input_dim, 1),
        )
        # NOTE(review): `self.mlp.to(device, dtype)` was commented out in the
        # original, so `device`/`dtype` are not applied here. Presumably the
        # caller moves the whole module; confirm before relying on `dtype`.

    @staticmethod
    def _to_windows(t, groups, size):
        """Regroup a (B, groups*size, groups*size, C) grid into (B, groups**2, size**2, C).

        Windows are ordered row-major over the grid, and tokens row-major
        within each window. This matches
        ``x[:, i:i+size, j:j+size, :].reshape(B, -1, C)`` visited in
        ``for i: for j:`` order.
        """
        batch, channels = t.shape[0], t.shape[-1]
        t = t.reshape(batch, groups, size, groups, size, channels)
        t = t.permute(0, 1, 3, 2, 4, 5)
        return t.reshape(batch, groups * groups, size * size, channels)

    def forward(self, x):
        """Pool ``x`` of shape ``(batch, n, dim)`` to ``(batch, ceil(side / pooling_size) ** 2, dim)``."""
        batch_size, n, dim = x.shape
        side = int(n ** 0.5)
        size = self.pooling_size
        grid = x.view(batch_size, side, side, dim)
        # The MLP acts on each token independently, so all tokens are scored
        # in one batched call instead of one MLP call per window.
        logits = self.mlp(grid)

        pad = (-side) % size
        if pad:
            # Pad bottom/right so the grid tiles exactly. Padded tokens get a
            # -inf score (zero softmax weight) and zero features. Each edge
            # window therefore pools over exactly its real tokens.
            grid = F.pad(grid, (0, 0, 0, pad, 0, pad))
            logits = F.pad(logits, (0, 0, 0, pad, 0, pad), value=float('-inf'))
        groups = (side + pad) // size

        alpha = torch.softmax(self._to_windows(logits, groups, size), dim=2)
        return (alpha * self._to_windows(grid, groups, size)).sum(dim=2)
def build_pooling(pooling_type, input_dim=None, pooling_size=2, device='cpu', dtype=torch.float32):
    """Construct a pooling module by name.

    Args:
        pooling_type: ``'average'`` or ``'attention'``.
        input_dim: Token feature size; required when ``pooling_type`` is
            ``'attention'``.
        pooling_size: Window side length passed to the module.
        device: Forwarded to the module constructor.
        dtype: Forwarded to ``AttentionPooling`` only.

    Raises:
        ValueError: If ``pooling_type`` is unknown, or ``input_dim`` is
            missing for attention pooling.
    """
    if pooling_type == 'average':
        return AveragePooling(pooling_size=pooling_size, device=device)
    if pooling_type != 'attention':
        raise ValueError(f"Unknown pooling type: {pooling_type}")
    if input_dim is None:
        raise ValueError("input_dim must be specified for attention pooling")
    return AttentionPooling(input_dim=input_dim, pooling_size=pooling_size, device=device, dtype=dtype)