"""Training script for the FEELVOS model. |
|
|
|
See model.py for more details and usage. |
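
Example invocation (the flag values below are illustrative placeholders, not
recommended settings):

  python train.py \
    --train_logdir=/tmp/feelvos_train \
    --tf_initial_checkpoint=/path/to/initial_checkpoint \
    --dataset=<dataset_name> \
    --dataset_dir=/path/to/tfrecords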
|
""" |
|
import six |
|
import tensorflow as tf |
|
|
|
from feelvos import common |
|
from feelvos import model |
|
from feelvos.datasets import video_dataset |
|
from feelvos.utils import embedding_utils |
|
from feelvos.utils import train_utils |
|
from feelvos.utils import video_input_generator |
|
from deployment import model_deploy |
|
|
|
slim = tf.contrib.slim |
|
prefetch_queue = slim.prefetch_queue |
|
flags = tf.app.flags |
|
FLAGS = flags.FLAGS |
|
|
|
|
|
|
|
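# Settings for multi-GPUs/multi-replicas training.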
flags.DEFINE_integer('num_clones', 1, 'Number of clones to deploy.') |
|
|
|
flags.DEFINE_boolean('clone_on_cpu', False, 'Use CPUs to deploy clones.') |
|
|
|
flags.DEFINE_integer('num_replicas', 1, 'Number of worker replicas.') |
|
|
|
flags.DEFINE_integer('startup_delay_steps', 15, |
|
                     'Number of training steps between replica startups.')
|
|
|
flags.DEFINE_integer('num_ps_tasks', 0, |
|
'The number of parameter servers. If the value is 0, then ' |
|
'the parameters are handled locally by the worker.') |
|
|
|
flags.DEFINE_string('master', '', 'BNS name of the TensorFlow server.')
|
|
|
flags.DEFINE_integer('task', 0, 'The task ID.') |
|
|
|
|
|
|
|
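# Settings for logging.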
flags.DEFINE_string('train_logdir', None, |
|
'Where the checkpoint and logs are stored.') |
|
|
|
flags.DEFINE_integer('log_steps', 10, |
|
                     'Display logging information every log_steps steps.')
|
|
|
flags.DEFINE_integer('save_interval_secs', 1200, |
|
'How often, in seconds, we save the model to disk.') |
|
|
|
flags.DEFINE_integer('save_summaries_secs', 600, |
|
'How often, in seconds, we compute the summaries.') |
|
|
|
|
|
|
|
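# Settings for training strategy.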
flags.DEFINE_enum('learning_policy', 'poly', ['poly', 'step'], |
|
'Learning rate policy for training.') |
|
|
|
flags.DEFINE_float('base_learning_rate', 0.0007, |
|
'The base learning rate for model training.') |
|
|
|
flags.DEFINE_float('learning_rate_decay_factor', 0.1, |
|
'The rate to decay the base learning rate.') |
|
|
|
flags.DEFINE_integer('learning_rate_decay_step', 2000, |
|
'Decay the base learning rate at a fixed step.') |
|
|
|
flags.DEFINE_float('learning_power', 0.9, |
|
'The power value used in the poly learning policy.') |
|
|
|
flags.DEFINE_integer('training_number_of_steps', 200000, |
|
                     'The number of steps used for training.')
|
|
|
flags.DEFINE_float('momentum', 0.9, 'The momentum value to use.')
|
|
|
flags.DEFINE_integer('train_batch_size', 6, |
|
'The number of images in each batch during training.') |
|
|
|
flags.DEFINE_integer('train_num_frames_per_video', 3, |
|
                     'The number of frames used per video during training.')
|
|
|
flags.DEFINE_float('weight_decay', 0.00004, |
|
'The value of the weight decay for training.') |
|
|
|
flags.DEFINE_multi_integer('train_crop_size', [465, 465], |
|
'Image crop size [height, width] during training.') |
|
|
|
flags.DEFINE_float('last_layer_gradient_multiplier', 1.0, |
|
'The gradient multiplier for last layers, which is used to ' |
|
'boost the gradient of last layers if the value > 1.') |
|
|
|
flags.DEFINE_boolean('upsample_logits', True, |
|
'Upsample logits during training.') |
|
|
|
flags.DEFINE_integer('batch_capacity_factor', 16, 'Batch capacity factor.') |
|
|
|
flags.DEFINE_integer('num_readers', 1, 'Number of readers for data provider.') |
|
|
|
flags.DEFINE_integer('batch_num_threads', 1,
                     'Number of threads used for batching.')
|
|
|
flags.DEFINE_integer('prefetch_queue_capacity_factor', 32, |
|
'Prefetch queue capacity factor.') |
|
|
|
flags.DEFINE_integer('prefetch_queue_num_threads', 1, |
|
                     'Number of threads for the prefetch queue.')
|
|
|
flags.DEFINE_integer('train_max_neighbors_per_object', 1024, |
|
'The maximum number of candidates for the nearest ' |
|
                     'neighbor query per object after subsampling.')
|
|
|
|
|
|
|
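# Settings for fine-tuning the network.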
flags.DEFINE_string('tf_initial_checkpoint', None, |
|
'The initial checkpoint in tensorflow format.') |
|
|
|
flags.DEFINE_boolean('initialize_last_layer', False, |
|
'Initialize the last layer.') |
|
|
|
flags.DEFINE_boolean('last_layers_contain_logits_only', False, |
|
                     'Whether to consider only the logits as the last layers.')
|
|
|
flags.DEFINE_integer('slow_start_step', 0, |
|
                     'Train the model with a small learning rate for the '
                     'first few steps.')
|
|
|
flags.DEFINE_float('slow_start_learning_rate', 1e-4, |
|
'Learning rate employed during slow start.') |
|
|
|
flags.DEFINE_boolean('fine_tune_batch_norm', False, |
|
                     'Whether to fine-tune the batch norm parameters.')
|
|
|
flags.DEFINE_float('min_scale_factor', 1., |
|
                   'Minimum scale factor for data augmentation.')
|
|
|
flags.DEFINE_float('max_scale_factor', 1.3, |
|
'Maximum scale factor for data augmentation.') |
|
|
|
flags.DEFINE_float('scale_factor_step_size', 0, |
|
'Scale factor step size for data augmentation.') |
|
|
|
flags.DEFINE_multi_integer('atrous_rates', None, |
|
'Atrous rates for atrous spatial pyramid pooling.') |
|
|
|
flags.DEFINE_integer('output_stride', 8, |
|
'The ratio of input to output spatial resolution.') |
|
|
|
flags.DEFINE_boolean('sample_only_first_frame_for_finetuning', False, |
|
'Whether to only sample the first frame during ' |
|
'fine-tuning. This should be False when using lucid data, ' |
|
'but True when fine-tuning on the first frame only. Only ' |
|
'has an effect if first_frame_finetuning is True.') |
|
|
|
flags.DEFINE_multi_integer('first_frame_finetuning', [0], |
|
                           'Whether to use first-frame fine-tuning for each '
                           'dataset. Provide one value per dataset, or a '
                           'single value that applies to all datasets.')
|
|
|
|
|
|
|
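# Dataset settings.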
flags.DEFINE_multi_string('dataset', [],
                          'Names of the segmentation datasets.')
|
|
|
flags.DEFINE_multi_float('dataset_sampling_probabilities', [], |
|
'A list of probabilities to sample each of the ' |
|
'datasets.') |
|
|
|
flags.DEFINE_string('train_split', 'train', |
|
                    'Which split of the dataset to use for training.')
|
|
|
flags.DEFINE_multi_string('dataset_dir', [], 'Where the datasets reside.') |
|
|
|
flags.DEFINE_multi_integer('three_frame_dataset', [0], |
|
                           'Whether the dataset has exactly three frames per '
                           'video, of which the first is to be used as '
                           'reference and the two others are consecutive '
                           'frames to be used as query frames. '
                           'Set to 1 for PASCAL lucid data.')
|
|
|
flags.DEFINE_boolean('damage_initial_previous_frame_mask', False, |
|
'Whether to artificially damage the initial previous ' |
|
'frame mask. Only has an effect if ' |
|
'also_attend_to_previous_frame is True.') |
|
|
|
flags.DEFINE_float('top_k_percent_pixels', 0.15,
                   'Float in [0.0, 1.0]. When its value < 1.0, only compute '
                   'the loss for the top k percent of pixels (e.g., the top '
                   '20% of pixels). This is useful for hard pixel mining.')
|
|
|
flags.DEFINE_integer('hard_example_mining_step', 100000,
                     'The training step at which hard example mining kicks '
                     'in. Note that we gradually reduce the mining percent '
                     'to top_k_percent_pixels. For example, if '
                     'hard_example_mining_step=100K and '
                     'top_k_percent_pixels=0.25, the mining percent is '
                     'gradually reduced from 100% to 25% until 100K steps, '
                     'after which we only mine the top 25% of pixels. Only '
                     'has an effect if top_k_percent_pixels < 1.0.')
|
|
|
|
|
def _build_deeplab(inputs_queue_or_samples, outputs_to_num_classes, |
|
ignore_label): |
|
"""Builds a clone of DeepLab. |
|
|
|
Args: |
|
inputs_queue_or_samples: A prefetch queue for images and labels, or |
|
directly a dict of the samples. |
|
outputs_to_num_classes: A map from output type to the number of classes. |
|
For example, for the task of semantic segmentation with 21 semantic |
|
classes, we would have outputs_to_num_classes['semantic'] = 21. |
|
ignore_label: Ignore label. |
|
|
|
Returns: |
|
A map of maps from output_type (e.g., semantic prediction) to a |
|
dictionary of multi-scale logits names to logits. For each output_type, |
|
the dictionary has keys which correspond to the scales and values which |
|
correspond to the logits. For example, if `scales` equals [1.0, 1.5], |
|
then the keys would include 'merged_logits', 'logits_1.00' and |
|
'logits_1.50'. |
|
|
|
Raises: |
|
ValueError: If classification_loss is not softmax, softmax_with_attention, |
|
or triplet. |
|
""" |
|
if hasattr(inputs_queue_or_samples, 'dequeue'): |
|
samples = inputs_queue_or_samples.dequeue() |
|
else: |
|
samples = inputs_queue_or_samples |
|
train_crop_size = (None if 0 in FLAGS.train_crop_size else |
|
FLAGS.train_crop_size) |
|
|
|
model_options = common.VideoModelOptions( |
|
outputs_to_num_classes=outputs_to_num_classes, |
|
crop_size=train_crop_size, |
|
atrous_rates=FLAGS.atrous_rates, |
|
output_stride=FLAGS.output_stride) |
|
|
|
if model_options.classification_loss == 'softmax_with_attention': |
|
clone_batch_size = FLAGS.train_batch_size // FLAGS.num_clones |
|
|
|
|
|
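    # Visualize the ground truth labels: multiply by 32 so that small label
    # ids become visible intensities in the uint8 summary images.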
for n in range(clone_batch_size): |
|
tf.summary.image( |
|
'gt_label_%d' % n, |
|
tf.cast(samples[common.LABEL][ |
|
n * FLAGS.train_num_frames_per_video: |
|
(n + 1) * FLAGS.train_num_frames_per_video], |
|
tf.uint8) * 32, max_outputs=FLAGS.train_num_frames_per_video) |
|
|
|
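    # Build the initial softmax feedback from the preceding frame's label,
    # one entry per video in the clone batch.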
if common.PRECEDING_FRAME_LABEL in samples: |
|
preceding_frame_label = samples[common.PRECEDING_FRAME_LABEL] |
|
init_softmax = [] |
|
for n in range(clone_batch_size): |
|
init_softmax_n = embedding_utils.create_initial_softmax_from_labels( |
|
preceding_frame_label[n, tf.newaxis], |
|
samples[common.LABEL][n * FLAGS.train_num_frames_per_video, |
|
tf.newaxis], |
|
common.parse_decoder_output_stride(), |
|
reduce_labels=True) |
|
init_softmax_n = tf.squeeze(init_softmax_n, axis=0) |
|
init_softmax.append(init_softmax_n) |
|
tf.summary.image('preceding_frame_label', |
|
tf.cast(preceding_frame_label[n, tf.newaxis], |
|
tf.uint8) * 32) |
|
else: |
|
init_softmax = None |
|
|
|
outputs_to_scales_to_logits = ( |
|
model.multi_scale_logits_with_nearest_neighbor_matching( |
|
samples[common.IMAGE], |
|
model_options=model_options, |
|
image_pyramid=FLAGS.image_pyramid, |
|
weight_decay=FLAGS.weight_decay, |
|
is_training=True, |
|
fine_tune_batch_norm=FLAGS.fine_tune_batch_norm, |
|
reference_labels=samples[common.LABEL], |
|
clone_batch_size=FLAGS.train_batch_size // FLAGS.num_clones, |
|
num_frames_per_video=FLAGS.train_num_frames_per_video, |
|
embedding_dimension=FLAGS.embedding_dimension, |
|
max_neighbors_per_object=FLAGS.train_max_neighbors_per_object, |
|
k_nearest_neighbors=FLAGS.k_nearest_neighbors, |
|
use_softmax_feedback=FLAGS.use_softmax_feedback, |
|
initial_softmax_feedback=init_softmax, |
|
embedding_seg_feature_dimension= |
|
FLAGS.embedding_seg_feature_dimension, |
|
embedding_seg_n_layers=FLAGS.embedding_seg_n_layers, |
|
embedding_seg_kernel_size=FLAGS.embedding_seg_kernel_size, |
|
embedding_seg_atrous_rates=FLAGS.embedding_seg_atrous_rates, |
|
normalize_nearest_neighbor_distances= |
|
FLAGS.normalize_nearest_neighbor_distances, |
|
also_attend_to_previous_frame=FLAGS.also_attend_to_previous_frame, |
|
damage_initial_previous_frame_mask= |
|
FLAGS.damage_initial_previous_frame_mask, |
|
use_local_previous_frame_attention= |
|
FLAGS.use_local_previous_frame_attention, |
|
previous_frame_attention_window_size= |
|
FLAGS.previous_frame_attention_window_size, |
|
use_first_frame_matching=FLAGS.use_first_frame_matching |
|
)) |
|
else: |
|
outputs_to_scales_to_logits = model.multi_scale_logits_v2( |
|
samples[common.IMAGE], |
|
model_options=model_options, |
|
image_pyramid=FLAGS.image_pyramid, |
|
weight_decay=FLAGS.weight_decay, |
|
is_training=True, |
|
fine_tune_batch_norm=FLAGS.fine_tune_batch_norm) |
|
|
|
if model_options.classification_loss == 'softmax': |
|
for output, num_classes in six.iteritems(outputs_to_num_classes): |
|
train_utils.add_softmax_cross_entropy_loss_for_each_scale( |
|
outputs_to_scales_to_logits[output], |
|
samples[common.LABEL], |
|
num_classes, |
|
ignore_label, |
|
loss_weight=1.0, |
|
upsample_logits=FLAGS.upsample_logits, |
|
scope=output) |
|
elif model_options.classification_loss == 'triplet': |
|
for output, _ in six.iteritems(outputs_to_num_classes): |
|
train_utils.add_triplet_loss_for_each_scale( |
|
FLAGS.train_batch_size // FLAGS.num_clones, |
|
FLAGS.train_num_frames_per_video, |
|
FLAGS.embedding_dimension, outputs_to_scales_to_logits[output], |
|
samples[common.LABEL], scope=output) |
|
elif model_options.classification_loss == 'softmax_with_attention': |
|
labels = samples[common.LABEL] |
|
batch_size = FLAGS.train_batch_size // FLAGS.num_clones |
|
num_frames_per_video = FLAGS.train_num_frames_per_video |
|
h, w = train_utils.resolve_shape(labels)[1:3] |
|
labels = tf.reshape(labels, tf.stack( |
|
[batch_size, num_frames_per_video, h, w, 1])) |
|
|
|
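    # The first n_ref_frames frames of each video are reference frames and
    # receive no loss, so remove their labels before computing it.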
if FLAGS.also_attend_to_previous_frame or FLAGS.use_softmax_feedback: |
|
n_ref_frames = 2 |
|
else: |
|
n_ref_frames = 1 |
|
labels = labels[:, n_ref_frames:] |
|
|
|
labels = tf.reshape(labels, tf.stack( |
|
[batch_size * (num_frames_per_video - n_ref_frames), h, w, 1])) |
|
|
|
for output, num_classes in six.iteritems(outputs_to_num_classes): |
|
train_utils.add_dynamic_softmax_cross_entropy_loss_for_each_scale( |
|
outputs_to_scales_to_logits[output], |
|
labels, |
|
ignore_label, |
|
loss_weight=1.0, |
|
upsample_logits=FLAGS.upsample_logits, |
|
scope=output, |
|
top_k_percent_pixels=FLAGS.top_k_percent_pixels, |
|
hard_example_mining_step=FLAGS.hard_example_mining_step) |
|
else: |
|
    raise ValueError('classification_loss must be softmax, '
                     'softmax_with_attention, or triplet.')
|
|
|
return outputs_to_scales_to_logits |
|
|
|
|
|
def main(unused_argv): |
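  """Runs training for FEELVOS."""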
|
|
|
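  # Set up deployment (i.e., multi-GPUs and/or multi-replicas).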
config = model_deploy.DeploymentConfig( |
|
num_clones=FLAGS.num_clones, |
|
clone_on_cpu=FLAGS.clone_on_cpu, |
|
replica_id=FLAGS.task, |
|
num_replicas=FLAGS.num_replicas, |
|
num_ps_tasks=FLAGS.num_ps_tasks) |
|
|
|
with tf.Graph().as_default(): |
|
with tf.device(config.inputs_device()): |
|
train_crop_size = (None if 0 in FLAGS.train_crop_size else |
|
FLAGS.train_crop_size) |
|
assert FLAGS.dataset |
|
assert len(FLAGS.dataset) == len(FLAGS.dataset_dir) |
|
if len(FLAGS.first_frame_finetuning) == 1: |
|
first_frame_finetuning = (list(FLAGS.first_frame_finetuning) |
|
* len(FLAGS.dataset)) |
|
else: |
|
first_frame_finetuning = FLAGS.first_frame_finetuning |
|
if len(FLAGS.three_frame_dataset) == 1: |
|
three_frame_dataset = (list(FLAGS.three_frame_dataset) |
|
* len(FLAGS.dataset)) |
|
else: |
|
three_frame_dataset = FLAGS.three_frame_dataset |
|
assert len(FLAGS.dataset) == len(first_frame_finetuning) |
|
assert len(FLAGS.dataset) == len(three_frame_dataset) |
|
datasets, samples_list = zip( |
|
*[_get_dataset_and_samples(config, train_crop_size, dataset, |
|
dataset_dir, bool(first_frame_finetuning_), |
|
bool(three_frame_dataset_)) |
|
for dataset, dataset_dir, first_frame_finetuning_, |
|
three_frame_dataset_ in zip(FLAGS.dataset, FLAGS.dataset_dir, |
|
first_frame_finetuning, |
|
three_frame_dataset)]) |
|
|
|
|
|
|
|
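      # All datasets are assumed to share num_classes and ignore_label, so it
      # is enough to keep the first one for those attributes.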
dataset = datasets[0] |
|
if len(samples_list) == 1: |
|
samples = samples_list[0] |
|
else: |
|
probabilities = FLAGS.dataset_sampling_probabilities |
|
if probabilities: |
|
assert len(probabilities) == len(samples_list) |
|
else: |
|
|
|
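          # Default to sampling the datasets uniformly.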
probabilities = [1.0 / len(samples_list) for _ in samples_list] |
|
probabilities = tf.constant(probabilities) |
|
logits = tf.log(probabilities[tf.newaxis]) |
|
rand_idx = tf.squeeze(tf.multinomial(logits, 1, output_dtype=tf.int32), |
|
axis=[0, 1]) |
|
|
|
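        # tf.case expects zero-argument callables as branch functions, so
        # wrap each samples dict in a closure.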
def wrap(x): |
|
def f(): |
|
return x |
|
return f |
|
|
|
samples = tf.case({tf.equal(rand_idx, idx): wrap(s) |
|
for idx, s in enumerate(samples_list)}, |
|
exclusive=True) |
|
|
|
|
|
|
|
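      # prefetch_queue requires fully defined shapes, so it can only be used
      # when training with a fixed crop size.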
if train_crop_size is None: |
|
inputs_queue = samples |
|
else: |
|
inputs_queue = prefetch_queue.prefetch_queue( |
|
samples, |
|
            capacity=FLAGS.prefetch_queue_capacity_factor * config.num_clones,
|
num_threads=FLAGS.prefetch_queue_num_threads) |
|
|
|
|
|
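    # Create the global step on the device storing the variables.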
with tf.device(config.variables_device()): |
|
global_step = tf.train.get_or_create_global_step() |
|
|
|
|
|
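      # Define the model and create clones.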
model_fn = _build_deeplab |
|
if FLAGS.classification_loss == 'triplet': |
|
embedding_dim = FLAGS.embedding_dimension |
|
output_type_to_dim = {'embedding': embedding_dim} |
|
else: |
|
output_type_to_dim = {common.OUTPUT_TYPE: dataset.num_classes} |
|
model_args = (inputs_queue, output_type_to_dim, dataset.ignore_label) |
|
clones = model_deploy.create_clones(config, model_fn, args=model_args) |
|
|
|
|
|
|
|
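      # Gather update_ops from the first clone. These contain, for example,
      # the updates for the batch_norm variables created by model_fn.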
first_clone_scope = config.clone_scope(0) |
|
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope) |
|
|
|
|
|
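    # Gather initial summaries.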
summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES)) |
|
|
|
|
|
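    # Add summaries for model variables.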
for model_var in tf.contrib.framework.get_model_variables(): |
|
summaries.add(tf.summary.histogram(model_var.op.name, model_var)) |
|
|
|
|
|
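    # Add summaries for losses.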
for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope): |
|
summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss)) |
|
|
|
|
|
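    # Build the optimizer based on the device specification.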
with tf.device(config.optimizer_device()): |
|
learning_rate = train_utils.get_model_learning_rate( |
|
FLAGS.learning_policy, |
|
FLAGS.base_learning_rate, |
|
FLAGS.learning_rate_decay_step, |
|
FLAGS.learning_rate_decay_factor, |
|
FLAGS.training_number_of_steps, |
|
FLAGS.learning_power, |
|
FLAGS.slow_start_step, |
|
FLAGS.slow_start_learning_rate) |
|
optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum) |
|
summaries.add(tf.summary.scalar('learning_rate', learning_rate)) |
|
|
|
startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps |
|
|
|
with tf.device(config.variables_device()): |
|
total_loss, grads_and_vars = model_deploy.optimize_clones( |
|
clones, optimizer) |
|
total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.') |
|
summaries.add(tf.summary.scalar('total_loss', total_loss)) |
|
|
|
|
|
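      # Modify the gradients for biases and last layer variables.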
last_layers = model.get_extra_layer_scopes( |
|
FLAGS.last_layers_contain_logits_only) |
|
grad_mult = train_utils.get_model_gradient_multipliers( |
|
last_layers, FLAGS.last_layer_gradient_multiplier) |
|
if grad_mult: |
|
grads_and_vars = slim.learning.multiply_gradients(grads_and_vars, |
|
grad_mult) |
|
|
|
with tf.name_scope('grad_clipping'): |
|
grads_and_vars = slim.learning.clip_gradient_norms(grads_and_vars, 5.0) |
|
|
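      # Create gradient update op.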
grad_updates = optimizer.apply_gradients(grads_and_vars, |
|
global_step=global_step) |
|
update_ops.append(grad_updates) |
|
update_op = tf.group(*update_ops) |
|
with tf.control_dependencies([update_op]): |
|
train_tensor = tf.identity(total_loss, name='train_op') |
|
|
|
|
|
|
|
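    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().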
summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES, |
|
first_clone_scope)) |
|
|
|
|
|
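    # Merge all summaries together.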
summary_op = tf.summary.merge(list(summaries)) |
|
|
|
|
|
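    # Soft placement allows placing on CPU ops without GPU implementation.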
session_config = tf.ConfigProto(allow_soft_placement=True, |
|
log_device_placement=False) |
|
|
|
|
|
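    # Start the training.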
slim.learning.train( |
|
train_tensor, |
|
logdir=FLAGS.train_logdir, |
|
log_every_n_steps=FLAGS.log_steps, |
|
master=FLAGS.master, |
|
number_of_steps=FLAGS.training_number_of_steps, |
|
is_chief=(FLAGS.task == 0), |
|
session_config=session_config, |
|
startup_delay_steps=startup_delay_steps, |
|
init_fn=train_utils.get_model_init_fn(FLAGS.train_logdir, |
|
FLAGS.tf_initial_checkpoint, |
|
FLAGS.initialize_last_layer, |
|
last_layers, |
|
ignore_missing_vars=True), |
|
summary_op=summary_op, |
|
save_summaries_secs=FLAGS.save_summaries_secs, |
|
save_interval_secs=FLAGS.save_interval_secs) |
|
|
|
|
|
def _get_dataset_and_samples(config, train_crop_size, dataset_name, |
|
dataset_dir, first_frame_finetuning, |
|
three_frame_dataset): |
|
"""Creates dataset object and samples dict of tensor. |
|
|
|
Args: |
|
config: A DeploymentConfig. |
|
    train_crop_size: A list of [height, width] specifying the crop size used
      for training, or None if no fixed crop is used.
|
dataset_name: String, the name of the dataset. |
|
dataset_dir: String, the directory of the dataset. |
|
first_frame_finetuning: Boolean, whether the used dataset is a dataset |
|
for first frame fine-tuning. |
|
three_frame_dataset: Boolean, whether the dataset has exactly three frames |
|
per video of which the first is to be used as reference and the two |
|
others are consecutive frames to be used as query frames. |
|
|
|
Returns: |
|
dataset: An instance of slim Dataset. |
|
samples: A dictionary of tensors for semantic segmentation. |
|
""" |
|
|
|
|
|
assert FLAGS.train_batch_size % config.num_clones == 0, ( |
|
      'Training batch size not divisible by number of clones (GPUs).')
|
|
|
  # Use integer division; divisibility was asserted above.
  clone_batch_size = FLAGS.train_batch_size // config.num_clones
|
|
|
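  # First-frame fine-tuning uses the first frames of the validation videos.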
if first_frame_finetuning: |
|
train_split = 'val' |
|
else: |
|
train_split = FLAGS.train_split |
|
|
|
data_type = 'tf_sequence_example' |
|
|
|
dataset = video_dataset.get_dataset( |
|
dataset_name, |
|
train_split, |
|
dataset_dir=dataset_dir, |
|
data_type=data_type) |
|
|
|
tf.gfile.MakeDirs(FLAGS.train_logdir) |
|
tf.logging.info('Training on %s set', train_split) |
|
|
|
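  # Build the input pipeline: sample frames from videos, apply augmentation
  # and batching, and return a dictionary of tensors.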
samples = video_input_generator.get( |
|
dataset, |
|
FLAGS.train_num_frames_per_video, |
|
train_crop_size, |
|
clone_batch_size, |
|
num_readers=FLAGS.num_readers, |
|
num_threads=FLAGS.batch_num_threads, |
|
min_resize_value=FLAGS.min_resize_value, |
|
max_resize_value=FLAGS.max_resize_value, |
|
resize_factor=FLAGS.resize_factor, |
|
min_scale_factor=FLAGS.min_scale_factor, |
|
max_scale_factor=FLAGS.max_scale_factor, |
|
scale_factor_step_size=FLAGS.scale_factor_step_size, |
|
dataset_split=FLAGS.train_split, |
|
is_training=True, |
|
model_variant=FLAGS.model_variant, |
|
batch_capacity_factor=FLAGS.batch_capacity_factor, |
|
decoder_output_stride=common.parse_decoder_output_stride(), |
|
first_frame_finetuning=first_frame_finetuning, |
|
sample_only_first_frame_for_finetuning= |
|
FLAGS.sample_only_first_frame_for_finetuning, |
|
sample_adjacent_and_consistent_query_frames= |
|
FLAGS.sample_adjacent_and_consistent_query_frames or |
|
FLAGS.use_softmax_feedback, |
|
remap_labels_to_reference_frame=True, |
|
three_frame_dataset=three_frame_dataset, |
|
add_prev_frame_label=not FLAGS.also_attend_to_previous_frame |
|
) |
|
return dataset, samples |
|
|
|
|
|
if __name__ == '__main__': |
|
flags.mark_flag_as_required('train_logdir') |
|
tf.logging.set_verbosity(tf.logging.INFO) |
|
tf.app.run() |
|
|