Spaces:
Runtime error
Runtime error
File size: 445 Bytes
231edce |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import SelectAdaptivePool2d
from .gem import GeM
def create_pool2d_layer(name, **kwargs):
    """Build a 2D pooling layer selected by name.

    Args:
        name: Pooling type. One of "avg", "max", "fast", "avgmax",
            "catavgmax" (built with timm's ``SelectAdaptivePool2d``) or
            "gem" (built with the local ``GeM`` module).
        **kwargs: Extra keyword arguments forwarded to ``GeM`` only; they
            are ignored for the timm pool types.

    Returns:
        The constructed pooling module. timm pool types are built with
        ``flatten=True``; ``GeM`` is built with ``dim=2``.

    Raises:
        ValueError: If ``name`` is not one of the supported pool types.
    """
    supported = ("avg", "max", "fast", "avgmax", "catavgmax", "gem")
    # Explicit raise instead of `assert`: asserts are stripped under
    # `python -O`, which would let an unknown name reach
    # SelectAdaptivePool2d unchecked.
    if name not in supported:
        raise ValueError(
            f"Unsupported pool type {name!r}; expected one of {supported}"
        )
    if name == "gem":
        return GeM(dim=2, **kwargs)
    # NOTE(review): kwargs are not forwarded here (matches prior behavior);
    # GeM-specific options may not be valid SelectAdaptivePool2d arguments,
    # so confirm before forwarding them.
    return SelectAdaptivePool2d(pool_type=name, flatten=True)