NeverlandPeter committed on
Commit
f7f9895
·
1 Parent(s): ba88ddd
img_demoAE.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ########################################################################################################
2
+ # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
3
+ ########################################################################################################
4
+
5
+ import torch, types, os
6
+ import numpy as np
7
+ from PIL import Image
8
+ import torch.nn as nn
9
+ from torch.nn import functional as F
10
+ import torchvision as vision
11
+ import torchvision.transforms as transforms
12
# Compact, non-scientific number formatting for the arrays printed by this script.
np.set_printoptions(precision=4, suppress=True, linewidth=200)
print('loading...')

########################################################################################################

# Checkpoint filename prefix: the script loads '<model_prefix>-E.pth' (encoder)
# and '<model_prefix>-D.pth' (decoder).
model_prefix = 'out-v7c_d8_256-224-13bit-OB32x0.5-201'
# Image file that is encoded to a binary code and then reconstructed.
input_img = 'kodim24-modified.png'
19
+
20
+ ########################################################################################################
21
+
22
class ToBinary(torch.autograd.Function):
    """Round a tensor to the nearest integer while letting gradients pass through.

    The forward pass computes ``floor(x + 0.5)``; the backward pass returns the
    incoming gradient unchanged, so the rounding step does not block training.
    """

    @staticmethod
    def forward(ctx, x):
        """Return ``floor(x + 0.5)`` element-wise (exact halves round up)."""
        return torch.floor(x + 0.5)  # no need for noise when we have plenty of data

    @staticmethod
    def backward(ctx, grad_output):
        """Return a copy of ``grad_output`` as the gradient w.r.t. ``x``."""
        return grad_output.clone()  # pass-through
30
+
31
class R_ENCODER(nn.Module):
    """Convolutional encoder mapping an image to a per-location code in (0, 1).

    The input is reduced 8x spatially by three ``pixel_unshuffle(…, 2)`` stages,
    each followed by two residual conv blocks.  A parallel shortcut branch
    unshuffles the stem output by 8 in one step and is added back before the
    output convolution.

    Args:
        args: namespace-like object; only ``args.my_img_bit`` (number of output
            channels) is read here.

    Attribute names are part of the checkpoint format (``state_dict`` keys) and
    must not be renamed.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        dd = 8  # base channel width; channels grow x4 with every 2x unshuffle

        # BatchNorm for the 8x-unshuffled shortcut branch.
        self.Bxx = nn.BatchNorm2d(dd * 64)

        # Stem at full resolution.
        self.CIN = nn.Conv2d(3, dd, kernel_size=3, padding=1)
        self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1)
        self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1)

        # Stage 0: 1/2 resolution, dd*4 channels.
        self.B00 = nn.BatchNorm2d(dd * 4)
        self.C00 = nn.Conv2d(dd * 4, 256, kernel_size=3, padding=1)
        self.C01 = nn.Conv2d(256, dd * 4, kernel_size=3, padding=1)
        self.C02 = nn.Conv2d(dd * 4, 256, kernel_size=3, padding=1)
        self.C03 = nn.Conv2d(256, dd * 4, kernel_size=3, padding=1)

        # Stage 1: 1/4 resolution, dd*16 channels.
        self.B10 = nn.BatchNorm2d(dd * 16)
        self.C10 = nn.Conv2d(dd * 16, 256, kernel_size=3, padding=1)
        self.C11 = nn.Conv2d(256, dd * 16, kernel_size=3, padding=1)
        self.C12 = nn.Conv2d(dd * 16, 256, kernel_size=3, padding=1)
        self.C13 = nn.Conv2d(256, dd * 16, kernel_size=3, padding=1)

        # Stage 2: 1/8 resolution, dd*64 channels.
        self.B20 = nn.BatchNorm2d(dd * 64)
        self.C20 = nn.Conv2d(dd * 64, 256, kernel_size=3, padding=1)
        self.C21 = nn.Conv2d(256, dd * 64, kernel_size=3, padding=1)
        self.C22 = nn.Conv2d(dd * 64, 256, kernel_size=3, padding=1)
        self.C23 = nn.Conv2d(256, dd * 64, kernel_size=3, padding=1)

        self.COUT = nn.Conv2d(dd * 64, args.my_img_bit, kernel_size=3, padding=1)

    def forward(self, img):
        """Encode ``img`` of shape (N, 3, H, W) into (N, my_img_bit, H/8, W/8).

        H and W must be divisible by 8 for the pixel-unshuffle steps.  The
        result is passed through a sigmoid, so values lie in (0, 1).
        """
        ACT = F.mish

        x = self.CIN(img)
        # Shortcut taken from the stem output, before any residual block.
        xx = self.Bxx(F.pixel_unshuffle(x, 8))
        x = x + self.Cx1(ACT(self.Cx0(x)))

        x = F.pixel_unshuffle(x, 2)
        x = x + self.C01(ACT(self.C00(ACT(self.B00(x)))))
        x = x + self.C03(ACT(self.C02(x)))

        x = F.pixel_unshuffle(x, 2)
        x = x + self.C11(ACT(self.C10(ACT(self.B10(x)))))
        x = x + self.C13(ACT(self.C12(x)))

        x = F.pixel_unshuffle(x, 2)
        x = x + self.C21(ACT(self.C20(ACT(self.B20(x)))))
        x = x + self.C23(ACT(self.C22(x)))

        x = self.COUT(x + xx)
        return torch.sigmoid(x)
83
+
84
class R_DECODER(nn.Module):
    """Convolutional decoder mapping a per-location code back to a 3-channel image.

    The code is expanded to ``dd*64`` channels and upsampled 8x spatially by
    three ``pixel_shuffle(…, 2)`` steps, each preceded by two residual conv
    blocks, then refined at full resolution and projected to 3 channels.

    Args:
        args: namespace-like object; only ``args.my_img_bit`` (number of input
            channels) is read here.

    Attribute names are part of the checkpoint format (``state_dict`` keys) and
    must not be renamed.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        dd = 8  # base channel width; channels shrink x4 with every 2x shuffle
        self.CIN = nn.Conv2d(args.my_img_bit, dd * 64, kernel_size=3, padding=1)

        # Stage 0: code resolution, dd*64 channels.
        self.B00 = nn.BatchNorm2d(dd * 64)
        self.C00 = nn.Conv2d(dd * 64, 256, kernel_size=3, padding=1)
        self.C01 = nn.Conv2d(256, dd * 64, kernel_size=3, padding=1)
        self.C02 = nn.Conv2d(dd * 64, 256, kernel_size=3, padding=1)
        self.C03 = nn.Conv2d(256, dd * 64, kernel_size=3, padding=1)

        # Stage 1: 2x code resolution, dd*16 channels.
        self.B10 = nn.BatchNorm2d(dd * 16)
        self.C10 = nn.Conv2d(dd * 16, 256, kernel_size=3, padding=1)
        self.C11 = nn.Conv2d(256, dd * 16, kernel_size=3, padding=1)
        self.C12 = nn.Conv2d(dd * 16, 256, kernel_size=3, padding=1)
        self.C13 = nn.Conv2d(256, dd * 16, kernel_size=3, padding=1)

        # Stage 2: 4x code resolution, dd*4 channels.
        self.B20 = nn.BatchNorm2d(dd * 4)
        self.C20 = nn.Conv2d(dd * 4, 256, kernel_size=3, padding=1)
        self.C21 = nn.Conv2d(256, dd * 4, kernel_size=3, padding=1)
        self.C22 = nn.Conv2d(dd * 4, 256, kernel_size=3, padding=1)
        self.C23 = nn.Conv2d(256, dd * 4, kernel_size=3, padding=1)

        # Full-resolution refinement and RGB projection.
        self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1)
        self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1)
        self.COUT = nn.Conv2d(dd, 3, kernel_size=3, padding=1)

    def forward(self, code):
        """Decode ``code`` of shape (N, my_img_bit, h, w) into (N, 3, 8*h, 8*w).

        The result is passed through a sigmoid, so values lie in (0, 1).
        """
        ACT = F.mish
        x = self.CIN(code)

        x = x + self.C01(ACT(self.C00(ACT(self.B00(x)))))
        x = x + self.C03(ACT(self.C02(x)))
        x = F.pixel_shuffle(x, 2)

        x = x + self.C11(ACT(self.C10(ACT(self.B10(x)))))
        x = x + self.C13(ACT(self.C12(x)))
        x = F.pixel_shuffle(x, 2)

        x = x + self.C21(ACT(self.C20(ACT(self.B20(x)))))
        x = x + self.C23(ACT(self.C22(x)))
        x = F.pixel_shuffle(x, 2)

        x = x + self.Cx1(ACT(self.Cx0(x)))
        x = self.COUT(x)

        return torch.sigmoid(x)
133
+
134
+ ########################################################################################################
135
+
136
print('building model...')
args = types.SimpleNamespace()
args.my_img_bit = 13

# Use the GPU when present (the original behaviour) and fall back to the CPU so
# the demo still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

encoder = R_ENCODER(args).eval().to(device)
decoder = R_DECODER(args).eval().to(device)

# Per-channel bit weights 2**0 .. 2**(bits-1), shaped (bits, 1, 1) so they
# broadcast over the spatial dimensions of the code.
zpow = torch.tensor(
    [2 ** i for i in range(args.my_img_bit)], dtype=torch.long, device=device
).reshape(args.my_img_bit, 1, 1)

# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
# map_location keeps loading working when the checkpoint was saved on another device.
encoder.load_state_dict(torch.load(f'{model_prefix}-E.pth', map_location=device))
decoder.load_state_dict(torch.load(f'{model_prefix}-D.pth', map_location=device))

########################################################################################################

print('test image...')
img_transform = transforms.Compose([
    transforms.PILToTensor(),
    transforms.ConvertImageDtype(torch.float),
    transforms.Resize((224, 224)),
])

with torch.no_grad():
    # Close the file handle deterministically and force 3 channels: the encoder's
    # first convolution only accepts RGB input (grayscale/RGBA files would fail).
    with Image.open(input_img) as pil_img:
        img = img_transform(pil_img.convert('RGB')).unsqueeze(0).to(device)
    z = encoder(img)
    z = ToBinary.apply(z)  # round the sigmoid outputs to hard 0/1 bits

    # Pack the bit planes into one integer per spatial location.  Only the batch
    # dimension is squeezed so size-1 spatial dimensions are never dropped.
    zz = torch.sum(z.squeeze(0).long() * zpow, dim=0)
    print(f'Code shape = {zz.shape}\n{zz.cpu().numpy()}\n')

    out = decoder(z)
    # splitext keeps directory components and inner dots of the input path intact.
    out_path = f"{os.path.splitext(input_img)[0]}-out-{args.my_img_bit}bit.png"
    vision.utils.save_image(out, out_path)
kodim24-modified-out-13bit.png ADDED
kodim24-modified.png ADDED
out-v7c_d8_256-224-13bit-OB32x0.5-201-D.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:917ddad270353caf0243dbd09c2257414b9cb599ee43fe1b41b8e7af49bf03b8
3
+ size 25068760
out-v7c_d8_256-224-13bit-OB32x0.5-201-E.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:65933944a19a00241ebfecce4e4b5e9bd2d7f1ac7d10f447b6b8c3e73a92093a
3
+ size 25076297