fbrynpk committed on
Commit
75e181a
·
1 Parent(s): 2322c93

Training and Saving the models

Browse files
Files changed (1) hide show
  1. training.py +383 -0
training.py ADDED
@@ -0,0 +1,383 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ import os
3
+ import json
4
+ import pandas as pd
5
+ import re
6
+ import numpy as np
7
+ import time
8
+ import matplotlib.pyplot as plt
9
+ import collections
10
+ import random
11
+ import requests
12
+ import json
13
+ import pickle
14
+ from math import sqrt
15
+ from PIL import Image
16
+ from tqdm.auto import tqdm
17
+
18
# Root folder of the COCO 2017 download (annotations/ and train2017/ live below it).
DATASET_PATH = './coco2017/'
# Fixed token length of every vectorized caption; also the size of the
# position-embedding table in the decoder.
MAX_LENGTH = 40
# Upper bound on the vocabulary learned by the TextVectorization layer.
MAX_VOCABULARY = 12000
# Number of (image, caption) pairs per batch.
BATCH_SIZE = 64
# Size of the tf.data shuffle buffer.
BUFFER_SIZE = 1000
# Embedding width handed to the encoder and decoder layers.
EMBEDDING_DIM = 512
# Width of the decoder's first feed-forward layer.
UNITS = 512
# Maximum number of epochs passed to `fit`.
EPOCHS = 5

# Read the caption annotations and keep only the list of annotation records.
with open(f'{DATASET_PATH}/annotations/captions_train2017.json', 'r') as f:
    data = json.load(f)['annotations']

# One [file name, caption] pair per annotation; the file name is the image id
# zero-padded to 12 digits plus the ".jpg" extension.
img_cap_pairs = [
    ['%012d.jpg' % sample['image_id'], sample['caption']]
    for sample in data
]

captions = pd.DataFrame(img_cap_pairs, columns=['image', 'caption'])
# Turn bare file names into paths inside the train2017 image folder.
captions['image'] = [
    f'{DATASET_PATH}/train2017/{x}' for x in captions['image']
]
# Work on a random subset of 70000 pairs, re-indexed from 0.
captions = captions.sample(70000).reset_index(drop=True)
captions.head()
44
+
45
def preprocessing(text):
    """Normalize a raw caption and wrap it in start/end marker tokens.

    The caption is lower-cased, every character that is neither a word
    character nor whitespace is removed, runs of whitespace are collapsed to
    a single space, and surrounding whitespace is stripped.

    Args:
        text: The raw caption string.

    Returns:
        The cleaned caption, prefixed with ``'[start] '`` and suffixed with
        ``' [end]'``.
    """
    text = text.lower()
    text = re.sub(r'[^\w\s]', '', text)
    # Raw string: a plain '\s' is an invalid escape sequence and triggers a
    # SyntaxWarning on recent Python versions.
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    text = '[start] ' + text + ' [end]'
    return text
52
+
53
# Clean every caption and add the [start]/[end] markers.
captions['caption'] = captions['caption'].apply(preprocessing)
captions.head()

# Captions are already normalized by `preprocessing`, so the layer's own
# standardization is disabled; every caption is padded/truncated to MAX_LENGTH.
tokenizer = tf.keras.layers.TextVectorization(
    max_tokens=MAX_VOCABULARY,
    standardize=None,
    output_sequence_length=MAX_LENGTH)

# Learn the vocabulary from the cleaned captions.
tokenizer.adapt(captions['caption'])

# Persist the learned vocabulary. The context manager guarantees the file is
# flushed and closed, which the previous bare `open(...)` argument did not.
with open('./image-caption-generator/vocabulary/vocab_coco.file', 'wb') as vocab_file:
    pickle.dump(tokenizer.get_vocabulary(), vocab_file)

# Token string -> integer id lookup built on the learned vocabulary.
word2idx = tf.keras.layers.StringLookup(
    mask_token="",
    vocabulary=tokenizer.get_vocabulary()
)

# Integer id -> token string lookup (inverse of `word2idx`).
idx2word = tf.keras.layers.StringLookup(
    mask_token="",
    vocabulary=tokenizer.get_vocabulary(),
    invert=True
)
75
+
76
# Group captions by image path so the split below is made per image and an
# image's captions never straddle the train/test boundary.
img_to_cap_vector = collections.defaultdict(list)
for img, cap in zip(captions['image'], captions['caption']):
    img_to_cap_vector[img].append(cap)

# Shuffle the distinct image paths, then cut 80% / 20%.
img_keys = list(img_to_cap_vector)
random.shuffle(img_keys)

slice_index = int(len(img_keys) * 0.8)
img_name_train_keys = img_keys[:slice_index]
img_name_test_keys = img_keys[slice_index:]

# Flatten back to parallel lists: each image path is repeated once per caption.
train_img = [
    path for path in img_name_train_keys for _ in img_to_cap_vector[path]
]
train_caption = [
    text for path in img_name_train_keys for text in img_to_cap_vector[path]
]

test_img = [
    path for path in img_name_test_keys for _ in img_to_cap_vector[path]
]
test_caption = [
    text for path in img_name_test_keys for text in img_to_cap_vector[path]
]

len(train_img), len(train_caption), len(test_img), len(test_caption)
101
+
102
def load_data(img_path, caption):
    """Load one training example: a preprocessed image and its token ids.

    Args:
        img_path: Path of a JPEG file.
        caption: Caption string, vectorized with the module-level `tokenizer`.

    Returns:
        A tuple ``(image, caption_tokens)``. The image is decoded with three
        channels, resized to 299x299 and passed through the InceptionV3
        input preprocessing.
    """
    raw_bytes = tf.io.read_file(img_path)
    pixels = tf.io.decode_jpeg(raw_bytes, channels=3)
    resized = tf.keras.layers.Resizing(299, 299)(pixels)
    prepared = tf.keras.applications.inception_v3.preprocess_input(resized)
    return prepared, tokenizer(caption)
109
+
110
# Input pipelines. Shuffling happens BEFORE `map` so the shuffle buffer holds
# BUFFER_SIZE (path, caption) string pairs rather than BUFFER_SIZE decoded
# 299x299x3 float images, which keeps memory use small. `prefetch` lets the
# next batches be prepared while the current one is being consumed.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((train_img, train_caption))
    .shuffle(BUFFER_SIZE)
    .map(load_data, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

test_dataset = (
    tf.data.Dataset.from_tensor_slices((test_img, test_caption))
    .shuffle(BUFFER_SIZE)
    .map(load_data, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)
117
+
118
# Random image augmentations that the captioning model applies in its
# training step: horizontal flip, rotation (factor 0.2), contrast (factor 0.3).
_augmentation_layers = [
    tf.keras.layers.RandomFlip("horizontal"),
    tf.keras.layers.RandomRotation(0.2),
    tf.keras.layers.RandomContrast(0.3),
]
image_augmentation = tf.keras.Sequential(_augmentation_layers)
125
+
126
def CNN_Encoder():
    """Build the image feature extractor.

    Returns:
        A Keras model that takes the InceptionV3 input and returns the
        network's final feature map (ImageNet weights, classification head
        removed) reshaped to ``(batch, -1, channels)``, i.e. the spatial grid
        is flattened into a sequence of feature vectors.
    """
    backbone = tf.keras.applications.InceptionV3(
        weights='imagenet',
        include_top=False,
    )

    feature_map = backbone.output
    channels = feature_map.shape[-1]
    feature_sequence = tf.keras.layers.Reshape((-1, channels))(feature_map)

    return tf.keras.models.Model(backbone.input, feature_sequence)
138
+
139
+
140
class TransformerEncoderLayer(tf.keras.layers.Layer):
    """Single encoder block: normalize, project, self-attend, normalize.

    Args:
        embed_dim: Output width of the dense projection; also used as the
            per-head key dimension of the attention layer.
        num_heads: Number of attention heads.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        # NOTE: sub-layers are created in the same order as before so that
        # automatically assigned layer names are unchanged.
        self.layer_norm_1 = tf.keras.layers.LayerNormalization()
        self.layer_norm_2 = tf.keras.layers.LayerNormalization()
        self.attention = tf.keras.layers.MultiHeadAttention(
            key_dim=embed_dim, num_heads=num_heads)
        self.dense = tf.keras.layers.Dense(embed_dim, activation="relu")

    def call(self, x, training):
        """Encode a sequence of image features.

        Args:
            x: Input feature sequence.
            training: Forwarded to the attention layer.

        Returns:
            The layer-normalized sum of the projected features and their
            self-attention output.
        """
        projected = self.dense(self.layer_norm_1(x))

        # Unmasked self-attention over the projected features.
        attended = self.attention(
            query=projected,
            value=projected,
            key=projected,
            attention_mask=None,
            training=training,
        )

        return self.layer_norm_2(projected + attended)
165
+
166
+
167
class Embeddings(tf.keras.layers.Layer):
    """Sum of learned token embeddings and learned position embeddings.

    Args:
        vocab_size: Number of rows in the token embedding table.
        embed_dim: Width of both embedding tables.
        max_len: Number of rows in the position embedding table.
    """

    def __init__(self, vocab_size, embed_dim, max_len):
        super().__init__()
        self.token_embeddings = tf.keras.layers.Embedding(
            vocab_size, embed_dim)
        self.position_embeddings = tf.keras.layers.Embedding(
            max_len, embed_dim, input_shape=(None, max_len))

    def call(self, input_ids):
        """Embed token ids and add an embedding of each token's position.

        Args:
            input_ids: Integer token ids; the last axis is the sequence axis.

        Returns:
            Token embeddings plus position embeddings (the position term has
            a leading axis of size 1 and broadcasts over the batch).
        """
        seq_len = tf.shape(input_ids)[-1]
        # Positions 0..seq_len-1 with a leading axis of size 1.
        positions = tf.range(seq_len)[tf.newaxis, :]

        return self.token_embeddings(input_ids) + self.position_embeddings(positions)
186
+
187
class TransformerDecoderLayer(tf.keras.layers.Layer):
    """Single decoder block that maps caption tokens to next-token probabilities.

    The block embeds the input ids, applies masked self-attention, attends
    over the encoder output, runs a two-layer feed-forward network and ends
    with a softmax over the vocabulary. Vocabulary size and maximum caption
    length come from the module-level `tokenizer` and `MAX_LENGTH`.

    Args:
        embed_dim: Embedding width; also the per-head key dimension of both
            attention layers and the output width of the feed-forward block.
        units: Width of the first feed-forward layer.
        num_heads: Number of heads in both attention layers.
    """

    def __init__(self, embed_dim, units, num_heads):
        super().__init__()
        # NOTE: sub-layers are created in the same order as before so that
        # automatically assigned layer names are unchanged.
        self.embedding = Embeddings(
            tokenizer.vocabulary_size(), embed_dim, MAX_LENGTH)

        self.attention_1 = tf.keras.layers.MultiHeadAttention(
            key_dim=embed_dim, num_heads=num_heads, dropout=0.1)
        self.attention_2 = tf.keras.layers.MultiHeadAttention(
            key_dim=embed_dim, num_heads=num_heads, dropout=0.1)

        self.layernorm_1 = tf.keras.layers.LayerNormalization()
        self.layernorm_2 = tf.keras.layers.LayerNormalization()
        self.layernorm_3 = tf.keras.layers.LayerNormalization()

        self.ffn_layer_1 = tf.keras.layers.Dense(units, activation="relu")
        self.ffn_layer_2 = tf.keras.layers.Dense(embed_dim)

        self.out = tf.keras.layers.Dense(
            tokenizer.vocabulary_size(), activation="softmax")

        self.dropout_1 = tf.keras.layers.Dropout(0.3)
        self.dropout_2 = tf.keras.layers.Dropout(0.5)

    def get_causal_attention_mask(self, inputs):
        """Return a lower-triangular int32 mask of shape (batch, seq, seq).

        Entry ``[b, i, j]`` is 1 when ``i >= j`` and 0 otherwise, so position
        ``i`` may only attend to positions up to and including itself.

        Args:
            inputs: Tensor whose first two axes are batch and sequence.
        """
        dims = tf.shape(inputs)
        batch_size, seq_len = dims[0], dims[1]

        rows = tf.range(seq_len)[:, tf.newaxis]
        cols = tf.range(seq_len)
        lower_triangle = tf.cast(rows >= cols, dtype="int32")
        lower_triangle = lower_triangle[tf.newaxis, :, :]

        # Repeat the single (1, seq, seq) mask once per batch element.
        repeats = tf.stack([batch_size, 1, 1])
        return tf.tile(lower_triangle, repeats)

    def call(self, input_ids, encoder_output, training, mask=None):
        """Predict a vocabulary distribution for every input position.

        Args:
            input_ids: Integer caption token ids.
            encoder_output: Encoded image features used as key/value of the
                second attention layer.
            training: Forwarded to the attention and dropout layers.
            mask: Optional per-token mask; when given it is combined with a
                causal mask for self-attention and used as a per-query mask
                for the encoder attention.

        Returns:
            Softmax probabilities over the vocabulary for each position.
        """
        embeddings = self.embedding(input_ids)

        if mask is None:
            combined_mask = None
            padding_mask = None
        else:
            causal_mask = self.get_causal_attention_mask(embeddings)
            padding_mask = tf.cast(mask[:, :, tf.newaxis], dtype=tf.int32)
            key_mask = tf.cast(mask[:, tf.newaxis, :], dtype=tf.int32)
            # A position is visible only if both the padding mask and the
            # causal mask allow it.
            combined_mask = tf.minimum(key_mask, causal_mask)

        # Masked self-attention over the caption tokens, with residual.
        self_attended = self.attention_1(
            query=embeddings,
            value=embeddings,
            key=embeddings,
            attention_mask=combined_mask,
            training=training,
        )
        out_1 = self.layernorm_1(embeddings + self_attended)

        # Attention from caption positions onto the encoder output, with residual.
        cross_attended = self.attention_2(
            query=out_1,
            value=encoder_output,
            key=encoder_output,
            attention_mask=padding_mask,
            training=training,
        )
        out_2 = self.layernorm_2(out_1 + cross_attended)

        # Position-wise feed-forward network with residual connection.
        hidden = self.ffn_layer_1(out_2)
        hidden = self.dropout_1(hidden, training=training)
        hidden = self.ffn_layer_2(hidden)

        normalized = self.layernorm_3(hidden + out_2)
        normalized = self.dropout_2(normalized, training=training)
        return self.out(normalized)
268
+
269
+
270
class ImageCaptioningModel(tf.keras.Model):
    """End-to-end captioning model: CNN features -> encoder -> decoder.

    Only the encoder and decoder variables are updated by `train_step`; the
    CNN is used purely as a feature extractor.

    Args:
        cnn_model: Model mapping a batch of images to a feature sequence.
        encoder: Layer encoding the CNN feature sequence.
        decoder: Layer mapping (caption tokens, encoder output) to per-token
            vocabulary probabilities.
        image_aug: Optional callable applied to the images in `train_step`.
    """

    def __init__(self, cnn_model, encoder, decoder, image_aug=None):
        super().__init__()
        self.cnn_model = cnn_model
        self.encoder = encoder
        self.decoder = decoder
        self.image_aug = image_aug
        # Running means of the per-batch loss / accuracy reported to Keras.
        self.loss_tracker = tf.keras.metrics.Mean(name="loss")
        self.acc_tracker = tf.keras.metrics.Mean(name="accuracy")

    def calculate_loss(self, y_true, y_pred, mask):
        """Return the mean loss over the positions selected by `mask`.

        Uses the loss passed to `compile`, which must return per-token
        values (no reduction) so padded positions can be zeroed out.
        """
        loss = self.loss(y_true, y_pred)
        mask = tf.cast(mask, dtype=loss.dtype)
        loss *= mask
        return tf.reduce_sum(loss) / tf.reduce_sum(mask)

    def calculate_accuracy(self, y_true, y_pred, mask):
        """Return the fraction of masked-in positions predicted correctly."""
        accuracy = tf.equal(y_true, tf.argmax(y_pred, axis=2))
        accuracy = tf.math.logical_and(mask, accuracy)
        accuracy = tf.cast(accuracy, dtype=tf.float32)
        mask = tf.cast(mask, dtype=tf.float32)
        return tf.reduce_sum(accuracy) / tf.reduce_sum(mask)

    def compute_loss_and_acc(self, img_embed, captions, training=True):
        """Run encoder and decoder on one batch and score the predictions.

        Args:
            img_embed: CNN feature sequence for the batch of images.
            captions: Integer token ids, one row per caption.
            training: Whether the encoder/decoder run in training mode.
                Previously this flag was ignored and `True` was always
                forwarded, so validation ran with dropout enabled.

        Returns:
            A tuple ``(loss, accuracy)`` of scalars for the batch.
        """
        encoder_output = self.encoder(img_embed, training=training)
        # Teacher forcing: feed tokens [0, n-1) and predict tokens [1, n).
        y_input = captions[:, :-1]
        y_true = captions[:, 1:]
        # Token id 0 is treated as padding and excluded from loss/accuracy.
        mask = (y_true != 0)
        y_pred = self.decoder(
            y_input, encoder_output, training=training, mask=mask
        )
        loss = self.calculate_loss(y_true, y_pred, mask)
        acc = self.calculate_accuracy(y_true, y_pred, mask)
        return loss, acc

    def train_step(self, batch):
        """Run one optimization step on a batch of (images, captions)."""
        imgs, captions = batch

        if self.image_aug:
            imgs = self.image_aug(imgs)

        # Computed outside the tape: the CNN is not among the trained variables.
        img_embed = self.cnn_model(imgs)

        with tf.GradientTape() as tape:
            loss, acc = self.compute_loss_and_acc(
                img_embed, captions
            )

        train_vars = (
            self.encoder.trainable_variables + self.decoder.trainable_variables
        )
        grads = tape.gradient(loss, train_vars)
        self.optimizer.apply_gradients(zip(grads, train_vars))
        self.loss_tracker.update_state(loss)
        self.acc_tracker.update_state(acc)

        return {"loss": self.loss_tracker.result(), "acc": self.acc_tracker.result()}

    def test_step(self, batch):
        """Evaluate one batch of (images, captions) without updating weights."""
        imgs, captions = batch

        img_embed = self.cnn_model(imgs)

        loss, acc = self.compute_loss_and_acc(
            img_embed, captions, training=False
        )

        self.loss_tracker.update_state(loss)
        self.acc_tracker.update_state(acc)

        return {"loss": self.loss_tracker.result(), "acc": self.acc_tracker.result()}

    @property
    def metrics(self):
        """Trackers exposed to Keras via the `metrics` property."""
        return [self.loss_tracker, self.acc_tracker]
351
+
352
+
353
# Assemble the captioning model: one encoder block with a single attention
# head and one decoder block with eight heads.
encoder = TransformerEncoderLayer(EMBEDDING_DIM, 1)
decoder = TransformerDecoderLayer(EMBEDDING_DIM, UNITS, 8)

cnn_model = CNN_Encoder()
caption_model = ImageCaptioningModel(
    cnn_model=cnn_model, encoder=encoder, decoder=decoder, image_aug=image_augmentation,
)

# The decoder ends in a softmax, hence from_logits=False. Reduction is
# disabled because the model masks padded positions and averages the
# per-token losses itself.
cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=False, reduction="none"
)

# Stop after 3 epochs without improvement and restore the best weights.
early_stopping = tf.keras.callbacks.EarlyStopping(patience=3, restore_best_weights=True)

caption_model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=cross_entropy
)

# The held-out split is named `test_dataset` in this file; the previous
# `val_dataset` was never defined and raised a NameError here.
history = caption_model.fit(
    train_dataset,
    epochs=EPOCHS,
    validation_data=test_dataset,
    callbacks=[early_stopping]
)

# NOTE(review): newer Keras releases may require a `.weights.h5` suffix for
# `save_weights` — confirm against the installed TensorFlow/Keras version
# before renaming, since any loader must use the same path.
caption_model.save_weights('./image-caption-generator/models/trained_coco_weights.h5')
381
+
382
+
383
+