peshk1n committed on
Commit bdc8b80 · verified · 1 Parent(s): 2aae9c6

Upload 6 files

app.py ADDED
@@ -0,0 +1,730 @@
1
+ import keras
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ from tensorflow import data as tf_data
5
+ from tensorflow import image as tf_image
6
+ from tensorflow import io as tf_io
7
+ from PIL import Image
8
+ import json
9
+ from tensorflow.keras import layers, Model
10
+ import string
11
+ from transformers import TFAutoModel
12
+ import gradio as gr
13
+ import os
15
+ from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
16
+ from tensorflow.keras.preprocessing import image
17
+ from tensorflow.keras.models import Model
18
+
19
+
20
+ # Variables ================================
21
+ start_token = "[BOS]"
22
+ end_token = "[EOS]"
23
+ cls_token = "[CLS]"
24
+
25
+ data_dir = '/content/coco'
26
+ data_type_train = 'train2014'
27
+ data_type_val = 'val2014'
28
+
29
+ vocab_size = 24000
30
+ sentence_length = 20
31
+ batch_size = 128
32
+ img_size = 224
33
+
34
+ proj_dim = 192
35
+ dropout_rate = 0.1
36
+ num_patches = 14
37
+ patch_size = img_size // num_patches
38
+
39
+ num_heads = 3
40
+ num_layers = 6
41
+ attn_pool_dim = proj_dim
42
+ attn_pool_heads = num_heads
43
+ cap_query_num = 128
44
+
45
+ # RNN hyperparameters
46
+ rnn_embedding_dim = 256
47
+ rnn_proj_dim = 512
48
+
49
+ # =================================
50
+
51
+ # Load word_index
52
+ with open('/content/drive/MyDrive/models/vocabs/word_index.json', 'r', encoding='utf-8') as f:
53
+ word_index = {np.str_(word): np.int64(idx) for word, idx in json.load(f).items()}
54
+
55
+ # Load index_word
56
+ with open('/content/drive/MyDrive/models/vocabs/index_word.json', 'r', encoding='utf-8') as f:
57
+ index_word = {np.int64(idx): np.str_(word) for idx, word in json.load(f).items()}
58
+
59
+ cls_token_id = word_index[cls_token]
60
+
61
+
62
+ class PositionalEmbedding(layers.Layer):
63
+ def __init__(self, sequence_length, input_dim, output_dim, **kwargs):
64
+ super().__init__(**kwargs)
65
+ self.sequence_length = sequence_length
66
+ self.input_dim = input_dim
67
+ self.output_dim = output_dim
68
+ self.token_embeddings = layers.Embedding(
69
+ input_dim=input_dim, output_dim=output_dim
70
+ )
71
+ self.position_embeddings = layers.Embedding(
72
+ input_dim=sequence_length, output_dim=output_dim
73
+ )
74
+
75
+ def call(self, inputs):
76
+ positions = tf.range(start=0, limit=self.sequence_length, delta=1)
77
+ embedded_tokens = self.token_embeddings(inputs)
78
+ embedded_positions = self.position_embeddings(positions)
79
+ output = embedded_tokens + embedded_positions
80
+ return output
81
+
82
+
83
+ class AttentionalPooling(layers.Layer):
84
+ def __init__(self, embed_dim, num_heads=6):
85
+ super().__init__()
86
+ self.embed_dim = embed_dim
87
+ self.num_heads = num_heads
88
+ self.multihead_attention = layers.MultiHeadAttention(num_heads=self.num_heads, key_dim=self.embed_dim)
89
+ self.norm = layers.LayerNormalization()
90
+
91
+
92
+ def call(self, features, query):
93
+ attn_output = self.multihead_attention(
94
+ query=query,
95
+ value=features,
96
+ key=features
97
+ )
98
+
99
+ return self.norm(attn_output)
100
+
101
+
102
+ class TransformerBlock(layers.Layer):
103
+ def __init__(self, embed_dim, dense_dim, num_heads, dropout_rate=0.1, ln_epsilon=1e-6, is_multimodal=False, **kwargs):
104
+ super().__init__(**kwargs)
105
+ self.embed_dim = embed_dim
106
+ self.dense_dim = dense_dim
107
+ self.num_heads = num_heads
108
+ self.dropout_rate = dropout_rate
109
+ self.ln_epsilon = ln_epsilon
110
+
111
+ # Self-Attention
112
+ self.self_attention = layers.MultiHeadAttention(
113
+ num_heads=self.num_heads,
114
+ key_dim=self.embed_dim,
115
+ dropout=self.dropout_rate
116
+ )
117
+
118
+ # Cross-Attention
119
+ if is_multimodal:
120
+ self.norm2 = layers.LayerNormalization(epsilon=self.ln_epsilon)
121
+ self.dropout2 = layers.Dropout(self.dropout_rate)
122
+ self.cross_attention = layers.MultiHeadAttention(
123
+ num_heads=self.num_heads,
124
+ key_dim=self.embed_dim,
125
+ dropout=self.dropout_rate
126
+ )
127
+
128
+
129
+ # Feed-Forward Network
130
+ self.dense_proj = keras.Sequential([
131
+ layers.Dense(self.dense_dim, activation="gelu"),
132
+ layers.Dropout(self.dropout_rate),
133
+ layers.Dense(self.embed_dim)
134
+ ])
135
+
136
+ # Layer Normalization
137
+ self.norm1 = layers.LayerNormalization(epsilon=self.ln_epsilon)
138
+ self.norm3 = layers.LayerNormalization(epsilon=self.ln_epsilon)
139
+
140
+ # Dropout
141
+ self.dropout1 = layers.Dropout(self.dropout_rate)
142
+ self.dropout3 = layers.Dropout(self.dropout_rate)
143
+
144
+
145
+ def get_causal_attention_mask(self, inputs):
146
+ seq_len = tf.shape(inputs)[1]
147
+ causal_mask = tf.linalg.band_part(tf.ones((seq_len, seq_len), tf.bool), -1, 0)
148
+ return tf.expand_dims(causal_mask, 0)
149
+
150
+
151
+ def get_combined_mask(self, causal_mask, padding_mask):
152
+ padding_mask = tf.cast(padding_mask, tf.bool)
153
+
154
+ padding_mask = tf.expand_dims(padding_mask, 1) # (B, 1, L)
155
+ return causal_mask & padding_mask
156
+
157
+
158
+ def call(self, inputs, encoder_outputs=None, mask=None):
159
+ att_mask = self.get_causal_attention_mask(inputs)
160
+ if mask is not None:
161
+ att_mask = self.get_combined_mask(att_mask, mask)
162
+
163
+ # Self-Attention
164
+ x = self.norm1(inputs)
165
+ attention_output_1 = self.self_attention(
166
+ query=x, key=x, value=x, attention_mask=att_mask
167
+ )
168
+ attention_output_1 = self.dropout1(attention_output_1)
169
+ x = x + attention_output_1 # Add & Norm
170
+
171
+ # Cross-Attention
172
+ if encoder_outputs is not None:
173
+ x_norm = self.norm2(x)
174
+ attention_output_2 = self.cross_attention(
175
+ query=x_norm, key=encoder_outputs, value=encoder_outputs
176
+ )
177
+ attention_output_2 = self.dropout2(attention_output_2)
178
+ x = x + attention_output_2 # Add & Norm
179
+
180
+ # Feed-Forward Network (FFN)
181
+ x_norm = self.norm3(x)
182
+ proj_output = self.dense_proj(x_norm)
183
+ proj_output = self.dropout3(proj_output)
184
+ return x + proj_output # Add & Norm
185
+
186
+
187
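+ # Unimodal text decoder: a stack of num_layers causal Transformer blocks without cross-attention, followed by a final LayerNorm.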
+ class UnimodalTextDecoder(tf.keras.layers.Layer):
188
+ def __init__(self, embed_dim, dense_dim, num_heads, dropout_rate=0.1, ln_epsilon=1e-6, num_layers=4, **kwargs):
189
+ super().__init__()
190
+ self.embed_dim = embed_dim
191
+ self.dense_dim = dense_dim
192
+ self.num_heads = num_heads
193
+ self.dropout_rate = dropout_rate
194
+ self.ln_epsilon = ln_epsilon
195
+ self.num_layers = num_layers
196
+
197
+ self.layers = [
198
+ TransformerBlock(self.embed_dim, self.dense_dim, self.num_heads, self.dropout_rate, self.ln_epsilon, is_multimodal=False)
199
+ for _ in range(self.num_layers)
200
+ ]
201
+ self.norm = tf.keras.layers.LayerNormalization()
202
+
203
+
204
+ def call(self, x, mask=None):
205
+ for layer in self.layers:
206
+ x = layer(inputs=x, mask=mask)
207
+ return self.norm(x)
208
+
209
+
210
+
211
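+ # Multimodal text decoder: the same stack, but each block also cross-attends to the pooled image features.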
+ class MultimodalTextDecoder(tf.keras.layers.Layer):
212
+ def __init__(self, embed_dim, dense_dim, num_heads, dropout_rate=0.1, ln_epsilon=1e-6, num_layers=4, **kwargs):
213
+ super().__init__()
214
+ self.embed_dim = embed_dim
215
+ self.dense_dim = dense_dim
216
+ self.num_heads = num_heads
217
+ self.dropout_rate = dropout_rate
218
+ self.ln_epsilon = ln_epsilon
219
+ self.num_layers = num_layers
220
+
221
+ self.layers = [
222
+ TransformerBlock(self.embed_dim, self.dense_dim, self.num_heads, self.dropout_rate, self.ln_epsilon, is_multimodal=True)
223
+ for _ in range(self.num_layers)
224
+ ]
225
+ self.norm = tf.keras.layers.LayerNormalization()
226
+
227
+
228
+ def call(self, x, encoder_outputs, mask=None):
229
+ for layer in self.layers:
230
+ x = layer(inputs=x, encoder_outputs=encoder_outputs, mask=mask)
231
+ return self.norm(x)
232
+
233
+
234
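+ # Projects embeddings into the shared contrastive latent space (bias-free Dense followed by L2 normalization).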
+ class EmbedToLatents(layers.Layer):
235
+ def __init__(self, dim_latents, **kwargs):
236
+ super(EmbedToLatents, self).__init__(**kwargs)
237
+ self.dim_latents = dim_latents
238
+ self.to_latents = layers.Dense(
239
+ self.dim_latents,
240
+ use_bias=False
241
+ )
242
+
243
+ def call(self, inputs):
244
+ latents = self.to_latents(inputs)
245
+ return tf.math.l2_normalize(latents, axis=-1)
246
+
247
+
248
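+ # Perplexity metric: exp(sum of masked token cross-entropy / number of non-pad tokens), with pad id 0 excluded.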
+ class Perplexity(tf.keras.metrics.Metric):
249
+ def __init__(self, name='perplexity', **kwargs):
250
+ super().__init__(name=name, **kwargs)
251
+ self.total_loss = self.add_weight(name='total_loss', initializer='zeros')
252
+ self.total_tokens = self.add_weight(name='total_tokens', initializer='zeros')
253
+
254
+ def update_state(self, y_true, y_pred, sample_weight=None):
255
+ loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
256
+ loss = loss_fn(y_true, y_pred)
257
+
258
+ mask = tf.cast(tf.not_equal(y_true, 0), tf.float32)
259
+ loss = tf.reduce_sum(loss * mask)
260
+ num_tokens = tf.reduce_sum(mask)
261
+
262
+ self.total_loss.assign_add(loss)
263
+ self.total_tokens.assign_add(num_tokens)
264
+
265
+ def result(self):
266
+ return tf.exp(self.total_loss / self.total_tokens)
267
+
268
+ def reset_state(self):
269
+ self.total_loss.assign(0.0)
270
+ self.total_tokens.assign(0.0)
271
+
272
+
273
+ model_name = "WinKawaks/vit-tiny-patch16-224"
274
+ vit_tiny_model = TFAutoModel.from_pretrained(model_name)
275
+ vit_tiny_model.trainable = True
276
+
277
+ for layer in vit_tiny_model.layers:
278
+ layer.trainable = True
279
+
280
+
281
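+ # CoCa image encoder: ViT backbone plus two attentional-pooling heads (a single learned query for the contrastive embedding, cap_query_num queries for the captioning features).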
+ class CoCaEncoder(keras.Model):
282
+ def __init__(self,
283
+ vit, **kwargs):
284
+
285
+ super().__init__(**kwargs)
286
+
287
+ self.vit = vit
288
+
289
+ self.contrastive_pooling = AttentionalPooling(attn_pool_dim, attn_pool_heads)
290
+ self.caption_pooling = AttentionalPooling(attn_pool_dim, attn_pool_heads)
291
+
292
+ self.con_query = tf.Variable(
293
+ initial_value=tf.random.normal([1, 1, proj_dim]),
294
+ trainable=True,
295
+ name="con_query"
296
+ )
297
+
298
+ self.cap_query = tf.Variable(
299
+ initial_value=tf.random.normal([1, cap_query_num, proj_dim]),
300
+ trainable=True,
301
+ name="cap_query"
302
+ )
303
+
304
+
305
+ def call(self, input, training=False):
306
+ img_feature = self.vit(input).last_hidden_state
307
+
308
+ batch_size = tf.shape(img_feature)[0]
309
+ con_query_b = tf.repeat(self.con_query, repeats=batch_size, axis=0)
310
+ cap_query_b = tf.repeat(self.cap_query, repeats=batch_size, axis=0)
311
+
312
+ con_feature = self.contrastive_pooling(img_feature, con_query_b)
313
+ cap_feature = self.caption_pooling(img_feature, cap_query_b)
314
+
315
+ return con_feature, cap_feature
316
+
317
+
318
+
319
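+ # CoCa text decoder: appends a [CLS] token to the caption, runs the unimodal decoder, uses the [CLS] position as the text embedding for the contrastive loss, and passes the remaining positions through the multimodal decoder to produce next-token logits.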
+ class CoCaDecoder(keras.Model):
320
+ def __init__(self,
321
+ cls_token_id,
322
+ num_heads,
323
+ num_layers,
324
+ **kwargs):
325
+
326
+ super().__init__(**kwargs)
327
+
328
+ self.cls_token_id = cls_token_id
329
+
330
+ self.pos_emb = PositionalEmbedding(sentence_length, vocab_size, proj_dim)
331
+
332
+ self.unimodal_decoder = UnimodalTextDecoder(
333
+ proj_dim, proj_dim * 4, num_heads, dropout_rate, num_layers=num_layers
334
+ )
335
+ self.multimodal_decoder = MultimodalTextDecoder(
336
+ proj_dim, proj_dim * 4, num_heads, dropout_rate, num_layers=num_layers
337
+ )
338
+
339
+ self.to_logits = tf.keras.layers.Dense(
340
+ vocab_size,
341
+ name='logits_projection'
342
+ )
343
+
344
+ self.norm = layers.LayerNormalization()
345
+
346
+
347
+ def call(self, inputs, training=False):
348
+ input_text, cap_feature = inputs
349
+ batch_size = tf.shape(input_text)[0]
350
+ cls_tokens = tf.fill([batch_size, 1], tf.cast(self.cls_token_id, input_text.dtype))
351
+ ids = tf.concat([input_text, cls_tokens], axis=1)
352
+
353
+ text_mask = tf.not_equal(input_text, 0)
354
+ cls_mask = tf.zeros([batch_size, 1], dtype=text_mask.dtype)
355
+ extended_mask = tf.concat([text_mask, cls_mask], axis=1)
356
+
357
+ txt_embs = self.pos_emb(ids)
358
+
359
+ unimodal_out = self.unimodal_decoder(txt_embs, mask=extended_mask)
360
+ multimodal_out = self.multimodal_decoder(unimodal_out[:, :-1, :], cap_feature, mask=text_mask)
361
+
362
+ cls_token_feature = self.norm(unimodal_out[:, -1:, :])
363
+ multimodal_logits = self.to_logits(multimodal_out)
364
+
365
+ return cls_token_feature, multimodal_logits
366
+
367
+
368
+
369
+ # day 6
370
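+ # Full CoCa model: jointly trains the captioning head (sparse cross-entropy, pad tokens ignored) and the image-text contrastive head (symmetric cross-entropy over the similarity matrix), and tracks perplexity.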
+ class CoCaModel(keras.Model):
371
+ def __init__(self,
372
+ vit,
373
+ cls_token_id,
374
+ num_heads,
375
+ num_layers):
376
+
377
+ super().__init__()
378
+
379
+ self.encoder = CoCaEncoder(vit, name="coca_encoder")
380
+ self.decoder = CoCaDecoder(cls_token_id, num_heads, num_layers, name="coca_decoder")
381
+
382
+ self.img_to_latents = EmbedToLatents(proj_dim)
383
+ self.text_to_latents = EmbedToLatents(proj_dim)
384
+
385
+ self.pad_id = 0
386
+ self.temperature = 0.2 # 0.5 #0.9 #1.0
387
+ self.caption_loss_weight = 1.0
388
+ self.contrastive_loss_weight = 1.0
389
+
390
+ self.perplexity = Perplexity()
391
+
392
+
393
+ def call(self, inputs, training=False):
394
+ image, text = inputs
395
+
396
+ con_feature, cap_feature = self.encoder(image)
397
+ cls_token_feature, multimodal_logits = self.decoder([text, cap_feature])
398
+
399
+ return con_feature, cls_token_feature, multimodal_logits
400
+
401
+
402
+ def compile(self, optimizer):
403
+ super().compile()
404
+ self.optimizer = optimizer
405
+
406
+
407
+ def compute_caption_loss(self, multimodal_out, caption_target):
408
+ caption_loss = tf.keras.losses.sparse_categorical_crossentropy(
409
+ caption_target, multimodal_out, from_logits=True, ignore_class=self.pad_id)
410
+
411
+ return tf.reduce_mean(caption_loss)
412
+
413
+
414
+ def compute_contrastive_loss(self, con_feature, cls_feature):
415
+ text_embeds = tf.squeeze(cls_feature, axis=1)
416
+ image_embeds = tf.squeeze(con_feature, axis=1)
417
+
418
+ text_latents = self.text_to_latents(text_embeds)
419
+ image_latents = self.img_to_latents(image_embeds)
420
+
421
+ # Similarity matrix
422
+ sim = tf.matmul(text_latents, image_latents, transpose_b=True) / self.temperature
423
+
424
+ # Labels
425
+ batch_size = tf.shape(sim)[0]
426
+ contrastive_labels = tf.range(batch_size)
427
+
428
+ # Compute the losses
429
+ loss1 = tf.keras.losses.sparse_categorical_crossentropy(contrastive_labels, sim, from_logits=True)
430
+ loss2 = tf.keras.losses.sparse_categorical_crossentropy(contrastive_labels, tf.transpose(sim), from_logits=True)
431
+ contrastive_loss = tf.reduce_mean((loss1 + loss2) * 0.5)
432
+
433
+ return contrastive_loss
434
+
435
+
436
+ def train_step(self, data):
437
+ (images, caption_input), caption_target = data
438
+
439
+ with tf.GradientTape() as tape:
440
+ con_feature, cls_feature, multimodal_out = self([images, caption_input], training=True)
441
+
442
+ caption_loss = self.compute_caption_loss(multimodal_out, caption_target)
443
+ contrastive_loss = self.compute_contrastive_loss(con_feature, cls_feature)
444
+
445
+ total_loss = self.caption_loss_weight * caption_loss + self.contrastive_loss_weight * contrastive_loss
446
+
447
+ gradients = tape.gradient(total_loss, self.trainable_variables)
448
+ self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
449
+
450
+ self.perplexity.update_state(caption_target, multimodal_out)
451
+
452
+ return {
453
+ 'total_loss': total_loss,
454
+ 'caption_loss': caption_loss,
455
+ 'contrastive_loss': contrastive_loss,
456
+ 'perplexity': self.perplexity.result()
457
+ }
458
+
459
+
460
+ def test_step(self, data):
461
+ (images, caption_input), caption_target = data
462
+
463
+ con_feature, cls_feature, multimodal_out = self([images, caption_input], training=False)
464
+
465
+ caption_loss = self.compute_caption_loss(multimodal_out, caption_target)
466
+ contrastive_loss = self.compute_contrastive_loss(con_feature, cls_feature)
467
+
468
+ total_loss = self.caption_loss_weight * caption_loss + self.contrastive_loss_weight * contrastive_loss
469
+
470
+ self.perplexity.update_state(caption_target, multimodal_out)
471
+
472
+ return {
473
+ 'total_loss': total_loss,
474
+ 'caption_loss': caption_loss,
475
+ 'contrastive_loss': contrastive_loss,
476
+ 'perplexity': self.perplexity.result()
477
+ }
478
+
479
+
480
+ def reset_metrics(self):
481
+ self.perplexity.reset_state()
482
+
483
+
484
+ # ===========================================
485
+ # Load the CoCa weights
486
+
487
+ coca_model = CoCaModel(vit_tiny_model, cls_token_id=cls_token_id, num_heads=num_heads, num_layers=num_layers)
488
+
489
+ dummy_features = tf.zeros((1, 3, img_size, img_size), dtype=tf.float32)
490
+ dummy_captions = tf.zeros((1, sentence_length-1), dtype=tf.int64)
491
+ _ = coca_model((dummy_features, dummy_captions))
492
+
493
+ optimizer = keras.optimizers.Adam(learning_rate=1e-4)
494
+ coca_model.compile(optimizer)
495
+
496
+ save_dir = "/models/"
497
+ model_name = "coca_007"
498
+ coca_model.load_weights(f"{save_dir}/{model_name}.weights.h5")
499
+
500
+ # ===========================================
501
+ # RNN =======================================
502
+ img_embed_dim = 2048
503
+ reg_count = 7 * 7
504
+
505
+ base_model = ResNet50(weights='imagenet', include_top=False)
506
+ model = Model(inputs=base_model.input, outputs=base_model.output)
507
+
508
+
509
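+ # ResNet50 region features for the RNN captioner: each image is mapped to a (1, 49, 2048) grid of conv features.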
+ def preprocess_image(img):
510
+ img = tf.image.resize(img, (img_size, img_size))
511
+ img = tf.convert_to_tensor(img)
512
+ img = preprocess_input(img)
513
+ return np.expand_dims(img, axis=0)
514
+
515
+
516
+ def create_features(img):
517
+ img = preprocess_image(img)
518
+ features = model.predict(img, verbose=0)
519
+ features = features.reshape((1, reg_count, img_embed_dim))
520
+ return features
521
+
522
+
523
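+ # Additive (Bahdanau) attention: scores the region features against the current LSTM hidden state and returns the attention-weighted context vector.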
+ class BahdanauAttention(layers.Layer):
524
+ def __init__(self, units, **kwargs):
525
+ super().__init__(**kwargs)
526
+ self.units = units
527
+ self.W1 = layers.Dense(units)
528
+ self.W2 = layers.Dense(units)
529
+ self.V = layers.Dense(1)
530
+
531
+ def call(self, features, hidden):
532
+ hidden = tf.expand_dims(hidden, 1)
533
+ score = self.V(tf.nn.tanh(
534
+ self.W1(features) + self.W2(hidden)
535
+ ))
536
+ alpha = tf.nn.softmax(score, axis=1)
537
+ context = tf.reduce_sum(alpha * features, axis=1)
538
+ return context, alpha
539
+
540
+
541
+
542
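+ # LSTM captioner with Bahdanau attention: the LSTM state is initialized from the mean image feature, and at each step the word embedding is concatenated with the attended image context.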
+ class ImageCaptioningModel(keras.Model):
543
+ def __init__(self, vocab_size, max_caption_len, embedding_dim=512, lstm_units=512, dropout_rate=0.5, **kwargs):
544
+ super().__init__(**kwargs)
545
+
546
+ self.vocab_size = vocab_size
547
+ self.max_caption_len = max_caption_len
548
+ self.embedding_dim = embedding_dim
549
+ self.lstm_units = lstm_units
550
+ self.dropout_rate = dropout_rate
551
+
552
+ self.embedding = layers.Embedding(vocab_size, embedding_dim)
553
+ self.embedding_dropout = layers.Dropout(dropout_rate)
554
+ self.lstm = layers.LSTM(lstm_units, return_sequences=True, return_state=True)
555
+ self.attention = BahdanauAttention(lstm_units)
556
+ self.fc_dropout = layers.Dropout(dropout_rate)
557
+ self.fc = layers.Dense(vocab_size, activation='softmax')
558
+
559
+ self.init_h = layers.Dense(lstm_units, activation='tanh')
560
+ self.init_c = layers.Dense(lstm_units)
561
+
562
+ self.concatenate = layers.Concatenate(axis=-1)
563
+
564
+
565
+ def call(self, inputs):
566
+ features, captions = inputs
567
+
568
+ mean_features = tf.reduce_mean(features, axis=1)
569
+ h = self.init_h(mean_features)
570
+ c = self.init_c(mean_features)
571
+
572
+ embeddings = self.embedding(captions)
573
+ embeddings = self.embedding_dropout(embeddings)
574
+
575
+ outputs = []
576
+ for t in range(self.max_caption_len):
577
+ context, _ = self.attention(features, h)
578
+
579
+ lstm_input = self.concatenate([embeddings[:, t, :], context])
580
+ lstm_input = tf.expand_dims(lstm_input, 1)
581
+
582
+ output, h, c = self.lstm(lstm_input, initial_state=[h, c])
583
+ outputs.append(output)
584
+
585
+ outputs = tf.concat(outputs, axis=1)
586
+ outputs = self.fc_dropout(outputs)
587
+ return self.fc(outputs)
588
+
589
+
590
+
591
+ rnn_model = ImageCaptioningModel(vocab_size, sentence_length-1, rnn_embedding_dim, rnn_proj_dim)
592
+ image_input = np.random.rand(batch_size, reg_count, img_embed_dim).astype(np.float32)
593
+ text_input = np.random.randint(0, 10000, size=(batch_size, sentence_length))
594
+ _ = rnn_model([image_input, text_input])
595
+
596
+ rnn_model.compile(
597
+ optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
598
+ loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
599
+ metrics=[Perplexity()]
600
+ )
601
+
602
+ save_dir = "/models/"
603
+ model_name = "rnn_att_v4"
604
+
605
+ rnn_model.load_weights(f"{save_dir}/{model_name}.weights.h5")
606
+
607
+ # =====================================
608
+ # Generation methods
609
+
610
+ beam_width = 3
611
+ max_length = sentence_length - 1
612
+ temperature = 1.0
613
+
614
+ image_mean = [0.5, 0.5, 0.5]
615
+ image_std = [0.5, 0.5, 0.5]
616
+
617
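+ # ViT preprocessing: resize to img_size, scale to [0, 1], normalize with mean/std 0.5, and convert to channels-first (C, H, W) as the HF ViT expects.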
+ def load_and_preprocess_image(img):
618
+ img = tf.convert_to_tensor(img)
619
+ img = tf.image.resize(img, (img_size, img_size))
620
+ img = img / 255.0
621
+
622
+ img = (img - image_mean) / image_std
623
+ img = tf.transpose(img, perm=[2, 0, 1])
624
+
625
+ return np.expand_dims(img, axis=0)
626
+
627
+
628
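+ # Returns True if the sequence contains any repeated n-gram; used below as a repetition penalty during beam search.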
+ def has_repeated_ngrams(seq, n=2):
629
+ ngrams = [tuple(seq[i:i+n]) for i in range(len(seq)-n+1)]
630
+ return len(ngrams) != len(set(ngrams))
631
+
632
+
633
+ # Beam-search caption generation for CoCa (length-normalized scores, bigram repetition penalty)
634
+ def generate_caption_coca(image):
635
+ img_processed = load_and_preprocess_image(image)
636
+ _, cap_features = coca_model.encoder.predict(img_processed, verbose=0)
637
+
638
+ beams = [([word_index[start_token]], 0.0)]
639
+
640
+ for _ in range(max_length):
641
+ new_beams = []
642
+ for seq, log_prob in beams:
643
+ if seq[-1] == word_index[end_token]:
644
+ new_beams.append((seq, log_prob))
645
+ continue
646
+
647
+ text_input = np.zeros((1, max_length), dtype=np.int32)
648
+ text_input[0, :len(seq)] = seq
649
+
650
+ predictions = coca_model.decoder.predict([text_input, cap_features], verbose=0)
651
+ _, logits = predictions
652
+ logits = logits[0, len(seq)-1, :] / temperature
653
+ probs = np.exp(logits - np.max(logits))
654
+ probs /= probs.sum()
655
+
656
+ top_k = np.argpartition(probs, -beam_width)[-beam_width:]
657
+ for token in top_k:
658
+ new_seq = seq + [token]
659
+ new_log_prob = (log_prob * len(seq) + np.log(probs[token])) / (len(seq) + 1)
660
+
661
+ # Repetition penalty
662
+ if has_repeated_ngrams(new_seq, n=2):
663
+ new_log_prob -= 0.5
664
+
665
+ new_beams.append((new_seq, new_log_prob))
666
+
667
+ beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:beam_width]
668
+ if all(beam[0][-1] == word_index[end_token] for beam in beams):
669
+ break
670
+
671
+ best_seq = max(beams, key=lambda x: x[1])[0]
672
+ return " ".join(index_word[i] for i in best_seq if i not in {word_index[start_token], word_index[end_token]})
673
+
674
+
675
+ # Beam-search caption generation for the RNN model (length-normalized scores, bigram repetition penalty)
676
+ def generate_caption_rnn(image):
677
+ image_embedding = create_features(image)
678
+ beams = [([word_index[start_token]], 0.0)]
679
+
680
+ for _ in range(max_length):
681
+ new_beams = []
682
+ for seq, log_prob in beams:
683
+ if seq[-1] == word_index[end_token]:
684
+ new_beams.append((seq, log_prob))
685
+ continue
686
+
687
+ text_input = np.zeros((1, max_length), dtype=np.int32)
688
+ text_input[0, :len(seq)] = seq
689
+
690
+ predictions = rnn_model.predict([image_embedding, text_input], verbose=0)
691
+ probs = predictions[0, len(seq)-1, :]
692
+ probs = probs ** (1 / temperature)
693
+ probs /= probs.sum()
694
+
695
+ top_k = np.argpartition(probs, -beam_width)[-beam_width:]
696
+ for token in top_k:
697
+ new_seq = seq + [token]
698
+ new_log_prob = (log_prob * len(seq) + np.log(probs[token])) / (len(seq) + 1)
699
+
700
+ # Repetition penalty
701
+ if has_repeated_ngrams(new_seq, n=2):
702
+ new_log_prob -= 0.5
703
+ new_beams.append((new_seq, new_log_prob))
704
+
705
+ beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:beam_width]
706
+ if all(beam[0][-1] == word_index[end_token] for beam in beams):
707
+ break
708
+
709
+ best_seq = max(beams, key=lambda x: x[1])[0]
710
+ return " ".join(index_word[i] for i in best_seq if i not in {word_index[start_token], word_index[end_token]})
711
+
712
+
713
+ def generate_both(image):
714
+ caption1 = generate_caption_rnn(image)
715
+ caption2 = generate_caption_coca(image)
716
+ return f"RNN: {caption1}\n\nCoCa: {caption2}"
717
+
718
+
719
+ interface = gr.Interface(
720
+ fn=generate_both,
721
+ inputs=gr.Image(type="pil", label="Image"),
722
+ outputs=gr.Textbox(label="Captions", autoscroll=True, show_copy_button=True),
723
+ title="Image Caption Generator",
724
+ allow_flagging="never",
725
+ submit_btn="Generate",
726
+ clear_btn="Clear"
727
+ )
728
+
729
+ if __name__ == "__main__":
730
+ interface.launch()
models/coca_007.weights.h5 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d057bdbcd874a572898a3803a08059fb558d8695b282c790185a61a3604e5955
3
+ size 262132344
models/rnn_att_v4.weights.h5 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:79e09a294e234d15baae6ef4916f35772ec53e2645e2de58c54e0996a7baa027
3
+ size 331683632
requirements.txt ADDED
@@ -0,0 +1,7 @@
1
+ tensorflow>=2.18.0
2
+ keras>=3.8.0
3
+ numpy>=2.0.2
4
+ pillow>=11.2.1
5
+ transformers>=4.52.4
6
+ gradio>=5.31.0
7
+ h5py>=3.14.0
vocabs/index_word.json ADDED
The diff for this file is too large to render. See raw diff
 
vocabs/word_index.json ADDED
The diff for this file is too large to render. See raw diff