Spaces:
Runtime error
Runtime error
import pickle | |
import tensorflow as tf | |
import pandas as pd | |
import numpy as np | |
# ---- Hyper-parameters -------------------------------------------------------
# Caption token budget: used as the tokenizer's output length and as the bound
# of the greedy decoding loop.
MAX_LENGTH = 40

# Data-pipeline settings; not referenced in this part of the file (presumably
# consumed by training code elsewhere — TODO confirm).
BATCH_SIZE = 32
BUFFER_SIZE = 1_000

# Transformer widths handed to the encoder/decoder layer constructors.
EMBEDDING_DIM = 512
UNITS = 512
# LOADING DATA
# The vocabulary is a pickled object shipped with the app. Use a context
# manager so the file handle is closed after loading. NOTE: pickle can execute
# arbitrary code while loading, so this path must only ever point at a trusted
# file.
with open('vocabulary/vocab_coco.file', 'rb') as vocab_file:
    vocab = pickle.load(vocab_file)

# Text -> integer token ids, padded/truncated to MAX_LENGTH positions.
# standardize=None disables lowercasing/punctuation stripping, so bracketed
# control tokens such as '[start]' and '[end]' are kept intact.
tokenizer = tf.keras.layers.TextVectorization(
    standardize=None,
    output_sequence_length=MAX_LENGTH,
    vocabulary=vocab,
)

# Inverse lookup (token id -> word string), built from the tokenizer's own
# vocabulary so ids agree between the two layers; "" is the padding/mask token.
idx2word = tf.keras.layers.StringLookup(
    mask_token="",
    vocabulary=tokenizer.get_vocabulary(),
    invert=True,
)
def load_image_from_path(img_path):
    """Read an image file and prepare it for the InceptionV3-based encoder.

    The file is decoded as a 3-channel JPEG, resized to 299x299, and scaled
    with ``inception_v3.preprocess_input``.

    Args:
        img_path: Path (string or string tensor) of the image file.

    Returns:
        A (299, 299, 3) float tensor ready for the CNN encoder.
    """
    raw_bytes = tf.io.read_file(img_path)
    pixels = tf.io.decode_jpeg(raw_bytes, channels=3)
    resize_layer = tf.keras.layers.Resizing(299, 299)
    resized = resize_layer(pixels)
    return tf.keras.applications.inception_v3.preprocess_input(resized)
def generate_caption(img, caption_model, add_noise=False):
    """Greedily decode a caption for a single image.

    Args:
        img: Either a path to an image file (loaded with
            ``load_image_from_path``) or an already-preprocessed image tensor
            without a batch axis (presumably (299, 299, 3) to match the
            loader — TODO confirm).
        caption_model: Model exposing ``cnn_model``, ``encoder`` and
            ``decoder`` sub-models, such as the one built by
            ``get_caption_model``.
        add_noise: If truthy, add Gaussian noise (stddev 0.1) to the image and
            min-max rescale it to [0, 1] before captioning.

    Returns:
        The predicted caption as a space-separated string. It excludes the
        ``[start]`` token and the terminating ``[end]`` token. The result is an
        empty string if the model predicts ``[end]`` immediately.
    """
    if isinstance(img, str):
        img = load_image_from_path(img)

    if add_noise:
        noise = tf.random.normal(img.shape) * 0.1
        img = img + noise
        # Min-max rescale the noisy image to [0, 1].
        img = (img - tf.reduce_min(img)) / (tf.reduce_max(img) - tf.reduce_min(img))

    # Encode the image once (batch of one); only the decoder runs in the loop.
    img = tf.expand_dims(img, axis=0)
    img_embed = caption_model.cnn_model(img)
    img_encoded = caption_model.encoder(img_embed, training=False)

    # Keep generated words separate from the '[start]' prefix. Stripping the
    # prefix with str.replace afterwards left '[start]' in the output when
    # '[end]' was predicted on the very first step.
    words = []
    for i in range(MAX_LENGTH - 1):
        # Re-tokenize the running caption and drop the last position, so the
        # decoder input is MAX_LENGTH - 1 tokens long (presumably the length
        # it was trained on — TODO confirm). Id 0 is padding, hence the mask.
        tokenized = tokenizer([' '.join(['[start]'] + words)])[:, :-1]
        mask = tf.cast(tokenized != 0, tf.int32)
        pred = caption_model.decoder(
            tokenized, img_encoded, training=False, mask=mask)

        # Greedy choice: the distribution at position i predicts the next word.
        pred_idx = np.argmax(pred[0, i, :])
        pred_word = idx2word(pred_idx).numpy().decode('utf-8')
        if pred_word == '[end]':
            break
        words.append(pred_word)

    return ' '.join(words)
def get_caption_model():
    """Build the image-captioning model and load its trained COCO weights.

    The outer model and its sub-models are called once on dummy inputs so
    their variables exist before ``load_weights`` tries to restore them.

    Returns:
        The ``ImageCaptioningModel`` instance with weights loaded.

    Raises:
        FileNotFoundError: If the weights file exists at none of the candidate
            paths.
    """
    import os

    encoder = TransformerEncoderLayer(EMBEDDING_DIM, 1)
    decoder = TransformerDecoderLayer(EMBEDDING_DIM, UNITS, 8)
    cnn_model = CNN_Encoder()
    caption_model = ImageCaptioningModel(
        cnn_model=cnn_model, encoder=encoder, decoder=decoder, image_aug=None,
    )

    # Replace the model's forward pass with an identity function. Calling the
    # model on a dummy batch then builds it without running
    # ImageCaptioningModel's own call. `training` defaults to None so the
    # stub also works if Keras does not pass that argument.
    def call_fn(batch, training=None):
        return batch

    caption_model.call = call_fn
    sample_x = tf.random.normal((1, 299, 299, 3))
    sample_y = tf.zeros((1, MAX_LENGTH))
    caption_model((sample_x, sample_y))

    # Run each sub-model once on the dummy inputs to create its variables.
    sample_img_embed = caption_model.cnn_model(sample_x)
    sample_enc_out = caption_model.encoder(sample_img_embed, training=False)
    caption_model.decoder(sample_y, sample_enc_out, training=False)

    # Paths are relative to the current working directory. The second one
    # presumably covers launching from the repository's parent directory.
    # Check for the file explicitly instead of relying on load_weights raising
    # exactly FileNotFoundError (older h5py versions raise a plain OSError).
    candidates = (
        'models/trained_coco_weights.h5',
        'image-caption-generator/models/trained_coco_weights.h5',
    )
    for weights_path in candidates:
        if os.path.isfile(weights_path):
            caption_model.load_weights(weights_path)
            return caption_model

    raise FileNotFoundError(
        'trained_coco_weights.h5 not found; tried: ' + ', '.join(candidates)
    )