peshk1n committed
Commit 080c130 · verified · 1 Parent(s): e1afdf2
app.py ADDED
@@ -0,0 +1,685 @@
+ import os
+
+ # Select the Keras backend before any Keras module is imported.
+ os.environ["KERAS_BACKEND"] = "tensorflow"
+
+ import json
+
+ import numpy as np
+ import tensorflow as tf
+ import gradio as gr
+ from tensorflow.keras import layers, Model
+ from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
+ from transformers import TFAutoModel
+
+
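+ # Special tokens and hyperparameters; these must match the values used when the
+ # checkpoints in models/ were trained.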
+ start_token = "[BOS]"
+ end_token = "[EOS]"
+ cls_token = "[CLS]"
+
+ data_dir = '/content/coco'
+ data_type_train = 'train2014'
+ data_type_val = 'val2014'
+
+ vocab_size = 24000
+ sentence_length = 20
+ batch_size = 128
+ img_size = 224
+
+ proj_dim = 192
+ dropout_rate = 0.1
+ num_patches = 14
+ patch_size = img_size // num_patches
+
+ num_heads = 3
+ num_layers = 6
+ attn_pool_dim = proj_dim
+ attn_pool_heads = num_heads
+ cap_query_num = 128
+
+ rnn_embedding_dim = 256
+ rnn_proj_dim = 512
+
+
+ with open('vocabs/word_index.json', 'r', encoding='utf-8') as f:
+     word_index = {np.str_(word): np.int64(idx) for word, idx in json.load(f).items()}
+
+ with open('vocabs/index_word.json', 'r', encoding='utf-8') as f:
+     index_word = {np.int64(idx): np.str_(word) for idx, word in json.load(f).items()}
+
+ cls_token_id = word_index[cls_token]
+
+
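+ # Learned token + position embeddings for the caption decoder input.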
+ class PositionalEmbedding(layers.Layer):
+     def __init__(self, sequence_length, input_dim, output_dim, **kwargs):
+         super().__init__(**kwargs)
+         self.sequence_length = sequence_length
+         self.input_dim = input_dim
+         self.output_dim = output_dim
+         self.token_embeddings = layers.Embedding(
+             input_dim=input_dim, output_dim=output_dim
+         )
+         self.position_embeddings = layers.Embedding(
+             input_dim=sequence_length, output_dim=output_dim
+         )
+
+     def call(self, inputs):
+         positions = tf.range(start=0, limit=self.sequence_length, delta=1)
+         embedded_tokens = self.token_embeddings(inputs)
+         embedded_positions = self.position_embeddings(positions)
+         output = embedded_tokens + embedded_positions
+         return output
+
+
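+ # Attentional pooling: learned query tokens cross-attend over the ViT patch features
+ # (CoCa-style) to produce pooled image representations.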
+ class AttentionalPooling(layers.Layer):
+     def __init__(self, embed_dim, num_heads=6):
+         super().__init__()
+         self.embed_dim = embed_dim
+         self.num_heads = num_heads
+         self.multihead_attention = layers.MultiHeadAttention(num_heads=self.num_heads, key_dim=self.embed_dim)
+         self.norm = layers.LayerNormalization()
+
+     def call(self, features, query):
+         attn_output = self.multihead_attention(
+             query=query,
+             value=features,
+             key=features
+         )
+         return self.norm(attn_output)
+
+
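+ # Pre-LN Transformer decoder block with causal self-attention and, when is_multimodal=True,
+ # cross-attention over the pooled image features.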
+ class TransformerBlock(layers.Layer):
+     def __init__(self, embed_dim, dense_dim, num_heads, dropout_rate=0.1, ln_epsilon=1e-6, is_multimodal=False, **kwargs):
+         super().__init__(**kwargs)
+         self.embed_dim = embed_dim
+         self.dense_dim = dense_dim
+         self.num_heads = num_heads
+         self.dropout_rate = dropout_rate
+         self.ln_epsilon = ln_epsilon
+
+         self.self_attention = layers.MultiHeadAttention(
+             num_heads=self.num_heads,
+             key_dim=self.embed_dim,
+             dropout=self.dropout_rate
+         )
+
+         if is_multimodal:
+             self.norm2 = layers.LayerNormalization(epsilon=self.ln_epsilon)
+             self.dropout2 = layers.Dropout(self.dropout_rate)
+             self.cross_attention = layers.MultiHeadAttention(
+                 num_heads=self.num_heads,
+                 key_dim=self.embed_dim,
+                 dropout=self.dropout_rate
+             )
+
+         self.dense_proj = tf.keras.Sequential([
+             layers.Dense(self.dense_dim, activation="gelu"),
+             layers.Dropout(self.dropout_rate),
+             layers.Dense(self.embed_dim)
+         ])
+
+         self.norm1 = layers.LayerNormalization(epsilon=self.ln_epsilon)
+         self.norm3 = layers.LayerNormalization(epsilon=self.ln_epsilon)
+
+         self.dropout1 = layers.Dropout(self.dropout_rate)
+         self.dropout3 = layers.Dropout(self.dropout_rate)
+
+     def get_causal_attention_mask(self, inputs):
+         seq_len = tf.shape(inputs)[1]
+         causal_mask = tf.linalg.band_part(tf.ones((seq_len, seq_len), tf.bool), -1, 0)
+         return tf.expand_dims(causal_mask, 0)
+
+     def get_combined_mask(self, causal_mask, padding_mask):
+         padding_mask = tf.cast(padding_mask, tf.bool)
+         padding_mask = tf.expand_dims(padding_mask, 1)
+         return causal_mask & padding_mask
+
+     def call(self, inputs, encoder_outputs=None, mask=None):
+         att_mask = self.get_causal_attention_mask(inputs)
+         if mask is not None:
+             att_mask = self.get_combined_mask(att_mask, mask)
+
+         x = self.norm1(inputs)
+         attention_output_1 = self.self_attention(
+             query=x, key=x, value=x, attention_mask=att_mask
+         )
+         attention_output_1 = self.dropout1(attention_output_1)
+         x = x + attention_output_1
+
+         if encoder_outputs is not None:
+             x_norm = self.norm2(x)
+             attention_output_2 = self.cross_attention(
+                 query=x_norm, key=encoder_outputs, value=encoder_outputs
+             )
+             attention_output_2 = self.dropout2(attention_output_2)
+             x = x + attention_output_2
+
+         x_norm = self.norm3(x)
+         proj_output = self.dense_proj(x_norm)
+         proj_output = self.dropout3(proj_output)
+         return x + proj_output
+
+
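+ # Stacks of decoder blocks: the unimodal stack encodes text only, the multimodal stack
+ # additionally cross-attends to the image features.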
+ class UnimodalTextDecoder(layers.Layer):
+     def __init__(self, embed_dim, dense_dim, num_heads, dropout_rate=0.1, ln_epsilon=1e-6, num_layers=4, **kwargs):
+         super().__init__()
+         self.embed_dim = embed_dim
+         self.dense_dim = dense_dim
+         self.num_heads = num_heads
+         self.dropout_rate = dropout_rate
+         self.ln_epsilon = ln_epsilon
+         self.num_layers = num_layers
+
+         self.layers = [
+             TransformerBlock(self.embed_dim, self.dense_dim, self.num_heads, self.dropout_rate, self.ln_epsilon, is_multimodal=False)
+             for _ in range(self.num_layers)
+         ]
+         self.norm = tf.keras.layers.LayerNormalization()
+
+     def call(self, x, mask=None):
+         for layer in self.layers:
+             x = layer(inputs=x, mask=mask)
+         return self.norm(x)
+
+
+ class MultimodalTextDecoder(layers.Layer):
+     def __init__(self, embed_dim, dense_dim, num_heads, dropout_rate=0.1, ln_epsilon=1e-6, num_layers=4, **kwargs):
+         super().__init__()
+         self.embed_dim = embed_dim
+         self.dense_dim = dense_dim
+         self.num_heads = num_heads
+         self.dropout_rate = dropout_rate
+         self.ln_epsilon = ln_epsilon
+         self.num_layers = num_layers
+
+         self.layers = [
+             TransformerBlock(self.embed_dim, self.dense_dim, self.num_heads, self.dropout_rate, self.ln_epsilon, is_multimodal=True)
+             for _ in range(self.num_layers)
+         ]
+         self.norm = tf.keras.layers.LayerNormalization()
+
+     def call(self, x, encoder_outputs, mask=None):
+         for layer in self.layers:
+             x = layer(inputs=x, encoder_outputs=encoder_outputs, mask=mask)
+         return self.norm(x)
+
+
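+ # Projects embeddings into an L2-normalized latent space for the contrastive loss.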
+ class EmbedToLatents(layers.Layer):
+     def __init__(self, dim_latents, **kwargs):
+         super().__init__(**kwargs)
+         self.dim_latents = dim_latents
+         self.to_latents = layers.Dense(
+             self.dim_latents,
+             use_bias=False
+         )
+
+     def call(self, inputs):
+         latents = self.to_latents(inputs)
+         return tf.math.l2_normalize(latents, axis=-1)
+
+
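+ # Streaming perplexity computed over non-padding tokens (expects logits).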
+ class Perplexity(tf.keras.metrics.Metric):
+     def __init__(self, name='perplexity', **kwargs):
+         super().__init__(name=name, **kwargs)
+         self.total_loss = self.add_weight(name='total_loss', initializer='zeros')
+         self.total_tokens = self.add_weight(name='total_tokens', initializer='zeros')
+
+     def update_state(self, y_true, y_pred, sample_weight=None):
+         loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
+         loss = loss_fn(y_true, y_pred)
+
+         # Ignore padding tokens (id 0) when accumulating the loss.
+         mask = tf.cast(tf.not_equal(y_true, 0), tf.float32)
+         loss = tf.reduce_sum(loss * mask)
+         num_tokens = tf.reduce_sum(mask)
+
+         self.total_loss.assign_add(loss)
+         self.total_tokens.assign_add(num_tokens)
+
+     def result(self):
+         return tf.exp(self.total_loss / self.total_tokens)
+
+     def reset_state(self):
+         self.total_loss.assign(0.0)
+         self.total_tokens.assign(0.0)
+
+
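+ # Pretrained ViT-Tiny image backbone from the Hugging Face Hub.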
+ model_name = "WinKawaks/vit-tiny-patch16-224"
+ vit_tiny_model = TFAutoModel.from_pretrained(model_name)
+ vit_tiny_model.trainable = True
+
+ for layer in vit_tiny_model.layers:
+     layer.trainable = True
+
+
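+ # Image encoder: ViT patch features followed by two attentional-pooling heads, one for
+ # the contrastive embedding and one for the caption features.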
+ class CoCaEncoder(tf.keras.Model):
+     def __init__(self, vit, **kwargs):
+         super().__init__(**kwargs)
+
+         self.vit = vit
+
+         self.contrastive_pooling = AttentionalPooling(attn_pool_dim, attn_pool_heads)
+         self.caption_pooling = AttentionalPooling(attn_pool_dim, attn_pool_heads)
+
+         self.con_query = tf.Variable(
+             initial_value=tf.random.normal([1, 1, proj_dim]),
+             trainable=True,
+             name="con_query"
+         )
+
+         self.cap_query = tf.Variable(
+             initial_value=tf.random.normal([1, cap_query_num, proj_dim]),
+             trainable=True,
+             name="cap_query"
+         )
+
+     def call(self, inputs, training=False):
+         img_feature = self.vit(inputs).last_hidden_state
+
+         batch_size = tf.shape(img_feature)[0]
+         con_query_b = tf.repeat(self.con_query, repeats=batch_size, axis=0)
+         cap_query_b = tf.repeat(self.cap_query, repeats=batch_size, axis=0)
+
+         con_feature = self.contrastive_pooling(img_feature, con_query_b)
+         cap_feature = self.caption_pooling(img_feature, cap_query_b)
+
+         return con_feature, cap_feature
+
+
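+ # Text decoder: a unimodal stack over the caption (with an appended [CLS] token), then a
+ # multimodal stack that attends to the image features and predicts token logits.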
+ class CoCaDecoder(tf.keras.Model):
+     def __init__(self, cls_token_id, num_heads, num_layers, **kwargs):
+         super().__init__(**kwargs)
+
+         self.cls_token_id = cls_token_id
+
+         self.pos_emb = PositionalEmbedding(sentence_length, vocab_size, proj_dim)
+
+         self.unimodal_decoder = UnimodalTextDecoder(
+             proj_dim, proj_dim * 4, num_heads, dropout_rate, num_layers=num_layers
+         )
+         self.multimodal_decoder = MultimodalTextDecoder(
+             proj_dim, proj_dim * 4, num_heads, dropout_rate, num_layers=num_layers
+         )
+
+         self.to_logits = tf.keras.layers.Dense(
+             vocab_size,
+             name='logits_projection'
+         )
+
+         self.norm = layers.LayerNormalization()
+
+     def call(self, inputs, training=False):
+         input_text, cap_feature = inputs
+         batch_size = tf.shape(input_text)[0]
+         cls_tokens = tf.fill([batch_size, 1], tf.cast(self.cls_token_id, input_text.dtype))
+         ids = tf.concat([input_text, cls_tokens], axis=1)
+
+         text_mask = tf.not_equal(input_text, 0)
+         cls_mask = tf.zeros([batch_size, 1], dtype=text_mask.dtype)
+         extended_mask = tf.concat([text_mask, cls_mask], axis=1)
+
+         txt_embs = self.pos_emb(ids)
+
+         unimodal_out = self.unimodal_decoder(txt_embs, mask=extended_mask)
+         multimodal_out = self.multimodal_decoder(unimodal_out[:, :-1, :], cap_feature, mask=text_mask)
+
+         cls_token_feature = self.norm(unimodal_out[:, -1:, :])
+         multimodal_logits = self.to_logits(multimodal_out)
+
+         return cls_token_feature, multimodal_logits
+
+
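+ # Full CoCa model: jointly optimizes a captioning (cross-entropy) loss and an
+ # image-text contrastive loss.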
+ class CoCaModel(tf.keras.Model):
+     def __init__(self, vit, cls_token_id, num_heads, num_layers):
+         super().__init__()
+
+         self.encoder = CoCaEncoder(vit, name="coca_encoder")
+         self.decoder = CoCaDecoder(cls_token_id, num_heads, num_layers, name="coca_decoder")
+
+         self.img_to_latents = EmbedToLatents(proj_dim)
+         self.text_to_latents = EmbedToLatents(proj_dim)
+
+         self.pad_id = 0
+         self.temperature = 0.07
+         self.caption_loss_weight = 1.0
+         self.contrastive_loss_weight = 1.0
+
+         self.perplexity = Perplexity()
+
+     def call(self, inputs, training=False):
+         image, text = inputs
+         con_feature, cap_feature = self.encoder(image)
+         cls_token_feature, multimodal_logits = self.decoder([text, cap_feature])
+         return con_feature, cls_token_feature, multimodal_logits
+
+     def compile(self, optimizer):
+         super().compile()
+         self.optimizer = optimizer
+
+     def compute_caption_loss(self, multimodal_out, caption_target):
+         caption_loss = tf.keras.losses.sparse_categorical_crossentropy(
+             caption_target, multimodal_out, from_logits=True, ignore_class=self.pad_id)
+         return tf.reduce_mean(caption_loss)
+
+     def compute_contrastive_loss(self, con_feature, cls_feature):
+         text_embeds = tf.squeeze(cls_feature, axis=1)
+         image_embeds = tf.squeeze(con_feature, axis=1)
+
+         text_latents = self.text_to_latents(text_embeds)
+         image_latents = self.img_to_latents(image_embeds)
+
+         sim = tf.matmul(text_latents, image_latents, transpose_b=True) / self.temperature
+
+         batch_size = tf.shape(sim)[0]
+         contrastive_labels = tf.range(batch_size)
+
+         loss1 = tf.keras.losses.sparse_categorical_crossentropy(contrastive_labels, sim, from_logits=True)
+         loss2 = tf.keras.losses.sparse_categorical_crossentropy(contrastive_labels, tf.transpose(sim), from_logits=True)
+         contrastive_loss = tf.reduce_mean((loss1 + loss2) * 0.5)
+
+         return contrastive_loss
+
+     def train_step(self, data):
+         (images, caption_input), caption_target = data
+
+         with tf.GradientTape() as tape:
+             con_feature, cls_feature, multimodal_out = self([images, caption_input], training=True)
+
+             caption_loss = self.compute_caption_loss(multimodal_out, caption_target)
+             contrastive_loss = self.compute_contrastive_loss(con_feature, cls_feature)
+
+             total_loss = self.caption_loss_weight * caption_loss + self.contrastive_loss_weight * contrastive_loss
+
+         gradients = tape.gradient(total_loss, self.trainable_variables)
+         self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
+
+         self.perplexity.update_state(caption_target, multimodal_out)
+
+         return {
+             'total_loss': total_loss,
+             'caption_loss': caption_loss,
+             'contrastive_loss': contrastive_loss,
+             'perplexity': self.perplexity.result()
+         }
+
+     def test_step(self, data):
+         (images, caption_input), caption_target = data
+
+         con_feature, cls_feature, multimodal_out = self([images, caption_input], training=False)
+
+         caption_loss = self.compute_caption_loss(multimodal_out, caption_target)
+         contrastive_loss = self.compute_contrastive_loss(con_feature, cls_feature)
+
+         total_loss = self.caption_loss_weight * caption_loss + self.contrastive_loss_weight * contrastive_loss
+
+         self.perplexity.update_state(caption_target, multimodal_out)
+
+         return {
+             'total_loss': total_loss,
+             'caption_loss': caption_loss,
+             'contrastive_loss': contrastive_loss,
+             'perplexity': self.perplexity.result()
+         }
+
+     def reset_metrics(self):
+         self.perplexity.reset_state()
+
+
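+ # Build the model with a dummy forward pass so the weights exist, then load the checkpoint.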
+ coca_model = CoCaModel(vit_tiny_model, cls_token_id=cls_token_id, num_heads=num_heads, num_layers=num_layers)
+
+ dummy_features = tf.zeros((1, 3, img_size, img_size), dtype=tf.float32)
+ dummy_captions = tf.zeros((1, sentence_length - 1), dtype=tf.int64)
+ _ = coca_model((dummy_features, dummy_captions))
+
+ optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
+ coca_model.compile(optimizer)
+
+ save_dir = "models/"
+ model_name = "coca"
+ coca_model.load_weights(f"{save_dir}/{model_name}.weights.h5")
+
+
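+ # ResNet50 feature extractor (7x7x2048 regions) for the RNN-with-attention captioner.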
+ img_embed_dim = 2048
+ reg_count = 7 * 7
+
+ base_model = ResNet50(weights='imagenet', include_top=False)
+ model = Model(inputs=base_model.input, outputs=base_model.output)
+
+ def preprocess_image(img):
+     img = tf.image.resize(img, (img_size, img_size))
+     img = tf.convert_to_tensor(img)
+     img = preprocess_input(img)
+     return np.expand_dims(img, axis=0)
+
+ def create_features(img):
+     img = preprocess_image(img)
+     features = model.predict(img, verbose=0)
+     features = features.reshape((1, reg_count, img_embed_dim))
+     return features
+
+
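+ # Additive (Bahdanau) attention over the ResNet region features.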
+ class BahdanauAttention(layers.Layer):
+     def __init__(self, units, **kwargs):
+         super().__init__(**kwargs)
+         self.units = units
+         self.W1 = layers.Dense(units)
+         self.W2 = layers.Dense(units)
+         self.V = layers.Dense(1)
+
+     def call(self, features, hidden):
+         hidden = tf.expand_dims(hidden, 1)
+         score = self.V(tf.nn.tanh(
+             self.W1(features) + self.W2(hidden)
+         ))
+         alpha = tf.nn.softmax(score, axis=1)
+         context = tf.reduce_sum(alpha * features, axis=1)
+         return context, alpha
+
+
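+ # Show-attend-and-tell style captioner: an LSTM decoder with Bahdanau attention over
+ # image regions, with its state initialized from the mean image feature.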
+ class ImageCaptioningModel(tf.keras.Model):
+     def __init__(self, vocab_size, max_caption_len, embedding_dim=512, lstm_units=512, dropout_rate=0.5, **kwargs):
+         super().__init__(**kwargs)
+
+         self.vocab_size = vocab_size
+         self.max_caption_len = max_caption_len
+         self.embedding_dim = embedding_dim
+         self.lstm_units = lstm_units
+         self.dropout_rate = dropout_rate
+
+         self.embedding = layers.Embedding(vocab_size, embedding_dim)
+         self.embedding_dropout = layers.Dropout(dropout_rate)
+         self.lstm = layers.LSTM(lstm_units, return_sequences=True, return_state=True)
+         self.attention = BahdanauAttention(lstm_units)
+         self.fc_dropout = layers.Dropout(dropout_rate)
+         self.fc = layers.Dense(vocab_size, activation='softmax')
+
+         self.init_h = layers.Dense(lstm_units, activation='tanh')
+         self.init_c = layers.Dense(lstm_units)
+
+         self.concatenate = layers.Concatenate(axis=-1)
+
+     def call(self, inputs):
+         features, captions = inputs
+
+         mean_features = tf.reduce_mean(features, axis=1)
+         h = self.init_h(mean_features)
+         c = self.init_c(mean_features)
+
+         embeddings = self.embedding(captions)
+         embeddings = self.embedding_dropout(embeddings)
+
+         outputs = []
+         for t in range(self.max_caption_len):
+             context, _ = self.attention(features, h)
+
+             lstm_input = self.concatenate([embeddings[:, t, :], context])
+             lstm_input = tf.expand_dims(lstm_input, 1)
+
+             output, h, c = self.lstm(lstm_input, initial_state=[h, c])
+             outputs.append(output)
+
+         outputs = tf.concat(outputs, axis=1)
+         outputs = self.fc_dropout(outputs)
+         return self.fc(outputs)
+
+
+ rnn_model = ImageCaptioningModel(vocab_size, sentence_length - 1, rnn_embedding_dim, rnn_proj_dim)
+ image_input = np.random.rand(batch_size, reg_count, img_embed_dim).astype(np.float32)
+ text_input = np.random.randint(0, 10000, size=(batch_size, sentence_length))
+ _ = rnn_model([image_input, text_input])
+
+ rnn_model.compile(
+     optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
+     loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
+     metrics=[Perplexity()]
+ )
+
+ save_dir = "models/"
+ model_name = "rnn_attn"
+
+ rnn_model.load_weights(f"{save_dir}/{model_name}.weights.h5")
+
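+ # Decoding settings and CoCa image preprocessing (normalize to [-1, 1], channels-first).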
+ beam_width = 3
+ max_length = sentence_length - 1
+ temperature = 1.0
+
+ image_mean = [0.5, 0.5, 0.5]
+ image_std = [0.5, 0.5, 0.5]
+
+ def load_and_preprocess_image(img):
+     img = tf.convert_to_tensor(img)
+     img = tf.image.resize(img, (img_size, img_size))
+     img = img / 255.0
+
+     img = (img - image_mean) / image_std
+     img = tf.transpose(img, perm=[2, 0, 1])
+
+     return np.expand_dims(img, axis=0)
+
+
+ def has_repeated_ngrams(seq, n=2):
+     ngrams = [tuple(seq[i:i+n]) for i in range(len(seq)-n+1)]
+     return len(ngrams) != len(set(ngrams))
+
+
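+ # Beam-search decoding with a length-normalized score and a penalty for repeated bigrams.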
+ def generate_caption_coca(image):
+     img_processed = load_and_preprocess_image(image)
+     _, cap_features = coca_model.encoder.predict(img_processed, verbose=0)
+
+     beams = [([word_index[start_token]], 0.0)]
+
+     for _ in range(max_length):
+         new_beams = []
+         for seq, log_prob in beams:
+             if seq[-1] == word_index[end_token]:
+                 new_beams.append((seq, log_prob))
+                 continue
+
+             text_input = np.zeros((1, max_length), dtype=np.int32)
+             text_input[0, :len(seq)] = seq
+
+             predictions = coca_model.decoder.predict([text_input, cap_features], verbose=0)
+             _, logits = predictions
+             logits = logits[0, len(seq)-1, :] / temperature
+             probs = np.exp(logits - np.max(logits))
+             probs /= probs.sum()
+
+             top_k = np.argpartition(probs, -beam_width)[-beam_width:]
+             for token in top_k:
+                 new_seq = seq + [token]
+                 new_log_prob = (log_prob * len(seq) + np.log(probs[token])) / (len(seq) + 1)
+
+                 if has_repeated_ngrams(new_seq, n=2):
+                     new_log_prob -= 0.5
+
+                 new_beams.append((new_seq, new_log_prob))
+
+         beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:beam_width]
+         if all(beam[0][-1] == word_index[end_token] for beam in beams):
+             break
+
+     best_seq = max(beams, key=lambda x: x[1])[0]
+     return " ".join(index_word[i] for i in best_seq if i not in {word_index[start_token], word_index[end_token]})
+
+
+ def generate_caption_rnn(image):
+     image_embedding = create_features(image)
+     beams = [([word_index[start_token]], 0.0)]
+
+     for _ in range(max_length):
+         new_beams = []
+         for seq, log_prob in beams:
+             if seq[-1] == word_index[end_token]:
+                 new_beams.append((seq, log_prob))
+                 continue
+
+             text_input = np.zeros((1, max_length), dtype=np.int32)
+             text_input[0, :len(seq)] = seq
+
+             predictions = rnn_model.predict([image_embedding, text_input], verbose=0)
+             probs = predictions[0, len(seq)-1, :]
+             probs = probs ** (1 / temperature)
+             probs /= probs.sum()
+
+             top_k = np.argpartition(probs, -beam_width)[-beam_width:]
+             for token in top_k:
+                 new_seq = seq + [token]
+                 new_log_prob = (log_prob * len(seq) + np.log(probs[token])) / (len(seq) + 1)
+
+                 if has_repeated_ngrams(new_seq, n=2):
+                     new_log_prob -= 0.5
+                 new_beams.append((new_seq, new_log_prob))
+
+         beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:beam_width]
+         if all(beam[0][-1] == word_index[end_token] for beam in beams):
+             break
+
+     best_seq = max(beams, key=lambda x: x[1])[0]
+     return " ".join(index_word[i] for i in best_seq if i not in {word_index[start_token], word_index[end_token]})
+
+
+ def generate_both(image):
+     caption1 = generate_caption_rnn(image)
+     caption2 = generate_caption_coca(image)
+     return f"RNN: {caption1}\n\nCoCa: {caption2}"
+
+
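+ # Gradio UI: one image input, both models' captions in a single textbox.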
+ interface = gr.Interface(
+     fn=generate_both,
+     inputs=gr.Image(type="pil", label="Image"),
+     outputs=gr.Textbox(label="Captions", autoscroll=True, show_copy_button=True),
+     allow_flagging="never",
+     submit_btn="Generate",
+     clear_btn="Clear",
+     deep_link=False
+ )
+
+ with gr.Blocks() as demo:
+     gr.Markdown("# 🖼️ Image Caption Generator")
+     interface.render()
+
+
+ if __name__ == "__main__":
+     demo.launch(ssr_mode=False, show_api=False)
models/coca.weights.h5 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5dc33edd1df6158e35bef3f5c4e151c6ce69f4105a487e052754712debfd3656
+ size 262132344
models/rnn_attn.weights.h5 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:79e09a294e234d15baae6ef4916f35772ec53e2645e2de58c54e0996a7baa027
+ size 331683632
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ tensorflow>=2.18.0
+ keras>=3.8.0
+ numpy>=2.0.2
+ pillow>=11.2.1
+ transformers>=4.52.4
+ gradio>=5.31.0
+ h5py>=3.14.0
vocabs/index_word.json ADDED
The diff for this file is too large to render. See raw diff
 
vocabs/word_index.json ADDED
The diff for this file is too large to render. See raw diff