# Spaces: Runtime error
"""Build a small convolutional network for MNIST and write it to disk.

Running this script loads the MNIST digits through Keras, normalizes the
images, one-hot encodes the labels, builds and compiles a CNN, renders an
architecture diagram to ``static/model.png`` and stores the model as
``.h5``, ``.keras`` and as an exported artifact under ``saved_models/``.
"""
from pathlib import Path

from icecream import ic

# Logged before the import on purpose, so the message appears before
# TensorFlow starts loading.
ic("--- Importing tensorflow ---")
import tensorflow as tf
from tensorflow.keras import Input
from tensorflow.keras import layers, models
from tensorflow.keras.utils import plot_model

# Digits 0-9.
NUM_CLASSES = 10
# 28x28 pixels, 1 channel.
INPUT_SHAPE = (28, 28, 1)
# Training is switched off by default (the original script had the fit call
# commented out), so the saved model carries freshly initialised weights.
# Set to True to train before saving.
TRAIN_MODEL = False

PLOT_PATH = 'static/model.png'
H5_PATH = 'saved_models/keras/mnist_model.h5'
KERAS_PATH = 'saved_models/keras/mnist_model.keras'
EXPORT_DIR = 'saved_models/exported'

# load mnist dataset
ic("------ Loading mnist dataset ------")
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

# normalize: add the channel axis and scale pixel values to [0, 1].
# -1 lets numpy infer the number of samples instead of hard-coding it.
ic("------ Normalizing data ------")
train_images = train_images.reshape((-1,) + INPUT_SHAPE).astype('float32') / 255
test_images = test_images.reshape((-1,) + INPUT_SHAPE).astype('float32') / 255

# labels are output numbers, 0 to 9; convert them to one-hot encoding
# (1 for the correct digit, 0 for the others). Both splits are encoded so
# they stay compatible with the categorical_crossentropy loss below.
ic("------ One-hot encoding labels ------")
train_labels = tf.keras.utils.to_categorical(train_labels, NUM_CLASSES)
test_labels = tf.keras.utils.to_categorical(test_labels, NUM_CLASSES)

ic("------ Creating model ------")
# define model
model = models.Sequential()
# Add an Input layer
model.add(Input(shape=INPUT_SHAPE))
# first convolutional layer: 32 filters, 3x3 kernel, relu activation.
# The input shape comes from the Input layer above, so no input_shape
# argument is passed here. Then a 2x2 max pooling layer.
model.add(layers.Conv2D(32, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.Flatten())
model.add(layers.Dense(64, activation='relu'))
# output layer: one softmax probability per digit class
model.add(layers.Dense(NUM_CLASSES, activation='softmax'))
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Make sure every output location has an existing parent directory before
# anything is written, so a fresh checkout does not fail on a missing folder.
for target in (PLOT_PATH, H5_PATH, KERAS_PATH, EXPORT_DIR):
    Path(target).parent.mkdir(parents=True, exist_ok=True)

# generate model.png with architecture plot
ic("------ Plotting model ------")
plot_model(model, to_file=PLOT_PATH, show_shapes=True, show_layer_names=True)

# train (only when TRAIN_MODEL is enabled)
ic("------ Training model ------")
if TRAIN_MODEL:
    model.fit(train_images, train_labels, epochs=5, batch_size=64, validation_split=0.1)

# save model
ic("------ Saving .h5 model ------")
model.save(H5_PATH)
ic("------ Saving .keras model ------")
model.save(KERAS_PATH)
ic("------ Exporting .keras model ------")
model.export(EXPORT_DIR)