deanna-emery's picture
updates
93528c6
raw
history blame
6.12 kB
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Runs a simple model on the MNIST dataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
# Import libraries
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf, tf_keras
import tensorflow_datasets as tfds
from official.common import distribute_utils
from official.legacy.image_classification.resnet import common
from official.utils.flags import core as flags_core
from official.utils.misc import model_helpers
FLAGS = flags.FLAGS
def build_model():
"""Constructs the ML model used to predict handwritten digits."""
image = tf_keras.layers.Input(shape=(28, 28, 1))
y = tf_keras.layers.Conv2D(filters=32,
kernel_size=5,
padding='same',
activation='relu')(image)
y = tf_keras.layers.MaxPooling2D(pool_size=(2, 2),
strides=(2, 2),
padding='same')(y)
y = tf_keras.layers.Conv2D(filters=32,
kernel_size=5,
padding='same',
activation='relu')(y)
y = tf_keras.layers.MaxPooling2D(pool_size=(2, 2),
strides=(2, 2),
padding='same')(y)
y = tf_keras.layers.Flatten()(y)
y = tf_keras.layers.Dense(1024, activation='relu')(y)
y = tf_keras.layers.Dropout(0.4)(y)
probs = tf_keras.layers.Dense(10, activation='softmax')(y)
model = tf_keras.models.Model(image, probs, name='mnist')
return model
@tfds.decode.make_decoder(output_dtype=tf.float32)
def decode_image(example, feature):
"""Convert image to float32 and normalize from [0, 255] to [0.0, 1.0]."""
return tf.cast(feature.decode_example(example), dtype=tf.float32) / 255
def run(flags_obj, datasets_override=None, strategy_override=None):
"""Run MNIST model training and eval loop using native Keras APIs.
Args:
flags_obj: An object containing parsed flag values.
datasets_override: A pair of `tf.data.Dataset` objects to train the model,
representing the train and test sets.
strategy_override: A `tf.distribute.Strategy` object to use for model.
Returns:
Dictionary of training and eval stats.
"""
# Start TF profiler server.
tf.profiler.experimental.server.start(flags_obj.profiler_port)
strategy = strategy_override or distribute_utils.get_distribution_strategy(
distribution_strategy=flags_obj.distribution_strategy,
num_gpus=flags_obj.num_gpus,
tpu_address=flags_obj.tpu)
strategy_scope = distribute_utils.get_strategy_scope(strategy)
mnist = tfds.builder('mnist', data_dir=flags_obj.data_dir)
if flags_obj.download:
mnist.download_and_prepare()
mnist_train, mnist_test = datasets_override or mnist.as_dataset(
split=['train', 'test'],
decoders={'image': decode_image()}, # pylint: disable=no-value-for-parameter
as_supervised=True)
train_input_dataset = mnist_train.cache().repeat().shuffle(
buffer_size=50000).batch(flags_obj.batch_size)
eval_input_dataset = mnist_test.cache().repeat().batch(flags_obj.batch_size)
with strategy_scope:
lr_schedule = tf_keras.optimizers.schedules.ExponentialDecay(
0.05, decay_steps=100000, decay_rate=0.96)
optimizer = tf_keras.optimizers.SGD(learning_rate=lr_schedule)
model = build_model()
model.compile(
optimizer=optimizer,
loss='sparse_categorical_crossentropy',
metrics=['sparse_categorical_accuracy'])
num_train_examples = mnist.info.splits['train'].num_examples
train_steps = num_train_examples // flags_obj.batch_size
train_epochs = flags_obj.train_epochs
ckpt_full_path = os.path.join(flags_obj.model_dir, 'model.ckpt-{epoch:04d}')
callbacks = [
tf_keras.callbacks.ModelCheckpoint(
ckpt_full_path, save_weights_only=True),
tf_keras.callbacks.TensorBoard(log_dir=flags_obj.model_dir),
]
num_eval_examples = mnist.info.splits['test'].num_examples
num_eval_steps = num_eval_examples // flags_obj.batch_size
history = model.fit(
train_input_dataset,
epochs=train_epochs,
steps_per_epoch=train_steps,
callbacks=callbacks,
validation_steps=num_eval_steps,
validation_data=eval_input_dataset,
validation_freq=flags_obj.epochs_between_evals)
export_path = os.path.join(flags_obj.model_dir, 'saved_model')
model.save(export_path, include_optimizer=False)
eval_output = model.evaluate(
eval_input_dataset, steps=num_eval_steps, verbose=2)
stats = common.build_stats(history, eval_output, callbacks)
return stats
def define_mnist_flags():
"""Define command line flags for MNIST model."""
flags_core.define_base(
clean=True,
num_gpu=True,
train_epochs=True,
epochs_between_evals=True,
distribution_strategy=True)
flags_core.define_device()
flags_core.define_distribution()
flags.DEFINE_bool('download', True,
'Whether to download data to `--data_dir`.')
flags.DEFINE_integer('profiler_port', 9012,
'Port to start profiler server on.')
FLAGS.set_default('batch_size', 1024)
def main(_):
model_helpers.apply_clean(FLAGS)
stats = run(flags.FLAGS)
logging.info('Run stats:\n%s', stats)
if __name__ == '__main__':
logging.set_verbosity(logging.INFO)
define_mnist_flags()
app.run(main)