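"""Model components for AFB-URR-style video object segmentation:
memory/query encoders, a feature-bank matcher, and a refining decoder."""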
import math
import torch
from torch import nn
from torch.nn import functional as F
from torchvision.models import resnet50, ResNet50_Weights
import myutils
class ResBlock(nn.Module):
"""A simple residual block component."""
def __init__(self, indim, outdim=None, stride=1):
super(ResBlock, self).__init__()
outdim = outdim or indim
self.conv1 = nn.Conv2d(indim, outdim, kernel_size=3, padding=1, stride=stride)
self.conv2 = nn.Conv2d(outdim, outdim, kernel_size=3, padding=1)
self.downsample = nn.Conv2d(indim, outdim, kernel_size=1, stride=stride) if indim != outdim or stride != 1 else None
def forward(self, x):
identity = x
out = F.relu(self.conv1(x))
out = self.conv2(out)
if self.downsample:
identity = self.downsample(identity)
out += identity
return F.relu(out)
class EncoderM(nn.Module):
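    """Memory encoder.

    Fuses an ImageNet-normalized RGB frame with a single-channel object mask
    and its complement (background) mask after the first conv layer, then runs
    the ResNet-50 stages and returns the 1/16- and 1/2-scale features.
    """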
def __init__(self, load_imagenet_params):
super(EncoderM, self).__init__()
self.conv1_m = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.conv1_o = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
weights = ResNet50_Weights.IMAGENET1K_V1 if load_imagenet_params else None
resnet = resnet50(weights=weights)
self.conv1 = resnet.conv1
self.bn1 = resnet.bn1
self.relu = resnet.relu # 1/2, 64
self.maxpool = resnet.maxpool
self.res2 = resnet.layer1 # 1/4, 256
self.res3 = resnet.layer2 # 1/8, 512
self.res4 = resnet.layer3 # 1/16, 1024
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
def forward(self, in_f, in_m, in_o):
f = (in_f - self.mean) / self.std
x = self.conv1(f) + self.conv1_m(in_m) + self.conv1_o(in_o)
x = self.bn1(x)
r1 = self.relu(x) # 1/2, 64
x = self.maxpool(r1) # 1/4, 64
r2 = self.res2(x) # 1/4, 256
r3 = self.res3(r2) # 1/8, 512
r4 = self.res4(r3) # 1/16, 1024
return r4, r1
class EncoderQ(nn.Module):
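    """Query encoder.

    Encodes the current RGB frame with a ResNet-50 backbone and returns
    features at 1/16, 1/8, 1/4, and 1/2 scale for matching and refinement.
    """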
def __init__(self, load_imagenet_params):
super(EncoderQ, self).__init__()
weights = ResNet50_Weights.IMAGENET1K_V1 if load_imagenet_params else None
resnet = resnet50(weights=weights)
self.conv1 = resnet.conv1
self.bn1 = resnet.bn1
self.relu = resnet.relu # 1/2, 64
self.maxpool = resnet.maxpool
self.res2 = resnet.layer1 # 1/4, 256
self.res3 = resnet.layer2 # 1/8, 512
        self.res4 = resnet.layer3  # 1/16, 1024
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
def forward(self, in_f):
f = (in_f - self.mean) / self.std
x = self.conv1(f)
x = self.bn1(x)
r1 = self.relu(x) # 1/2, 64
x = self.maxpool(r1) # 1/4, 64
r2 = self.res2(x) # 1/4, 256
r3 = self.res3(r2) # 1/8, 512
        r4 = self.res4(r3)  # 1/16, 1024
return r4, r3, r2, r1
class KeyValue(nn.Module):
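    """Projects a feature map into key/value embeddings, flattened over the
    spatial dimensions for attention-style matching."""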
def __init__(self, indim, keydim, valdim):
super(KeyValue, self).__init__()
self.keydim = keydim
self.valdim = valdim
self.Key = nn.Conv2d(indim, keydim, kernel_size=(3, 3), padding=(1, 1), stride=1)
self.Value = nn.Conv2d(indim, valdim, kernel_size=(3, 3), padding=(1, 1), stride=1)
def forward(self, x):
key = self.Key(x)
key = key.view(*key.shape[:2], -1) # obj_n, key_dim, pixel_n
val = self.Value(x)
        val = val.view(*val.shape[:2], -1)  # obj_n, val_dim, pixel_n
return key, val
class Refine(nn.Module):
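    """Fuses a skip-connection feature with a 2x-upsampled coarser map via
    residual blocks."""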
def __init__(self, inplanes, planes):
super(Refine, self).__init__()
self.convFS = nn.Conv2d(inplanes, planes, kernel_size=(3, 3), padding=(1, 1), stride=1)
self.ResFS = ResBlock(planes, planes)
self.ResMM = ResBlock(planes, planes)
self.scale_factor = 2
def forward(self, f, pm):
s = self.ResFS(self.convFS(f))
m = s + F.interpolate(pm, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
m = self.ResMM(m)
return m
class Matcher(nn.Module):
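    """Attention read-out from the feature bank.

    For each object, attends from the banked keys to the query keys, reads out
    the banked values, and optionally accumulates per-entry usage counts so the
    adaptive bank can be updated. Falls back to CPU matmuls when the GPU runs
    out of memory.
    """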
def __init__(self, thres_valid=1e-3, update_bank=False):
super(Matcher, self).__init__()
self.thres_valid = thres_valid
self.update_bank = update_bank
def forward(self, feature_bank, q_in, q_out):
mem_out_list = []
        for i in range(feature_bank.obj_n):
d_key, bank_n = feature_bank.keys[i].size()
            try:
                p = torch.matmul(feature_bank.keys[i].transpose(0, 1), q_in) / math.sqrt(d_key)  # bs, bank_n, HW
                p = F.softmax(p, dim=1)  # bs, bank_n, HW
                mem = torch.matmul(feature_bank.values[i], p)  # bs, D_o, HW
            except RuntimeError:
                # GPU out of memory: redo the matmuls on the CPU, then move the results back.
                device = feature_bank.keys[i].device
                key_cpu = feature_bank.keys[i].cpu()
                value_cpu = feature_bank.values[i].cpu()
                q_in_cpu = q_in.cpu()
                p = torch.matmul(key_cpu.transpose(0, 1), q_in_cpu) / math.sqrt(d_key)  # bs, bank_n, HW
                p = F.softmax(p, dim=1)  # bs, bank_n, HW
                mem = torch.matmul(value_cpu, p).to(device)  # bs, D_o, HW
                p = p.to(device)
                print('\tMatcher: GPU out of memory, falling back to CPU.', f'p size: {p.shape}')
mem_out_list.append(torch.cat([mem, q_out], dim=1))
if self.update_bank:
                try:
                    ones = torch.ones_like(p)
                    zeros = torch.zeros_like(p)
                    bank_cnt = torch.where(p > self.thres_valid, ones, zeros).sum(dim=2)[0]
                except RuntimeError:
                    # Same CPU fallback for the bank usage count.
                    device = p.device
                    p = p.cpu()
                    ones = torch.ones_like(p)
                    zeros = torch.zeros_like(p)
                    bank_cnt = torch.where(p > self.thres_valid, ones, zeros).sum(dim=2)[0].to(device)
                    print('\tMatcher: GPU out of memory, falling back to CPU.', f'p size: {p.shape}')
feature_bank.info[i][:, 1] += torch.log(bank_cnt + 1)
mem_out_tensor = torch.stack(mem_out_list, dim=0).transpose(0, 1) # bs, obj_n, dim, pixel_n
return mem_out_tensor
class Decoder(nn.Module):
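    """Two-stage decoder: a global patch-wise pathway (Refine modules over the
    1/8- and 1/4-scale features) followed by uncertainty-guided local
    refinement on the 1/2-scale features."""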
    def __init__(self, device):
super(Decoder, self).__init__()
self.device = device
mdim_global = 256
mdim_local = 32
local_size = 7
# Patch-wise
self.convFM = nn.Conv2d(1024, mdim_global, kernel_size=3, padding=1, stride=1)
self.ResMM = ResBlock(mdim_global, mdim_global)
        self.RF3 = Refine(512, mdim_global)  # 1/16 -> 1/8
self.RF2 = Refine(256, mdim_global) # 1/8 -> 1/4
self.pred2 = nn.Conv2d(mdim_global, 2, kernel_size=3, padding=1, stride=1)
# Local
self.local_avg = nn.AvgPool2d(local_size, stride=1, padding=local_size // 2)
self.local_max = nn.MaxPool2d(local_size, stride=1, padding=local_size // 2)
self.local_convFM = nn.Conv2d(128, mdim_local, kernel_size=3, padding=1, stride=1)
self.local_ResMM = ResBlock(mdim_local, mdim_local)
self.local_pred2 = nn.Conv2d(mdim_local, 2, kernel_size=3, padding=1, stride=1)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
def forward(self, patch_match, r3, r2, r1=None, feature_shape=None):
p = self.ResMM(self.convFM(patch_match))
p = self.RF3(r3, p) # out: 1/8, 256
p = self.RF2(r2, p) # out: 1/4, 256
p = self.pred2(F.relu(p))
p = F.interpolate(p, scale_factor=2, mode='bilinear', align_corners=False)
bs, obj_n, h, w = feature_shape
rough_seg = F.softmax(p, dim=1)[:, 1]
rough_seg = rough_seg.view(bs, obj_n, h, w)
rough_seg = F.softmax(rough_seg, dim=1) # object-level normalization
# Local refinement
uncertainty = myutils.calc_uncertainty(rough_seg)
uncertainty = uncertainty.expand(-1, obj_n, -1, -1).reshape(bs * obj_n, 1, h, w)
rough_seg = rough_seg.view(bs * obj_n, 1, h, w) # bs*obj_n, 1, h, w
r1_weighted = r1 * rough_seg
r1_local = self.local_avg(r1_weighted) # bs*obj_n, 64, h, w
r1_local = r1_local / (self.local_avg(rough_seg) + 1e-8) # neighborhood reference
r1_conf = self.local_max(rough_seg) # bs*obj_n, 1, h, w
local_match = torch.cat([r1, r1_local], dim=1)
q = self.local_ResMM(self.local_convFM(local_match))
q = r1_conf * self.local_pred2(F.relu(q))
p = p + uncertainty * q
p = F.interpolate(p, scale_factor=2, mode='bilinear', align_corners=False)
        p = F.softmax(p, dim=1)[:, 1]  # bs * obj_n, h, w
return p
class AFB_URR(nn.Module):
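    """AFB-URR model: memory/query encoders, a global matcher over the feature
    bank, and an uncertainty-refining decoder."""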
def __init__(self, device, update_bank, load_imagenet_params=False):
super(AFB_URR, self).__init__()
self.device = device
self.encoder_m = EncoderM(load_imagenet_params)
self.encoder_q = EncoderQ(load_imagenet_params)
self.keyval_r4 = KeyValue(1024, keydim=128, valdim=512)
self.global_matcher = Matcher(update_bank=update_bank)
self.decoder = Decoder(device)
def memorize(self, frame, mask):
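        """Encode a frame and its K object masks into per-object key/value
        lists for insertion into the feature bank."""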
_, K, H, W = mask.shape
(frame, mask), pad = myutils.pad_divide_by([frame, mask], 16, (frame.size()[2], frame.size()[3]))
frame = frame.expand(K, -1, -1, -1) # obj_n, 3, h, w
mask = mask[0].unsqueeze(1).float()
mask_ones = torch.ones_like(mask)
mask_inv = (mask_ones - mask).clamp(0, 1)
r4, r1 = self.encoder_m(frame, mask, mask_inv)
        k4, v4 = self.keyval_r4(r4)  # obj_n, 128 / 512, (H/16)*(W/16)
k4_list = [k4[i] for i in range(K)]
v4_list = [v4[i] for i in range(K)]
return k4_list, v4_list
def segment(self, frame, fb_global):
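        """Segment the current frame against the global feature bank.

        Returns per-object logits (cropped back to the unpadded size at
        inference) and, during training, a scalar uncertainty term.
        """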
obj_n = fb_global.obj_n
if not self.training:
[frame], pad = myutils.pad_divide_by([frame], 16, (frame.size()[2], frame.size()[3]))
r4, r3, r2, r1 = self.encoder_q(frame)
bs, _, global_match_h, global_match_w = r4.shape
_, _, local_match_h, local_match_w = r1.shape
        k4, v4 = self.keyval_r4(r4)  # bs, 128 / 512, (H/16)*(W/16)
res_global = self.global_matcher(fb_global, k4, v4)
res_global = res_global.reshape(bs * obj_n, v4.shape[1] * 2, global_match_h, global_match_w)
r3_size = r3.shape
r2_size = r2.shape
r3 = r3.unsqueeze(1).expand(-1, obj_n, -1, -1, -1).reshape(bs * obj_n, *r3_size[1:])
r2 = r2.unsqueeze(1).expand(-1, obj_n, -1, -1, -1).reshape(bs * obj_n, *r2_size[1:])
r1_size = r1.shape
r1 = r1.unsqueeze(1).expand(-1, obj_n, -1, -1, -1).reshape(bs * obj_n, *r1_size[1:])
feature_size = (bs, obj_n, r1_size[2], r1_size[3])
score = self.decoder(res_global, r3, r2, r1, feature_size)
# score = score.view(obj_n, bs, *frame.shape[-2:]).permute(1, 0, 2, 3)
score = score.view(bs, obj_n, *frame.shape[-2:])
if self.training:
uncertainty = myutils.calc_uncertainty(F.softmax(score, dim=1))
            uncertainty = uncertainty.view(bs, -1).norm(p=2, dim=1) / math.sqrt(frame.shape[-2] * frame.shape[-1])  # bs
uncertainty = uncertainty.mean()
else:
uncertainty = None
        score = torch.clamp(score, 1e-7, 1 - 1e-7)
        score = torch.log(score / (1 - score))
if not self.training:
if pad[2] + pad[3] > 0:
score = score[:, :, pad[2]:-pad[3], :]
if pad[0] + pad[1] > 0:
score = score[:, :, :, pad[0]:-pad[1]]
return score, uncertainty
    def forward(self, x):
        # Inference is driven by memorize() and segment(); the generic
        # nn.Module entry point is intentionally unimplemented.
        raise NotImplementedError('Use memorize() and segment() instead.')
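

# Minimal smoke-test sketch (not part of the model). It assumes `myutils`
# from this repo is importable and uses illustrative shapes; any frame size
# divisible by 16 avoids padding. Weights are random here, so this is only a
# shape check, not a meaningful segmentation.
if __name__ == '__main__':
    device = torch.device('cpu')
    model = AFB_URR(device, update_bank=False, load_imagenet_params=False).eval()
    frame = torch.rand(1, 3, 240, 432)                    # one RGB frame
    mask = torch.randint(0, 2, (1, 2, 240, 432)).float()  # background + one object
    with torch.no_grad():
        k4_list, v4_list = model.memorize(frame, mask)
    print([k.shape for k in k4_list])                     # per-object flattened keys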