# flashsloth/model/pooling.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class AveragePooling(nn.Module):
    """Downsamples a square grid of vision tokens with 2D average pooling."""

    def __init__(self, pooling_size=2, device='cpu'):
        super().__init__()
        self.pooling_size = pooling_size
        self.device = device
        self.to(device)  # no-op here (the module has no parameters), kept for interface parity

    def forward(self, image_features):
        # image_features: (batch, num_tokens, dim); tokens are assumed to form a square grid
        batch_size, num_features, dim = image_features.size()
        height = width = int(num_features ** 0.5)
        image_features = image_features.view(batch_size, height, width, dim)
        # avg_pool2d expects channels-first (batch, dim, H, W)
        pooled_features = F.avg_pool2d(image_features.permute(0, 3, 1, 2), kernel_size=self.pooling_size)
        # back to channels-last, then flatten the pooled grid into a token sequence
        pooled_features = pooled_features.permute(0, 2, 3, 1)
        pooled_features = pooled_features.view(batch_size, -1, dim)
        return pooled_features
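
# Usage sketch (hypothetical shapes): a 24x24 = 576-token vision grid with
# hidden size 1024, pooled 2x2, comes out as 144 tokens:
#
#     pool = AveragePooling(pooling_size=2)
#     feats = torch.randn(2, 576, 1024)   # (batch, num_tokens, dim)
#     out = pool(feats)                   # -> (2, 144, 1024)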


class AttentionPooling(nn.Module):
    """Pools each pooling_size x pooling_size region of the token grid with
    softmax attention weights produced by a small scoring MLP."""

    def __init__(self, input_dim, pooling_size=2, device='cpu', dtype=torch.float32):
        super().__init__()
        self.pooling_size = pooling_size
        self.device = device
        # MLP that scores each token in a region; a softmax over the region
        # turns the scores into attention weights
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, input_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(input_dim, 1))
        self.mlp.to(device, dtype=dtype)  # redundant if the parent model moves the module later

    def forward(self, x):
        # x: (batch, num_tokens, dim); tokens are assumed to form a square grid
        batch_size, n, dim = x.shape
        sqrt_n = int(n ** 0.5)
        pooling_size = self.pooling_size
        x = x.view(batch_size, sqrt_n, sqrt_n, dim)
        pooled_features = []
        for i in range(0, sqrt_n, pooling_size):
            for j in range(0, sqrt_n, pooling_size):
                # flatten the pooling_size x pooling_size region into a short token sequence
                region = x[:, i:i+pooling_size, j:j+pooling_size, :]
                region = region.reshape(batch_size, -1, dim)
                # score each token, normalize within the region, take the weighted sum
                alpha = self.mlp(region)          # (batch, region_tokens, 1)
                alpha = torch.softmax(alpha, dim=1)
                region_pooled = torch.sum(alpha * region, dim=1)
                pooled_features.append(region_pooled)
        # one pooled token per region: (batch, (sqrt_n // pooling_size) ** 2, dim)
        output = torch.stack(pooled_features, dim=1)
        return output
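
# Usage sketch (hypothetical shapes): same output geometry as AveragePooling,
# but each pooled token is a softmax-weighted sum over its 2x2 region rather
# than a uniform mean:
#
#     pool = AttentionPooling(input_dim=1024, pooling_size=2)
#     feats = torch.randn(2, 576, 1024)   # (batch, num_tokens, dim)
#     out = pool(feats)                   # -> (2, 144, 1024)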


def build_pooling(pooling_type, input_dim=None, pooling_size=2, device='cpu', dtype=torch.float32):
    """Factory for the pooling modules above; `input_dim` is only needed for attention pooling."""
    if pooling_type == 'average':
        return AveragePooling(pooling_size=pooling_size, device=device)
    elif pooling_type == 'attention':
        if input_dim is None:
            raise ValueError("input_dim must be specified for attention pooling")
        return AttentionPooling(input_dim=input_dim, pooling_size=pooling_size, device=device, dtype=dtype)
    else:
        raise ValueError(f"Unknown pooling type: {pooling_type}")