Lokesh1024 committed on
Commit f4c35df · verified · 1 Parent(s): 72311f8

Update app.py

Files changed (1)
app.py +267 -42
app.py CHANGED
@@ -1,55 +1,280 @@
  import os
- os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
  import io
- import os
  import streamlit as st
  import requests
  from PIL import Image
- from model import get_caption_model, generate_caption


- @st.cache(allow_output_mutation=True)
- def get_model():
-     return get_caption_model()

- caption_model = get_model()


- def predict():
-     captions = []
-     pred_caption = generate_caption('tmp.jpg', caption_model)

-     st.markdown('#### Predicted Captions:')
-     captions.append(pred_caption)

-     for _ in range(4):
-         pred_caption = generate_caption('tmp.jpg', caption_model, add_noise=True)
-         if pred_caption not in captions:
-             captions.append(pred_caption)

-     for c in captions:
-         st.write(c)
-
- st.title('Image Captioner')
- img_url = st.text_input(label='Enter Image URL')
-
- if (img_url != "") and (img_url != None):
-     img = Image.open(requests.get(img_url, stream=True).raw)
-     img = img.convert('RGB')
-     st.image(img)
-     img.save('tmp.jpg')
-     predict()
-     os.remove('tmp.jpg')
-
-
- st.markdown('<center style="opacity: 70%">OR</center>', unsafe_allow_html=True)
- img_upload = st.file_uploader(label='Upload Image', type=['jpg', 'png', 'jpeg'])
-
- if img_upload != None:
-     img = img_upload.read()
-     img = Image.open(io.BytesIO(img))
-     img = img.convert('RGB')
-     img.save('tmp.jpg')
-     st.image(img)
-     predict()
-     os.remove('tmp.jpg')
 
+ import pickle
+ import tensorflow as tf
+ import pandas as pd
+ import numpy as np
  import os
  import io
  import streamlit as st
  import requests
  from PIL import Image

+ # Set environment variable
+ os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
+
+ # Constants
+ MAX_LENGTH = 40
+ BATCH_SIZE = 32
+ BUFFER_SIZE = 1000
+ EMBEDDING_DIM = 512
+ UNITS = 512
+
+ # Load vocabulary
+ vocab = pickle.load(open('saved_vocabulary/vocab_coco.file', 'rb'))
+
+ tokenizer = tf.keras.layers.TextVectorization(
+     standardize=None,
+     output_sequence_length=MAX_LENGTH,
+     vocabulary=vocab
+ )
+
+ # id -> word lookup, used to turn predicted ids back into text
+ idx2word = tf.keras.layers.StringLookup(
+     mask_token="",
+     vocabulary=tokenizer.get_vocabulary(),
+     invert=True
+ )
+
+ # word -> id lookup; TextVectorization has no token_to_id() method, so
+ # generate_caption() below resolves the [START]/[END] token ids through this layer
+ word2idx = tf.keras.layers.StringLookup(
+     mask_token="",
+     vocabulary=tokenizer.get_vocabulary()
+ )
+
+ # Model Definitions
+ def CNN_Encoder():
+     inception_v3 = tf.keras.applications.InceptionV3(
+         include_top=False,
+         weights='imagenet'
+     )
+     output = inception_v3.output
+     output = tf.keras.layers.Reshape(
+         (-1, output.shape[-1]))(output)
+     cnn_model = tf.keras.models.Model(inception_v3.input, output)
+     return cnn_model
+
+ class TransformerEncoderLayer(tf.keras.layers.Layer):
+     def __init__(self, embed_dim, num_heads):
+         super().__init__()
+         self.layer_norm_1 = tf.keras.layers.LayerNormalization()
+         self.layer_norm_2 = tf.keras.layers.LayerNormalization()
+         self.attention = tf.keras.layers.MultiHeadAttention(
+             num_heads=num_heads, key_dim=embed_dim)
+         self.dense = tf.keras.layers.Dense(embed_dim, activation="relu")
+
+     def call(self, x, training):
+         x = self.layer_norm_1(x)
+         x = self.dense(x)
+         attn_output = self.attention(
+             query=x,
+             value=x,
+             key=x,
+             attention_mask=None,
+             training=training
+         )
+         x = self.layer_norm_2(x + attn_output)
+         return x
+
+ class Embeddings(tf.keras.layers.Layer):
+     def __init__(self, vocab_size, embed_dim, max_len):
+         super().__init__()
+         self.token_embeddings = tf.keras.layers.Embedding(
+             vocab_size, embed_dim)
+         self.position_embeddings = tf.keras.layers.Embedding(
+             max_len, embed_dim, input_shape=(None, max_len))
+
+     def call(self, input_ids):
+         length = tf.shape(input_ids)[-1]
+         position_ids = tf.range(start=0, limit=length, delta=1)
+         position_ids = tf.expand_dims(position_ids, axis=0)
+         token_embeddings = self.token_embeddings(input_ids)
+         position_embeddings = self.position_embeddings(position_ids)
+         return token_embeddings + position_embeddings
+
+ class TransformerDecoderLayer(tf.keras.layers.Layer):
+     def __init__(self, embed_dim, units, num_heads):
+         super().__init__()
+         self.embedding = Embeddings(
+             tokenizer.vocabulary_size(), embed_dim, MAX_LENGTH)
+         self.attention_1 = tf.keras.layers.MultiHeadAttention(
+             num_heads=num_heads, key_dim=embed_dim, dropout=0.1
+         )
+         self.attention_2 = tf.keras.layers.MultiHeadAttention(
+             num_heads=num_heads, key_dim=embed_dim, dropout=0.1
+         )
+         self.layernorm_1 = tf.keras.layers.LayerNormalization()
+         self.layernorm_2 = tf.keras.layers.LayerNormalization()
+         self.layernorm_3 = tf.keras.layers.LayerNormalization()
+         self.ffn_layer_1 = tf.keras.layers.Dense(units, activation="relu")
+         self.ffn_layer_2 = tf.keras.layers.Dense(embed_dim)
+         self.out = tf.keras.layers.Dense(tokenizer.vocabulary_size(), activation="softmax")
+         self.dropout_1 = tf.keras.layers.Dropout(0.3)
+         self.dropout_2 = tf.keras.layers.Dropout(0.5)
+
+     def call(self, input_ids, encoder_output, training, mask=None):
+         embeddings = self.embedding(input_ids)
+         combined_mask = None
+         padding_mask = None
+
+         if mask is not None:
+             causal_mask = self.get_causal_attention_mask(embeddings)
+             padding_mask = tf.cast(mask[:, :, tf.newaxis], dtype=tf.int32)
+             combined_mask = tf.cast(mask[:, tf.newaxis, :], dtype=tf.int32)
+             combined_mask = tf.minimum(combined_mask, causal_mask)

+         attn_output_1 = self.attention_1(
+             query=embeddings,
+             value=embeddings,
+             key=embeddings,
+             attention_mask=combined_mask,
+             training=training
+         )
+         out_1 = self.layernorm_1(embeddings + attn_output_1)

+         attn_output_2 = self.attention_2(
+             query=out_1,
+             value=encoder_output,
+             key=encoder_output,
+             attention_mask=padding_mask,
+             training=training
+         )
+         out_2 = self.layernorm_2(out_1 + attn_output_2)

+         ffn_out = self.ffn_layer_1(out_2)
+         ffn_out = self.dropout_1(ffn_out, training=training)
+         ffn_out = self.ffn_layer_2(ffn_out)

+         ffn_out = self.layernorm_3(ffn_out + out_2)
+         ffn_out = self.dropout_2(ffn_out, training=training)
+         preds = self.out(ffn_out)
+         return preds

+     def get_causal_attention_mask(self, inputs):
+         input_shape = tf.shape(inputs)
+         batch_size, sequence_length = input_shape[0], input_shape[1]
+         i = tf.range(sequence_length)[:, tf.newaxis]
+         j = tf.range(sequence_length)
+         mask = tf.cast(i >= j, dtype="int32")
+         mask = tf.reshape(mask, (1, input_shape[1], input_shape[1]))
+         mult = tf.concat(
+             [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)],
+             axis=0
+         )
+         return tf.tile(mask, mult)

+ class ImageCaptioningModel(tf.keras.Model):
+     def __init__(self, cnn_model, encoder, decoder, image_aug=None):
+         super().__init__()
+         self.cnn_model = cnn_model
+         self.encoder = encoder
+         self.decoder = decoder
+         self.image_aug = image_aug
+         self.loss_tracker = tf.keras.metrics.Mean(name="loss")
+         self.acc_tracker = tf.keras.metrics.Mean(name="accuracy")
+
+     def calculate_loss(self, y_true, y_pred, mask):
+         loss = self.loss(y_true, y_pred)
+         mask = tf.cast(mask, dtype=loss.dtype)
+         loss *= mask
+         return tf.reduce_sum(loss) / tf.reduce_sum(mask)
+
+     def calculate_accuracy(self, y_true, y_pred, mask):
+         accuracy = tf.equal(y_true, tf.argmax(y_pred, axis=2))
+         accuracy = tf.math.logical_and(mask, accuracy)
+         accuracy = tf.cast(accuracy, dtype=tf.float32)
+         mask = tf.cast(mask, dtype=tf.float32)
+         return tf.reduce_sum(accuracy) / tf.reduce_sum(mask)
+
+     def compute_loss_and_acc(self, img_embed, captions, training=True):
+         # propagate the training flag so dropout is disabled during test_step
+         encoder_output = self.encoder(img_embed, training=training)
+         y_input = captions[:, :-1]
+         y_true = captions[:, 1:]
+         mask = (y_true != 0)
+         y_pred = self.decoder(
+             y_input, encoder_output, training=training, mask=mask
+         )
+         loss = self.calculate_loss(y_true, y_pred, mask)
+         acc = self.calculate_accuracy(y_true, y_pred, mask)
+         return loss, acc
+
+     def train_step(self, batch):
+         imgs, captions = batch
+         if self.image_aug:
+             imgs = self.image_aug(imgs)
+         img_embed = self.cnn_model(imgs)
+         with tf.GradientTape() as tape:
+             loss, acc = self.compute_loss_and_acc(
+                 img_embed, captions
+             )
+         train_vars = (
+             self.encoder.trainable_variables + self.decoder.trainable_variables
+         )
+         grads = tape.gradient(loss, train_vars)
+         self.optimizer.apply_gradients(zip(grads, train_vars))
+         self.loss_tracker.update_state(loss)
+         self.acc_tracker.update_state(acc)
+         return {"loss": self.loss_tracker.result(), "acc": self.acc_tracker.result()}
+
+     def test_step(self, batch):
+         imgs, captions = batch
+         img_embed = self.cnn_model(imgs)
+         loss, acc = self.compute_loss_and_acc(
+             img_embed, captions, training=False
+         )
+         self.loss_tracker.update_state(loss)
+         self.acc_tracker.update_state(acc)
+         return {"loss": self.loss_tracker.result(), "acc": self.acc_tracker.result()}
+
+     @property
+     def metrics(self):
+         return [self.loss_tracker, self.acc_tracker]
+
+ def load_image_from_path(img_path):
+     img = tf.io.read_file(img_path)
+     img = tf.io.decode_jpeg(img, channels=3)
+     img = tf.keras.layers.Resizing(299, 299)(img)
+     img = tf.keras.applications.inception_v3.preprocess_input(img)
+     return img
+
+ def generate_caption(img, caption_model, add_noise=False):
+     if isinstance(img, str):
+         img = load_image_from_path(img)
+     if add_noise:
+         noise = tf.random.normal(img.shape) * 0.1
+         img = (img + noise)
+         img = (img - tf.reduce_min(img)) / (tf.reduce_max(img) - tf.reduce_min(img))
+     img = tf.expand_dims(img, 0)  # Add batch dimension
+     img_embed = caption_model.cnn_model(img, training=False)
+     encoder_output = caption_model.encoder(img_embed, training=False)
+     # TextVectorization has no token_to_id(); resolve the special tokens via word2idx
+     start_id = int(word2idx("[START]").numpy())
+     end_id = int(word2idx("[END]").numpy())
+     caption = [start_id]
+     for _ in range(MAX_LENGTH):
+         input_caption = tf.convert_to_tensor([caption], dtype=tf.int32)
+         pred = caption_model.decoder(input_caption, encoder_output, training=False)
+         pred = tf.argmax(pred[0, -1, :]).numpy()
+         caption.append(pred)
+         if pred == end_id:
+             break
+     return ' '.join([idx2word(word).numpy().decode('utf-8') for word in caption[1:-1]])
+
+ # Load saved model weights
+ cnn_model = CNN_Encoder()
+ encoder = TransformerEncoderLayer(embed_dim=EMBEDDING_DIM, num_heads=8)
+ decoder = TransformerDecoderLayer(embed_dim=EMBEDDING_DIM, units=UNITS, num_heads=8)
+ caption_model = ImageCaptioningModel(cnn_model=cnn_model, encoder=encoder, decoder=decoder)
+ caption_model.load_weights('saved_model_weights/caption_model')
+
+ # Streamlit App
+ st.title('Image Captioning with Transformer')
+ st.write('Upload an image to generate a caption.')
+
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
+
+ if uploaded_file is not None:
+     image = Image.open(uploaded_file)
+     st.image(image, caption='Uploaded Image', use_column_width=True)
+     st.write("")
+     st.write("Generating caption...")

+     # Save the upload to a temp file (create the directory first so open() cannot fail)
+     os.makedirs("temp", exist_ok=True)
+     img_path = os.path.join("temp", uploaded_file.name)
+     with open(img_path, "wb") as f:
+         f.write(uploaded_file.getbuffer())
+
+     img = load_image_from_path(img_path)
+     caption = generate_caption(img, caption_model)
+     st.write("Caption:", caption)
+
+ # Remove temp file after captioning
+ if uploaded_file is not None and os.path.exists(img_path):
+     os.remove(img_path)
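
For quick verification outside the Streamlit UI, a minimal smoke test of the captioning path could look like the sketch below; it is not part of the commit. It uses only names defined in app.py and the assets already referenced there (saved_vocabulary/vocab_coco.file, saved_model_weights/caption_model); 'sample.jpg' is a hypothetical local test image, and importing app also runs the Streamlit calls at module scope, which is harmless in a plain Python session.

# Smoke-test sketch: caption one local image with the model defined above.
# 'sample.jpg' is a placeholder path; point it at any JPEG on disk.
from app import caption_model, generate_caption  # importing app builds the model and loads the weights

print(generate_caption('sample.jpg', caption_model, add_noise=False))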