AItool committed (verified) · commit 3cce1e9 · 1 parent: b120055

Delete model

Files changed (5):
  1. model/IFNet_HDv3.py +0 -115
  2. model/RIFE_HDv3.py +0 -88
  3. model/flownet.pkl +0 -3
  4. model/loss.py +0 -128
  5. model/warplayer.py +0 -22
model/IFNet_HDv3.py DELETED
@@ -1,115 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from model.warplayer import warp
-
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
-    return nn.Sequential(
-        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
-                  padding=padding, dilation=dilation, bias=True),
-        nn.PReLU(out_planes)
-    )
-
-def conv_bn(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
-    return nn.Sequential(
-        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
-                  padding=padding, dilation=dilation, bias=False),
-        nn.BatchNorm2d(out_planes),
-        nn.PReLU(out_planes)
-    )
-
-class IFBlock(nn.Module):
-    def __init__(self, in_planes, c=64):
-        super(IFBlock, self).__init__()
-        self.conv0 = nn.Sequential(
-            conv(in_planes, c//2, 3, 2, 1),
-            conv(c//2, c, 3, 2, 1),
-        )
-        self.convblock0 = nn.Sequential(
-            conv(c, c),
-            conv(c, c)
-        )
-        self.convblock1 = nn.Sequential(
-            conv(c, c),
-            conv(c, c)
-        )
-        self.convblock2 = nn.Sequential(
-            conv(c, c),
-            conv(c, c)
-        )
-        self.convblock3 = nn.Sequential(
-            conv(c, c),
-            conv(c, c)
-        )
-        self.conv1 = nn.Sequential(
-            nn.ConvTranspose2d(c, c//2, 4, 2, 1),
-            nn.PReLU(c//2),
-            nn.ConvTranspose2d(c//2, 4, 4, 2, 1),
-        )
-        self.conv2 = nn.Sequential(
-            nn.ConvTranspose2d(c, c//2, 4, 2, 1),
-            nn.PReLU(c//2),
-            nn.ConvTranspose2d(c//2, 1, 4, 2, 1),
-        )
-
-    def forward(self, x, flow, scale=1):
-        x = F.interpolate(x, scale_factor=1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
-        flow = F.interpolate(flow, scale_factor=1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 1. / scale
-        feat = self.conv0(torch.cat((x, flow), 1))
-        feat = self.convblock0(feat) + feat
-        feat = self.convblock1(feat) + feat
-        feat = self.convblock2(feat) + feat
-        feat = self.convblock3(feat) + feat
-        flow = self.conv1(feat)
-        mask = self.conv2(feat)
-        flow = F.interpolate(flow, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * scale
-        mask = F.interpolate(mask, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
-        return flow, mask
-
-class IFNet(nn.Module):
-    def __init__(self):
-        super(IFNet, self).__init__()
-        self.block0 = IFBlock(7+4, c=90)
-        self.block1 = IFBlock(7+4, c=90)
-        self.block2 = IFBlock(7+4, c=90)
-        self.block_tea = IFBlock(10+4, c=90)
-        # self.contextnet = Contextnet()
-        # self.unet = Unet()
-
-    def forward(self, x, scale_list=[4, 2, 1], training=False):
-        if training == False:
-            channel = x.shape[1] // 2
-            img0 = x[:, :channel]
-            img1 = x[:, channel:]
-        flow_list = []
-        merged = []
-        mask_list = []
-        warped_img0 = img0
-        warped_img1 = img1
-        flow = (x[:, :4]).detach() * 0
-        mask = (x[:, :1]).detach() * 0
-        loss_cons = 0
-        block = [self.block0, self.block1, self.block2]
-        for i in range(3):
-            f0, m0 = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1), flow, scale=scale_list[i])
-            f1, m1 = block[i](torch.cat((warped_img1[:, :3], warped_img0[:, :3], -mask), 1), torch.cat((flow[:, 2:4], flow[:, :2]), 1), scale=scale_list[i])
-            flow = flow + (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2
-            mask = mask + (m0 + (-m1)) / 2
-            mask_list.append(mask)
-            flow_list.append(flow)
-            warped_img0 = warp(img0, flow[:, :2])
-            warped_img1 = warp(img1, flow[:, 2:4])
-            merged.append((warped_img0, warped_img1))
-        '''
-        c0 = self.contextnet(img0, flow[:, :2])
-        c1 = self.contextnet(img1, flow[:, 2:4])
-        tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
-        res = tmp[:, 1:4] * 2 - 1
-        '''
-        for i in range(3):
-            mask_list[i] = torch.sigmoid(mask_list[i])
-            merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
-            # merged[i] = torch.clamp(merged[i] + res, 0, 1)
-        return flow_list, mask_list[2], merged
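
For reference, a minimal sketch (not part of this commit) of how the deleted IFNet was typically driven at inference time. The 256×448 frame size (a multiple of 32, which the strided pyramid requires) and the random input frames are illustrative assumptions:

    # Illustrative only: exercises model/IFNet_HDv3.py as it stood before this commit.
    import torch
    from model.IFNet_HDv3 import IFNet, device

    net = IFNet().to(device).eval()
    img0 = torch.rand(1, 3, 256, 448, device=device)  # frames in [0, 1]
    img1 = torch.rand(1, 3, 256, 448, device=device)
    with torch.no_grad():
        flow_list, mask, merged = net(torch.cat((img0, img1), 1), scale_list=[4, 2, 1])
    mid = merged[2]  # blended intermediate frame from the finest pyramid level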
model/RIFE_HDv3.py DELETED
@@ -1,88 +0,0 @@
-import torch
-import torch.nn as nn
-import numpy as np
-from torch.optim import AdamW
-import torch.optim as optim
-import itertools
-from model.warplayer import warp
-from torch.nn.parallel import DistributedDataParallel as DDP
-from train_log.IFNet_HDv3 import *
-import torch.nn.functional as F
-from model.loss import *
-
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-class Model:
-    def __init__(self, local_rank=-1):
-        self.flownet = IFNet()
-        self.device()
-        self.optimG = AdamW(self.flownet.parameters(), lr=1e-6, weight_decay=1e-4)
-        self.epe = EPE()
-        # self.vgg = VGGPerceptualLoss().to(device)
-        self.sobel = SOBEL()
-        if local_rank != -1:
-            self.flownet = DDP(self.flownet, device_ids=[local_rank], output_device=local_rank)
-
-    def train(self):
-        self.flownet.train()
-
-    def eval(self):
-        self.flownet.eval()
-
-    def device(self):
-        self.flownet.to(device)
-
-    def load_model(self, path, rank=0):
-        def convert(param):
-            if rank == -1:
-                return {
-                    k.replace("module.", ""): v
-                    for k, v in param.items()
-                    if "module." in k
-                }
-            else:
-                return param
-        if rank <= 0:
-            if torch.cuda.is_available():
-                self.flownet.load_state_dict(convert(torch.load('{}/flownet.pkl'.format(path))))
-            else:
-                self.flownet.load_state_dict(convert(torch.load('{}/flownet.pkl'.format(path), map_location='cpu')))
-
-    def save_model(self, path, rank=0):
-        if rank == 0:
-            torch.save(self.flownet.state_dict(), '{}/flownet.pkl'.format(path))
-
-    def inference(self, img0, img1, scale=1.0):
-        imgs = torch.cat((img0, img1), 1)
-        scale_list = [4/scale, 2/scale, 1/scale]
-        flow, mask, merged = self.flownet(imgs, scale_list)
-        return merged[2]
-
-    def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
-        for param_group in self.optimG.param_groups:
-            param_group['lr'] = learning_rate
-        img0 = imgs[:, :3]
-        img1 = imgs[:, 3:]
-        if training:
-            self.train()
-        else:
-            self.eval()
-        scale = [4, 2, 1]
-        flow, mask, merged = self.flownet(torch.cat((imgs, gt), 1), scale=scale, training=training)
-        loss_l1 = (merged[2] - gt).abs().mean()
-        loss_smooth = self.sobel(flow[2], flow[2]*0).mean()
-        # loss_vgg = self.vgg(merged[2], gt)
-        if training:
-            self.optimG.zero_grad()
-            loss_G = loss_cons + loss_smooth * 0.1
-            loss_G.backward()
-            self.optimG.step()
-        else:
-            flow_teacher = flow[2]
-        return merged[2], {
-            'mask': mask,
-            'flow': flow[2][:, :2],
-            'loss_l1': loss_l1,
-            'loss_cons': loss_cons,
-            'loss_smooth': loss_smooth,
-        }
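
For reference, a hedged sketch (not part of this commit) of the usual inference path through the deleted Model class. Two quirks of the file as it stood: it imports IFNet from train_log.IFNet_HDv3 rather than from the model/ package deleted alongside it, and update() references loss_cons, which is never defined in this file, so the training branch would raise a NameError as written. The train_log checkpoint directory below is an assumption:

    # Illustrative only: load the (now deleted) flownet.pkl and interpolate a middle frame.
    import torch
    from model.RIFE_HDv3 import Model

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Model()
    model.load_model('train_log', -1)  # expects train_log/flownet.pkl; -1 strips DDP "module." prefixes
    model.eval()
    img0 = torch.rand(1, 3, 256, 448, device=device)
    img1 = torch.rand(1, 3, 256, 448, device=device)
    with torch.no_grad():
        mid = model.inference(img0, img1)  # returns merged[2]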
model/flownet.pkl DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:fe854fc8996547c953f732aaa3b78cae76cc0a12833ae856ea0749c4c570d7d8
-size 12186817
model/loss.py DELETED
@@ -1,128 +0,0 @@
-import torch
-import numpy as np
-import torch.nn as nn
-import torch.nn.functional as F
-import torchvision.models as models
-
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-
-class EPE(nn.Module):
-    def __init__(self):
-        super(EPE, self).__init__()
-
-    def forward(self, flow, gt, loss_mask):
-        loss_map = (flow - gt.detach()) ** 2
-        loss_map = (loss_map.sum(1, True) + 1e-6) ** 0.5
-        return (loss_map * loss_mask)
-
-
-class Ternary(nn.Module):
-    def __init__(self):
-        super(Ternary, self).__init__()
-        patch_size = 7
-        out_channels = patch_size * patch_size
-        self.w = np.eye(out_channels).reshape(
-            (patch_size, patch_size, 1, out_channels))
-        self.w = np.transpose(self.w, (3, 2, 0, 1))
-        self.w = torch.tensor(self.w).float().to(device)
-
-    def transform(self, img):
-        patches = F.conv2d(img, self.w, padding=3, bias=None)
-        transf = patches - img
-        transf_norm = transf / torch.sqrt(0.81 + transf**2)
-        return transf_norm
-
-    def rgb2gray(self, rgb):
-        r, g, b = rgb[:, 0:1, :, :], rgb[:, 1:2, :, :], rgb[:, 2:3, :, :]
-        gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
-        return gray
-
-    def hamming(self, t1, t2):
-        dist = (t1 - t2) ** 2
-        dist_norm = torch.mean(dist / (0.1 + dist), 1, True)
-        return dist_norm
-
-    def valid_mask(self, t, padding):
-        n, _, h, w = t.size()
-        inner = torch.ones(n, 1, h - 2 * padding, w - 2 * padding).type_as(t)
-        mask = F.pad(inner, [padding] * 4)
-        return mask
-
-    def forward(self, img0, img1):
-        img0 = self.transform(self.rgb2gray(img0))
-        img1 = self.transform(self.rgb2gray(img1))
-        return self.hamming(img0, img1) * self.valid_mask(img0, 1)
-
-
-class SOBEL(nn.Module):
-    def __init__(self):
-        super(SOBEL, self).__init__()
-        self.kernelX = torch.tensor([
-            [1, 0, -1],
-            [2, 0, -2],
-            [1, 0, -1],
-        ]).float()
-        self.kernelY = self.kernelX.clone().T
-        self.kernelX = self.kernelX.unsqueeze(0).unsqueeze(0).to(device)
-        self.kernelY = self.kernelY.unsqueeze(0).unsqueeze(0).to(device)
-
-    def forward(self, pred, gt):
-        N, C, H, W = pred.shape[0], pred.shape[1], pred.shape[2], pred.shape[3]
-        img_stack = torch.cat(
-            [pred.reshape(N*C, 1, H, W), gt.reshape(N*C, 1, H, W)], 0)
-        sobel_stack_x = F.conv2d(img_stack, self.kernelX, padding=1)
-        sobel_stack_y = F.conv2d(img_stack, self.kernelY, padding=1)
-        pred_X, gt_X = sobel_stack_x[:N*C], sobel_stack_x[N*C:]
-        pred_Y, gt_Y = sobel_stack_y[:N*C], sobel_stack_y[N*C:]
-
-        L1X, L1Y = torch.abs(pred_X-gt_X), torch.abs(pred_Y-gt_Y)
-        loss = (L1X+L1Y)
-        return loss
-
-class MeanShift(nn.Conv2d):
-    def __init__(self, data_mean, data_std, data_range=1, norm=True):
-        c = len(data_mean)
-        super(MeanShift, self).__init__(c, c, kernel_size=1)
-        std = torch.Tensor(data_std)
-        self.weight.data = torch.eye(c).view(c, c, 1, 1)
-        if norm:
-            self.weight.data.div_(std.view(c, 1, 1, 1))
-            self.bias.data = -1 * data_range * torch.Tensor(data_mean)
-            self.bias.data.div_(std)
-        else:
-            self.weight.data.mul_(std.view(c, 1, 1, 1))
-            self.bias.data = data_range * torch.Tensor(data_mean)
-        self.requires_grad = False
-
-class VGGPerceptualLoss(torch.nn.Module):
-    def __init__(self, rank=0):
-        super(VGGPerceptualLoss, self).__init__()
-        blocks = []
-        pretrained = True
-        self.vgg_pretrained_features = models.vgg19(pretrained=pretrained).features
-        self.normalize = MeanShift([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], norm=True).cuda()
-        for param in self.parameters():
-            param.requires_grad = False
-
-    def forward(self, X, Y, indices=None):
-        X = self.normalize(X)
-        Y = self.normalize(Y)
-        indices = [2, 7, 12, 21, 30]
-        weights = [1.0/2.6, 1.0/4.8, 1.0/3.7, 1.0/5.6, 10/1.5]
-        k = 0
-        loss = 0
-        for i in range(indices[-1]):
-            X = self.vgg_pretrained_features[i](X)
-            Y = self.vgg_pretrained_features[i](Y)
-            if (i+1) in indices:
-                loss += weights[k] * (X - Y.detach()).abs().mean() * 0.1
-                k += 1
-        return loss
-
-if __name__ == '__main__':
-    img0 = torch.zeros(3, 3, 256, 256).float().to(device)
-    img1 = torch.tensor(np.random.normal(
-        0, 1, (3, 3, 256, 256))).float().to(device)
-    ternary_loss = Ternary()
-    print(ternary_loss(img0, img1).shape)
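
For reference, a small sketch (not part of this commit) showing how the deleted SOBEL loss was consumed by RIFE_HDv3.update() as a flow-smoothness term; the random flow tensor is an illustrative placeholder:

    # Illustrative only: Sobel edge penalty on a bidirectional flow field (N, 4, H, W).
    import torch
    from model.loss import SOBEL, device

    sobel = SOBEL()
    flow = torch.rand(2, 4, 64, 64, device=device)
    loss_smooth = sobel(flow, flow * 0).mean()  # compare against zero flow, as update() did
    print(loss_smooth.item())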
model/warplayer.py DELETED
@@ -1,22 +0,0 @@
-import torch
-import torch.nn as nn
-
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-backwarp_tenGrid = {}
-
-
-def warp(tenInput, tenFlow):
-    k = (str(tenFlow.device), str(tenFlow.size()))
-    if k not in backwarp_tenGrid:
-        tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view(
-            1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
-        tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view(
-            1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
-        backwarp_tenGrid[k] = torch.cat(
-            [tenHorizontal, tenVertical], 1).to(device)
-
-    tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
-                         tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)
-
-    g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
-    return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)
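
For reference, a quick sanity-check sketch (not part of this commit) for the deleted warp(): with zero flow, backward warping through the cached identity grid should return the input essentially unchanged:

    # Illustrative only: a zero-flow warp is (numerically) the identity mapping.
    import torch
    from model.warplayer import warp, device

    img = torch.rand(1, 3, 64, 64, device=device)
    flow = torch.zeros(1, 2, 64, 64, device=device)  # per-pixel (dx, dy) offsets
    out = warp(img, flow)
    assert torch.allclose(out, img, atol=1e-6)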