Update app.py
app.py CHANGED
@@ -16,8 +16,9 @@ from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
 from tensorflow.keras.preprocessing import image
 from tensorflow.keras.models import Model

-
 os.environ["KERAS_BACKEND"] = "tensorflow"
+
+# Variables ================================
 start_token = "[BOS]"
 end_token = "[EOS]"
 cls_token = "[CLS]"
@@ -42,13 +43,17 @@ attn_pool_dim = proj_dim
 attn_pool_heads = num_heads
 cap_query_num = 128

+# RNN
 rnn_embedding_dim = 256
 rnn_proj_dim = 512

+# =================================

+# Load word_index
 with open('vocabs/word_index.json', 'r', encoding='utf-8') as f:
     word_index = {np.str_(word): np.int64(idx) for word, idx in json.load(f).items()}

+# Load index_word
 with open('vocabs/index_word.json', 'r', encoding='utf-8') as f:
     index_word = {np.int64(idx): np.str_(word) for idx, word in json.load(f).items()}

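For reference, the two JSON vocabularies above are inverse mappings used to turn captions into token ids and back. A minimal sketch with a toy vocabulary (not the Space's actual vocabs/*.json files):

word_index = {"[BOS]": 1, "[EOS]": 2, "a": 3, "cat": 4}
index_word = {idx: word for word, idx in word_index.items()}

ids = [word_index[w] for w in ["[BOS]", "a", "cat", "[EOS]"]]
caption = " ".join(index_word[i] for i in ids if i not in {word_index["[BOS]"], word_index["[EOS]"]})
print(ids)      # [1, 3, 4, 2]
print(caption)  # a cat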
@@ -76,7 +81,7 @@ class PositionalEmbedding(layers.Layer):
         return output


-class AttentionalPooling(layers.Layer):
+class AttentionalPooling(tf.keras.layers.Layer):
     def __init__(self, embed_dim, num_heads=6):
         super().__init__()
         self.embed_dim = embed_dim
@@ -95,7 +100,7 @@ class AttentionalPooling(layers.Layer):
         return self.norm(attn_output)


-class TransformerBlock(layers.Layer):
+class TransformerBlock(tf.keras.layers.Layer):
     def __init__(self, embed_dim, dense_dim, num_heads, dropout_rate=0.1, ln_epsilon=1e-6, is_multimodal=False, **kwargs):
         super().__init__(**kwargs)
         self.embed_dim = embed_dim
@@ -104,12 +109,14 @@ class TransformerBlock(layers.Layer):
         self.dropout_rate = dropout_rate
         self.ln_epsilon = ln_epsilon

+        # Self-Attention
         self.self_attention = layers.MultiHeadAttention(
             num_heads=self.num_heads,
             key_dim=self.embed_dim,
             dropout=self.dropout_rate
         )

+        # Cross-Attention
         if is_multimodal:
             self.norm2 = layers.LayerNormalization(epsilon=self.ln_epsilon)
             self.dropout2 = layers.Dropout(self.dropout_rate)
@@ -119,15 +126,19 @@ class TransformerBlock(layers.Layer):
                 dropout=self.dropout_rate
             )

+
+        # Feed-Forward Network
         self.dense_proj = tf.keras.Sequential([
             layers.Dense(self.dense_dim, activation="gelu"),
             layers.Dropout(self.dropout_rate),
             layers.Dense(self.embed_dim)
         ])

+        # Layer Normalization
         self.norm1 = layers.LayerNormalization(epsilon=self.ln_epsilon)
         self.norm3 = layers.LayerNormalization(epsilon=self.ln_epsilon)

+        # Dropout
         self.dropout1 = layers.Dropout(self.dropout_rate)
         self.dropout3 = layers.Dropout(self.dropout_rate)

@@ -137,11 +148,11 @@ class TransformerBlock(layers.Layer):
         causal_mask = tf.linalg.band_part(tf.ones((seq_len, seq_len), tf.bool), -1, 0)
         return tf.expand_dims(causal_mask, 0)

-
+
     def get_combined_mask(self, causal_mask, padding_mask):
         padding_mask = tf.cast(padding_mask, tf.bool)

-        padding_mask = tf.expand_dims(padding_mask, 1)
+        padding_mask = tf.expand_dims(padding_mask, 1)  # (B, 1, L)
         return causal_mask & padding_mask


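The causal-mask helper and get_combined_mask above produce the boolean mask fed to MultiHeadAttention. A self-contained sketch of the shapes involved, using toy values rather than the model's real inputs:

import tensorflow as tf

seq_len = 4
causal_mask = tf.expand_dims(
    tf.linalg.band_part(tf.ones((seq_len, seq_len), tf.bool), -1, 0), 0)  # (1, L, L)

padding_mask = tf.constant([[1, 1, 1, 0]])                        # (B, L), last token is padding
padding_mask = tf.expand_dims(tf.cast(padding_mask, tf.bool), 1)  # (B, 1, L)

combined = causal_mask & padding_mask                             # broadcasts to (B, L, L)
print(combined.shape)  # (1, 4, 4)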
@@ -150,28 +161,31 @@ class TransformerBlock(layers.Layer):
         if mask is not None:
             att_mask = self.get_combined_mask(att_mask, mask)

+        # Self-Attention
         x = self.norm1(inputs)
         attention_output_1 = self.self_attention(
             query=x, key=x, value=x, attention_mask=att_mask
         )
         attention_output_1 = self.dropout1(attention_output_1)
-        x = x + attention_output_1
-
+        x = x + attention_output_1  # Add & Norm
+
+        # Cross-Attention
         if encoder_outputs is not None:
             x_norm = self.norm2(x)
             attention_output_2 = self.cross_attention(
                 query=x_norm, key=encoder_outputs, value=encoder_outputs
             )
             attention_output_2 = self.dropout2(attention_output_2)
-            x = x + attention_output_2
+            x = x + attention_output_2  # Add & Norm

+        # Feed-Forward Network (FFN)
         x_norm = self.norm3(x)
         proj_output = self.dense_proj(x_norm)
         proj_output = self.dropout3(proj_output)
-        return x + proj_output
+        return x + proj_output  # Add & Norm


-class UnimodalTextDecoder(layers.Layer):
+class UnimodalTextDecoder(tf.keras.layers.Layer):
     def __init__(self, embed_dim, dense_dim, num_heads, dropout_rate=0.1, ln_epsilon=1e-6, num_layers=4, **kwargs):
         super().__init__()
         self.embed_dim = embed_dim
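TransformerBlock follows a pre-LayerNorm residual layout with optional cross-attention. A hedged usage sketch, assuming the call signature call(inputs, encoder_outputs=None, mask=None) that the decoder classes below imply, with random toy tensors:

import tensorflow as tf

# Assumes the TransformerBlock defined above is in scope
block = TransformerBlock(embed_dim=64, dense_dim=128, num_heads=4, is_multimodal=True)

text_tokens = tf.random.normal((2, 10, 64))    # (batch, seq_len, embed_dim)
image_tokens = tf.random.normal((2, 49, 64))   # encoder outputs for cross-attention
pad_mask = tf.ones((2, 10), dtype=tf.int32)    # 1 = real token, 0 = padding

out = block(inputs=text_tokens, encoder_outputs=image_tokens, mask=pad_mask)
print(out.shape)  # (2, 10, 64)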
@@ -187,13 +201,15 @@ class UnimodalTextDecoder(layers.Layer):
         ]
         self.norm = tf.keras.layers.LayerNormalization()

+
     def call(self, x, mask=None):
         for layer in self.layers:
             x = layer(inputs=x, mask=mask)
         return self.norm(x)


-class MultimodalTextDecoder(layers.Layer):
+
+class MultimodalTextDecoder(tf.keras.layers.Layer):
     def __init__(self, embed_dim, dense_dim, num_heads, dropout_rate=0.1, ln_epsilon=1e-6, num_layers=4, **kwargs):
         super().__init__()
         self.embed_dim = embed_dim
@@ -209,6 +225,7 @@ class MultimodalTextDecoder(layers.Layer):
         ]
         self.norm = tf.keras.layers.LayerNormalization()

+
     def call(self, x, encoder_outputs, mask=None):
         for layer in self.layers:
             x = layer(inputs=x, encoder_outputs=encoder_outputs, mask=mask)
@@ -285,6 +302,7 @@ class CoCaEncoder(tf.keras.Model):
             name="cap_query"
         )

+
     def call(self, input, training=False):
         img_feature = self.vit(input).last_hidden_state

@@ -298,6 +316,7 @@ class CoCaEncoder(tf.keras.Model):
         return con_feature, cap_feature


+
 class CoCaDecoder(tf.keras.Model):
     def __init__(self,
                  cls_token_id,
@@ -325,6 +344,7 @@ class CoCaDecoder(tf.keras.Model):

         self.norm = layers.LayerNormalization()

+
     def call(self, inputs, training=False):
         input_text, cap_feature = inputs
         batch_size = tf.shape(input_text)[0]
@@ -346,12 +366,15 @@ class CoCaDecoder(tf.keras.Model):
         return cls_token_feature, multimodal_logits


+
+# day 6
 class CoCaModel(tf.keras.Model):
     def __init__(self,
                  vit,
                  cls_token_id,
                  num_heads,
                  num_layers):
+
         super().__init__()

         self.encoder = CoCaEncoder(vit, name="coca_encoder")
@@ -361,28 +384,34 @@ class CoCaModel(tf.keras.Model):
         self.text_to_latents = EmbedToLatents(proj_dim)

         self.pad_id = 0
-        self.temperature = 0.
+        self.temperature = 0.2  # 0.5 #0.9 #1.0
         self.caption_loss_weight = 1.0
         self.contrastive_loss_weight = 1.0

         self.perplexity = Perplexity()

+
     def call(self, inputs, training=False):
         image, text = inputs
+
         con_feature, cap_feature = self.encoder(image)
         cls_token_feature, multimodal_logits = self.decoder([text, cap_feature])
+
         return con_feature, cls_token_feature, multimodal_logits

+
     def compile(self, optimizer):
         super().compile()
         self.optimizer = optimizer

+
     def compute_caption_loss(self, multimodal_out, caption_target):
         caption_loss = tf.keras.losses.sparse_categorical_crossentropy(
             caption_target, multimodal_out, from_logits=True, ignore_class=self.pad_id)

         return tf.reduce_mean(caption_loss)

+
     def compute_contrastive_loss(self, con_feature, cls_feature):
         text_embeds = tf.squeeze(cls_feature, axis=1)
         image_embeds = tf.squeeze(con_feature, axis=1)
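compute_caption_loss relies on ignore_class so that padded target positions do not contribute to the token-level cross-entropy. A small sketch with toy logits (values are illustrative only):

import tensorflow as tf

pad_id = 0
logits = tf.random.normal((1, 3, 4))   # (batch, seq_len, vocab)
target = tf.constant([[2, 1, 0]])      # last position is PAD (id 0)

per_token = tf.keras.losses.sparse_categorical_crossentropy(
    target, logits, from_logits=True, ignore_class=pad_id)
print(tf.reduce_mean(per_token))       # PAD positions are ignored in the loss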
@@ -390,17 +419,21 @@ class CoCaModel(tf.keras.Model):
         text_latents = self.text_to_latents(text_embeds)
         image_latents = self.img_to_latents(image_embeds)

-
+        # Similarity matrix
+        sim = tf.matmul(text_latents, image_latents, transpose_b=True) / self.temperature  # tf.exp(self.log_temp)

+        # Labels
         batch_size = tf.shape(sim)[0]
         contrastive_labels = tf.range(batch_size)

+        # Compute the losses
         loss1 = tf.keras.losses.sparse_categorical_crossentropy(contrastive_labels, sim, from_logits=True)
         loss2 = tf.keras.losses.sparse_categorical_crossentropy(contrastive_labels, tf.transpose(sim), from_logits=True)
         contrastive_loss = tf.reduce_mean((loss1 + loss2) * 0.5)

         return contrastive_loss

+
     def train_step(self, data):
         (images, caption_input), caption_target = data

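The contrastive head scales a text-to-image similarity matrix by the temperature and treats the diagonal as the positive pairs, taking cross-entropy in both directions. A standalone sketch; tf.math.l2_normalize merely stands in for whatever EmbedToLatents produces, and the batch size and width are arbitrary:

import tensorflow as tf

temperature = 0.2
text_latents = tf.math.l2_normalize(tf.random.normal((8, 64)), axis=-1)
image_latents = tf.math.l2_normalize(tf.random.normal((8, 64)), axis=-1)

sim = tf.matmul(text_latents, image_latents, transpose_b=True) / temperature  # (B, B)
labels = tf.range(tf.shape(sim)[0])    # item i should match item i

loss_t2i = tf.keras.losses.sparse_categorical_crossentropy(labels, sim, from_logits=True)
loss_i2t = tf.keras.losses.sparse_categorical_crossentropy(labels, tf.transpose(sim), from_logits=True)
print(tf.reduce_mean((loss_t2i + loss_i2t) * 0.5))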
@@ -424,6 +457,7 @@ class CoCaModel(tf.keras.Model):
             'perplexity': self.perplexity.result()
         }

+
     def test_step(self, data):
         (images, caption_input), caption_target = data

@@ -443,10 +477,14 @@ class CoCaModel(tf.keras.Model):
             'perplexity': self.perplexity.result()
         }

+
     def reset_metrics(self):
         self.perplexity.reset_state()


+# ===========================================
+# Load the CoCa weights
+
 coca_model = CoCaModel(vit_tiny_model, cls_token_id=cls_token_id, num_heads=num_heads, num_layers=num_layers)

 dummy_features = tf.zeros((1, 3, img_size, img_size), dtype=tf.float32)
@@ -460,19 +498,22 @@ save_dir = "models/"
 model_name = "coca_007"
 coca_model.load_weights(f"{save_dir}/{model_name}.weights.h5")

-
+# ===========================================
+# RNN =======================================
 img_embed_dim = 2048
 reg_count = 7 * 7

 base_model = ResNet50(weights='imagenet', include_top=False)
 model = Model(inputs=base_model.input, outputs=base_model.output)

+
 def preprocess_image(img):
     img = tf.image.resize(img, (img_size, img_size))
     img = tf.convert_to_tensor(img)
     img = preprocess_input(img)
     return np.expand_dims(img, axis=0)

+
 def create_features(img):
     img = preprocess_image(img)
     features = model.predict(img, verbose=0)
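create_features is truncated in this hunk, but reg_count = 7 * 7 and img_embed_dim = 2048 suggest the ResNet50 feature map is flattened into 49 region vectors. A hypothetical helper along those lines (the name regions_from_image and the 224-pixel input size are assumptions, and the ImageNet weights must be downloadable):

import numpy as np
import tensorflow as tf
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input

img_embed_dim = 2048
reg_count = 7 * 7
base_model = ResNet50(weights='imagenet', include_top=False)

def regions_from_image(img, img_size=224):
    # resize, apply ResNet50 preprocessing, and flatten the 7x7 grid into regions
    img = tf.image.resize(img, (img_size, img_size))
    img = preprocess_input(img)
    feats = base_model.predict(np.expand_dims(img, axis=0), verbose=0)  # (1, 7, 7, 2048)
    return feats.reshape(1, reg_count, img_embed_dim)

dummy = np.random.rand(224, 224, 3).astype(np.float32)
print(regions_from_image(dummy).shape)  # (1, 49, 2048)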
@@ -498,6 +539,7 @@ class BahdanauAttention(layers.Layer):
         return context, alpha


+
 class ImageCaptioningModel(tf.keras.Model):
     def __init__(self, vocab_size, max_caption_len, embedding_dim=512, lstm_units=512, dropout_rate=0.5, **kwargs):
         super().__init__(**kwargs)
@@ -520,6 +562,7 @@ class ImageCaptioningModel(tf.keras.Model):

         self.concatenate = layers.Concatenate(axis=-1)

+
     def call(self, inputs):
         features, captions = inputs

@@ -545,6 +588,7 @@ class ImageCaptioningModel(tf.keras.Model):
         return self.fc(outputs)


+
 rnn_model = ImageCaptioningModel(vocab_size, sentence_length-1, rnn_embedding_dim, rnn_proj_dim)
 image_input = np.random.rand(batch_size, reg_count, img_embed_dim).astype(np.float32)
 text_input = np.random.randint(0, 10000, size=(batch_size, sentence_length))
@@ -561,6 +605,9 @@ model_name = "rnn_att_v4"

 rnn_model.load_weights(f"{save_dir}/{model_name}.weights.h5")

+# =====================================
+# Generation methods
+
 beam_width=3
 max_length=sentence_length-1
 temperature=1.0
@@ -584,6 +631,7 @@ def has_repeated_ngrams(seq, n=2):
     return len(ngrams) != len(set(ngrams))


+# improved generation method for CoCa
 def generate_caption_coca(image):
     img_processed = load_and_preprocess_image(image)
     _, cap_features = coca_model.encoder.predict(img_processed, verbose=0)
@@ -611,6 +659,7 @@ def generate_caption_coca(image):
                 new_seq = seq + [token]
                 new_log_prob = (log_prob * len(seq) + np.log(probs[token])) / (len(seq) + 1)

+                # Penalty for repetitions
                 if has_repeated_ngrams(new_seq, n=2):
                     new_log_prob -= 0.5

@@ -624,6 +673,7 @@ def generate_caption_coca(image):
     return " ".join(index_word[i] for i in best_seq if i not in {word_index[start_token], word_index[end_token]})


+# improved generation method for the RNN
 def generate_caption_rnn(image):
     image_embedding = create_features(image)
     beams = [([word_index[start_token]], 0.0)]
@@ -648,6 +698,7 @@ def generate_caption_rnn(image):
                 new_seq = seq + [token]
                 new_log_prob = (log_prob * len(seq) + np.log(probs[token])) / (len(seq) + 1)

+                # Penalty for repetitions
                 if has_repeated_ngrams(new_seq, n=2):
                     new_log_prob -= 0.5
                 new_beams.append((new_seq, new_log_prob))
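Both beam searches keep a length-normalised mean log-probability per hypothesis and subtract a flat 0.5 whenever a bigram repeats. A tiny numeric sketch of that update (the helper name is illustrative):

import numpy as np

def update_score(log_prob, seq_len, token_prob, repeats_bigram):
    # running mean of log-probabilities over the extended sequence
    new_log_prob = (log_prob * seq_len + np.log(token_prob)) / (seq_len + 1)
    if repeats_bigram:  # has_repeated_ngrams(new_seq, n=2) in the app
        new_log_prob -= 0.5
    return new_log_prob

print(update_score(-1.0, 3, 0.5, repeats_bigram=False))  # ≈ -0.923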
@@ -666,6 +717,25 @@ def generate_both(image):
     return f"RNN: {caption1}\n\nCoCa: {caption2}"


+# interface = gr.Interface(
+#     fn=generate_both,
+#     inputs=gr.Image(type="pil", label="Изображение"),
+#     outputs=gr.Textbox(label="Описания", autoscroll=True, show_copy_button=True),
+#     title="Генератор описаний к изображениям",
+#     allow_flagging="never",
+#     submit_btn="Сгенерировать",
+#     clear_btn="Очистить"
+# )
+
+#------------------------------
+css = """
+#hosted-by-hf {
+    top: unset !important;
+    bottom: 20px !important;
+    right: 20px !important;
+}
+"""
+
 interface = gr.Interface(
     fn=generate_both,
     inputs=gr.Image(type="pil", label="Изображение"),
@@ -676,11 +746,36 @@ interface = gr.Interface(
     deep_link=False
 )

-with gr.Blocks() as demo:
+with gr.Blocks(css=css) as demo:
     gr.Markdown("# 🖼️ Генератор описаний к изображениям")
     interface.render()

+# if __name__ == "__main__":
+#     #interface.launch(ssr_mode=False)
+#     demo.launch(ssr_mode=False)
+
+
+# custom_css = """
+# footer {visibility: hidden !important;}
+# .share-button {display: none !important;}
+# #component-1 {margin-top: -1.5rem !important;}  # Reduce the top margin
+# """
+
+# interface = gr.Interface(
+#     fn=generate_both,
+#     inputs=gr.Image(type="pil", label="Изображение"),
+#     outputs=gr.Textbox(label="Описания", autoscroll=True, show_copy_button=True),
+#     allow_flagging="never",
+#     submit_btn="Сгенерировать",
+#     clear_btn="Очистить"
+# )
+
+# with gr.Blocks(css=custom_css) as demo:
+#     gr.Markdown("## 🖼️ Генератор описаний к изображениям")
+#     interface.render()

 if __name__ == "__main__":
-    demo.launch(
-
+    demo.launch(
+        ssr_mode=False,
+        show_api=False
+    )