# fewshot_random_subspace / models /R2D2_embedding.py
# darklord25's picture
# Upload model files
# c10198c verified
# raw
# history blame
# 2.11 kB
import torch.nn as nn
import torch
# Embedding network used in Meta-learning with differentiable closed-form solvers
# (Bertinetto et al., in submission to NIPS 2018).
# They call the ridge regressor version the "Ridge Regression Differentiable Discriminator (R2D2)."
# Note that they use a peculiar ordering of functions, namely conv-BN-pooling-lrelu,
# as opposed to the conventional one (conv-BN-lrelu-pooling).
def R2D2_conv_block(in_channels, out_channels, retain_activation=True, keep_prob=1.0):
    """Build one R2D2 convolutional stage as an ``nn.Sequential``.

    The layers run in this order:
      3x3 conv (padding 1) -> BatchNorm2d -> MaxPool2d(2)
      [-> LeakyReLU(0.1)] [-> Dropout]

    Args:
        in_channels: input channel count of the 3x3 convolution.
        out_channels: output channel count of the convolution; also the
            BatchNorm width.
        retain_activation: when True, append a LeakyReLU with negative
            slope 0.1 after the pooling layer.
        keep_prob: keep probability. A Dropout layer with
            ``p = 1 - keep_prob`` is appended only when ``keep_prob < 1.0``.

    Returns:
        An ``nn.Sequential``. Its children are named "0", "1", "2" and,
        when present, "LeakyReLU" and "Dropout". These names determine the
        module's state_dict keys, so they must not change.
    """
    # Build the layers in the same order as before, so parameter
    # initialisation consumes the RNG identically.
    stages = [
        ("0", nn.Conv2d(in_channels, out_channels, 3, padding=1)),
        ("1", nn.BatchNorm2d(out_channels)),
        ("2", nn.MaxPool2d(2)),
    ]
    if retain_activation:
        stages.append(("LeakyReLU", nn.LeakyReLU(0.1)))
    if keep_prob < 1.0:
        stages.append(("Dropout", nn.Dropout(p=1 - keep_prob, inplace=False)))

    block = nn.Sequential()
    for name, layer in stages:
        block.add_module(name, layer)
    return block
class R2D2Embedding(nn.Module):
    """Four-stage convolutional embedding network used by R2D2.

    Each stage is built by ``R2D2_conv_block``. The forward pass returns the
    flattened output of stage 3 concatenated with the flattened output of
    stage 4 along dimension 1.

    Args:
        x_dim: number of input channels.
        h1_dim, h2_dim, h3_dim: output channels of stages 1-3.
        z_dim: output channels of stage 4.
        retain_last_activation: whether stage 4 keeps its LeakyReLU.
            Defaults to False (see the note in ``__init__``).
    """

    def __init__(self, x_dim=3, h1_dim=96, h2_dim=192, h3_dim=384, z_dim=512,
                 retain_last_activation=False):
        super().__init__()
        self.block1 = R2D2_conv_block(x_dim, h1_dim)
        self.block2 = R2D2_conv_block(h1_dim, h2_dim)
        self.block3 = R2D2_conv_block(h2_dim, h3_dim, keep_prob=0.9)
        # In the last conv block, we disable the activation function to boost
        # classification accuracy. This trick was proposed by Gidaris et al.
        # (CVPR 2018); with it, the accuracy reportedly goes up from 50% to 51%.
        # Although the authors of R2D2 did not mention this trick in the paper,
        # we were unable to reproduce the result of Bertinetto et al. without it.
        self.block4 = R2D2_conv_block(h3_dim, z_dim,
                                      retain_activation=retain_last_activation,
                                      keep_prob=0.7)

    def forward(self, x):
        """Embed ``x`` and return the concatenated stage-3 and stage-4 features.

        Args:
            x: image batch, presumably of shape (batch, x_dim, H, W).
                TODO confirm against callers.

        Returns:
            Tensor of shape (batch, h3_dim*H3*W3 + z_dim*H4*W4). H3/W3 and
            H4/W4 are the spatial sizes after stages 3 and 4; each stage
            halves them via MaxPool2d(2). For example, an 84x84 input with
            default widths gives 384*10*10 + 512*5*5 = 51200 features.
        """
        b1 = self.block1(x)
        b2 = self.block2(b1)
        b3 = self.block3(b2)
        b4 = self.block4(b3)
        # Flatten and concatenate the outputs of the 3rd and 4th conv blocks,
        # as proposed in the R2D2 paper.
        # torch.flatten returns a view when possible and copies otherwise.
        # .view(N, -1) instead raises for non-view-compatible layouts (e.g.
        # channels_last feature maps) and for an empty batch.
        return torch.cat(
            (torch.flatten(b3, start_dim=1), torch.flatten(b4, start_dim=1)),
            dim=1,
        )