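"""ImageNet-pretrained ResNet50 + DistilBERT language-conditioned U-Net with
lateral connections (a CLIPort-style two-stream model component).
"""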
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import cliport.utils.utils as utils
from transformers import DistilBertTokenizer, DistilBertModel
from cliport.models.resnet import IdentityBlock, ConvBlock
from cliport.models.core.unet import Up
from cliport.models.core import fusion
from cliport.models.core.fusion import FusionConvLat
class RN50BertLingUNetLat(nn.Module):
    """ ImageNet RN50 & DistilBERT with U-Net skip connections and lateral connections """

    def __init__(self, input_shape, output_dim, cfg, device, preprocess):
        super(RN50BertLingUNetLat, self).__init__()
        self.input_shape = input_shape
        self.output_dim = output_dim
        self.input_dim = 2048  # channel dim of the ResNet50 layer4 output
        self.cfg = cfg
        self.batchnorm = self.cfg['train']['batchnorm']
        self.lang_fusion_type = self.cfg['train']['lang_fusion_type']
        self.bilinear = True
        self.up_factor = 2 if self.bilinear else 1
        self.device = device
        self.preprocess = preprocess

        self._load_vision_fcn()
        self._load_lang_enc()
        self._build_decoder()

    def _load_vision_fcn(self):
        resnet50 = models.resnet50(pretrained=True)
        modules = list(resnet50.children())[:-2]  # drop avgpool and fc head

        # The encoder stages are stored as enc_layer1..4; naming them
        # layer1..4 would let the decoder blocks built afterwards in
        # _build_decoder() (which use the names layer1..3) silently
        # overwrite them and break resnet50() below.
        self.stem = nn.Sequential(*modules[:4])
        self.enc_layer1 = modules[4]
        self.enc_layer2 = modules[5]
        self.enc_layer3 = modules[6]
        self.enc_layer4 = modules[7]

    def _load_lang_enc(self):
        self.tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
        self.text_encoder = DistilBertModel.from_pretrained('distilbert-base-uncased')
        self.text_fc = nn.Linear(768, 1024)

        self.lang_fuser1 = fusion.names[self.lang_fusion_type](input_dim=self.input_dim // 2)
        self.lang_fuser2 = fusion.names[self.lang_fusion_type](input_dim=self.input_dim // 4)
        self.lang_fuser3 = fusion.names[self.lang_fusion_type](input_dim=self.input_dim // 8)

        # Projection input dim depends on whether fusion consumes per-token
        # embeddings ('word' fusion) or the pooled 1024-d sentence feature.
        self.proj_input_dim = 512 if 'word' in self.lang_fusion_type else 1024
        self.lang_proj1 = nn.Linear(self.proj_input_dim, 1024)
        self.lang_proj2 = nn.Linear(self.proj_input_dim, 512)
        self.lang_proj3 = nn.Linear(self.proj_input_dim, 256)

    def _build_decoder(self):
        self.conv1 = nn.Sequential(
            nn.Conv2d(self.input_dim, 1024, kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(True)
        )

        self.up1 = Up(2048, 1024 // self.up_factor, self.bilinear)
        self.lat_fusion1 = FusionConvLat(input_dim=1024+512, output_dim=512)

        self.up2 = Up(1024, 512 // self.up_factor, self.bilinear)
        self.lat_fusion2 = FusionConvLat(input_dim=512+256, output_dim=256)

        self.up3 = Up(512, 256 // self.up_factor, self.bilinear)
        self.lat_fusion3 = FusionConvLat(input_dim=256+128, output_dim=128)

        self.layer1 = nn.Sequential(
            ConvBlock(128, [64, 64, 64], kernel_size=3, stride=1, batchnorm=self.batchnorm),
            IdentityBlock(64, [64, 64, 64], kernel_size=3, stride=1, batchnorm=self.batchnorm),
            nn.UpsamplingBilinear2d(scale_factor=2),
        )
        self.lat_fusion4 = FusionConvLat(input_dim=128+64, output_dim=64)

        self.layer2 = nn.Sequential(
            ConvBlock(64, [32, 32, 32], kernel_size=3, stride=1, batchnorm=self.batchnorm),
            IdentityBlock(32, [32, 32, 32], kernel_size=3, stride=1, batchnorm=self.batchnorm),
            nn.UpsamplingBilinear2d(scale_factor=2),
        )
        self.lat_fusion5 = FusionConvLat(input_dim=64+32, output_dim=32)

        self.layer3 = nn.Sequential(
            ConvBlock(32, [16, 16, 16], kernel_size=3, stride=1, batchnorm=self.batchnorm),
            IdentityBlock(16, [16, 16, 16], kernel_size=3, stride=1, batchnorm=self.batchnorm),
            nn.UpsamplingBilinear2d(scale_factor=2),
        )
        self.lat_fusion6 = FusionConvLat(input_dim=32+16, output_dim=16)

        self.conv2 = nn.Sequential(
            nn.Conv2d(16, self.output_dim, kernel_size=1)
        )

    def resnet50(self, x):
        # Run the encoder stage by stage, keeping each intermediate feature
        # map for the U-Net skip connections in forward().
        im = []
        for layer in [self.stem, self.enc_layer1, self.enc_layer2, self.enc_layer3, self.enc_layer4]:
            x = layer(x)
            im.append(x)
        return x, im

    def encode_image(self, img):
        with torch.no_grad():
            img_encoding, img_im = self.resnet50(img)
        return img_encoding, img_im

    def encode_text(self, x):
        with torch.no_grad():
            inputs = self.tokenizer(x, return_tensors='pt')
            input_ids, attention_mask = inputs['input_ids'].to(self.device), inputs['attention_mask'].to(self.device)
            text_embeddings = self.text_encoder(input_ids, attention_mask)
            text_encodings = text_embeddings.last_hidden_state.mean(1)  # mean-pool over tokens

        # text_fc is trainable, so it stays outside the no_grad block.
        text_feat = self.text_fc(text_encodings)
        text_mask = torch.ones_like(input_ids)  # [1, max_token_len]
        return text_feat, text_embeddings.last_hidden_state, text_mask

    def forward(self, x, lat, l):
        x = self.preprocess(x, dist='clip')

        in_type = x.dtype
        in_shape = x.shape
        x = x[:,:3]  # select RGB
        x, im = self.encode_image(x)
        x = x.to(in_type)

        l_enc, l_emb, l_mask = self.encode_text(l)
        l_input = l_emb if 'word' in self.lang_fusion_type else l_enc
        l_input = l_input.to(dtype=x.dtype)

        assert x.shape[1] == self.input_dim
        x = self.conv1(x)

        # Each decoder stage: fuse language, upsample with an encoder skip
        # connection, then fuse the lateral feature from the parallel stream.
        x = self.lang_fuser1(x, l_input, x2_mask=l_mask, x2_proj=self.lang_proj1)
        x = self.up1(x, im[-2])
        x = self.lat_fusion1(x, lat[-6])

        x = self.lang_fuser2(x, l_input, x2_mask=l_mask, x2_proj=self.lang_proj2)
        x = self.up2(x, im[-3])
        x = self.lat_fusion2(x, lat[-5])

        x = self.lang_fuser3(x, l_input, x2_mask=l_mask, x2_proj=self.lang_proj3)
        x = self.up3(x, im[-4])
        x = self.lat_fusion3(x, lat[-4])

        x = self.layer1(x)
        x = self.lat_fusion4(x, lat[-3])

        x = self.layer2(x)
        x = self.lat_fusion5(x, lat[-2])

        x = self.layer3(x)
        x = self.lat_fusion6(x, lat[-1])

        x = self.conv2(x)

        x = F.interpolate(x, size=(in_shape[-2], in_shape[-1]), mode='bilinear')
        return x
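

# Minimal usage sketch: the cfg values, the identity `preprocess` stand-in,
# and the lateral tensor shapes below are assumptions, not part of the
# original module. Lateral channel dims are inferred from the FusionConvLat
# input_dims above (lat[-6]..lat[-1] carry 1024..32 channels); spatial sizes
# follow a 320x160 input through the decoder. Running this downloads
# pretrained ResNet50 and DistilBERT weights if they are not cached.
if __name__ == '__main__':
    cfg = {'train': {'batchnorm': False, 'lang_fusion_type': 'mult'}}  # assumed config
    model = RN50BertLingUNetLat(input_shape=(320, 160, 6), output_dim=1,
                                cfg=cfg, device=torch.device('cpu'),
                                preprocess=lambda img, dist: img)  # identity preprocess

    x = torch.randn(1, 6, 320, 160)  # e.g. an RGB-D stack; only RGB is encoded
    # One lateral tensor per fusion stage, deepest first (lat[-6] is fused first).
    lat_specs = [(1024, 20, 10), (512, 40, 20), (256, 80, 40),
                 (128, 160, 80), (64, 320, 160), (32, 640, 320)]
    lat = [torch.randn(1, c, h, w) for c, h, w in lat_specs]

    out = model(x, lat, 'pack the red block into the brown box')
    print(out.shape)  # -> torch.Size([1, 1, 320, 160])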