Lokesh1024 commited on
Commit
73f387e
·
verified ·
1 Parent(s): 0c88930

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -268
app.py CHANGED
@@ -1,280 +1,53 @@
1
- import pickle
2
- import tensorflow as tf
3
- import pandas as pd
4
- import numpy as np
5
- import os
6
  import io
 
7
  import streamlit as st
8
  import requests
9
  from PIL import Image
 
10
 
11
- # Set environment variable
12
- os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
13
-
14
- # Constants
15
- MAX_LENGTH = 40
16
- BATCH_SIZE = 32
17
- BUFFER_SIZE = 1000
18
- EMBEDDING_DIM = 512
19
- UNITS = 512
20
-
21
- # Load vocabulary
22
- vocab = pickle.load(open('saved_vocabulary/vocab_coco.file', 'rb'))
23
-
24
- tokenizer = tf.keras.layers.TextVectorization(
25
- standardize=None,
26
- output_sequence_length=MAX_LENGTH,
27
- vocabulary=vocab
28
- )
29
-
30
- idx2word = tf.keras.layers.StringLookup(
31
- mask_token="",
32
- vocabulary=tokenizer.get_vocabulary(),
33
- invert=True
34
- )
35
-
36
- # Model Definitions
37
- def CNN_Encoder():
38
- inception_v3 = tf.keras.applications.InceptionV3(
39
- include_top=False,
40
- weights='imagenet'
41
- )
42
- output = inception_v3.output
43
- output = tf.keras.layers.Reshape(
44
- (-1, output.shape[-1]))(output)
45
- cnn_model = tf.keras.models.Model(inception_v3.input, output)
46
- return cnn_model
47
-
48
- class TransformerEncoderLayer(tf.keras.layers.Layer):
49
- def __init__(self, embed_dim, num_heads):
50
- super().__init__()
51
- self.layer_norm_1 = tf.keras.layers.LayerNormalization()
52
- self.layer_norm_2 = tf.keras.layers.LayerNormalization()
53
- self.attention = tf.keras.layers.MultiHeadAttention(
54
- num_heads=num_heads, key_dim=embed_dim)
55
- self.dense = tf.keras.layers.Dense(embed_dim, activation="relu")
56
-
57
- def call(self, x, training):
58
- x = self.layer_norm_1(x)
59
- x = self.dense(x)
60
- attn_output = self.attention(
61
- query=x,
62
- value=x,
63
- key=x,
64
- attention_mask=None,
65
- training=training
66
- )
67
- x = self.layer_norm_2(x + attn_output)
68
- return x
69
-
70
- class Embeddings(tf.keras.layers.Layer):
71
- def __init__(self, vocab_size, embed_dim, max_len):
72
- super().__init__()
73
- self.token_embeddings = tf.keras.layers.Embedding(
74
- vocab_size, embed_dim)
75
- self.position_embeddings = tf.keras.layers.Embedding(
76
- max_len, embed_dim, input_shape=(None, max_len))
77
-
78
- def call(self, input_ids):
79
- length = tf.shape(input_ids)[-1]
80
- position_ids = tf.range(start=0, limit=length, delta=1)
81
- position_ids = tf.expand_dims(position_ids, axis=0)
82
- token_embeddings = self.token_embeddings(input_ids)
83
- position_embeddings = self.position_embeddings(position_ids)
84
- return token_embeddings + position_embeddings
85
-
86
- class TransformerDecoderLayer(tf.keras.layers.Layer):
87
- def __init__(self, embed_dim, units, num_heads):
88
- super().__init__()
89
- self.embedding = Embeddings(
90
- tokenizer.vocabulary_size(), embed_dim, MAX_LENGTH)
91
- self.attention_1 = tf.keras.layers.MultiHeadAttention(
92
- num_heads=num_heads, key_dim=embed_dim, dropout=0.1
93
- )
94
- self.attention_2 = tf.keras.layers.MultiHeadAttention(
95
- num_heads=num_heads, key_dim=embed_dim, dropout=0.1
96
- )
97
- self.layernorm_1 = tf.keras.layers.LayerNormalization()
98
- self.layernorm_2 = tf.keras.layers.LayerNormalization()
99
- self.layernorm_3 = tf.keras.layers.LayerNormalization()
100
- self.ffn_layer_1 = tf.keras.layers.Dense(units, activation="relu")
101
- self.ffn_layer_2 = tf.keras.layers.Dense(embed_dim)
102
- self.out = tf.keras.layers.Dense(tokenizer.vocabulary_size(), activation="softmax")
103
- self.dropout_1 = tf.keras.layers.Dropout(0.3)
104
- self.dropout_2 = tf.keras.layers.Dropout(0.5)
105
-
106
- def call(self, input_ids, encoder_output, training, mask=None):
107
- embeddings = self.embedding(input_ids)
108
- combined_mask = None
109
- padding_mask = None
110
-
111
- if mask is not None:
112
- causal_mask = self.get_causal_attention_mask(embeddings)
113
- padding_mask = tf.cast(mask[:, :, tf.newaxis], dtype=tf.int32)
114
- combined_mask = tf.cast(mask[:, tf.newaxis, :], dtype=tf.int32)
115
- combined_mask = tf.minimum(combined_mask, causal_mask)
116
 
117
- attn_output_1 = self.attention_1(
118
- query=embeddings,
119
- value=embeddings,
120
- key=embeddings,
121
- attention_mask=combined_mask,
122
- training=training
123
- )
124
- out_1 = self.layernorm_1(embeddings + attn_output_1)
125
 
126
- attn_output_2 = self.attention_2(
127
- query=out_1,
128
- value=encoder_output,
129
- key=encoder_output,
130
- attention_mask=padding_mask,
131
- training=training
132
- )
133
- out_2 = self.layernorm_2(out_1 + attn_output_2)
134
 
135
- ffn_out = self.ffn_layer_1(out_2)
136
- ffn_out = self.dropout_1(ffn_out, training=training)
137
- ffn_out = self.ffn_layer_2(ffn_out)
138
 
139
- ffn_out = self.layernorm_3(ffn_out + out_2)
140
- ffn_out = self.dropout_2(ffn_out, training=training)
141
- preds = self.out(ffn_out)
142
- return preds
143
 
144
- def get_causal_attention_mask(self, inputs):
145
- input_shape = tf.shape(inputs)
146
- batch_size, sequence_length = input_shape[0], input_shape[1]
147
- i = tf.range(sequence_length)[:, tf.newaxis]
148
- j = tf.range(sequence_length)
149
- mask = tf.cast(i >= j, dtype="int32")
150
- mask = tf.reshape(mask, (1, input_shape[1], input_shape[1]))
151
- mult = tf.concat(
152
- [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)],
153
- axis=0
154
- )
155
- return tf.tile(mask, mult)
156
 
157
- class ImageCaptioningModel(tf.keras.Model):
158
- def __init__(self, cnn_model, encoder, decoder, image_aug=None):
159
- super().__init__()
160
- self.cnn_model = cnn_model
161
- self.encoder = encoder
162
- self.decoder = decoder
163
- self.image_aug = image_aug
164
- self.loss_tracker = tf.keras.metrics.Mean(name="loss")
165
- self.acc_tracker = tf.keras.metrics.Mean(name="accuracy")
166
-
167
- def calculate_loss(self, y_true, y_pred, mask):
168
- loss = self.loss(y_true, y_pred)
169
- mask = tf.cast(mask, dtype=loss.dtype)
170
- loss *= mask
171
- return tf.reduce_sum(loss) / tf.reduce_sum(mask)
172
-
173
- def calculate_accuracy(self, y_true, y_pred, mask):
174
- accuracy = tf.equal(y_true, tf.argmax(y_pred, axis=2))
175
- accuracy = tf.math.logical_and(mask, accuracy)
176
- accuracy = tf.cast(accuracy, dtype=tf.float32)
177
- mask = tf.cast(mask, dtype=tf.float32)
178
- return tf.reduce_sum(accuracy) / tf.reduce_sum(mask)
179
-
180
- def compute_loss_and_acc(self, img_embed, captions, training=True):
181
- encoder_output = self.encoder(img_embed, training=True)
182
- y_input = captions[:, :-1]
183
- y_true = captions[:, 1:]
184
- mask = (y_true != 0)
185
- y_pred = self.decoder(
186
- y_input, encoder_output, training=True, mask=mask
187
- )
188
- loss = self.calculate_loss(y_true, y_pred, mask)
189
- acc = self.calculate_accuracy(y_true, y_pred, mask)
190
- return loss, acc
191
-
192
- def train_step(self, batch):
193
- imgs, captions = batch
194
- if self.image_aug:
195
- imgs = self.image_aug(imgs)
196
- img_embed = self.cnn_model(imgs)
197
- with tf.GradientTape() as tape:
198
- loss, acc = self.compute_loss_and_acc(
199
- img_embed, captions
200
- )
201
- train_vars = (
202
- self.encoder.trainable_variables + self.decoder.trainable_variables
203
- )
204
- grads = tape.gradient(loss, train_vars)
205
- self.optimizer.apply_gradients(zip(grads, train_vars))
206
- self.loss_tracker.update_state(loss)
207
- self.acc_tracker.update_state(acc)
208
- return {"loss": self.loss_tracker.result(), "acc": self.acc_tracker.result()}
209
-
210
- def test_step(self, batch):
211
- imgs, captions = batch
212
- img_embed = self.cnn_model(imgs)
213
- loss, acc = self.compute_loss_and_acc(
214
- img_embed, captions, training=False
215
- )
216
- self.loss_tracker.update_state(loss)
217
- self.acc_tracker.update_state(acc)
218
- return {"loss": self.loss_tracker.result(), "acc": self.acc_tracker.result()}
219
-
220
- @property
221
- def metrics(self):
222
- return [self.loss_tracker, self.acc_tracker]
223
-
224
- def load_image_from_path(img_path):
225
- img = tf.io.read_file(img_path)
226
- img = tf.io.decode_jpeg(img, channels=3)
227
- img = tf.keras.layers.Resizing(299, 299)(img)
228
- img = tf.keras.applications.inception_v3.preprocess_input(img)
229
- return img
230
-
231
- def generate_caption(img, caption_model, add_noise=False):
232
- if isinstance(img, str):
233
- img = load_image_from_path(img)
234
- if add_noise:
235
- noise = tf.random.normal(img.shape) * 0.1
236
- img = (img + noise)
237
- img = (img - tf.reduce_min(img)) / (tf.reduce_max(img) - tf.reduce_min(img))
238
- img = tf.expand_dims(img, 0) # Add batch dimension
239
- img_embed = caption_model.cnn_model(img, training=False)
240
- encoder_output = caption_model.encoder(img_embed, training=False)
241
- caption = [tokenizer.token_to_id("[START]")]
242
- for _ in range(MAX_LENGTH):
243
- input_caption = tf.convert_to_tensor([caption], dtype=tf.int32)
244
- pred = caption_model.decoder(input_caption, encoder_output, training=False)
245
- pred = tf.argmax(pred[0, -1, :]).numpy()
246
- caption.append(pred)
247
- if pred == tokenizer.token_to_id("[END]"):
248
- break
249
- return ' '.join([idx2word(word).numpy().decode('utf-8') for word in caption[1:-1]])
250
-
251
- # Load saved model weights
252
- cnn_model = CNN_Encoder()
253
- encoder = TransformerEncoderLayer(embed_dim=EMBEDDING_DIM, num_heads=8)
254
- decoder = TransformerDecoderLayer(embed_dim=EMBEDDING_DIM, units=UNITS, num_heads=8)
255
- caption_model = ImageCaptioningModel(cnn_model=cnn_model, encoder=encoder, decoder=decoder)
256
- caption_model.load_weights('saved_model_weights/caption_model')
257
-
258
- # Streamlit App
259
- st.title('Image Captioning with Transformer')
260
- st.write('Upload an image to generate a caption.')
261
-
262
- uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
263
-
264
- if uploaded_file is not None:
265
- image = Image.open(uploaded_file)
266
- st.image(image, caption='Uploaded Image', use_column_width=True)
267
- st.write("")
268
- st.write("Generating caption...")
269
 
270
- img_path = os.path.join("temp", uploaded_file.name)
271
- with open(img_path, "wb") as f:
272
- f.write(uploaded_file.getbuffer())
273
-
274
- img = load_image_from_path(img_path)
275
- caption = generate_caption(img, caption_model)
276
- st.write("Caption:", caption)
277
-
278
- # Remove temp file after captioning
279
- if uploaded_file is not None and os.path.exists(img_path):
280
- os.remove(img_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import io
2
+ import os
3
  import streamlit as st
4
  import requests
5
  from PIL import Image
6
+ from model import get_caption_model, generate_caption
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
+ @st.cache(allow_output_mutation=True)
10
+ def get_model():
11
+ return get_caption_model()
 
 
 
 
 
12
 
13
+ caption_model = get_model()
 
 
 
 
 
 
 
14
 
 
 
 
15
 
16
+ def predict():
17
+ captions = []
18
+ pred_caption = generate_caption('tmp.jpg', caption_model)
 
19
 
20
+ st.markdown('#### Predicted Captions:')
21
+ captions.append(pred_caption)
 
 
 
 
 
 
 
 
 
 
22
 
23
+ for _ in range(4):
24
+ pred_caption = generate_caption('tmp.jpg', caption_model, add_noise=True)
25
+ if pred_caption not in captions:
26
+ captions.append(pred_caption)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
+ for c in captions:
29
+ st.write(c)
30
+
31
+ st.title('Image Captioner')
32
+ img_url = st.text_input(label='Enter Image URL')
33
+
34
+ if (img_url != "") and (img_url != None):
35
+ img = Image.open(requests.get(img_url, stream=True).raw)
36
+ img = img.convert('RGB')
37
+ st.image(img)
38
+ img.save('tmp.jpg')
39
+ predict()
40
+ os.remove('tmp.jpg')
41
+
42
+
43
+ st.markdown('<center style="opacity: 70%">OR</center>', unsafe_allow_html=True)
44
+ img_upload = st.file_uploader(label='Upload Image', type=['jpg', 'png', 'jpeg'])
45
+
46
+ if img_upload != None:
47
+ img = img_upload.read()
48
+ img = Image.open(io.BytesIO(img))
49
+ img = img.convert('RGB')
50
+ img.save('tmp.jpg')
51
+ st.image(img)
52
+ predict()
53
+ os.remove('tmp.jpg')