File size: 445 Bytes
231edce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.models.layers import SelectAdaptivePool2d

from .gem import GeM


def create_pool2d_layer(name, **kwargs):
    """Build a 2-D pooling layer selected by name.

    Args:
        name: Pool type; one of ``"avg"``, ``"max"``, ``"fast"``,
            ``"avgmax"``, ``"catavgmax"`` or ``"gem"``.
        **kwargs: Forwarded to ``GeM`` only when ``name == "gem"``; they are
            ignored for every other pool type.

    Returns:
        ``GeM(dim=2, **kwargs)`` for ``"gem"``; otherwise a timm
        ``SelectAdaptivePool2d`` constructed with ``pool_type=name`` and
        ``flatten=True``.

    Raises:
        ValueError: If ``name`` is not one of the supported pool types.
    """
    supported = ("avg", "max", "fast", "avgmax", "catavgmax", "gem")
    # Explicit raise rather than ``assert``: asserts are stripped under
    # ``python -O``, which would let an unknown name slip through silently.
    if name not in supported:
        raise ValueError(
            f"Unsupported pool type {name!r}; expected one of {supported}"
        )
    if name == "gem":
        return GeM(dim=2, **kwargs)
    # NOTE(review): kwargs are dropped on this path; confirm no caller expects
    # them to reach SelectAdaptivePool2d.
    return SelectAdaptivePool2d(pool_type=name, flatten=True)