# Manu
# updated files
# 54c0759
import os

from icecream import ic
ic("--- Importing tensorflow ---")
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.utils import plot_model
from tensorflow.keras import Input
# Load the MNIST dataset as (images, labels) for the train and test splits.
ic("------ Loading mnist dataset ------")
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

# Add an explicit single-channel axis so each sample is 28x28x1, and scale
# pixel values into [0, 1] by dividing by 255. Using -1 lets reshape infer the
# number of samples instead of hard-coding 60000 / 10000.
ic("------ Normalizing data ------")
train_images = train_images.reshape((-1, 28, 28, 1)).astype('float32') / 255
test_images = test_images.reshape((-1, 28, 28, 1)).astype('float32') / 255

# Labels are the digits 0-9; convert them to one-hot vectors (1 for the correct
# digit, 0 elsewhere) because the model is trained with categorical_crossentropy.
# Both splits are encoded so the test set has the same target format as the
# training set (e.g. for model.evaluate). num_classes=10 pins the vector width
# to the 10-unit output layer regardless of which labels a split contains.
ic("------ One-hot encoding labels ------")
train_labels = tf.keras.utils.to_categorical(train_labels, num_classes=10)
test_labels = tf.keras.utils.to_categorical(test_labels, num_classes=10)
ic("------ Creating model ------")
# Define a small CNN: three 3x3 conv layers separated by two 2x2 max-pool
# layers, then a dense classifier head with a 10-way softmax output.
model = models.Sequential()
# Declare the input shape once with an Input layer, so the conv layers below
# do not need an input_shape argument.
model.add(Input(shape=(28, 28, 1)))
# First conv block: 32 filters, 3x3 kernel, ReLU, followed by 2x2 max pooling.
# (This Conv2D was previously commented out, which made the network start with
# a max-pool on raw pixels instead of the convolution described here.)
model.add(layers.Conv2D(32, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
# Second conv block: 64 filters, 3x3 kernel, ReLU, then 2x2 max pooling.
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
# Third conv layer: 64 filters, 3x3 kernel, ReLU (no pooling afterwards).
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
# Classifier head: flatten the feature maps into a dense hidden layer.
model.add(layers.Flatten())
model.add(layers.Dense(64, activation='relu'))
# Output layer: one softmax probability per digit class (0-9).
model.add(layers.Dense(10, activation='softmax'))
# categorical_crossentropy expects one-hot encoded targets.
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
# Generate static/model.png: a diagram of the model architecture including
# tensor shapes and layer names. Create the output directory first in case it
# does not exist yet, since plot_model only writes the file.
ic("------ Plotting model ------")
os.makedirs('static', exist_ok=True)
plot_model(model, to_file='static/model.png', show_shapes=True, show_layer_names=True)
# Train on the one-hot-encoded training set for 5 epochs, holding out the last
# 10% of the training data for validation. Without this call the model saved
# below would still have its random initial weights.
ic("------ Training model ------")
model.fit(train_images, train_labels, epochs=5, batch_size=64, validation_split=0.1)

# Persist the trained model in three formats. Make sure the target directory
# exists before writing, in case this is a fresh checkout.
os.makedirs('saved_models/keras', exist_ok=True)
ic("------ Saving .h5 model ------")
model.save('saved_models/keras/mnist_model.h5')  # legacy HDF5 format
ic("------ Saving .keras model ------")
model.save('saved_models/keras/mnist_model.keras')  # native .keras format
# Model.export writes an inference artifact directory (not a .keras file),
# despite the log message wording.
ic("------ Exporting .keras model ------")
model.export('saved_models/exported')