File size: 1,976 Bytes
3a3c68a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e1cca9e
3a3c68a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
"""Training entry point for the GRID lip-reading model.

Builds a tf.data pipeline of (video frames, alignment tokens) pairs from the
GRID corpus and trains the CTC-based Predictor.

Expects:
  - config.yml with a `datasets` section (`video_dir`, `alignments_dir`)
  - corrupted.txt: newline-separated video paths to exclude from loading
"""
import os
import cv2
import numpy as np
import tensorflow as tf
import imageio
import yaml

from matplotlib import pyplot as plt
# BUG FIX: was `from keras.src.callbacks import ...` — keras.src is Keras's
# private internal package and can break between releases. Use the public
# tf.keras path, consistent with the Adam import below.
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler
from tensorflow.keras.optimizers import Adam

from helpers import *
from typing import List
from Loader import GridLoader
from model import Predictor, ProduceExample

with open('config.yml', 'r') as config_file_obj:
    yaml_config = yaml.safe_load(config_file_obj)

dataset_config = yaml_config['datasets']
VIDEO_DIR = dataset_config['video_dir']
ALIGNMENTS_DIR = dataset_config['alignments_dir']

# BUG FIX: original used a bare open() with no close, leaking the handle.
with open('corrupted.txt', 'r') as corrupt_file:
    corrupt_video_paths = set(corrupt_file.read().strip().split('\n'))

loader = GridLoader()
video_filepaths = loader.load_videos(blacklist=corrupt_video_paths)
print(f'videos loaded: {len(video_filepaths)}')

data = tf.data.Dataset.from_tensor_slices(video_filepaths)

# NOTE(review): a loop that iterated the whole dataset just to collect
# filenames into an unused list was removed here — it forced a full eager
# traversal at startup for no effect (its only consumer was a commented-out
# print).

# Shuffle once up front (reshuffle_each_iteration=False keeps the train/test
# split below stable across epochs), then decode each path into
# (frames, alignment-tokens) via mappable_function.
data = data.shuffle(500, reshuffle_each_iteration=False)
data = data.map(mappable_function)
# Pad every clip to 75 frames (height/width/channels vary per clip) and
# every alignment to 40 tokens; batches of 2.
data = data.padded_batch(2, padded_shapes=(
    [75, None, None, None], [40]
))
data = data.prefetch(tf.data.AUTOTUNE)

# Train/validation split: first 450 batches train, remainder test.
train = data.take(450)
test = data.skip(450)

# NOTE(review): an eager `data.as_numpy_iterator().next()` debug fetch was
# removed — it decoded and padded a full batch whose result was never used.

predictor = Predictor()
predictor.compile(
    optimizer=Adam(learning_rate=0.0001),
    loss=predictor.CTCLoss  # model-supplied CTC loss for unaligned targets
)

# Save weights (not full model) whenever training loss is monitored.
checkpoint_callback = ModelCheckpoint(
    os.path.join('models', 'checkpoint'),
    monitor='loss', save_weights_only=True
)
schedule_callback = LearningRateScheduler(predictor.scheduler)
# Emits decoded example predictions on the test set after each epoch.
example_callback = ProduceExample(test)

predictor.fit(
    train, validation_data=test, epochs=100,
    callbacks=[
        checkpoint_callback, schedule_callback,
        example_callback
    ]
)