File size: 1,976 Bytes
3a3c68a e1cca9e 3a3c68a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import os
from typing import List

import cv2
import imageio
import numpy as np
import tensorflow as tf
import yaml
from matplotlib import pyplot as plt
# Use the public Keras API (consistent with the tensorflow.keras.optimizers
# import below) instead of the private `keras.src` package, whose module
# layout is an implementation detail and can change between releases.
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler
from tensorflow.keras.optimizers import Adam

from helpers import *
from Loader import GridLoader
from model import Predictor, ProduceExample
# Load dataset locations from the YAML config (safe_load: config is local,
# but never use yaml.load on untrusted input).
with open('config.yml', 'r') as config_file_obj:
    yaml_config = yaml.safe_load(config_file_obj)
dataset_config = yaml_config['datasets']
VIDEO_DIR = dataset_config['video_dir']
ALIGNMENTS_DIR = dataset_config['alignments_dir']

# Known-bad video paths to exclude from loading. Use a context manager so the
# file handle is closed (the original `open(...)` was never closed), and
# splitlines() so trailing newlines don't produce a spurious '' entry.
with open('corrupted.txt', 'r') as corrupt_file:
    corrupt_video_paths = set(corrupt_file.read().splitlines())
# Enumerate usable video files, skipping the corrupted ones.
loader = GridLoader()
video_filepaths = loader.load_videos(blacklist=corrupt_video_paths)
print(f'videos loaded: {len(video_filepaths)}')

data = tf.data.Dataset.from_tensor_slices(video_filepaths)
# Decode each path tensor back to a plain-Python string.
# NOTE(review): `filenames` is not referenced later in this script —
# kept for debugging parity; confirm before deleting.
filenames = [file_path.numpy().decode("utf-8") for file_path in data]
# tf.data input pipeline: shuffle once with a fixed order across epochs
# (reshuffle_each_iteration=False keeps the train/test split below stable),
# map paths to tensors, pad-batch in pairs, and prefetch.
data = data.shuffle(500, reshuffle_each_iteration=False)
data = data.map(mappable_function)
# padded_shapes: 75 frames with variable spatial dims/channels, and
# alignments padded to 40 tokens.
# NOTE(review): assumes mappable_function yields (video, alignment) — confirm.
data = data.padded_batch(2, padded_shapes=(
    [75, None, None, None], [40]
))
data = data.prefetch(tf.data.AUTOTUNE)

# Train/test split: first 450 batches train, the remainder test.
train = data.take(450)
test = data.skip(450)

# Pull one sample batch for inspection; the builtin next() is the idiomatic
# way to advance an iterator (rather than the .next() method).
frames, alignments = next(data.as_numpy_iterator())
# Build and compile the model: Adam optimizer with the model's own CTC loss.
predictor = Predictor()
predictor.compile(
    optimizer=Adam(learning_rate=0.0001),
    loss=predictor.CTCLoss
)

# Training callbacks: checkpoint weights each epoch (monitoring the training
# loss), anneal the learning rate via the model's schedule, and decode
# example predictions from the held-out set after each epoch.
training_callbacks = [
    ModelCheckpoint(
        os.path.join('models', 'checkpoint'),
        monitor='loss', save_weights_only=True
    ),
    LearningRateScheduler(predictor.scheduler),
    ProduceExample(test),
]

predictor.fit(
    train, validation_data=test, epochs=100,
    callbacks=training_callbacks
)