# lipnet / train.py — LipNet training script
# (Hugging Face Hub page header removed; originally uploaded via
# huggingface_hub, commit e1cca9e)
import os
import cv2
import numpy as np
import tensorflow as tf
import imageio
import yaml
from keras.src.callbacks import ModelCheckpoint, LearningRateScheduler
from matplotlib import pyplot as plt
from helpers import *
from typing import List
from Loader import GridLoader
from model import Predictor, ProduceExample
from tensorflow.keras.optimizers import Adam
# Load dataset locations from the YAML config file.
with open('config.yml', 'r') as config_file_obj:
    yaml_config = yaml.safe_load(config_file_obj)

dataset_config = yaml_config['datasets']
VIDEO_DIR = dataset_config['video_dir']
ALIGNMENTS_DIR = dataset_config['alignments_dir']

# Known-corrupt videos to exclude from training, one path per line.
# Fix: read via a context manager — the original bare open() call
# leaked the file handle.
with open('corrupted.txt', 'r') as corrupt_file:
    corrupt_video_paths = set(
        corrupt_file.read().strip().split('\n')
    )
# Discover all usable video files (minus the corrupt ones) and wrap
# the path list in a tf.data Dataset for the input pipeline.
loader = GridLoader()
video_filepaths = loader.load_videos(blacklist=corrupt_video_paths)
print(f'videos loaded: {len(video_filepaths)}')

data = tf.data.Dataset.from_tensor_slices(video_filepaths)

# Decode the dataset's byte-string path tensors back to Python strings.
# Fix: build the list with a comprehension instead of a manual
# append loop (clearer, and avoids the mutable-accumulator pattern).
filenames = [file_path.numpy().decode("utf-8") for file_path in data]
# print(filenames)
# Input pipeline: shuffle once with a fixed order across epochs
# (reshuffle_each_iteration=False), map each path through
# mappable_function (from helpers — presumably loads the video frames
# and alignment labels; verify there), pad-batch in pairs, and
# prefetch to overlap loading with training.
data = (
    data
    .shuffle(500, reshuffle_each_iteration=False)
    .map(mappable_function)
    .padded_batch(
        2,
        padded_shapes=([75, None, None, None], [40]),
    )
    .prefetch(tf.data.AUTOTUNE)
)

# Train/validation split: the first 450 batches train, the rest test.
train = data.take(450)
test = data.skip(450)

# Eagerly pull one batch as a sanity check that the pipeline yields data.
frames, alignments = data.as_numpy_iterator().next()
# Build and compile the predictor with CTC loss (defined on the model).
predictor = Predictor()
predictor.compile(
    optimizer=Adam(learning_rate=0.0001),
    loss=predictor.CTCLoss
)

# Fix: with save_weights_only=True, Keras 3 (this file imports from the
# keras.src layout) requires the checkpoint filepath to end in
# '.weights.h5' — the bare 'checkpoint' path raises a ValueError.
# The suffix is also a valid HDF5 weights path on older Keras.
checkpoint_callback = ModelCheckpoint(
    os.path.join('models', 'checkpoint.weights.h5'),
    monitor='loss', save_weights_only=True
)
# Per-epoch learning-rate schedule supplied by the model itself.
schedule_callback = LearningRateScheduler(predictor.scheduler)
# Emits example predictions from the test split during training
# (see ProduceExample in model.py for exact behavior).
example_callback = ProduceExample(test)

predictor.fit(
    train, validation_data=test, epochs=100,
    callbacks=[
        checkpoint_callback, schedule_callback,
        example_callback
    ]
)