wvle committed
Commit 24ecec9 · 1 Parent(s): 8d9934c

Create new file

Files changed (1)
  1. app.py +392 -0
app.py ADDED
@@ -0,0 +1,392 @@
import re

import gradio as gr
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.applications import efficientnet
from tensorflow.keras.layers import TextVectorization

# Desired image dimensions
IMAGE_SIZE = (299, 299)
# Vocabulary size
VOCAB_SIZE = 10000
# Fixed length allowed for any sequence
SEQ_LENGTH = 25
# Dimension for the image embeddings and token embeddings
EMBED_DIM = 512
# Per-layer units in the feed-forward network
FF_DIM = 512

# Text preprocessing: lowercase the captions and strip punctuation, but keep
# "<" and ">" so the "<start>"/"<end>" tokens survive.
strip_chars = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
strip_chars = strip_chars.replace("<", "")
strip_chars = strip_chars.replace(">", "")


def custom_standardization(input_string):
    lowercase = tf.strings.lower(input_string)
    return tf.strings.regex_replace(lowercase, "[%s]" % re.escape(strip_chars), "")


vectorization = TextVectorization(
    max_tokens=VOCAB_SIZE,
    output_mode="int",
    output_sequence_length=SEQ_LENGTH,
    standardize=custom_standardization,
)
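
# `text_data` is the corpus of caption strings the vectorizer is adapted on; it
# is not built anywhere in this file. A minimal, hypothetical sketch of how it
# could be prepared, assuming a plain-text file with one "<start> ... <end>"
# caption per line (the file name "captions.txt" is an illustration, not part
# of the original app):
with open("captions.txt", "r", encoding="utf-8") as f:
    text_data = [line.strip() for line in f if line.strip()]
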
vectorization.adapt(text_data)


# Image preprocessing
def decode_and_resize(img_path):
    img = tf.io.read_file(img_path)
    img = tf.image.decode_jpeg(img, channels=3)
    img = tf.image.resize(img, IMAGE_SIZE)
    img = tf.image.convert_image_dtype(img, tf.float32)
    return img


# Data augmentation for image data
image_augmentation = keras.Sequential(
    [
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.2),
        layers.RandomContrast(0.3),
    ]
)

# Model building
def get_cnn_model():
    base_model = efficientnet.EfficientNetB0(
        input_shape=(*IMAGE_SIZE, 3), include_top=False, weights="imagenet",
    )
    # We freeze our feature extractor
    base_model.trainable = False
    base_model_out = base_model.output
    # Flatten the spatial feature map into a sequence of feature vectors
    base_model_out = layers.Reshape((-1, base_model_out.shape[-1]))(base_model_out)
    cnn_model = keras.models.Model(base_model.input, base_model_out)
    return cnn_model


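# Shape sketch (an expectation, not asserted by the original code): for 299x299
# inputs, EfficientNetB0's last feature map should be 10x10x1280, so the Reshape
# above yields a sequence of 100 image "tokens" of dimension 1280 per image.
# Uncomment to check locally:
# print(get_cnn_model().output_shape)  # expected: (None, 100, 1280)

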
class TransformerEncoderBlock(layers.Layer):
    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        self.attention_1 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim, dropout=0.0
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.dense_1 = layers.Dense(embed_dim, activation="relu")

    def call(self, inputs, training, mask=None):
        # Normalize and project the image features, then apply self-attention
        # with a residual connection.
        inputs = self.layernorm_1(inputs)
        inputs = self.dense_1(inputs)

        attention_output_1 = self.attention_1(
            query=inputs,
            value=inputs,
            key=inputs,
            attention_mask=None,
            training=training,
        )
        out_1 = self.layernorm_2(inputs + attention_output_1)
        return out_1


class PositionalEmbedding(layers.Layer):
    def __init__(self, sequence_length, vocab_size, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.token_embeddings = layers.Embedding(
            input_dim=vocab_size, output_dim=embed_dim
        )
        self.position_embeddings = layers.Embedding(
            input_dim=sequence_length, output_dim=embed_dim
        )
        self.sequence_length = sequence_length
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.embed_scale = tf.math.sqrt(tf.cast(embed_dim, tf.float32))

    def call(self, inputs):
        length = tf.shape(inputs)[-1]
        positions = tf.range(start=0, limit=length, delta=1)
        embedded_tokens = self.token_embeddings(inputs)
        embedded_tokens = embedded_tokens * self.embed_scale
        embedded_positions = self.position_embeddings(positions)
        return embedded_tokens + embedded_positions

    def compute_mask(self, inputs, mask=None):
        return tf.math.not_equal(inputs, 0)


class TransformerDecoderBlock(layers.Layer):
    def __init__(self, embed_dim, ff_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.ff_dim = ff_dim
        self.num_heads = num_heads
        self.attention_1 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim, dropout=0.1
        )
        self.attention_2 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim, dropout=0.1
        )
        self.ffn_layer_1 = layers.Dense(ff_dim, activation="relu")
        self.ffn_layer_2 = layers.Dense(embed_dim)

        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.layernorm_3 = layers.LayerNormalization()

        self.embedding = PositionalEmbedding(
            embed_dim=EMBED_DIM, sequence_length=SEQ_LENGTH, vocab_size=VOCAB_SIZE
        )
        self.out = layers.Dense(VOCAB_SIZE, activation="softmax")

        self.dropout_1 = layers.Dropout(0.3)
        self.dropout_2 = layers.Dropout(0.5)
        self.supports_masking = True

    def call(self, inputs, encoder_outputs, training, mask=None):
        inputs = self.embedding(inputs)
        causal_mask = self.get_causal_attention_mask(inputs)

        # Combine the padding mask with the causal mask so a token can only
        # attend to earlier, non-padding positions.
        if mask is not None:
            padding_mask = tf.cast(mask[:, :, tf.newaxis], dtype=tf.int32)
            combined_mask = tf.cast(mask[:, tf.newaxis, :], dtype=tf.int32)
            combined_mask = tf.minimum(combined_mask, causal_mask)
        else:
            padding_mask = None
            combined_mask = causal_mask

        # Masked self-attention over the (shifted) caption tokens
        attention_output_1 = self.attention_1(
            query=inputs,
            value=inputs,
            key=inputs,
            attention_mask=combined_mask,
            training=training,
        )
        out_1 = self.layernorm_1(inputs + attention_output_1)

        # Cross-attention over the encoded image features
        attention_output_2 = self.attention_2(
            query=out_1,
            value=encoder_outputs,
            key=encoder_outputs,
            attention_mask=padding_mask,
            training=training,
        )
        out_2 = self.layernorm_2(out_1 + attention_output_2)

        ffn_out = self.ffn_layer_1(out_2)
        ffn_out = self.dropout_1(ffn_out, training=training)
        ffn_out = self.ffn_layer_2(ffn_out)

        ffn_out = self.layernorm_3(ffn_out + out_2, training=training)
        ffn_out = self.dropout_2(ffn_out, training=training)
        preds = self.out(ffn_out)
        return preds

    def get_causal_attention_mask(self, inputs):
        input_shape = tf.shape(inputs)
        batch_size, sequence_length = input_shape[0], input_shape[1]
        i = tf.range(sequence_length)[:, tf.newaxis]
        j = tf.range(sequence_length)
        mask = tf.cast(i >= j, dtype="int32")
        mask = tf.reshape(mask, (1, input_shape[1], input_shape[1]))
        mult = tf.concat(
            [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)],
            axis=0,
        )
        return tf.tile(mask, mult)


class ImageCaptioningModel(keras.Model):
    def __init__(
        self, cnn_model, encoder, decoder, num_captions_per_image=5, image_aug=None,
    ):
        super().__init__()
        self.cnn_model = cnn_model
        self.encoder = encoder
        self.decoder = decoder
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.acc_tracker = keras.metrics.Mean(name="accuracy")
        self.num_captions_per_image = num_captions_per_image
        self.image_aug = image_aug

    def calculate_loss(self, y_true, y_pred, mask):
        loss = self.loss(y_true, y_pred)
        mask = tf.cast(mask, dtype=loss.dtype)
        loss *= mask
        return tf.reduce_sum(loss) / tf.reduce_sum(mask)

    def calculate_accuracy(self, y_true, y_pred, mask):
        accuracy = tf.equal(y_true, tf.argmax(y_pred, axis=2))
        accuracy = tf.math.logical_and(mask, accuracy)
        accuracy = tf.cast(accuracy, dtype=tf.float32)
        mask = tf.cast(mask, dtype=tf.float32)
        return tf.reduce_sum(accuracy) / tf.reduce_sum(mask)

    def _compute_caption_loss_and_acc(self, img_embed, batch_seq, training=True):
        encoder_out = self.encoder(img_embed, training=training)
        batch_seq_inp = batch_seq[:, :-1]
        batch_seq_true = batch_seq[:, 1:]
        mask = tf.math.not_equal(batch_seq_true, 0)
        batch_seq_pred = self.decoder(
            batch_seq_inp, encoder_out, training=training, mask=mask
        )
        loss = self.calculate_loss(batch_seq_true, batch_seq_pred, mask)
        acc = self.calculate_accuracy(batch_seq_true, batch_seq_pred, mask)
        return loss, acc

    def train_step(self, batch_data):
        batch_img, batch_seq = batch_data
        batch_loss = 0
        batch_acc = 0

        if self.image_aug:
            batch_img = self.image_aug(batch_img)

        # 1. Get image embeddings
        img_embed = self.cnn_model(batch_img)

        # 2. Pass each of the five captions one by one to the decoder
        # along with the encoder outputs and compute the loss as well as accuracy
        # for each caption.
        for i in range(self.num_captions_per_image):
            with tf.GradientTape() as tape:
                loss, acc = self._compute_caption_loss_and_acc(
                    img_embed, batch_seq[:, i, :], training=True
                )

                # 3. Update loss and accuracy
                batch_loss += loss
                batch_acc += acc

            # 4. Get the list of all the trainable weights
            train_vars = (
                self.encoder.trainable_variables + self.decoder.trainable_variables
            )

            # 5. Get the gradients
            grads = tape.gradient(loss, train_vars)

            # 6. Update the trainable weights
            self.optimizer.apply_gradients(zip(grads, train_vars))

        # 7. Update the trackers
        batch_acc /= float(self.num_captions_per_image)
        self.loss_tracker.update_state(batch_loss)
        self.acc_tracker.update_state(batch_acc)

        # 8. Return the loss and accuracy values
        return {"loss": self.loss_tracker.result(), "acc": self.acc_tracker.result()}

    def test_step(self, batch_data):
        batch_img, batch_seq = batch_data
        batch_loss = 0
        batch_acc = 0

        # 1. Get image embeddings
        img_embed = self.cnn_model(batch_img)

        # 2. Pass each of the five captions one by one to the decoder
        # along with the encoder outputs and compute the loss as well as accuracy
        # for each caption.
        for i in range(self.num_captions_per_image):
            loss, acc = self._compute_caption_loss_and_acc(
                img_embed, batch_seq[:, i, :], training=False
            )

            # 3. Update batch loss and batch accuracy
            batch_loss += loss
            batch_acc += acc

        batch_acc /= float(self.num_captions_per_image)

        # 4. Update the trackers
        self.loss_tracker.update_state(batch_loss)
        self.acc_tracker.update_state(batch_acc)

        # 5. Return the loss and accuracy values
        return {"loss": self.loss_tracker.result(), "acc": self.acc_tracker.result()}

    @property
    def metrics(self):
        # We need to list our metrics here so the `reset_states()` can be
        # called automatically.
        return [self.loss_tracker, self.acc_tracker]


# Wrapping the models
cnn_model = get_cnn_model()
encoder = TransformerEncoderBlock(embed_dim=EMBED_DIM, dense_dim=FF_DIM, num_heads=1)
decoder = TransformerDecoderBlock(embed_dim=EMBED_DIM, ff_dim=FF_DIM, num_heads=2)
caption_model = ImageCaptioningModel(
    cnn_model=cnn_model, encoder=encoder, decoder=decoder, image_aug=image_augmentation,
)

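# For training, `caption_model` would need to be compiled with a per-token
# (non-reduced) loss, because calculate_loss() applies the padding mask and
# averages by itself. A minimal sketch; the plain Adam optimizer here is an
# assumption for illustration, not taken from the original training setup:
cross_entropy = keras.losses.SparseCategoricalCrossentropy(
    from_logits=False, reduction="none"
)
caption_model.compile(optimizer=keras.optimizers.Adam(), loss=cross_entropy)
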

# The inference model reuses the same cnn_model/encoder/decoder instances, so it
# shares their weights with caption_model.
loaded_model = ImageCaptioningModel(
    cnn_model=cnn_model, encoder=encoder, decoder=decoder, image_aug=image_augmentation,
)
# Load the trained weights
loaded_model.built = True
loaded_model.load_weights('/content/drive/My Drive/AI_Hack/cap_model')

vocab = vectorization.get_vocabulary()
index_lookup = dict(zip(range(len(vocab)), vocab))
max_decoded_sentence_length = SEQ_LENGTH - 1
# `valid_data` is not defined in this file and `valid_images` is never used
# below, so this lookup is disabled:
# valid_images = list(valid_data.keys())


def generate_caption(image):
    # `image` is the path of the uploaded file (the Gradio Image input below
    # uses type="filepath").

    # Read the image from the disk
    sample_img = decode_and_resize(image)

    # Pass the image to the CNN
    img = tf.expand_dims(sample_img, 0)
    img = loaded_model.cnn_model(img)

    # Pass the image features to the Transformer encoder
    encoded_img = loaded_model.encoder(img, training=False)

    # Generate the caption token by token with the Transformer decoder
    decoded_caption = "<start> "
    for i in range(max_decoded_sentence_length):
        tokenized_caption = vectorization([decoded_caption])[:, :-1]
        mask = tf.math.not_equal(tokenized_caption, 0)
        predictions = loaded_model.decoder(
            tokenized_caption, encoded_img, training=False, mask=mask
        )
        sampled_token_index = np.argmax(predictions[0, i, :])
        sampled_token = index_lookup[sampled_token_index]
        if sampled_token == "<end>":
            break
        decoded_caption += " " + sampled_token

    decoded_caption = decoded_caption.replace("<start> ", "")
    decoded_caption = decoded_caption.replace(" <end>", "").strip()
    print("Predicted Caption: ", decoded_caption)
    # Return the caption so Gradio can show it in the Textbox output
    return decoded_caption

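# Optional smoke test (a sketch, not part of the original app): caption the
# first bundled example image directly, bypassing the UI. Wrapped in try/except
# so a missing example file or weights checkpoint does not stop the app.
try:
    print("Smoke test caption:", generate_caption("pic 1.jpg"))
except Exception as err:
    print("Smoke test skipped:", err)
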
inputs = [
    # type="filepath" (available in recent Gradio releases) hands the handler a
    # file path that decode_and_resize() can read from disk.
    gr.inputs.Image(type="filepath", label="Original Image")
]

outputs = [
    gr.outputs.Textbox(label="Caption")
]

title = "Image Captioning using a CNN and a Transformer"
description = (
    "An image captioning model that combines a pretrained EfficientNet CNN with a "
    "Transformer to generate a caption for the uploaded image. The Flickr8k dataset "
    "was used for training."
)
article = " "
examples = [["pic 1.jpg"], ["pic 2.jpg"], ["pic 3.jpg"], ["pic 4.jpg"]]

gr.Interface(
    generate_caption,
    inputs,
    outputs,
    title=title,
    description=description,
    article=article,
    examples=examples,
).launch(debug=True, enable_queue=True)