# Training script: Transformer-based image captioning on COCO 2017.
import tensorflow as tf
import os
import json
import pandas as pd
import re
import numpy as np
import time
import matplotlib.pyplot as plt
import collections
import random
import requests
import pickle
from math import sqrt
from PIL import Image
from tqdm.auto import tqdm

from model import CNN_Encoder, TransformerEncoderLayer, Embeddings, TransformerDecoderLayer, ImageCaptioningModel
DATASET_PATH = "coco2017"
MAX_LENGTH = 40         # max caption length in tokens
MAX_VOCABULARY = 12000  # cap on tokenizer vocabulary size
BATCH_SIZE = 64
BUFFER_SIZE = 1000      # tf.data shuffle buffer
EMBEDDING_DIM = 512
UNITS = 512
EPOCHS = 1
with open(f"{DATASET_PATH}/annotations/captions_train2017.json", "r") as f: | |
data = json.load(f) | |
data = data["annotations"] | |
img_cap_pairs = [] | |
for sample in data: | |
img_name = "%012d.jpg" % sample["image_id"] | |
img_cap_pairs.append([img_name, sample["caption"]]) | |
captions = pd.DataFrame(img_cap_pairs, columns=["image", "caption"]) | |
captions["image"] = captions["image"].apply(lambda x: f"{DATASET_PATH}/train2017/{x}") | |
captions = captions.sample(70000) | |
captions = captions.reset_index(drop=True) | |
captions.head() | |
def preprocessing(text):
    """Lowercase, strip punctuation, collapse whitespace, and add the
    [start]/[end] markers the decoder trains on."""
    text = text.lower()
    text = re.sub(r"[^\w\s]", "", text)
    text = re.sub(r"\s+", " ", text)
    text = text.strip()
    text = "[start] " + text + " [end]"
    return text
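
# Illustrative example of the cleaning rules above:
# preprocessing("A man, riding a horse!") -> "[start] a man riding a horse [end]"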
captions["caption"] = captions["caption"].apply(preprocessing) | |
captions.head() | |
# Fit a TextVectorization tokenizer on the cleaned captions.
# standardize=None because preprocessing() already normalized the text.
tokenizer = tf.keras.layers.TextVectorization(
    max_tokens=MAX_VOCABULARY, standardize=None, output_sequence_length=MAX_LENGTH
)
tokenizer.adapt(captions["caption"])
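
# Illustrative check (not in the original flow): the adapted vocabulary is
# capped at MAX_VOCABULARY entries, including "" (padding) and [UNK].
print(len(tokenizer.get_vocabulary()))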
# Persist the vocabulary so inference code can rebuild the same lookups.
os.makedirs("./vocabulary", exist_ok=True)
with open("./vocabulary/vocab_coco.file", "wb") as f:
    pickle.dump(tokenizer.get_vocabulary(), f)
# word <-> index lookups built from the tokenizer's vocabulary.
word2idx = tf.keras.layers.StringLookup(
    mask_token="", vocabulary=tokenizer.get_vocabulary()
)
idx2word = tf.keras.layers.StringLookup(
    mask_token="", vocabulary=tokenizer.get_vocabulary(), invert=True
)
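
# Illustrative sanity check (not part of the original pipeline): round-trip one
# caption through tokenizer and idx2word; padding ids map back to "".
sample = captions["caption"].iloc[0]
tokens = tokenizer(sample)  # -> int tensor of shape (MAX_LENGTH,)
print(tf.strings.reduce_join(idx2word(tokens), separator=" ").numpy().decode().strip())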
# Group captions by image so the train/test split is made per image,
# keeping all captions of one image on the same side of the split.
img_to_cap_vector = collections.defaultdict(list)
for img, cap in zip(captions["image"], captions["caption"]):
    img_to_cap_vector[img].append(cap)

img_keys = list(img_to_cap_vector.keys())
random.shuffle(img_keys)

slice_index = int(len(img_keys) * 0.8)  # 80/20 train/test split
img_name_train_keys, img_name_test_keys = (
    img_keys[:slice_index],
    img_keys[slice_index:],
)
# Flatten back into parallel (image path, caption) lists.
train_img = []
train_caption = []
for imgt in img_name_train_keys:
    capt_len = len(img_to_cap_vector[imgt])
    train_img.extend([imgt] * capt_len)
    train_caption.extend(img_to_cap_vector[imgt])

test_img = []
test_caption = []
for imgtest in img_name_test_keys:
    capv_len = len(img_to_cap_vector[imgtest])
    test_img.extend([imgtest] * capv_len)
    test_caption.extend(img_to_cap_vector[imgtest])

print(len(train_img), len(train_caption), len(test_img), len(test_caption))
def load_data(img_path, caption):
    """Read and preprocess one image for InceptionV3 and tokenize its caption."""
    img = tf.io.read_file(img_path)
    img = tf.io.decode_jpeg(img, channels=3)
    img = tf.image.resize(img, (299, 299))  # avoids building a Keras layer per call
    img = tf.keras.applications.inception_v3.preprocess_input(img)
    caption = tokenizer(caption)
    return img, caption
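
# Illustrative shape check: one pair through load_data should yield a
# (299, 299, 3) float image and a (MAX_LENGTH,) token sequence.
img0, cap0 = load_data(train_img[0], train_caption[0])
print(img0.shape, cap0.shape)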
train_dataset = tf.data.Dataset.from_tensor_slices((train_img, train_caption))
train_dataset = (
    train_dataset.map(load_data, num_parallel_calls=tf.data.AUTOTUNE)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)  # overlap preprocessing with training
)

test_dataset = tf.data.Dataset.from_tensor_slices((test_img, test_caption))
test_dataset = (
    test_dataset.map(load_data, num_parallel_calls=tf.data.AUTOTUNE)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)
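
# Illustrative: pull a single batch to confirm the pipeline output shapes,
# (BATCH_SIZE, 299, 299, 3) images and (BATCH_SIZE, MAX_LENGTH) captions.
for img_batch, cap_batch in train_dataset.take(1):
    print(img_batch.shape, cap_batch.shape)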
# Light image augmentation, applied inside the model during training.
image_augmentation = tf.keras.Sequential(
    [
        tf.keras.layers.RandomFlip("horizontal"),
        tf.keras.layers.RandomRotation(0.2),
        tf.keras.layers.RandomContrast(0.3),
    ]
)
# Assemble the captioning model from the components defined in model.py.
encoder = TransformerEncoderLayer(EMBEDDING_DIM, 1)
decoder = TransformerDecoderLayer(EMBEDDING_DIM, UNITS, 8)
cnn_model = CNN_Encoder()

caption_model = ImageCaptioningModel(
    cnn_model=cnn_model,
    encoder=encoder,
    decoder=decoder,
    image_aug=image_augmentation,
)
# reduction="none" leaves per-token losses unreduced, presumably so the
# model's custom train_step can mask padding tokens before averaging.
cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=False, reduction="none"
)

# Stop on stalled validation loss and keep the best weights.
early_stopping = tf.keras.callbacks.EarlyStopping(patience=3, restore_best_weights=True)

caption_model.compile(optimizer=tf.keras.optimizers.Adam(), loss=cross_entropy)
history = caption_model.fit(
    train_dataset,
    epochs=EPOCHS,
    validation_data=test_dataset,
    callbacks=[early_stopping],
)

os.makedirs("./image-caption-generator/models", exist_ok=True)
caption_model.save_weights("./image-caption-generator/models/trained_coco_weights_2.h5")
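
# Reloading later (sketch, assuming the same architecture is re-instantiated):
# a subclassed model must be built first, e.g. by one forward pass on a batch,
# before load_weights can restore the h5 checkpoint.
# caption_model.load_weights("./image-caption-generator/models/trained_coco_weights_2.h5")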