Arijit-hazra commited on
Commit
af10606
1 Parent(s): 2ecfc1e

Upload 6 files

Browse files
.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ model/captioner_weights.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
load_model.py ADDED
@@ -0,0 +1,363 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### IMPORTS
2
+ import tensorflow as tf
3
+ import numpy as np
4
+
5
+ import einops
6
+ import numpy as np
7
+ import tqdm
8
+
9
+ import collections
10
+ import re
11
+ import string
12
+ import pickle
13
+
14
+ print("import complete")
15
+ #=========================================================================================================================
16
+ ### UTILITY FUNCTIONS
17
+ #=========================================================================================================================
18
+
19
+ IMAGE_SHAPE=(224, 224, 3)
20
+
21
+ @tf.keras.utils.register_keras_serializable()
22
+ def custom_standardization(s):
23
+ s = tf.strings.lower(s)
24
+ s = tf.strings.regex_replace(s, f'[{re.escape(string.punctuation)}]', '')
25
+ s = tf.strings.join(['[START]', s, '[END]'], separator=' ')
26
+ return s
27
+
28
+ def load_image(image_path):
29
+ img = tf.io.read_file(image_path)
30
+ img = tf.io.decode_jpeg(img, channels=3)
31
+ img = tf.image.resize(img, IMAGE_SHAPE[:-1])
32
+ return img
33
+
34
+ def load_image_obj(img):
35
+ img = tf.image.resize(img, IMAGE_SHAPE[:-1])
36
+ return img
37
+
38
+ def masked_loss(labels, preds):
39
+ loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels, preds)
40
+
41
+ mask = (labels != 0) & (loss < 1e8)
42
+ mask = tf.cast(mask, loss.dtype)
43
+
44
+ loss = loss*mask
45
+ loss = tf.reduce_sum(loss)/tf.reduce_sum(mask)
46
+ return loss
47
+
48
+ def masked_acc(labels, preds):
49
+ mask = tf.cast(labels!=0, tf.float32)
50
+ preds = tf.argmax(preds, axis=-1)
51
+ labels = tf.cast(labels, tf.int64)
52
+ match = tf.cast(preds == labels, mask.dtype)
53
+ acc = tf.reduce_sum(match*mask)/tf.reduce_sum(mask)
54
+ return acc
55
+
56
+ print("utility complete")
57
+ #=========================================================================================================================
58
+ ### MODEL CLASS
59
+ #=========================================================================================================================
60
+
61
+ mobilenet = tf.keras.applications.MobileNetV3Small(
62
+ input_shape=IMAGE_SHAPE,
63
+ include_top=False,
64
+ include_preprocessing=True)
65
+ mobilenet.trainable=False
66
+
67
+ class SeqEmbedding(tf.keras.layers.Layer):
68
+ def __init__(self, vocab_size, max_length, depth):
69
+ super().__init__()
70
+ self.pos_embedding = tf.keras.layers.Embedding(input_dim=max_length, output_dim=depth)
71
+
72
+ self.token_embedding = tf.keras.layers.Embedding(
73
+ input_dim=vocab_size,
74
+ output_dim=depth,
75
+ mask_zero=True)
76
+
77
+ self.add = tf.keras.layers.Add()
78
+
79
+
80
+ def call(self, seq):
81
+ seq = self.token_embedding(seq) # (batch, seq, depth)
82
+
83
+ x = tf.range(tf.shape(seq)[1]) # (seq)
84
+ x = x[tf.newaxis, :] # (1, seq)
85
+ x = self.pos_embedding(x) # (1, seq, depth)
86
+
87
+ return self.add([seq,x])
88
+
89
+ class CausalSelfAttention(tf.keras.layers.Layer):
90
+ def __init__(self, **kwargs):
91
+ super().__init__()
92
+ self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)
93
+ # Use Add instead of + so the keras mask propagates through.
94
+ self.add = tf.keras.layers.Add()
95
+ self.layernorm = tf.keras.layers.LayerNormalization()
96
+
97
+
98
+ def call(self, x):
99
+ attn = self.mha(query=x, value=x,
100
+ use_causal_mask=True)
101
+ x = self.add([x, attn])
102
+ return self.layernorm(x)
103
+
104
+ class CrossAttention(tf.keras.layers.Layer):
105
+ def __init__(self,**kwargs):
106
+ super().__init__()
107
+ self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)
108
+ self.add = tf.keras.layers.Add()
109
+ self.layernorm = tf.keras.layers.LayerNormalization()
110
+
111
+ def call(self, x, y, **kwargs):
112
+ attn, attention_scores = self.mha(
113
+ query=x, value=y,
114
+ return_attention_scores=True)
115
+
116
+ self.last_attention_scores = attention_scores
117
+
118
+ x = self.add([x, attn])
119
+ return self.layernorm(x)
120
+
121
+ class FeedForward(tf.keras.layers.Layer):
122
+ def __init__(self, units, dropout_rate=0.1):
123
+ super().__init__()
124
+ self.seq = tf.keras.Sequential([
125
+ tf.keras.layers.Dense(units=2*units, activation='relu'),
126
+ tf.keras.layers.Dense(units=units),
127
+ tf.keras.layers.Dropout(rate=dropout_rate),
128
+ ])
129
+
130
+ self.layernorm = tf.keras.layers.LayerNormalization()
131
+
132
+ def call(self, x):
133
+ x = x + self.seq(x)
134
+ return self.layernorm(x)
135
+
136
+ class DecoderLayer(tf.keras.layers.Layer):
137
+ def __init__(self, units, num_heads=1, dropout_rate=0.1):
138
+ super().__init__()
139
+
140
+ self.self_attention = CausalSelfAttention(num_heads=num_heads,
141
+ key_dim=units,
142
+ dropout=dropout_rate)
143
+ self.cross_attention = CrossAttention(num_heads=num_heads,
144
+ key_dim=units,
145
+ dropout=dropout_rate)
146
+ self.ff = FeedForward(units=units, dropout_rate=dropout_rate)
147
+
148
+
149
+ def call(self, inputs, training=False):
150
+ in_seq, out_seq = inputs
151
+
152
+ # Text input
153
+ out_seq = self.self_attention(out_seq)
154
+
155
+ out_seq = self.cross_attention(out_seq, in_seq)
156
+
157
+ self.last_attention_scores = self.cross_attention.last_attention_scores
158
+
159
+ out_seq = self.ff(out_seq)
160
+
161
+ return out_seq
162
+
163
+ class TokenOutput(tf.keras.layers.Layer):
164
+ def __init__(self, tokenizer, banned_tokens=('', '[UNK]', '[START]'), bias=None, **kwargs):
165
+ super().__init__()
166
+
167
+ self.dense = tf.keras.layers.Dense(
168
+ units=tokenizer.vocabulary_size(), **kwargs)
169
+ self.tokenizer = tokenizer
170
+ self.banned_tokens = banned_tokens
171
+
172
+ self.bias = bias
173
+
174
+ def adapt(self, ds):
175
+ counts = collections.Counter()
176
+ vocab_dict = {name: id
177
+ for id, name in enumerate(self.tokenizer.get_vocabulary())}
178
+
179
+ for tokens in tqdm.tqdm(ds):
180
+ counts.update(tokens.numpy().flatten())
181
+
182
+ counts_arr = np.zeros(shape=(self.tokenizer.vocabulary_size(),))
183
+ counts_arr[np.array(list(counts.keys()), dtype=np.int32)] = list(counts.values())
184
+
185
+ counts_arr = counts_arr[:]
186
+ for token in self.banned_tokens:
187
+ counts_arr[vocab_dict[token]] = 0
188
+
189
+ total = counts_arr.sum()
190
+ p = counts_arr/total
191
+ p[counts_arr==0] = 1.0
192
+ log_p = np.log(p) # log(1) == 0
193
+
194
+ entropy = -(log_p*p).sum()
195
+
196
+ print()
197
+ print(f"Uniform entropy: {np.log(self.tokenizer.vocabulary_size()):0.2f}")
198
+ print(f"Marginal entropy: {entropy:0.2f}")
199
+
200
+ self.bias = log_p
201
+ self.bias[counts_arr==0] = -1e9
202
+
203
+ def call(self, x):
204
+ x = self.dense(x)
205
+ return x + self.bias
206
+
207
+ def get_config(self):
208
+ config = super(TokenOutput, self).get_config()
209
+ config.update({
210
+ "tokenizer": self.tokenizer,
211
+ "banned_tokens": self.banned_tokens,
212
+ "bias": self.bias,
213
+ "dense":self.dense
214
+ })
215
+
216
+ return config
217
+
218
+ class Captioner(tf.keras.Model):
219
+ @classmethod
220
+ def add_method(cls, fun):
221
+ setattr(cls, fun.__name__, fun)
222
+ return fun
223
+
224
+ def __init__(self, tokenizer, feature_extractor, output_layer, num_layers=1,
225
+ units=256, max_length=50, num_heads=1, dropout_rate=0.1):
226
+ super().__init__()
227
+ self.feature_extractor = feature_extractor
228
+ self.tokenizer = tokenizer
229
+ self.word_to_index = tf.keras.layers.StringLookup(
230
+ mask_token="",
231
+ vocabulary=tokenizer.get_vocabulary())
232
+ self.index_to_word = tf.keras.layers.StringLookup(
233
+ mask_token="",
234
+ vocabulary=tokenizer.get_vocabulary(),
235
+ invert=True)
236
+
237
+ self.seq_embedding = SeqEmbedding(
238
+ vocab_size=tokenizer.vocabulary_size(),
239
+ depth=units,
240
+ max_length=max_length)
241
+
242
+ self.decoder_layers = [
243
+ DecoderLayer(units, num_heads=num_heads, dropout_rate=dropout_rate)
244
+ for n in range(num_layers)]
245
+
246
+ self.output_layer = output_layer
247
+
248
+ def call(self, inputs):
249
+ image, txt = inputs
250
+
251
+ if image.shape[-1] == 3:
252
+ # Apply the feature-extractor, if you get an RGB image.
253
+ image = self.feature_extractor(image)
254
+
255
+ # Flatten the feature map
256
+ image = einops.rearrange(image, 'b h w c -> b (h w) c')
257
+
258
+
259
+ if txt.dtype == tf.string:
260
+ # Apply the tokenizer if you get string inputs.
261
+ txt = self.tokenizer(txt)
262
+
263
+ txt = self.seq_embedding(txt)
264
+
265
+ # Look at the image
266
+ for dec_layer in self.decoder_layers:
267
+ txt = dec_layer(inputs=(image, txt))
268
+
269
+ txt = self.output_layer(txt)
270
+
271
+ return txt
272
+
273
+
274
+ def simple_gen(self, image, temperature=1):
275
+ initial = self.word_to_index([['[START]']]) # (batch, sequence)
276
+ img_features = self.feature_extractor(image[tf.newaxis, ...])
277
+
278
+ tokens = initial # (batch, sequence)
279
+ for n in range(50):
280
+ preds = self((img_features, tokens)).numpy() # (batch, sequence, vocab)
281
+ preds = preds[:,-1, :] #(batch, vocab)
282
+ if temperature==0:
283
+ next = tf.argmax(preds, axis=-1)[:, tf.newaxis] # (batch, 1)
284
+ else:
285
+ next = tf.random.categorical(preds/temperature, num_samples=1) # (batch, 1)
286
+ tokens = tf.concat([tokens, next], axis=1) # (batch, sequence)
287
+
288
+ if next[0] == self.word_to_index('[END]'):
289
+ break
290
+
291
+ words = self.index_to_word(tokens[0, 1:-1])
292
+ result = tf.strings.reduce_join(words, axis=-1, separator=' ')
293
+ return result.numpy().decode()
294
+
295
+ # def get_config(self):
296
+ # config = super().get_config()
297
+ # config.update({"feature_extractor": self.feature_extractor,
298
+ # "tokenizer": self.tokenizer,
299
+ # "word_to_index": self.word_to_index,
300
+ # "index_to_word": self.index_to_word,
301
+ # "outputlayer": self.output_layer,
302
+ # "seq_embedding": self.seq_embedding,
303
+ # "decoder_layers": self.decoder_layers
304
+ # })
305
+ # return config
306
+
307
+ # def build_from_config(self, config):
308
+ # return super().build_from_config(config)
309
+
310
+ # model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
311
+ # loss=masked_loss,
312
+ # metrics=[masked_acc])
313
+
314
+ print("model complete")
315
+ #=========================================================================================================================
316
+ ### LOAD FUNCTION
317
+ #=========================================================================================================================
318
+
319
+ def build():
320
+ filename = "model/tokenizer.pkl"
321
+ token_meta = pickle.load(open(filename, 'rb'))
322
+ tokenizer = tf.keras.layers.TextVectorization.from_config(token_meta["config"])
323
+ tokenizer.set_weights(token_meta['weights'])
324
+ print(tokenizer("bulid sentence"))
325
+ word_to_index = tf.keras.layers.StringLookup(
326
+ mask_token="",
327
+ vocabulary=tokenizer.get_vocabulary())
328
+
329
+ index_to_word = tf.keras.layers.StringLookup(
330
+ mask_token="",
331
+ vocabulary=tokenizer.get_vocabulary(),
332
+ invert=True)
333
+
334
+ output_layer = TokenOutput(tokenizer, banned_tokens=('', '[UNK]', '[START]'))
335
+ filename = "model/output_layer.pkl"
336
+ bias = pickle.load(open(filename, 'rb'))
337
+ output_layer.bias = bias
338
+
339
+ load_model = Captioner(tokenizer, feature_extractor=mobilenet, output_layer=output_layer,
340
+ units=256, dropout_rate=0.5, num_layers=2, num_heads=2)
341
+ load_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
342
+ loss=masked_loss,
343
+ metrics=[masked_acc])
344
+
345
+ image_url = 'https://tensorflow.org/images/surf.jpg'
346
+ image_path = tf.keras.utils.get_file('surf.jpg', origin=image_url)
347
+ image = load_image(image_path)
348
+ load_model.simple_gen(image)
349
+
350
+ path = "model/captioner_weights"
351
+ load_model.load_weights(path)
352
+ return load_model
353
+
354
+ # loaded_model = build()
355
+ print("loaded")
356
+ #=========================================================================================================================
357
+ ### TEST RUN
358
+ #=========================================================================================================================
359
+
360
+ image_url = 'https://tensorflow.org/images/surf.jpg'
361
+ image_path = tf.keras.utils.get_file('surf.jpg', origin=image_url)
362
+ image = load_image(image_path)
363
+
model/captioner_weights.data-00000-of-00001 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5fa754d28af355d673c5a5250a65eb9d95d9a5981c6a45f6b01a4f7c562b1bfd
3
+ size 80382098
model/captioner_weights.index ADDED
Binary file (24.3 kB). View file
 
model/checkpoint ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ model_checkpoint_path: "captioner_weights"
2
+ all_model_checkpoint_paths: "captioner_weights"
model/output_layer.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ce05dcabab270cce9610dc02f1eceeee990839820ebfc315fd3e5f24c87920dd
3
+ size 48157
model/tokenizer.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:36d04920b7dc008907947069ca75e03e3de98a13011cd97d1bbf66bdeef99093
3
+ size 81048