add app and generation / model code
app.py
ADDED
@@ -0,0 +1,34 @@
import torch
import torchvision
import gradio as gr
from PIL import Image
from cli import iterative_refinement
from viz import grid_of_images_default
import subprocess

# Fetch the pre-trained autoencoder checkpoints, then load them on CPU.
subprocess.call("download_models.sh", shell=True)
models = {
    "convae": torch.load("convae.th", map_location="cpu"),
    "deep_convae": torch.load("deep_convae.th", map_location="cpu"),
}

def gen(model, seed, nb_iter, nb_samples, width, height):
    torch.manual_seed(int(seed))
    bs = 64
    model = models[model]
    # Generate by iteratively reconstructing random noise with the autoencoder.
    samples = iterative_refinement(
        model,
        nb_iter=int(nb_iter),
        nb_examples=int(nb_samples),
        w=int(width), h=int(height), c=1,
        batch_size=bs,
    )
    # Tile all iterations of all samples into a single image grid.
    grid = grid_of_images_default(
        samples.reshape((samples.shape[0] * samples.shape[1], int(height), int(width), 1)).numpy(),
        shape=(samples.shape[0], samples.shape[1]),
    )
    grid = (grid * 255).astype("uint8")
    return Image.fromarray(grid)

iface = gr.Interface(
    fn=gen,
    inputs=[gr.Dropdown(list(models.keys()), value="deep_convae"), gr.Number(value=0), gr.Number(value=20), gr.Number(value=1), gr.Number(value=28), gr.Number(value=28)],
    outputs="image",
)
iface.launch()
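Note: for a quick smoke test of the generation path without launching the Gradio UI, gen can be called directly. This is an illustrative snippet, not part of app.py, and it assumes the checkpoints have already been fetched by download_models.sh:

    # Hypothetical direct call; the arguments mirror the Interface inputs.
    img = gen("deep_convae", seed=0, nb_iter=20, nb_samples=1, width=28, height=28)
    img.save("sample.png")  # PIL image of the (nb_iter x nb_samples) grid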
cli.py
ADDED
@@ -0,0 +1,320 @@
import os
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
from functools import partial

from clize import run
import numpy as np
from skimage.io import imsave

from viz import grid_of_images_default

import torch.nn as nn
import torch

from model import DenseAE
from model import ConvAE
from model import DeepConvAE
from model import SimpleConvAE
from model import ZAE
from model import KAE
from data import load_dataset

device = "cuda" if torch.cuda.is_available() else "cpu"


def plot_dataset(code_2d, categories):
    # Scatter plot of the 2D embedding of real examples, one color per digit class.
    colors = [
        'r',
        'b',
        'g',
        'crimson',
        'gold',
        'yellow',
        'maroon',
        'm',
        'c',
        'orange'
    ]
    for cat in range(0, 10):
        g = (categories == cat)
        plt.scatter(
            code_2d[g, 0],
            code_2d[g, 1],
            marker='+',
            c=colors[cat],
            s=40,
            alpha=0.7,
            label="digit {}".format(cat)
        )


def plot_generated(code_2d, categories):
    # Generated examples are labeled -1 and drawn in gray.
    g = (categories < 0)
    plt.scatter(
        code_2d[g, 0],
        code_2d[g, 1],
        marker='+',
        c='gray',
        s=30
    )


def grid_embedding(h):
    # Assign each 2D point to a cell of a regular grid by solving a linear
    # assignment problem (Jonker-Volgenant) on squared Euclidean distances.
    from lapjv import lapjv
    from scipy.spatial.distance import cdist
    assert int(np.sqrt(h.shape[0])) ** 2 == h.shape[0], 'Nb of examples must be a square number'
    size = int(np.sqrt(h.shape[0]))
    grid = np.dstack(np.meshgrid(np.linspace(0, 1, size), np.linspace(0, 1, size))).reshape(-1, 2)
    cost_matrix = cdist(grid, h, "sqeuclidean").astype('float32')
    cost_matrix = cost_matrix * (100000 / cost_matrix.max())
    _, rows, cols = lapjv(cost_matrix)
    return rows


def save_weights(m, folder='.'):
    # Dump the filters of linear or transposed-conv layers as an image grid.
    if isinstance(m, nn.Linear):
        w = m.weight.data
        if w.size(1) == 28 * 28 or w.size(0) == 28 * 28:
            w0, w1 = w.size(0), w.size(1)
            if w0 == 28 * 28:
                w = w.transpose(0, 1)
            w = w.contiguous()
            w = w.view(w.size(0), 1, 28, 28)
            gr = grid_of_images_default(np.array(w.tolist()), normalize=True)
            imsave('{}/feat_{}.png'.format(folder, w0), gr)
    elif isinstance(m, nn.ConvTranspose2d):
        w = m.weight.data
        if w.size(0) in (32, 64, 128, 256, 512) and w.size(1) in (1, 3):
            gr = grid_of_images_default(np.array(w.tolist()), normalize=True)
            imsave('{}/feat.png'.format(folder), gr)


@torch.no_grad()
def iterative_refinement(ae, nb_examples=1, nb_iter=10, w=28, h=28, c=1, batch_size=None):
    # Generation procedure: start from uniform noise and repeatedly feed the
    # samples back through the autoencoder. Returns a tensor of shape
    # (nb_iter, nb_examples, c, w, h) holding every intermediate iteration.
    if batch_size is None:
        batch_size = nb_examples
    x = torch.rand(nb_iter, nb_examples, c, w, h)
    for i in range(1, nb_iter):
        for j in range(0, nb_examples, batch_size):
            oldv = x[i - 1][j:j + batch_size].to(device)
            newv = ae(oldv)
            newv = newv.data.cpu()
            x[i][j:j + batch_size] = newv
    return x


def build_model(name, w, h, c):
    if name == 'convae':
        ae = ConvAE(
            w=w, h=h, c=c,
            nb_filters=128,
            spatial=True,
            channel=True,
            channel_stride=4,
        )
    elif name == 'zae':
        ae = ZAE(
            w=w, h=h, c=c,
            theta=3,
            nb_hidden=1000,
        )
    elif name == 'kae':
        ae = KAE(
            w=w, h=h, c=c,
            nb_active=1000,
            nb_hidden=1000,
        )
    elif name == 'denseae':
        ae = DenseAE(
            w=w, h=h, c=c,
            encode_hidden=[1000],
            decode_hidden=[],
            ksparse=True,
            nb_active=50,
        )
    elif name == 'simple_convae':
        ae = SimpleConvAE(
            w=w, h=h, c=c,
            nb_filters=128,
        )
    elif name == 'deep_convae':
        ae = DeepConvAE(
            w=w, h=h, c=c,
            nb_filters=128,
            spatial=True,
            channel=True,
            channel_stride=4,
            nb_layers=3,
        )
    else:
        raise ValueError('Unknown model')

    return ae


def salt_and_pepper(X, proba=0.5):
    # Corrupt a fraction `proba` of the pixels, setting each corrupted pixel to 0 or 1.
    a = (torch.rand(X.size()).to(device) <= (1 - proba)).float()
    b = (torch.rand(X.size()).to(device) <= 0.5).float()
    c = ((a == 0).float() * b)
    return X * a + c


def train(*, dataset='mnist', folder='mnist', resume=False, model='convae', walkback=False, denoise=False, epochs=100, batch_size=64, log_interval=100):
    gamma = 0.99
    dataset = load_dataset(dataset, split='train')
    x0, _ = dataset[0]
    c, h, w = x0.size()
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4
    )
    if resume:
        ae = torch.load('{}/model.th'.format(folder))
        ae = ae.to(device)
    else:
        ae = build_model(model, w=w, h=h, c=c)
        ae = ae.to(device)
    optim = torch.optim.Adadelta(ae.parameters(), lr=0.1, eps=1e-7, rho=0.95, weight_decay=0)
    avg_loss = 0.
    nb_updates = 0
    _save_weights = partial(save_weights, folder=folder)

    for epoch in range(epochs):
        for X, y in dataloader:
            ae.zero_grad()
            X = X.to(device)
            if hasattr(ae, 'nb_active'):
                ae.nb_active = max(ae.nb_active - 1, 32)
            # walkback + denoise
            if walkback:
                loss = 0.
                x = X.data
                nb = 5
                for _ in range(nb):
                    x = salt_and_pepper(x, proba=0.3)  # denoise
                    x = x.to(device)
                    x = ae(x)  # reconstruct
                    Xr = x
                    loss += (((x - X) ** 2).view(X.size(0), -1).sum(1).mean()) / nb
                    x = (torch.rand(x.size()).to(device) <= x.data).float()  # sample
            # denoise only
            elif denoise:
                Xc = salt_and_pepper(X.data, proba=0.3)
                Xr = ae(Xc)
                loss = ((Xr - X) ** 2).view(X.size(0), -1).sum(1).mean()
            # normal training
            else:
                Xr = ae(X)
                loss = ((Xr - X) ** 2).view(X.size(0), -1).sum(1).mean()
            loss.backward()
            optim.step()
            avg_loss = avg_loss * gamma + loss.item() * (1 - gamma)
            if nb_updates % log_interval == 0:
                print('Epoch : {:05d} AvgTrainLoss: {:.6f}, Batch Loss : {:.6f}'.format(epoch, avg_loss, loss.item()))
                gr = grid_of_images_default(np.array(Xr.data.tolist()))
                imsave('{}/rec.png'.format(folder), gr)
                ae.apply(_save_weights)
                torch.save(ae, '{}/model.th'.format(folder))
            nb_updates += 1


def test(*, dataset='mnist', folder='out', model_path=None, nb_iter=100, nb_generate=100, tsne=False):
    if not os.path.exists(folder):
        os.makedirs(folder, exist_ok=True)
    dataset = load_dataset(dataset, split='train')
    x0, _ = dataset[0]
    c, h, w = x0.size()
    nb = nb_generate
    print('Load model...')
    if model_path is None:
        model_path = os.path.join(folder, "model.th")
    ae = torch.load(model_path, map_location="cpu")
    ae = ae.to(device)

    def enc(X):
        # Encode a batch of images into the feature space used for the t-SNE plot.
        batch_size = 64
        h_list = []
        for i in range(0, X.size(0), batch_size):
            x = X[i:i + batch_size]
            x = x.to(device)
            name = ae.__class__.__name__
            if name in ('ConvAE',):
                h = ae.encode(x)
                h, _ = h.max(2)
                h = h.view((h.size(0), -1))
            elif name in ('DenseAE',):
                x = x.view(x.size(0), -1)
                h = x
                # h = ae.encode(x)
            else:
                h = x.view(x.size(0), -1)
            h = h.data.cpu()
            h_list.append(h)
        return torch.cat(h_list, 0)

    print('iterative refinement...')
    g = iterative_refinement(
        ae,
        nb_iter=nb_iter,
        nb_examples=nb,
        w=w, h=h, c=c,
        batch_size=64
    )
    np.savez('{}/generated.npz'.format(folder), X=g.numpy())
    g_subset = g[:, 0:100]
    gr = grid_of_images_default(g_subset.reshape((g_subset.shape[0] * g_subset.shape[1], h, w, 1)).numpy(), shape=(g_subset.shape[0], g_subset.shape[1]))
    imsave('{}/gen_full_iters.png'.format(folder), gr)

    g = g[-1]  # last iter
    print(g.shape)
    gr = grid_of_images_default(g.numpy())
    imsave('{}/gen_full.png'.format(folder), gr)

    if tsne:
        from sklearn.manifold import TSNE
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=nb,
            shuffle=True,
            num_workers=1
        )
        print('Load data...')
        X, y = next(iter(dataloader))
        print('Encode data...')
        xh = enc(X)
        print('Encode generated...')
        gh = enc(g)
        X = X.numpy()
        g = g.numpy()
        xh = xh.numpy()
        gh = gh.numpy()

        a = np.concatenate((X, g), axis=0)
        ah = np.concatenate((xh, gh), axis=0)
        labels = np.array(y.tolist() + [-1] * len(g))
        sne = TSNE()
        print('fit tsne...')
        ah = sne.fit_transform(ah)
        print('grid embedding...')

        asmall = np.concatenate((a[0:450], a[nb:nb + 450]), axis=0)
        ahsmall = np.concatenate((ah[0:450], ah[nb:nb + 450]), axis=0)
        rows = grid_embedding(ahsmall)
        asmall = asmall[rows]
        gr = grid_of_images_default(asmall)
        imsave('{}/sne_grid.png'.format(folder), gr)

        fig = plt.figure(figsize=(10, 10))
        plot_dataset(ah, labels)
        plot_generated(ah, labels)
        plt.legend(loc='best')
        plt.axis('off')
        plt.savefig('{}/sne.png'.format(folder))
        plt.close(fig)


if __name__ == '__main__':
    run([train, test])
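Note: run([train, test]) exposes both functions as subcommands through clize. Roughly, and assuming clize's default mapping of keyword-only parameters to dashed command-line options, the scripts would be driven as:

    python cli.py train --dataset=mnist --folder=mnist --model=deep_convae --epochs=100
    python cli.py test --folder=mnist --nb-iter=100 --nb-generate=100 --tsne

The exact option spellings are generated by clize; `python cli.py train --help` shows the actual interface.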
convert.py
ADDED
@@ -0,0 +1,52 @@
import numpy as np
import torch, h5py
from model import *

w, h, c = 28, 28, 1
model_new = DeepConvAE(
    w=w, h=h, c=c,
    nb_filters=128,
    spatial=True,
    channel=True,
    channel_stride=4,
    # total layers = nb_layers*2, where we have nb_layers for encoder and nb_layers for decoder
    nb_layers=3,
)
# model_old = h5py.File("mnist_deepconvae/model.h5")
model_old = h5py.File("/home/mehdi/work/code/out_of_class/ae/mnist/model.h5")


print(model_new)
print(model_old["model_weights"].keys())


for name, param in model_new.named_parameters():
    enc_or_decode, layer_id, bias_or_kernel = name.split(".")

    if enc_or_decode == "encode":
        layer_name = "conv2d"
    else:
        layer_name = "up_conv2d"

    layer_id = (int(layer_id) // 2) + 1

    full_layer_name = f"{layer_name}_{layer_id}"
    print(full_layer_name)

    k = "kernel" if bias_or_kernel == "weight" else "bias"
    weights = model_old["model_weights"][full_layer_name][full_layer_name][k][()]
    weights = np.array(weights)
    weights = torch.from_numpy(weights)
    print(name, layer_id, param.shape, weights.shape)
    # Keras stores Conv2D kernels as (kh, kw, in, out); PyTorch Conv2d expects
    # (out, in, kh, kw), and the 5x5 encoder kernels are also flipped spatially
    # using the reversed index list below.
    inds = [4, 3, 2, 1, 0]
    if k == "kernel":
        if layer_name == "conv2d":
            weights = weights.permute((3, 2, 0, 1))
            weights = weights[:, :, inds]
            weights = weights[:, :, :, inds]
            print("W", weights.shape)
        elif layer_name == "up_conv2d":
            weights = weights.permute((2, 3, 0, 1))
    print(param.shape, weights.shape)
    param.data.copy_(weights)
    print((param - weights).sum())
torch.save(model_new, "mnist_deepconvae/model.th")
data.py
ADDED
@@ -0,0 +1,94 @@
import torch
import numpy as np
from torch.utils.data import TensorDataset

import torchvision.transforms as transforms
import torchvision.datasets as dset


class Invert:
    def __call__(self, x):
        return 1 - x

class Gray:
    def __call__(self, x):
        return x[0:1]


def load_dataset(dataset_name, split='full'):
    if dataset_name == 'mnist':
        dataset = dset.MNIST(
            root='data/mnist',
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
            ])
        )
        return dataset
    elif dataset_name == 'coco':
        # transforms.Scale is the older torchvision name for transforms.Resize.
        dataset = dset.ImageFolder(root='data/coco',
                                   transform=transforms.Compose([
                                       transforms.Scale(64),
                                       transforms.CenterCrop(64),
                                       transforms.ToTensor(),
                                   ]))
        return dataset
    elif dataset_name == 'quickdraw':
        X = np.load('data/quickdraw/teapot.npy')
        X = X.reshape((X.shape[0], 28, 28))
        X = X / 255.
        X = X.astype(np.float32)
        X = torch.from_numpy(X)
        dataset = TensorDataset(X, X)
        return dataset
    elif dataset_name == 'shoes':
        dataset = dset.ImageFolder(root='data/shoes/ut-zap50k-images/Shoes',
                                   transform=transforms.Compose([
                                       transforms.Scale(64),
                                       transforms.CenterCrop(64),
                                       transforms.ToTensor(),
                                   ]))
        return dataset
    elif dataset_name == 'footwear':
        dataset = dset.ImageFolder(root='data/shoes/ut-zap50k-images',
                                   transform=transforms.Compose([
                                       transforms.Scale(64),
                                       transforms.CenterCrop(64),
                                       transforms.ToTensor(),
                                   ]))
        return dataset
    elif dataset_name == 'celeba':
        dataset = dset.ImageFolder(root='data/celeba',
                                   transform=transforms.Compose([
                                       transforms.Scale(32),
                                       transforms.CenterCrop(32),
                                       transforms.ToTensor(),
                                   ]))
        return dataset
    elif dataset_name == 'birds':
        dataset = dset.ImageFolder(root='data/birds/' + split,
                                   transform=transforms.Compose([
                                       transforms.Scale(32),
                                       transforms.CenterCrop(32),
                                       transforms.ToTensor(),
                                   ]))
        return dataset
    elif dataset_name == 'sketchy':
        dataset = dset.ImageFolder(root='data/sketchy/' + split,
                                   transform=transforms.Compose([
                                       transforms.Scale(64),
                                       transforms.CenterCrop(64),
                                       transforms.ToTensor(),
                                       Gray()
                                   ]))
        return dataset

    elif dataset_name == 'fonts':
        dataset = dset.ImageFolder(root='data/fonts/' + split,
                                   transform=transforms.Compose([
                                       transforms.ToTensor(),
                                       Invert(),
                                       Gray(),
                                   ]))
        return dataset
    else:
        raise ValueError('Error : unknown dataset')
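Note: a minimal, illustrative check of the loader used by the training script (not part of the repo; downloads MNIST into data/mnist on first use):

    from data import load_dataset
    ds = load_dataset('mnist', split='train')
    x, y = ds[0]
    print(x.shape, x.min().item(), x.max().item())  # torch.Size([1, 28, 28]), values in [0, 1]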
model.py
ADDED
@@ -0,0 +1,260 @@
import numpy as np
import torch
import torch.nn as nn
from torch.nn.init import xavier_uniform

class KAE(nn.Module):
    # k-sparse autoencoder with tied weights: only the nb_active largest hidden
    # units per example are kept before decoding.

    def __init__(self, w=32, h=32, c=1, nb_hidden=300, nb_active=16):
        super().__init__()
        self.nb_hidden = nb_hidden
        self.nb_active = nb_active
        self.encode = nn.Sequential(
            nn.Linear(w*h*c, nb_hidden, bias=False)
        )
        self.bias = nn.Parameter(torch.zeros(w*h*c))
        self.params = nn.ParameterList([self.bias])
        self.apply(_weights_init)

    def forward(self, X):
        size = X.size()
        X = X.view(X.size(0), -1)
        h = self.encode(X)
        Xr, _ = self.decode(h)
        Xr = Xr.view(size)
        return Xr

    def decode(self, h):
        thetas, _ = torch.sort(h, dim=1, descending=True)
        thetas = thetas[:, self.nb_active:self.nb_active+1]
        h = h * (h > thetas).float()
        Xr = torch.matmul(h, self.encode[0].weight) + self.bias
        Xr = nn.Sigmoid()(Xr)
        return Xr, h


class ZAE(nn.Module):
    # Thresholded autoencoder with tied weights: hidden units below theta are
    # zeroed before decoding.

    def __init__(self, w=32, h=32, c=1, nb_hidden=300, theta=1):
        super().__init__()
        self.nb_hidden = nb_hidden
        self.theta = theta
        self.encode = nn.Sequential(
            nn.Linear(w*h*c, nb_hidden, bias=False)
        )
        self.bias = nn.Parameter(torch.zeros(w*h*c))
        self.params = nn.ParameterList([self.bias])
        self.apply(_weights_init)

    def forward(self, X):
        size = X.size()
        X = X.view(X.size(0), -1)
        h = self.encode(X)
        Xr, _ = self.decode(h)
        Xr = Xr.view(size)
        return Xr

    def decode(self, h):
        h = h * (h > self.theta).float()
        Xr = torch.matmul(h, self.encode[0].weight) + self.bias
        Xr = nn.Sigmoid()(Xr)
        return Xr, h


class DenseAE(nn.Module):

    def __init__(self, w=32, h=32, c=1, encode_hidden=(300,), decode_hidden=(300,), ksparse=True, nb_active=10, denoise=None):
        super().__init__()
        self.encode_hidden = encode_hidden
        self.decode_hidden = decode_hidden
        self.ksparse = ksparse
        self.nb_active = nb_active
        self.denoise = denoise

        # encode layers
        layers = []
        hid_prev = w * h * c
        for hid in encode_hidden:
            layers.extend([
                nn.Linear(hid_prev, hid),
                nn.ReLU(True)
            ])
            hid_prev = hid
        self.encode = nn.Sequential(*layers)

        # decode layers
        layers = []
        for hid in decode_hidden:
            layers.extend([
                nn.Linear(hid_prev, hid),
                nn.ReLU(True)
            ])
            hid_prev = hid
        layers.extend([
            nn.Linear(hid_prev, w * h * c),
            nn.Sigmoid()
        ])
        self.decode = nn.Sequential(*layers)

        self.apply(_weights_init)

    def forward(self, X):
        size = X.size()
        if self.denoise is not None:
            X = X * ((torch.rand(X.size()) <= self.denoise).float()).to(X.device)
        X = X.view(X.size(0), -1)
        h = self.encode(X)
        if self.ksparse:
            h = ksparse(h, nb_active=self.nb_active)
        Xr = self.decode(h)
        Xr = Xr.view(size)
        return Xr


def ksparse(x, nb_active=10):
    # Keep the nb_active largest activations per row, zero out the rest.
    mask = torch.ones(x.size())
    for i, xi in enumerate(x.data.tolist()):
        inds = np.argsort(xi)
        inds = inds[::-1]
        inds = inds[nb_active:]
        if len(inds):
            inds = np.array(inds)
            inds = torch.from_numpy(inds).long()
            mask[i][inds] = 0
    return x * (mask).float().to(x.device)


class ConvAE(nn.Module):

    def __init__(self, w=32, h=32, c=1, nb_filters=64, spatial=True, channel=True, channel_stride=4):
        super().__init__()
        self.spatial = spatial
        self.channel = channel
        self.channel_stride = channel_stride
        self.encode = nn.Sequential(
            nn.Conv2d(c, nb_filters, 5, 1, 0),
            nn.ReLU(True),
            nn.Conv2d(nb_filters, nb_filters, 5, 1, 0),
            nn.ReLU(True),
            nn.Conv2d(nb_filters, nb_filters, 5, 1, 0),
        )
        self.decode = nn.Sequential(
            nn.ConvTranspose2d(nb_filters, c, 13, 1, 0),
            nn.Sigmoid()
        )
        self.apply(_weights_init)

    def forward(self, X):
        size = X.size()
        h = self.encode(X)
        h = self.sparsify(h)
        Xr = self.decode(h)
        return Xr

    def sparsify(self, h):
        if self.spatial:
            h = spatial_sparsity(h)
        if self.channel:
            h = strided_channel_sparsity(h, stride=self.channel_stride)
        return h

class SimpleConvAE(nn.Module):

    def __init__(self, w=32, h=32, c=1, nb_filters=64, spatial=True, channel=True, channel_stride=4):
        super().__init__()
        self.spatial = spatial
        self.channel = channel
        self.channel_stride = channel_stride
        self.encode = nn.Sequential(
            nn.Conv2d(c, nb_filters, 13, 1, 0),
            nn.ReLU(True),
        )
        self.decode = nn.Sequential(
            nn.ConvTranspose2d(nb_filters, c, 13, 1, 0),
            nn.Sigmoid()
        )
        self.apply(_weights_init)

    def forward(self, X):
        size = X.size()
        h = self.encode(X)
        h = self.sparsify(h)
        Xr = self.decode(h)
        return Xr

    def sparsify(self, h):
        if self.spatial:
            h = spatial_sparsity(h)
        if self.channel:
            h = strided_channel_sparsity(h, stride=self.channel_stride)
        return h

class DeepConvAE(nn.Module):

    def __init__(self, w=32, h=32, c=1, nb_filters=64, nb_layers=3, spatial=True, channel=True, channel_stride=4):
        super().__init__()
        self.spatial = spatial
        self.channel = channel
        self.channel_stride = channel_stride

        layers = [
            nn.Conv2d(c, nb_filters, 5, 1, 0),
            nn.ReLU(True),
        ]
        for _ in range(nb_layers - 1):
            layers.extend([
                nn.Conv2d(nb_filters, nb_filters, 5, 1, 0),
                nn.ReLU(True),
            ])
        self.encode = nn.Sequential(*layers)
        layers = []
        for _ in range(nb_layers - 1):
            layers.extend([
                nn.ConvTranspose2d(nb_filters, nb_filters, 5, 1, 0),
                nn.ReLU(True),
            ])
        layers.extend([
            nn.ConvTranspose2d(nb_filters, c, 5, 1, 0),
            nn.Sigmoid()
        ])
        self.decode = nn.Sequential(*layers)
        self.apply(_weights_init)

    def forward(self, X):
        size = X.size()
        h = self.encode(X)
        h = self.sparsify(h)
        Xr = self.decode(h)
        return Xr

    def sparsify(self, h):
        if self.spatial:
            h = spatial_sparsity(h)
        if self.channel:
            h = strided_channel_sparsity(h, stride=self.channel_stride)
        return h


def spatial_sparsity(x):
    # Winner-take-all over spatial positions: keep, per example and per channel,
    # only the maximally activated location.
    maxes = x.amax(dim=(2, 3), keepdim=True)
    return x * equals(x, maxes)

def equals(x, y, eps=1e-8):
    return torch.abs(x - y) <= eps

def strided_channel_sparsity(x, stride=1):
    # Winner-take-all across channels within each stride x stride spatial block.
    B, F = x.shape[0:2]
    h, w = x.shape[2:]
    x_ = x.view(B, F, h // stride, stride, w // stride, stride)
    mask = equals(x_, x_.amax(dim=(1, 3, 5), keepdim=True))
    mask = mask.view(x.shape).float()
    return x * mask


def _weights_init(m):
    if hasattr(m, 'weight'):
        xavier_uniform(m.weight.data)
        if m.bias is not None:
            m.bias.data.fill_(0)
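Note: a small illustrative check (not part of the repo) of what the winner-take-all sparsity functions do to a feature map:

    import torch
    from model import spatial_sparsity, strided_channel_sparsity
    x = torch.rand(2, 8, 16, 16)                # (batch, channels, h, w)
    s = spatial_sparsity(x)
    print((s != 0).sum(dim=(2, 3)))             # one surviving location per (example, channel)
    c = strided_channel_sparsity(x, stride=4)   # one surviving channel per 4x4 spatial block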
test.py
ADDED
@@ -0,0 +1,21 @@
import torch
import numpy as np
from machinedesign.autoencoder.interface import load
from keras.models import Model

torch.use_deterministic_algorithms(True)
model = torch.load("mnist_deepconvae/model.th")
model_keras = load("/home/mehdi/work/code/out_of_class/ae/mnist")
print(model_keras.layers[8])

m = Model(model_keras.inputs, model_keras.layers[8].output)
X = torch.rand(1, 1, 28, 28)
with torch.no_grad():
    # X1 = model.sparsify(model.encode(X))
    X1 = model(X)
X2 = model_keras.predict(X)
X2 = torch.from_numpy(X2)
# Compare the converted PyTorch model against the original Keras model on the same input.
print(torch.abs(X1 - X2).sum())
# for i in range(128):
#     print(i, torch.abs(X1[0,i]-X2[0,i]).sum())
#     print(X1[0,i, 0, :])
#     print(X2[0,i,0, :])
viz.py
ADDED
@@ -0,0 +1,204 @@
"""
This module contains common visualization functions
used to report results of the models.
"""

from functools import partial
import numpy as np


def horiz_merge(left, right):
    """
    Merges two images, left and right, horizontally to obtain
    a bigger image containing both.

    Parameters
    ----------
    left : 2D or 3D numpy array
        left image.
        2D for grayscale.
        3D for color.
    right : 2D or 3D numpy array
        right image.
        2D for grayscale.
        3D for color.

    Returns
    -------

    numpy array (2D or 3D depending on left and right)
    """
    assert left.shape[0] == right.shape[0]
    assert left.shape[2:] == right.shape[2:]
    shape = (left.shape[0], left.shape[1] + right.shape[1],) + left.shape[2:]
    im_merge = np.zeros(shape)
    im_merge[:, 0:left.shape[1]] = left
    im_merge[:, left.shape[1]:] = right
    return im_merge

def vert_merge(top, bottom):
    """
    Merges two images, top and bottom, vertically to obtain
    a bigger image containing both.

    Parameters
    ----------
    top : 2D or 3D numpy array
        top image.
        2D for grayscale.
        3D for color.
    bottom : 2D or 3D numpy array
        bottom image.
        2D for grayscale.
        3D for color.

    Returns
    -------

    numpy array (2D or 3D depending on top and bottom)
    """
    im = horiz_merge(top, bottom)
    if len(im.shape) == 2:
        im = im.transpose((1, 0))
    elif len(im.shape) == 3:
        im = im.transpose((1, 0, 2))
    return im


def grid_of_images(M, border=0, bordercolor=[0.0, 0.0, 0.0], shape=None, normalize=False):
    """
    Draw a grid of images from M.
    The order in the grid, which corresponds to the order in M,
    is from top to bottom, then left to right.

    Parameters
    ----------

    M : numpy array
        if 3D, the shape is interpreted as (nb_images, h, w) and converted to (nb_images, 1, h, w).
        if 4D, consider it as colored or grayscale:
        - if the shape is (nb_images, nb_colors, h, w), it is converted to (nb_images, h, w, nb_colors)
        - otherwise, if it is already (nb_images, h, w, nb_colors), it is used as is.
        - nb_colors can be 1 (grayscale) or 3 (color).
    border : int
        thickness of the border (default=0)
    shape : tuple (nb_cols, nb_rows)
        shape of the grid.
        By default, make a square shape
        (in that case, it is possible that not all images from M will be part of the grid).
    normalize : bool (default=False)
        whether to normalize the pixel values of each image independently
        by min and max. If False, clip the values of pixels to 0 and 1
        without normalizing.

    Returns
    -------

    3D numpy array of shape (h, w, 3)
    (with a color channel regardless of whether the original images were grayscale or colored)
    """
    if len(M.shape) == 3:
        M = M[:, :, :, np.newaxis]
    if M.shape[-1] not in (1, 3):
        M = M.transpose((0, 2, 3, 1))
    if M.shape[-1] == 1:
        M = np.ones((1, 1, 1, 3)) * M
    bordercolor = np.array(bordercolor)[None, None, :]
    numimages = len(M)
    M = M.copy()

    if normalize:
        for i in range(M.shape[0]):
            M[i] -= M[i].flatten().min()
            M[i] /= M[i].flatten().max()
    else:
        M = np.clip(M, 0, 1)
    height, width, color = M[0].shape
    assert color == 3, 'Nb of color channels are {}'.format(color)
    if shape is None:
        n0 = int(np.ceil(np.sqrt(numimages)))
        n1 = int(np.ceil(np.sqrt(numimages)))
    else:
        n0 = shape[0]
        n1 = shape[1]

    im = np.array(bordercolor) * np.ones(
        ((height + border) * n1 + border, (width + border) * n0 + border, 1), dtype='<f8')
    # shape = (n0, n1)
    # j corresponds to rows in the grid, n1 should correspond to nb of rows
    # i corresponds to columns in the grid, n0 should correspond to nb of cols
    # M should be such that the first n1 examples correspond to row 1,
    # next n1 examples correspond to row 2, etc. that is, M first axis
    # can be reshaped to (n1, n0)
    for i in range(n0):
        for j in range(n1):
            if i * n1 + j < numimages:
                im[j * (height + border) + border:(j + 1) * (height + border) + border,
                   i * (width + border) + border:(i + 1) * (width + border) + border, :] = np.concatenate((
                       np.concatenate((M[i * n1 + j, :, :, :],
                                       bordercolor * np.ones((height, border, 3), dtype=float)), 1),
                       bordercolor * np.ones((border, width + border, 3), dtype=float)
                   ), 0)
    return im

grid_of_images_default = partial(grid_of_images, border=1, bordercolor=(0.3, 0, 0))


def reshape_to_images(x, input_shape=None):
    """
    A function that takes a numpy array and tries to
    reshape it to an array of images that would
    be compatible with the function grid_of_images.
    Two cases are considered.

    if x is a 2D numpy array, it uses input_shape:
    - x can either be (nb_examples, nb_features) or (nb_features, nb_examples)
    - nb_features should be prod(input_shape)
    - the nb_features dim is then expanded to obtain
      (nb_examples, h, w, nb_channels), so input_shape should be (h, w, nb_channels).

    if x is a 4D numpy array:
    - if the first tensor dim is 1 or 3, like e.g. (1, a, b, c), then assume it is the
      color channel and transform to (a, 1, b, c)
    - if the second tensor dim is 1 or 3, leave x as it is
    - if the third tensor dim is 1 or 3, like e.g. (a, b, 1, c), then assume it is the
      color channel and transform to (c, 1, a, b)
    - if the fourth tensor dim is 1 or 3, like e.g. (a, b, c, 1), then assume it is the
      color channel and transform to (c, 1, a, b)

    Parameters
    ----------

    x : numpy array
        input to be reshaped
    input_shape : tuple, needed only when x is a 2D numpy array
    """
    if len(x.shape) == 2:
        assert input_shape is not None
        if x.shape[0] == np.prod(input_shape):
            x = x.T
            x = x.reshape((x.shape[0],) + input_shape)
            x = x.transpose((0, 2, 3, 1))
            return x
        elif x.shape[1] == np.prod(input_shape):
            x = x.reshape((x.shape[0],) + input_shape)
            x = x.transpose((0, 2, 3, 1))
            return x
        else:
            raise ValueError('Cannot recognize this shape : {}'.format(x.shape))
    elif len(x.shape) == 4:
        if x.shape[0] in (1, 3):
            x = x.transpose((1, 0, 2, 3))
            return x
        elif x.shape[1] in (1, 3):
            return x
        elif x.shape[2] in (1, 3):
            x = x.transpose((3, 2, 0, 1))
            return x
        elif x.shape[3] in (1, 3):
            x = x.transpose((2, 3, 0, 1))
            return x
        else:
            raise ValueError('Cannot recognize a shape of size : {}'.format(len(x.shape)))
    else:
        raise ValueError('Cannot recognize a shape of size : {}'.format(len(x.shape)))
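Note: an illustrative use of grid_of_images_default (not part of the repo), tiling a batch of grayscale images into one RGB grid image:

    import numpy as np
    from skimage.io import imsave
    from viz import grid_of_images_default
    imgs = np.random.rand(16, 28, 28)            # (nb_images, h, w)
    grid = grid_of_images_default(imgs)          # float array (H, W, 3) with values in [0, 1]
    imsave("grid.png", (grid * 255).astype("uint8"))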