mehdidc commited on
Commit
fa128ec
·
1 Parent(s): ef0ee1c

add app and generation / model code

Browse files
Files changed (7) hide show
  1. app.py +34 -0
  2. cli.py +320 -0
  3. convert.py +52 -0
  4. data.py +94 -0
  5. model.py +260 -0
  6. test.py +21 -0
  7. viz.py +204 -0
app.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision
3
+ import gradio as gr
4
+ from PIL import Image
5
+ from cli import iterative_refinement
6
+ from viz import grid_of_images_default
7
+ from subprocess
8
+ subprocess.call("download_models.sh", shell=True)
9
+ models = {
10
+ "convae": torch.load("convae.th", map_location="cpu"),
11
+ "deep_convae": torch.load("deep_convae.th", map_location="cpu"),
12
+ }
13
+
14
+ def gen(model, seed, nb_iter, nb_samples, width, height):
15
+ torch.manual_seed(int(seed))
16
+ bs = 64
17
+ model = models[model]
18
+ samples = iterative_refinement(
19
+ model,
20
+ nb_iter=int(nb_iter),
21
+ nb_examples=int(nb_samples),
22
+ w=int(width), h=int(height), c=1,
23
+ batch_size=bs,
24
+ )
25
+ grid = grid_of_images_default(samples.reshape((samples.shape[0]*samples.shape[1], int(height), int(width), 1)).numpy(), shape=(samples.shape[0], samples.shape[1]))
26
+ grid = (grid*255).astype("uint8")
27
+ return Image.fromarray(grid)
28
+
29
+ iface = gr.Interface(
30
+ fn=gen,
31
+ inputs=[gr.Dropdown(list(models.keys()), value="deep_convae"), gr.Number(value=0), gr.Number(value=20), gr.Number(value=1), gr.Number(value=28), gr.Number(value=28)],
32
+ outputs="image"
33
+ )
34
+ iface.launch()
cli.py ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import matplotlib as mpl
3
+ mpl.use('Agg')
4
+ import matplotlib.pyplot as plt
5
+ from functools import partial
6
+
7
+ from clize import run
8
+ import numpy as np
9
+ from skimage.io import imsave
10
+
11
+ from viz import grid_of_images_default
12
+
13
+ import torch.nn as nn
14
+ import torch
15
+
16
+ from model import DenseAE
17
+ from model import ConvAE
18
+ from model import DeepConvAE
19
+ from model import SimpleConvAE
20
+ from model import ZAE
21
+ from model import KAE
22
+ from data import load_dataset
23
+
24
+ device = "cuda" if torch.cuda.is_available() else "cpu"
25
+
26
+
27
+ def plot_dataset(code_2d, categories):
28
+ colors = [
29
+ 'r',
30
+ 'b',
31
+ 'g',
32
+ 'crimson',
33
+ 'gold',
34
+ 'yellow',
35
+ 'maroon',
36
+ 'm',
37
+ 'c',
38
+ 'orange'
39
+ ]
40
+ for cat in range(0, 10):
41
+ g = (categories == cat)
42
+ plt.scatter(
43
+ code_2d[g, 0],
44
+ code_2d[g, 1],
45
+ marker='+',
46
+ c=colors[cat],
47
+ s=40,
48
+ alpha=0.7,
49
+ label="digit {}".format(cat)
50
+ )
51
+
52
+
53
+ def plot_generated(code_2d, categories):
54
+ g = (categories < 0)
55
+ plt.scatter(
56
+ code_2d[g, 0],
57
+ code_2d[g, 1],
58
+ marker='+',
59
+ c='gray',
60
+ s=30
61
+ )
62
+
63
+
64
+ def grid_embedding(h):
65
+ from lapjv import lapjv
66
+ from scipy.spatial.distance import cdist
67
+ assert int(np.sqrt(h.shape[0])) ** 2 == h.shape[0], 'Nb of examples must be a square number'
68
+ size = int(np.sqrt(h.shape[0]))
69
+ grid = np.dstack(np.meshgrid(np.linspace(0, 1, size), np.linspace(0, 1, size))).reshape(-1, 2)
70
+ cost_matrix = cdist(grid, h, "sqeuclidean").astype('float32')
71
+ cost_matrix = cost_matrix * (100000 / cost_matrix.max())
72
+ _, rows, cols = lapjv(cost_matrix)
73
+ return rows
74
+
75
+
76
+ def save_weights(m, folder='.'):
77
+ if isinstance(m, nn.Linear):
78
+ w = m.weight.data
79
+ if w.size(1) == 28*28 or w.size(0) == 28*28:
80
+ w0, w1 = w.size(0), w.size(1)
81
+ if w0 == 28*28:
82
+ w = w.transpose(0, 1)
83
+ w = w.contiguous()
84
+ w = w.view(w.size(0), 1, 28, 28)
85
+ gr = grid_of_images_default(np.array(w.tolist()), normalize=True)
86
+ imsave('{}/feat_{}.png'.format(folder, w0), gr)
87
+ elif isinstance(m, nn.ConvTranspose2d):
88
+ w = m.weight.data
89
+ if w.size(0) in (32, 64, 128, 256, 512) and w.size(1) in (1, 3):
90
+ gr = grid_of_images_default(np.array(w.tolist()), normalize=True)
91
+ imsave('{}/feat.png'.format(folder), gr)
92
+
93
+ @torch.no_grad()
94
+ def iterative_refinement(ae, nb_examples=1, nb_iter=10, w=28, h=28, c=1, batch_size=None):
95
+ if batch_size is None:
96
+ batch_size = nb_examples
97
+ x = torch.rand(nb_iter, nb_examples, c, w, h)
98
+ for i in range(1, nb_iter):
99
+ for j in range(0, nb_examples, batch_size):
100
+ oldv = x[i-1][j:j + batch_size].to(device)
101
+ newv = ae(oldv)
102
+ newv = newv.data.cpu()
103
+ x[i][j:j + batch_size] = newv
104
+ return x
105
+
106
+
107
+ def build_model(name, w, h, c):
108
+ if name == 'convae':
109
+ ae = ConvAE(
110
+ w=w, h=h, c=c,
111
+ nb_filters=128,
112
+ spatial=True,
113
+ channel=True,
114
+ channel_stride=4,
115
+ )
116
+ elif name == 'zae':
117
+ ae = ZAE(
118
+ w=w, h=h, c=c,
119
+ theta=3,
120
+ nb_hidden=1000,
121
+ )
122
+ elif name == 'kae':
123
+ ae = KAE(
124
+ w=w, h=h, c=c,
125
+ nb_active=1000,
126
+ nb_hidden=1000,
127
+ )
128
+ elif name == 'denseae':
129
+ ae = DenseAE(
130
+ w=w, h=h, c=c,
131
+ encode_hidden=[1000],
132
+ decode_hidden=[],
133
+ ksparse=True,
134
+ nb_active=50,
135
+ )
136
+ elif name == 'simple_convae':
137
+ ae = SimpleConvAE(
138
+ w=w, h=h, c=c,
139
+ nb_filters=128,
140
+ )
141
+ elif name == 'deep_convae':
142
+ ae = DeepConvAE(
143
+ w=w, h=h, c=c,
144
+ nb_filters=128,
145
+ spatial=True,
146
+ channel=True,
147
+ channel_stride=4,
148
+ nb_layers=3,
149
+ )
150
+ else:
151
+ raise ValueError('Unknown model')
152
+
153
+ return ae
154
+
155
+
156
+ def salt_and_pepper(X, proba=0.5):
157
+ a = (torch.rand(X.size()).to(device) <= (1 - proba)).float()
158
+ b = (torch.rand(X.size()).to(device) <= 0.5).float()
159
+ c = ((a == 0).float() * b)
160
+ return X * a + c
161
+
162
+
163
+ def train(*, dataset='mnist', folder='mnist', resume=False, model='convae', walkback=False, denoise=False, epochs=100, batch_size=64, log_interval=100):
164
+ gamma = 0.99
165
+ dataset = load_dataset(dataset, split='train')
166
+ x0, _ = dataset[0]
167
+ c, h, w = x0.size()
168
+ dataloader = torch.utils.data.DataLoader(
169
+ dataset,
170
+ batch_size=batch_size,
171
+ shuffle=True,
172
+ num_workers=4
173
+ )
174
+ if resume:
175
+ ae = torch.load('{}/model.th'.format(folder))
176
+ ae = ae.to(device)
177
+ else:
178
+ ae = build_model(model, w=w, h=h, c=c)
179
+ ae = ae.to(device)
180
+ optim = torch.optim.Adadelta(ae.parameters(), lr=0.1, eps=1e-7, rho=0.95, weight_decay=0)
181
+ avg_loss = 0.
182
+ nb_updates = 0
183
+ _save_weights = partial(save_weights, folder=folder)
184
+
185
+ for epoch in range(epochs):
186
+ for X, y in dataloader:
187
+ ae.zero_grad()
188
+ X = X.to(device)
189
+ if hasattr(ae, 'nb_active'):
190
+ ae.nb_active = max(ae.nb_active - 1, 32)
191
+ # walkback + denoise
192
+ if walkback:
193
+ loss = 0.
194
+ x = X.data
195
+ nb = 5
196
+ for _ in range(nb):
197
+ x = salt_and_pepper(x, proba=0.3) # denoise
198
+ x = x.to(device)
199
+ x = ae(x) # reconstruct
200
+ Xr = x
201
+ loss += (((x - X) ** 2).view(X.size(0), -1).sum(1).mean()) / nb
202
+ x = (torch.rand(x.size()).to(device) <= x.data).float() # sample
203
+ # denoise only
204
+ elif denoise:
205
+ Xc = salt_and_pepper(X.data, proba=0.3)
206
+ Xr = ae(Xc)
207
+ loss = ((Xr - X) ** 2).view(X.size(0), -1).sum(1).mean()
208
+ # normal training
209
+ else:
210
+ Xr = ae(X)
211
+ loss = ((Xr - X) ** 2).view(X.size(0), -1).sum(1).mean()
212
+ loss.backward()
213
+ optim.step()
214
+ avg_loss = avg_loss * gamma + loss.item() * (1 - gamma)
215
+ if nb_updates % log_interval == 0:
216
+ print('Epoch : {:05d} AvgTrainLoss: {:.6f}, Batch Loss : {:.6f}'.format(epoch, avg_loss, loss.item() ))
217
+ gr = grid_of_images_default(np.array(Xr.data.tolist()))
218
+ imsave('{}/rec.png'.format(folder), gr)
219
+ ae.apply(_save_weights)
220
+ torch.save(ae, '{}/model.th'.format(folder))
221
+ nb_updates += 1
222
+
223
+
224
+ def test(*, dataset='mnist', folder='out', model_path=None, nb_iter=100, nb_generate=100, tsne=False):
225
+ if not os.path.exists(folder):
226
+ os.makedirs(folder, exist_ok=True)
227
+ dataset = load_dataset(dataset, split='train')
228
+ x0, _ = dataset[0]
229
+ c, h, w = x0.size()
230
+ nb = nb_generate
231
+ print('Load model...')
232
+ if model_path is None:
233
+ model_path = os.path.join(folder, "model.th")
234
+ ae = torch.load(model_path, map_location="cpu")
235
+ ae = ae.to(device)
236
+ def enc(X):
237
+ batch_size = 64
238
+ h_list = []
239
+ for i in range(0, X.size(0), batch_size):
240
+ x = X[i:i + batch_size]
241
+ x = x.to(device)
242
+ name = ae.__class__.__name__
243
+ if name in ('ConvAE',):
244
+ h = ae.encode(x)
245
+ h, _ = h.max(2)
246
+ h = h.view((h.size(0), -1))
247
+ elif name in ('DenseAE',):
248
+ x = x.view(x.size(0), -1)
249
+ h = x
250
+ #h = ae.encode(x)
251
+ else:
252
+ h = x.view(x.size(0), -1)
253
+ h = h.data.cpu()
254
+ h_list.append(h)
255
+ return torch.cat(h_list, 0)
256
+
257
+ print('iterative refinement...')
258
+ g = iterative_refinement(
259
+ ae,
260
+ nb_iter=nb_iter,
261
+ nb_examples=nb,
262
+ w=w, h=h, c=c,
263
+ batch_size=64
264
+ )
265
+ np.savez('{}/generated.npz'.format(folder), X=g.numpy())
266
+ g_subset = g[:, 0:100]
267
+ gr = grid_of_images_default(g_subset.reshape((g_subset.shape[0]*g_subset.shape[1], h, w, 1)).numpy(), shape=(g_subset.shape[0], g_subset.shape[1]))
268
+ imsave('{}/gen_full_iters.png'.format(folder), gr)
269
+
270
+ g = g[-1] # last iter
271
+ print(g.shape)
272
+ gr = grid_of_images_default(g.numpy())
273
+ imsave('{}/gen_full.png'.format(folder), gr)
274
+
275
+ if tsne:
276
+ from sklearn.manifold import TSNE
277
+ dataloader = torch.utils.data.DataLoader(
278
+ dataset,
279
+ batch_size=nb,
280
+ shuffle=True,
281
+ num_workers=1
282
+ )
283
+ print('Load data...')
284
+ X, y = next(iter(dataloader))
285
+ print('Encode data...')
286
+ xh = enc(X)
287
+ print('Encode generated...')
288
+ gh = enc(g)
289
+ X = X.numpy()
290
+ g = g.numpy()
291
+ xh = xh.numpy()
292
+ gh = gh.numpy()
293
+
294
+ a = np.concatenate((X, g), axis=0)
295
+ ah = np.concatenate((xh, gh), axis=0)
296
+ labels = np.array(y.tolist() + [-1] * len(g))
297
+ sne = TSNE()
298
+ print('fit tsne...')
299
+ ah = sne.fit_transform(ah)
300
+ print('grid embedding...')
301
+
302
+ asmall = np.concatenate((a[0:450], a[nb:nb + 450]), axis=0)
303
+ ahsmall = np.concatenate((ah[0:450], ah[nb:nb + 450]), axis=0)
304
+ rows = grid_embedding(ahsmall)
305
+ asmall = asmall[rows]
306
+ gr = grid_of_images_default(asmall)
307
+ imsave('{}/sne_grid.png'.format(folder), gr)
308
+
309
+ fig = plt.figure(figsize=(10, 10))
310
+ plot_dataset(ah, labels)
311
+ plot_generated(ah, labels)
312
+ plt.legend(loc='best')
313
+ plt.axis('off')
314
+ plt.savefig('{}/sne.png'.format(folder))
315
+ plt.close(fig)
316
+
317
+
318
+
319
+ if __name__ == '__main__':
320
+ run([train, test])
convert.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch, h5py
3
+ from model import *
4
+ w, h, c = 28, 28, 1
5
+ model_new = DeepConvAE(
6
+ w=w, h=h, c=c,
7
+ nb_filters=128,
8
+ spatial=True,
9
+ channel=True,
10
+ channel_stride=4,
11
+ # total layers = nb_layers*2, where we have nb_layers for encoder and nb_layers for decoder
12
+ nb_layers=3,
13
+ )
14
+ # model_old = h5py.File("mnist_deepconvae/model.h5")
15
+ model_old = h5py.File("/home/mehdi/work/code/out_of_class/ae/mnist/model.h5")
16
+
17
+
18
+ print(model_new)
19
+ print(model_old["model_weights"].keys())
20
+
21
+
22
+ for name, param in model_new.named_parameters():
23
+ enc_or_decode, layer_id, bias_or_kernel = name.split(".")
24
+
25
+ if enc_or_decode == "encode":
26
+ layer_name = "conv2d"
27
+ else:
28
+ layer_name = "up_conv2d"
29
+
30
+ layer_id = (int(layer_id)//2) + 1
31
+
32
+ full_layer_name = f"{layer_name}_{layer_id}"
33
+ print(full_layer_name)
34
+
35
+ k = "kernel" if bias_or_kernel == "weight" else "bias"
36
+ weights = model_old["model_weights"][full_layer_name][full_layer_name][k][()]
37
+ weights = np.array(weights)
38
+ weights = torch.from_numpy(weights)
39
+ print(name, layer_id, param.shape, weights.shape)
40
+ inds = [4,3,2,1,0]
41
+ if k == "kernel":
42
+ if layer_name == "conv2d":
43
+ weights = weights.permute((3,2,0,1))
44
+ weights = weights[:,:,inds]
45
+ weights = weights[:,:,:, inds]
46
+ print("W", weights.shape)
47
+ elif layer_name == "up_conv2d":
48
+ weights = weights.permute((2,3,0,1))
49
+ print(param.shape, weights.shape)
50
+ param.data.copy_(weights)
51
+ print((param-weights).sum())
52
+ torch.save(model_new, "mnist_deepconvae/model.th")
data.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ import torchvision.transforms as transforms
4
+ import torchvision.datasets as dset
5
+
6
+
7
+ class Invert:
8
+ def __call__(self, x):
9
+ return 1 - x
10
+
11
+ class Gray:
12
+ def __call__(self, x):
13
+ return x[0:1]
14
+
15
+
16
+
17
+ def load_dataset(dataset_name, split='full'):
18
+ if dataset_name == 'mnist':
19
+ dataset = dset.MNIST(
20
+ root='data/mnist',
21
+ download=True,
22
+ transform=transforms.Compose([
23
+ transforms.ToTensor(),
24
+ ])
25
+ )
26
+ return dataset
27
+ elif dataset_name == 'coco':
28
+ dataset = dset.ImageFolder(root='data/coco',
29
+ transform=transforms.Compose([
30
+ transforms.Scale(64),
31
+ transforms.CenterCrop(64),
32
+ transforms.ToTensor(),
33
+ ]))
34
+ return dataset
35
+ elif dataset_name == 'quickdraw':
36
+ X = (np.load('data/quickdraw/teapot.npy'))
37
+ X = X.reshape((X.shape[0], 28, 28))
38
+ X = X / 255.
39
+ X = X.astype(np.float32)
40
+ X = torch.from_numpy(X)
41
+ dataset = TensorDataset(X, X)
42
+ return dataset
43
+ elif dataset_name == 'shoes':
44
+ dataset = dset.ImageFolder(root='data/shoes/ut-zap50k-images/Shoes',
45
+ transform=transforms.Compose([
46
+ transforms.Scale(64),
47
+ transforms.CenterCrop(64),
48
+ transforms.ToTensor(),
49
+ ]))
50
+ return dataset
51
+ elif dataset_name == 'footwear':
52
+ dataset = dset.ImageFolder(root='data/shoes/ut-zap50k-images',
53
+ transform=transforms.Compose([
54
+ transforms.Scale(64),
55
+ transforms.CenterCrop(64),
56
+ transforms.ToTensor(),
57
+ ]))
58
+ return dataset
59
+ elif dataset_name == 'celeba':
60
+ dataset = dset.ImageFolder(root='data/celeba',
61
+ transform=transforms.Compose([
62
+ transforms.Scale(32),
63
+ transforms.CenterCrop(32),
64
+ transforms.ToTensor(),
65
+ ]))
66
+ return dataset
67
+ elif dataset_name == 'birds':
68
+ dataset = dset.ImageFolder(root='data/birds/'+split,
69
+ transform=transforms.Compose([
70
+ transforms.Scale(32),
71
+ transforms.CenterCrop(32),
72
+ transforms.ToTensor(),
73
+ ]))
74
+ return dataset
75
+ elif dataset_name == 'sketchy':
76
+ dataset = dset.ImageFolder(root='data/sketchy/'+split,
77
+ transform=transforms.Compose([
78
+ transforms.Scale(64),
79
+ transforms.CenterCrop(64),
80
+ transforms.ToTensor(),
81
+ Gray()
82
+ ]))
83
+ return dataset
84
+
85
+ elif dataset_name == 'fonts':
86
+ dataset = dset.ImageFolder(root='data/fonts/'+split,
87
+ transform=transforms.Compose([
88
+ transforms.ToTensor(),
89
+ Invert(),
90
+ Gray(),
91
+ ]))
92
+ return dataset
93
+ else:
94
+ raise ValueError('Error : unknown dataset')
model.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.nn.init import xavier_uniform
5
+
6
+ class KAE(nn.Module):
7
+
8
+ def __init__(self, w=32, h=32, c=1, nb_hidden=300, nb_active=16):
9
+ super().__init__()
10
+ self.nb_hidden = nb_hidden
11
+ self.nb_active = nb_active
12
+ self.encode = nn.Sequential(
13
+ nn.Linear(w*h*c, nb_hidden, bias=False)
14
+ )
15
+ self.bias = nn.Parameter(torch.zeros(w*h*c))
16
+ self.params = nn.ParameterList([self.bias])
17
+ self.apply(_weights_init)
18
+
19
+ def forward(self, X):
20
+ size = X.size()
21
+ X = X.view(X.size(0), -1)
22
+ h = self.encode(X)
23
+ Xr, _ = self.decode(h)
24
+ Xr = Xr.view(size)
25
+ return Xr
26
+
27
+ def decode(self, h):
28
+ thetas, _ = torch.sort(h, dim=1, descending=True)
29
+ thetas = thetas[:, self.nb_active:self.nb_active+1]
30
+ h = h * (h > thetas).float()
31
+ Xr = torch.matmul(h, self.encode[0].weight) + self.bias
32
+ Xr = nn.Sigmoid()(Xr)
33
+ return Xr, h
34
+
35
+
36
+ class ZAE(nn.Module):
37
+
38
+ def __init__(self, w=32, h=32, c=1, nb_hidden=300, theta=1):
39
+ super().__init__()
40
+ self.nb_hidden = nb_hidden
41
+ self.theta = theta
42
+ self.encode = nn.Sequential(
43
+ nn.Linear(w*h*c, nb_hidden, bias=False)
44
+ )
45
+ self.bias = nn.Parameter(torch.zeros(w*h*c))
46
+ self.params = nn.ParameterList([self.bias])
47
+ self.apply(_weights_init)
48
+
49
+ def forward(self, X):
50
+ size = X.size()
51
+ X = X.view(X.size(0), -1)
52
+ h = self.encode(X)
53
+ Xr, _ = self.decode(h)
54
+ Xr = Xr.view(size)
55
+ return Xr
56
+
57
+ def decode(self, h):
58
+ h = h * (h > self.theta).float()
59
+ Xr = torch.matmul(h, self.encode[0].weight) + self.bias
60
+ Xr = nn.Sigmoid()(Xr)
61
+ return Xr, h
62
+
63
+
64
+
65
+ class DenseAE(nn.Module):
66
+
67
+ def __init__(self, w=32, h=32, c=1, encode_hidden=(300,), decode_hidden=(300,), ksparse=True, nb_active=10, denoise=None):
68
+ super().__init__()
69
+ self.encode_hidden = encode_hidden
70
+ self.decode_hidden = decode_hidden
71
+ self.ksparse = ksparse
72
+ self.nb_active = nb_active
73
+ self.denoise = denoise
74
+
75
+ # encode layers
76
+ layers = []
77
+ hid_prev = w * h * c
78
+ for hid in encode_hidden:
79
+ layers.extend([
80
+ nn.Linear(hid_prev, hid),
81
+ nn.ReLU(True)
82
+ ])
83
+ hid_prev = hid
84
+ self.encode = nn.Sequential(*layers)
85
+
86
+ # decode layers
87
+ layers = []
88
+ for hid in decode_hidden:
89
+ layers.extend([
90
+ nn.Linear(hid_prev, hid),
91
+ nn.ReLU(True)
92
+ ])
93
+ hid_prev = hid
94
+ layers.extend([
95
+ nn.Linear(hid_prev, w * h * c),
96
+ nn.Sigmoid()
97
+ ])
98
+ self.decode = nn.Sequential(*layers)
99
+
100
+ self.apply(_weights_init)
101
+
102
+ def forward(self, X):
103
+ size = X.size()
104
+ if self.denoise is not None:
105
+ X = X * ((torch.rand(X.size()) <= self.denoise).float()).to(X.device)
106
+ X = X.view(X.size(0), -1)
107
+ h = self.encode(X)
108
+ if self.ksparse:
109
+ h = ksparse(h, nb_active=self.nb_active)
110
+ Xr = self.decode(h)
111
+ Xr = Xr.view(size)
112
+ return Xr
113
+
114
+
115
+
116
+ def ksparse(x, nb_active=10):
117
+ mask = torch.ones(x.size())
118
+ for i, xi in enumerate(x.data.tolist()):
119
+ inds = np.argsort(xi)
120
+ inds = inds[::-1]
121
+ inds = inds[nb_active:]
122
+ if len(inds):
123
+ inds = np.array(inds)
124
+ inds = torch.from_numpy(inds).long()
125
+ mask[i][inds] = 0
126
+ return x * (mask).float().to(x.device)
127
+
128
+
129
+ class ConvAE(nn.Module):
130
+
131
+ def __init__(self, w=32, h=32, c=1, nb_filters=64, spatial=True, channel=True, channel_stride=4):
132
+ super().__init__()
133
+ self.spatial = spatial
134
+ self.channel = channel
135
+ self.channel_stride = channel_stride
136
+ self.encode = nn.Sequential(
137
+ nn.Conv2d(c, nb_filters, 5, 1, 0),
138
+ nn.ReLU(True),
139
+ nn.Conv2d(nb_filters, nb_filters, 5, 1, 0),
140
+ nn.ReLU(True),
141
+ nn.Conv2d(nb_filters, nb_filters, 5, 1, 0),
142
+ )
143
+ self.decode = nn.Sequential(
144
+ nn.ConvTranspose2d(nb_filters, c, 13, 1, 0),
145
+ nn.Sigmoid()
146
+ )
147
+ self.apply(_weights_init)
148
+
149
+ def forward(self, X):
150
+ size = X.size()
151
+ h = self.encode(X)
152
+ h = self.sparsify(h)
153
+ Xr = self.decode(h)
154
+ return Xr
155
+
156
+ def sparsify(self, h):
157
+ if self.spatial:
158
+ h = spatial_sparsity(h)
159
+ if self.channel:
160
+ h = strided_channel_sparsity(h, stride=self.channel_stride)
161
+ return h
162
+
163
+ class SimpleConvAE(nn.Module):
164
+
165
+ def __init__(self, w=32, h=32, c=1, nb_filters=64, spatial=True, channel=True, channel_stride=4):
166
+ super().__init__()
167
+ self.spatial = spatial
168
+ self.channel = channel
169
+ self.channel_stride = channel_stride
170
+ self.encode = nn.Sequential(
171
+ nn.Conv2d(c, nb_filters, 13, 1, 0),
172
+ nn.ReLU(True),
173
+ )
174
+ self.decode = nn.Sequential(
175
+ nn.ConvTranspose2d(nb_filters, c, 13, 1, 0),
176
+ nn.Sigmoid()
177
+ )
178
+ self.apply(_weights_init)
179
+
180
+ def forward(self, X):
181
+ size = X.size()
182
+ h = self.encode(X)
183
+ h = self.sparsify(h)
184
+ Xr = self.decode(h)
185
+ return Xr
186
+
187
+ def sparsify(self, h):
188
+ if self.spatial:
189
+ h = spatial_sparsity(h)
190
+ if self.channel:
191
+ h = strided_channel_sparsity(h, stride=self.channel_stride)
192
+ return h
193
+
194
+ class DeepConvAE(nn.Module):
195
+
196
+ def __init__(self, w=32, h=32, c=1, nb_filters=64, nb_layers=3, spatial=True, channel=True, channel_stride=4):
197
+ super().__init__()
198
+ self.spatial = spatial
199
+ self.channel = channel
200
+ self.channel_stride = channel_stride
201
+
202
+ layers = [
203
+ nn.Conv2d(c, nb_filters, 5, 1, 0),
204
+ nn.ReLU(True),
205
+ ]
206
+ for _ in range(nb_layers - 1):
207
+ layers.extend([
208
+ nn.Conv2d(nb_filters, nb_filters, 5, 1, 0),
209
+ nn.ReLU(True),
210
+ ])
211
+ self.encode = nn.Sequential(*layers)
212
+ layers = []
213
+ for _ in range(nb_layers - 1):
214
+ layers.extend([
215
+ nn.ConvTranspose2d(nb_filters, nb_filters, 5, 1, 0),
216
+ nn.ReLU(True),
217
+ ])
218
+ layers.extend([
219
+ nn.ConvTranspose2d(nb_filters, c, 5, 1, 0),
220
+ nn.Sigmoid()
221
+ ])
222
+ self.decode = nn.Sequential(*layers)
223
+ self.apply(_weights_init)
224
+
225
+ def forward(self, X):
226
+ size = X.size()
227
+ h = self.encode(X)
228
+ h = self.sparsify(h)
229
+ Xr = self.decode(h)
230
+ return Xr
231
+
232
+ def sparsify(self, h):
233
+ if self.spatial:
234
+ h = spatial_sparsity(h)
235
+ if self.channel:
236
+ h = strided_channel_sparsity(h, stride=self.channel_stride)
237
+ return h
238
+
239
+
240
+ def spatial_sparsity(x):
241
+ maxes = x.amax(dim=(2,3), keepdims=True)
242
+ return x * equals(x, maxes)
243
+
244
+ def equals(x, y, eps=1e-8):
245
+ return torch.abs(x-y) <= eps
246
+
247
+ def strided_channel_sparsity(x, stride=1):
248
+ B, F = x.shape[0:2]
249
+ h, w = x.shape[2:]
250
+ x_ = x.view(B, F, h // stride, stride, w // stride, stride)
251
+ mask = equals(x_, x_.amax(axis=(1, 3, 5), keepdims=True))
252
+ mask = mask.view(x.shape).float()
253
+ return x * mask
254
+
255
+
256
+ def _weights_init(m):
257
+ if hasattr(m, 'weight'):
258
+ xavier_uniform(m.weight.data)
259
+ if m.bias is not None:
260
+ m.bias.data.fill_(0)
test.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from machinedesign.autoencoder.interface import load
4
+ from keras.models import Model
5
+ torch.use_deterministic_algorithms(True)
6
+ model = torch.load("mnist_deepconvae/model.th")
7
+ model_keras = load("/home/mehdi/work/code/out_of_class/ae/mnist")
8
+ print(model_keras.layers[8])
9
+
10
+ m = Model(model_keras.inputs, model_keras.layers[8].output)
11
+ X = torch.rand(1,1,28,28)
12
+ with torch.no_grad():
13
+ # X1 = model.sparsify(model.encode(X))
14
+ X1 = model(X)
15
+ X2 = model_keras.predict(X)
16
+ X2 = torch.from_numpy(X2)
17
+ print(torch.abs(X1-X2).sum())
18
+ # for i in range(128):
19
+ # print(i, torch.abs(X1[0,i]-X2[0,i]).sum())
20
+ # print(X1[0,i, 0, :])
21
+ # print(X2[0,i,0, :])
viz.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This module contains common visualization functions
3
+ used to report results of the models.
4
+ """
5
+
6
+ from functools import partial
7
+ import numpy as np
8
+
9
+
10
+ def horiz_merge(left, right):
11
+ """
12
+ merges two images, left and right horizontally to obtain
13
+ a bigger image containing both.
14
+
15
+ Parameters
16
+ ---------
17
+ left: 2D or 3D numpy array
18
+ left image.
19
+ 2D for grayscale.
20
+ 3D for color.
21
+ right : numpy array array
22
+ right image.
23
+ 2D for grayscale
24
+ 3D for color.
25
+
26
+ Returns
27
+ -------
28
+
29
+ numpy array (2D or 3D depending on left and right)
30
+ """
31
+ assert left.shape[0] == right.shape[0]
32
+ assert left.shape[2:] == right.shape[2:]
33
+ shape = (left.shape[0], left.shape[1] + right.shape[1],) + left.shape[2:]
34
+ im_merge = np.zeros(shape)
35
+ im_merge[:, 0:left.shape[1]] = left
36
+ im_merge[:, left.shape[1]:] = right
37
+ return im_merge
38
+
39
+ def vert_merge(top, bottom):
40
+ """
41
+ merges two images, top and bottom vertically to obtain
42
+ a bigger image containing both.
43
+
44
+ Parameters
45
+ ---------
46
+ top: 2D or 3D numpy array
47
+ top image.
48
+ 2D for grayscale.
49
+ 3D for color.
50
+ bottom : numpy array array
51
+ bottom image.
52
+ 2D for grayscale
53
+ 3D for color.
54
+
55
+ Returns
56
+ -------
57
+
58
+ numpy array (2D or 3D depending on left and right)
59
+ """
60
+ im = horiz_merge(top, bottom)
61
+ if len(im.shape) == 2:
62
+ im = im.transpose((1, 0))
63
+ elif len(im.shape) == 3:
64
+ im = im.transpose((1, 0, 2))
65
+ return im
66
+
67
+
68
+ def grid_of_images(M, border=0, bordercolor=[0.0, 0.0, 0.0], shape=None, normalize=False):
69
+ """
70
+ Draw a grid of images from M
71
+ The order in the grid which corresponds to the order in M
72
+ is starting from top to bottom then left to right.
73
+
74
+ Parameters
75
+ ----------
76
+
77
+ M : numpy array
78
+ if 3D, convert it to 4D, the shape will be interpreted as (nb_images, h, w) and converted to (nb_images, 1, h, w).
79
+ if 4D, consider it as colored or grayscale
80
+ - if the shape is (nb_images, nb_colors, h, w), it is converted to (nb_images, h, w, nb_colors)
81
+ - otherwise, if it already (nb_images, h, w, nb_colors), use it as it is.
82
+ - nb_colors can be 1 (grayscale) or 3 (colors).
83
+ border: int
84
+ thickness of border(default=0)
85
+ shape: tuple (nb_cols, nb_rows)
86
+ shape of the grid
87
+ by default make a square shape
88
+ (in that case, it is possible that not all images from M will be part of the grid).
89
+ normalize: bool(default=False)
90
+ whether to normalize the pixel values of each image independently
91
+ by min and max. if False, clip the values of pixels to 0 and 1
92
+ without normalizing.
93
+
94
+ Returns
95
+ -------
96
+
97
+ 3D numpy array of shape (h, w, 3)
98
+ (with a color channel regardless of whether the original images were grayscale or colored)
99
+ """
100
+ if len(M.shape) == 3:
101
+ M = M[:, :, :, np.newaxis]
102
+ if M.shape[-1] not in (1, 3):
103
+ M = M.transpose((0, 2, 3, 1))
104
+ if M.shape[-1] == 1:
105
+ M = np.ones((1, 1, 1, 3)) * M
106
+ bordercolor = np.array(bordercolor)[None, None, :]
107
+ numimages = len(M)
108
+ M = M.copy()
109
+
110
+ if normalize:
111
+ for i in range(M.shape[0]):
112
+ M[i] -= M[i].flatten().min()
113
+ M[i] /= M[i].flatten().max()
114
+ else:
115
+ M = np.clip(M, 0, 1)
116
+ height, width, color = M[0].shape
117
+ assert color == 3, 'Nb of color channels are {}'.format(color)
118
+ if shape is None:
119
+ n0 = np.int(np.ceil(np.sqrt(numimages)))
120
+ n1 = np.int(np.ceil(np.sqrt(numimages)))
121
+ else:
122
+ n0 = shape[0]
123
+ n1 = shape[1]
124
+
125
+ im = np.array(bordercolor) * np.ones(
126
+ ((height + border) * n1 + border, (width + border) * n0 + border, 1), dtype='<f8')
127
+ # shape = (n0, n1)
128
+ # j corresponds to rows in the grid, n1 should correspond to nb of rows
129
+ # i corresponds to columns in the grid, n0 should correspond to nb of cols
130
+ # M should be such that the first n1 examples correspond to row 1,
131
+ # next n1 examples correspond to row 2, etc. that is, M first axis
132
+ # can be reshaped to (n1, n0)
133
+ for i in range(n0):
134
+ for j in range(n1):
135
+ if i * n1 + j < numimages:
136
+ im[j * (height + border) + border:(j + 1) * (height + border) + border,
137
+ i * (width + border) + border:(i + 1) * (width + border) + border, :] = np.concatenate((
138
+ np.concatenate((M[i * n1 + j, :, :, :],
139
+ bordercolor * np.ones((height, border, 3), dtype=float)), 1),
140
+ bordercolor * np.ones((border, width + border, 3), dtype=float)
141
+ ), 0)
142
+ return im
143
+
144
+ grid_of_images_default = partial(grid_of_images, border=1, bordercolor=(0.3, 0, 0))
145
+
146
+
147
+ def reshape_to_images(x, input_shape=None):
148
+ """
149
+ a function that takes a numpy array and try to
150
+ reshape it to an array of images that would
151
+ be compatible with the function grid_of_images.
152
+ Two cases are considered.
153
+
154
+ if x is a 2D numpy array, it uses input_shape:
155
+ - x can either be (nb_examples, nb_features) or (nb_features, nb_examples)
156
+ - nb_features should be prod(input_shape)
157
+ - the nb_features dim is then expanded to have :
158
+ (nb_examples, h, w, nb_channels), sorted input_shape shoud
159
+ be (h, w, nb_channels).
160
+
161
+ if x is a 4D numpy array:
162
+ - if the first tensor dim is 1 or 3 like e.g. (1, a, b, c), then assume it is
163
+ color channel and transform to (a, 1, b, c)
164
+ - if the second tensor dim is 1 or 3, leave x it as it is
165
+ - if the third tensor dim is 1 or 3, like e.g. (a, b, 1, c), then assume it is
166
+ color channel and transform to (c, 1, a, b)
167
+ - if the fourth tensor dim is 1 or 3, like e.g. (a, b, c, 1), then assume it is
168
+ color channel and transform to (c, 1, a, b)
169
+ Parameters
170
+ ----------
171
+
172
+ x : numpy array
173
+ input to be reshape
174
+ input_shape : tuple needed only when x is 2D numpy array
175
+ """
176
+ if len(x.shape) == 2:
177
+ assert input_shape is not None
178
+ if x.shape[0] == np.prod(input_shape):
179
+ x = x.T
180
+ x = x.reshape((x.shape[0],) + input_shape)
181
+ x = x.transpose((0, 2, 3, 1))
182
+ return x
183
+ elif x.shape[1] == np.prod(input_shape):
184
+ x = x.reshape((x.shape[0],) + input_shape)
185
+ x = x.transpose((0, 2, 3, 1))
186
+ return x
187
+ else:
188
+ raise ValueError('Cant recognize this shape : {}'.format(x.shape))
189
+ elif len(x.shape) == 4:
190
+ if x.shape[0] in (1, 3):
191
+ x = x.transpose((1, 0, 2, 3))
192
+ return x
193
+ elif x.shape[1] in (1, 3):
194
+ return x
195
+ elif x.shape[2] in (1, 3):
196
+ x = x.transpose((3, 2, 0, 1))
197
+ return x
198
+ elif x.shape[3] in (1, 3):
199
+ x = x.transpose((2, 3, 0, 1))
200
+ return x
201
+ else:
202
+ raise ValueError('Cant recognize a shape of size : {}'.format(len(x.shape)))
203
+ else:
204
+ raise ValueError('Cant recognize a shape of size : {}'.format(len(x.shape)))