Add files
- app.py +685 -0
- models/coca.weights.h5 +3 -0
- models/rnn_attn.weights.h5 +3 -0
- requirements.txt +7 -0
- vocabs/index_word.json +0 -0
- vocabs/word_index.json +0 -0
app.py
ADDED
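
# Gradio demo that captions an uploaded image with two models:
# a CoCa-style ViT-tiny encoder/decoder and a ResNet50 + LSTM captioner
# with Bahdanau attention, both restored from local weight files.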
from tensorflow import keras
import numpy as np
import tensorflow as tf
from tensorflow import data as tf_data
from tensorflow import image as tf_image
from tensorflow import io as tf_io
from PIL import Image
import json
from tensorflow.keras import layers, Model
import string
from transformers import TFAutoModel
import gradio as gr
import os
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
from tensorflow.keras.preprocessing import image
from tensorflow.keras.models import Model


os.environ["KERAS_BACKEND"] = "tensorflow"
start_token = "[BOS]"
end_token = "[EOS]"
cls_token = "[CLS]"

data_dir = '/content/coco'
data_type_train = 'train2014'
data_type_val = 'val2014'

vocab_size = 24000
sentence_length = 20
batch_size = 128
img_size = 224

proj_dim = 192
dropout_rate = 0.1
num_patches = 14
patch_size = img_size // num_patches

num_heads = 3
num_layers = 6
attn_pool_dim = proj_dim
attn_pool_heads = num_heads
cap_query_num = 128

rnn_embedding_dim = 256
rnn_proj_dim = 512


with open('vocabs/word_index.json', 'r', encoding='utf-8') as f:
    word_index = {np.str_(word): np.int64(idx) for word, idx in json.load(f).items()}

with open('vocabs/index_word.json', 'r', encoding='utf-8') as f:
    index_word = {np.int64(idx): np.str_(word) for idx, word in json.load(f).items()}

cls_token_id = word_index[cls_token]
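
# Learned token + position embeddings for the caption decoder.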
class PositionalEmbedding(layers.Layer):
    def __init__(self, sequence_length, input_dim, output_dim, **kwargs):
        super().__init__(**kwargs)
        self.sequence_length = sequence_length
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.token_embeddings = layers.Embedding(
            input_dim=input_dim, output_dim=output_dim
        )
        self.position_embeddings = layers.Embedding(
            input_dim=sequence_length, output_dim=output_dim
        )

    def call(self, inputs):
        positions = tf.range(start=0, limit=self.sequence_length, delta=1)
        embedded_tokens = self.token_embeddings(inputs)
        embedded_positions = self.position_embeddings(positions)
        output = embedded_tokens + embedded_positions
        return output
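
# Attentional pooling: a set of learned queries cross-attends over the
# ViT patch features to produce pooled image representations.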
class AttentionalPooling(layers.Layer):
    def __init__(self, embed_dim, num_heads=6):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.multihead_attention = layers.MultiHeadAttention(num_heads=self.num_heads, key_dim=self.embed_dim)
        self.norm = layers.LayerNormalization()

    def call(self, features, query):
        attn_output = self.multihead_attention(
            query=query,
            value=features,
            key=features
        )

        return self.norm(attn_output)
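
# Pre-norm Transformer block with causal self-attention; when
# is_multimodal=True it also cross-attends to the image features.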
class TransformerBlock(layers.Layer):
    def __init__(self, embed_dim, dense_dim, num_heads, dropout_rate=0.1, ln_epsilon=1e-6, is_multimodal=False, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.ln_epsilon = ln_epsilon

        self.self_attention = layers.MultiHeadAttention(
            num_heads=self.num_heads,
            key_dim=self.embed_dim,
            dropout=self.dropout_rate
        )

        if is_multimodal:
            self.norm2 = layers.LayerNormalization(epsilon=self.ln_epsilon)
            self.dropout2 = layers.Dropout(self.dropout_rate)
            self.cross_attention = layers.MultiHeadAttention(
                num_heads=self.num_heads,
                key_dim=self.embed_dim,
                dropout=self.dropout_rate
            )

        self.dense_proj = tf.keras.Sequential([
            layers.Dense(self.dense_dim, activation="gelu"),
            layers.Dropout(self.dropout_rate),
            layers.Dense(self.embed_dim)
        ])

        self.norm1 = layers.LayerNormalization(epsilon=self.ln_epsilon)
        self.norm3 = layers.LayerNormalization(epsilon=self.ln_epsilon)

        self.dropout1 = layers.Dropout(self.dropout_rate)
        self.dropout3 = layers.Dropout(self.dropout_rate)

    def get_causal_attention_mask(self, inputs):
        seq_len = tf.shape(inputs)[1]
        causal_mask = tf.linalg.band_part(tf.ones((seq_len, seq_len), tf.bool), -1, 0)
        return tf.expand_dims(causal_mask, 0)

    def get_combined_mask(self, causal_mask, padding_mask):
        padding_mask = tf.cast(padding_mask, tf.bool)
        padding_mask = tf.expand_dims(padding_mask, 1)
        return causal_mask & padding_mask

    def call(self, inputs, encoder_outputs=None, mask=None):
        att_mask = self.get_causal_attention_mask(inputs)
        if mask is not None:
            att_mask = self.get_combined_mask(att_mask, mask)

        x = self.norm1(inputs)
        attention_output_1 = self.self_attention(
            query=x, key=x, value=x, attention_mask=att_mask
        )
        attention_output_1 = self.dropout1(attention_output_1)
        x = x + attention_output_1

        if encoder_outputs is not None:
            x_norm = self.norm2(x)
            attention_output_2 = self.cross_attention(
                query=x_norm, key=encoder_outputs, value=encoder_outputs
            )
            attention_output_2 = self.dropout2(attention_output_2)
            x = x + attention_output_2

        x_norm = self.norm3(x)
        proj_output = self.dense_proj(x_norm)
        proj_output = self.dropout3(proj_output)
        return x + proj_output
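
# Stacks of TransformerBlocks: the unimodal stack sees only text, while the
# multimodal stack additionally cross-attends to the encoder outputs.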
class UnimodalTextDecoder(layers.Layer):
    def __init__(self, embed_dim, dense_dim, num_heads, dropout_rate=0.1, ln_epsilon=1e-6, num_layers=4, **kwargs):
        super().__init__()
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.ln_epsilon = ln_epsilon
        self.num_layers = num_layers

        self.layers = [
            TransformerBlock(self.embed_dim, self.dense_dim, self.num_heads, self.dropout_rate, self.ln_epsilon, is_multimodal=False)
            for _ in range(self.num_layers)
        ]
        self.norm = tf.keras.layers.LayerNormalization()

    def call(self, x, mask=None):
        for layer in self.layers:
            x = layer(inputs=x, mask=mask)
        return self.norm(x)


class MultimodalTextDecoder(layers.Layer):
    def __init__(self, embed_dim, dense_dim, num_heads, dropout_rate=0.1, ln_epsilon=1e-6, num_layers=4, **kwargs):
        super().__init__()
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.ln_epsilon = ln_epsilon
        self.num_layers = num_layers

        self.layers = [
            TransformerBlock(self.embed_dim, self.dense_dim, self.num_heads, self.dropout_rate, self.ln_epsilon, is_multimodal=True)
            for _ in range(self.num_layers)
        ]
        self.norm = tf.keras.layers.LayerNormalization()

    def call(self, x, encoder_outputs, mask=None):
        for layer in self.layers:
            x = layer(inputs=x, encoder_outputs=encoder_outputs, mask=mask)
        return self.norm(x)
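
# Projection to the shared contrastive latent space (L2-normalized), plus a
# perplexity metric that ignores padding tokens.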
class EmbedToLatents(layers.Layer):
    def __init__(self, dim_latents, **kwargs):
        super(EmbedToLatents, self).__init__(**kwargs)
        self.dim_latents = dim_latents
        self.to_latents = layers.Dense(
            self.dim_latents,
            use_bias=False
        )

    def call(self, inputs):
        latents = self.to_latents(inputs)
        return tf.math.l2_normalize(latents, axis=-1)


class Perplexity(tf.keras.metrics.Metric):
    def __init__(self, name='perplexity', **kwargs):
        super().__init__(name=name, **kwargs)
        self.total_loss = self.add_weight(name='total_loss', initializer='zeros')
        self.total_tokens = self.add_weight(name='total_tokens', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
        loss = loss_fn(y_true, y_pred)

        mask = tf.cast(tf.not_equal(y_true, 0), tf.float32)
        loss = tf.reduce_sum(loss * mask)
        num_tokens = tf.reduce_sum(mask)

        self.total_loss.assign_add(loss)
        self.total_tokens.assign_add(num_tokens)

    def result(self):
        return tf.exp(self.total_loss / self.total_tokens)

    def reset_state(self):
        self.total_loss.assign(0.0)
        self.total_tokens.assign(0.0)
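
# Image encoder: a pretrained ViT-tiny backbone followed by two attentional
# pooling heads, one query for the contrastive embedding and cap_query_num
# queries for the captioning features.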
model_name = "WinKawaks/vit-tiny-patch16-224"
vit_tiny_model = TFAutoModel.from_pretrained(model_name)
vit_tiny_model.trainable = True

for layer in vit_tiny_model.layers:
    layer.trainable = True


class CoCaEncoder(tf.keras.Model):
    def __init__(self, vit, **kwargs):
        super().__init__(**kwargs)

        self.vit = vit

        self.contrastive_pooling = AttentionalPooling(attn_pool_dim, attn_pool_heads)
        self.caption_pooling = AttentionalPooling(attn_pool_dim, attn_pool_heads)

        self.con_query = tf.Variable(
            initial_value=tf.random.normal([1, 1, proj_dim]),
            trainable=True,
            name="con_query"
        )

        self.cap_query = tf.Variable(
            initial_value=tf.random.normal([1, cap_query_num, proj_dim]),
            trainable=True,
            name="cap_query"
        )

    def call(self, input, training=False):
        img_feature = self.vit(input).last_hidden_state

        batch_size = tf.shape(img_feature)[0]
        con_query_b = tf.repeat(self.con_query, repeats=batch_size, axis=0)
        cap_query_b = tf.repeat(self.cap_query, repeats=batch_size, axis=0)

        con_feature = self.contrastive_pooling(img_feature, con_query_b)
        cap_feature = self.caption_pooling(img_feature, cap_query_b)

        return con_feature, cap_feature
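
# Text decoder: appends a [CLS] token, runs the unimodal stack, then feeds all
# positions except [CLS] through the multimodal stack to predict caption
# logits; the [CLS] output serves as the text embedding for the contrastive loss.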
class CoCaDecoder(tf.keras.Model):
    def __init__(self, cls_token_id, num_heads, num_layers, **kwargs):
        super().__init__(**kwargs)

        self.cls_token_id = cls_token_id

        self.pos_emb = PositionalEmbedding(sentence_length, vocab_size, proj_dim)

        self.unimodal_decoder = UnimodalTextDecoder(
            proj_dim, proj_dim * 4, num_heads, dropout_rate, num_layers=num_layers
        )
        self.multimodal_decoder = MultimodalTextDecoder(
            proj_dim, proj_dim * 4, num_heads, dropout_rate, num_layers=num_layers
        )

        self.to_logits = tf.keras.layers.Dense(
            vocab_size,
            name='logits_projection'
        )

        self.norm = layers.LayerNormalization()

    def call(self, inputs, training=False):
        input_text, cap_feature = inputs
        batch_size = tf.shape(input_text)[0]
        cls_tokens = tf.fill([batch_size, 1], tf.cast(self.cls_token_id, input_text.dtype))
        ids = tf.concat([input_text, cls_tokens], axis=1)

        text_mask = tf.not_equal(input_text, 0)
        cls_mask = tf.zeros([batch_size, 1], dtype=text_mask.dtype)
        extended_mask = tf.concat([text_mask, cls_mask], axis=1)

        txt_embs = self.pos_emb(ids)

        unimodal_out = self.unimodal_decoder(txt_embs, mask=extended_mask)
        multimodal_out = self.multimodal_decoder(unimodal_out[:, :-1, :], cap_feature, mask=text_mask)

        cls_token_feature = self.norm(unimodal_out[:, -1:, :])
        multimodal_logits = self.to_logits(multimodal_out)

        return cls_token_feature, multimodal_logits
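
# Full CoCa-style model: combines a captioning (cross-entropy) loss and a
# symmetric image-text contrastive loss in custom train/test steps.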
class CoCaModel(tf.keras.Model):
    def __init__(self, vit, cls_token_id, num_heads, num_layers):
        super().__init__()

        self.encoder = CoCaEncoder(vit, name="coca_encoder")
        self.decoder = CoCaDecoder(cls_token_id, num_heads, num_layers, name="coca_decoder")

        self.img_to_latents = EmbedToLatents(proj_dim)
        self.text_to_latents = EmbedToLatents(proj_dim)

        self.pad_id = 0
        self.temperature = 0.07
        self.caption_loss_weight = 1.0
        self.contrastive_loss_weight = 1.0

        self.perplexity = Perplexity()

    def call(self, inputs, training=False):
        image, text = inputs
        con_feature, cap_feature = self.encoder(image)
        cls_token_feature, multimodal_logits = self.decoder([text, cap_feature])
        return con_feature, cls_token_feature, multimodal_logits

    def compile(self, optimizer):
        super().compile()
        self.optimizer = optimizer

    def compute_caption_loss(self, multimodal_out, caption_target):
        caption_loss = tf.keras.losses.sparse_categorical_crossentropy(
            caption_target, multimodal_out, from_logits=True, ignore_class=self.pad_id)

        return tf.reduce_mean(caption_loss)

    def compute_contrastive_loss(self, con_feature, cls_feature):
        text_embeds = tf.squeeze(cls_feature, axis=1)
        image_embeds = tf.squeeze(con_feature, axis=1)

        text_latents = self.text_to_latents(text_embeds)
        image_latents = self.img_to_latents(image_embeds)

        sim = tf.matmul(text_latents, image_latents, transpose_b=True) / self.temperature

        batch_size = tf.shape(sim)[0]
        contrastive_labels = tf.range(batch_size)

        loss1 = tf.keras.losses.sparse_categorical_crossentropy(contrastive_labels, sim, from_logits=True)
        loss2 = tf.keras.losses.sparse_categorical_crossentropy(contrastive_labels, tf.transpose(sim), from_logits=True)
        contrastive_loss = tf.reduce_mean((loss1 + loss2) * 0.5)

        return contrastive_loss

    def train_step(self, data):
        (images, caption_input), caption_target = data

        with tf.GradientTape() as tape:
            con_feature, cls_feature, multimodal_out = self([images, caption_input], training=True)

            caption_loss = self.compute_caption_loss(multimodal_out, caption_target)
            contrastive_loss = self.compute_contrastive_loss(con_feature, cls_feature)

            total_loss = self.caption_loss_weight * caption_loss + self.contrastive_loss_weight * contrastive_loss

        gradients = tape.gradient(total_loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        self.perplexity.update_state(caption_target, multimodal_out)

        return {
            'total_loss': total_loss,
            'caption_loss': caption_loss,
            'contrastive_loss': contrastive_loss,
            'perplexity': self.perplexity.result()
        }

    def test_step(self, data):
        (images, caption_input), caption_target = data

        con_feature, cls_feature, multimodal_out = self([images, caption_input], training=False)

        caption_loss = self.compute_caption_loss(multimodal_out, caption_target)
        contrastive_loss = self.compute_contrastive_loss(con_feature, cls_feature)

        total_loss = self.caption_loss_weight * caption_loss + self.contrastive_loss_weight * contrastive_loss

        self.perplexity.update_state(caption_target, multimodal_out)

        return {
            'total_loss': total_loss,
            'caption_loss': caption_loss,
            'contrastive_loss': contrastive_loss,
            'perplexity': self.perplexity.result()
        }

    def reset_metrics(self):
        self.perplexity.reset_state()
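
# Build the model once on dummy inputs so its variables exist, then restore
# the trained weights.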
coca_model = CoCaModel(vit_tiny_model, cls_token_id=cls_token_id, num_heads=num_heads, num_layers=num_layers)

dummy_features = tf.zeros((1, 3, img_size, img_size), dtype=tf.float32)
dummy_captions = tf.zeros((1, sentence_length-1), dtype=tf.int64)
_ = coca_model((dummy_features, dummy_captions))

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
coca_model.compile(optimizer)

save_dir = "models/"
model_name = "coca"
coca_model.load_weights(f"{save_dir}/{model_name}.weights.h5")
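
# ResNet50 feature extractor for the RNN captioner: the 7x7x2048 convolutional
# feature map is reshaped into 49 region vectors per image.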
img_embed_dim = 2048
reg_count = 7 * 7

base_model = ResNet50(weights='imagenet', include_top=False)
model = Model(inputs=base_model.input, outputs=base_model.output)

def preprocess_image(img):
    img = tf.image.resize(img, (img_size, img_size))
    img = tf.convert_to_tensor(img)
    img = preprocess_input(img)
    return np.expand_dims(img, axis=0)

def create_features(img):
    img = preprocess_image(img)
    features = model.predict(img, verbose=0)
    features = features.reshape((1, reg_count, img_embed_dim))
    return features
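
# Additive (Bahdanau) attention over the region features, conditioned on the
# LSTM hidden state.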
class BahdanauAttention(layers.Layer):
    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.W1 = layers.Dense(units)
        self.W2 = layers.Dense(units)
        self.V = layers.Dense(1)

    def call(self, features, hidden):
        hidden = tf.expand_dims(hidden, 1)
        score = self.V(tf.nn.tanh(
            self.W1(features) + self.W2(hidden)
        ))
        alpha = tf.nn.softmax(score, axis=1)
        context = tf.reduce_sum(alpha * features, axis=1)
        return context, alpha
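
# LSTM captioner: the hidden/cell states are initialized from the mean image
# feature; each step feeds the word embedding concatenated with the attention
# context into the LSTM and predicts the next-word distribution.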
class ImageCaptioningModel(tf.keras.Model):
    def __init__(self, vocab_size, max_caption_len, embedding_dim=512, lstm_units=512, dropout_rate=0.5, **kwargs):
        super().__init__(**kwargs)

        self.vocab_size = vocab_size
        self.max_caption_len = max_caption_len
        self.embedding_dim = embedding_dim
        self.lstm_units = lstm_units
        self.dropout_rate = dropout_rate

        self.embedding = layers.Embedding(vocab_size, embedding_dim)
        self.embedding_dropout = layers.Dropout(dropout_rate)
        self.lstm = layers.LSTM(lstm_units, return_sequences=True, return_state=True)
        self.attention = BahdanauAttention(lstm_units)
        self.fc_dropout = layers.Dropout(dropout_rate)
        self.fc = layers.Dense(vocab_size, activation='softmax')

        self.init_h = layers.Dense(lstm_units, activation='tanh')
        self.init_c = layers.Dense(lstm_units)

        self.concatenate = layers.Concatenate(axis=-1)

    def call(self, inputs):
        features, captions = inputs

        mean_features = tf.reduce_mean(features, axis=1)
        h = self.init_h(mean_features)
        c = self.init_c(mean_features)

        embeddings = self.embedding(captions)
        embeddings = self.embedding_dropout(embeddings)

        outputs = []
        for t in range(self.max_caption_len):
            context, _ = self.attention(features, h)

            lstm_input = self.concatenate([embeddings[:, t, :], context])
            lstm_input = tf.expand_dims(lstm_input, 1)

            output, h, c = self.lstm(lstm_input, initial_state=[h, c])
            outputs.append(output)

        outputs = tf.concat(outputs, axis=1)
        outputs = self.fc_dropout(outputs)
        return self.fc(outputs)
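
# Build the RNN captioner on dummy inputs, restore its trained weights, and
# set the shared beam-search and preprocessing settings used at inference time.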
rnn_model = ImageCaptioningModel(vocab_size, sentence_length-1, rnn_embedding_dim, rnn_proj_dim)
image_input = np.random.rand(batch_size, reg_count, img_embed_dim).astype(np.float32)
text_input = np.random.randint(0, 10000, size=(batch_size, sentence_length))
_ = rnn_model([image_input, text_input])

rnn_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=[Perplexity()]
)

save_dir = "models/"
model_name = "rnn_attn"

rnn_model.load_weights(f"{save_dir}/{model_name}.weights.h5")

beam_width = 3
max_length = sentence_length - 1
temperature = 1.0

image_mean = [0.5, 0.5, 0.5]
image_std = [0.5, 0.5, 0.5]

def load_and_preprocess_image(img):
    img = tf.convert_to_tensor(img)
    img = tf.image.resize(img, (img_size, img_size))
    img = img / 255.0

    img = (img - image_mean) / image_std
    img = tf.transpose(img, perm=[2, 0, 1])

    return np.expand_dims(img, axis=0)
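
# Beam search decoding for the CoCa model, with length-normalized log-probs
# and a penalty for repeated bigrams.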
def has_repeated_ngrams(seq, n=2):
    ngrams = [tuple(seq[i:i+n]) for i in range(len(seq)-n+1)]
    return len(ngrams) != len(set(ngrams))


def generate_caption_coca(image):
    img_processed = load_and_preprocess_image(image)
    _, cap_features = coca_model.encoder.predict(img_processed, verbose=0)

    beams = [([word_index[start_token]], 0.0)]

    for _ in range(max_length):
        new_beams = []
        for seq, log_prob in beams:
            if seq[-1] == word_index[end_token]:
                new_beams.append((seq, log_prob))
                continue

            text_input = np.zeros((1, max_length), dtype=np.int32)
            text_input[0, :len(seq)] = seq

            predictions = coca_model.decoder.predict([text_input, cap_features], verbose=0)
            _, logits = predictions
            logits = logits[0, len(seq)-1, :] / temperature
            probs = np.exp(logits - np.max(logits))
            probs /= probs.sum()

            top_k = np.argpartition(probs, -beam_width)[-beam_width:]
            for token in top_k:
                new_seq = seq + [token]
                new_log_prob = (log_prob * len(seq) + np.log(probs[token])) / (len(seq) + 1)

                if has_repeated_ngrams(new_seq, n=2):
                    new_log_prob -= 0.5

                new_beams.append((new_seq, new_log_prob))

        beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:beam_width]
        if all(beam[0][-1] == word_index[end_token] for beam in beams):
            break

    best_seq = max(beams, key=lambda x: x[1])[0]
    return " ".join(index_word[i] for i in best_seq if i not in {word_index[start_token], word_index[end_token]})
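
# The same beam search, driven by the RNN captioner over ResNet50 features.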
def generate_caption_rnn(image):
    image_embedding = create_features(image)
    beams = [([word_index[start_token]], 0.0)]

    for _ in range(max_length):
        new_beams = []
        for seq, log_prob in beams:
            if seq[-1] == word_index[end_token]:
                new_beams.append((seq, log_prob))
                continue

            text_input = np.zeros((1, max_length), dtype=np.int32)
            text_input[0, :len(seq)] = seq

            predictions = rnn_model.predict([image_embedding, text_input], verbose=0)
            probs = predictions[0, len(seq)-1, :]
            probs = probs ** (1 / temperature)
            probs /= probs.sum()

            top_k = np.argpartition(probs, -beam_width)[-beam_width:]
            for token in top_k:
                new_seq = seq + [token]
                new_log_prob = (log_prob * len(seq) + np.log(probs[token])) / (len(seq) + 1)

                if has_repeated_ngrams(new_seq, n=2):
                    new_log_prob -= 0.5
                new_beams.append((new_seq, new_log_prob))

        beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:beam_width]
        if all(beam[0][-1] == word_index[end_token] for beam in beams):
            break

    best_seq = max(beams, key=lambda x: x[1])[0]
    return " ".join(index_word[i] for i in best_seq if i not in {word_index[start_token], word_index[end_token]})
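
# Gradio UI: run both captioners on the uploaded image and show the results.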
def generate_both(image):
    caption1 = generate_caption_rnn(image)
    caption2 = generate_caption_coca(image)
    return f"RNN: {caption1}\n\nCoCa: {caption2}"


interface = gr.Interface(
    fn=generate_both,
    inputs=gr.Image(type="pil", label="Image"),
    outputs=gr.Textbox(label="Captions", autoscroll=True, show_copy_button=True),
    allow_flagging="never",
    submit_btn="Generate",
    clear_btn="Clear",
    deep_link=False
)

with gr.Blocks() as demo:
    gr.Markdown("# 🖼️ Image caption generator")
    interface.render()


if __name__ == "__main__":
    demo.launch(ssr_mode=False, show_api=False)
models/coca.weights.h5
ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:5dc33edd1df6158e35bef3f5c4e151c6ce69f4105a487e052754712debfd3656
size 262132344
models/rnn_attn.weights.h5
ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:79e09a294e234d15baae6ef4916f35772ec53e2645e2de58c54e0996a7baa027
size 331683632
requirements.txt
ADDED
tensorflow>=2.18.0
keras>=3.8.0
numpy>=2.0.2
pillow>=11.2.1
transformers>=4.52.4
gradio>=5.31.0
h5py>=3.14.0
vocabs/index_word.json
ADDED
The diff for this file is too large to render. See raw diff
vocabs/word_index.json
ADDED
The diff for this file is too large to render. See raw diff