"""Train a small CNN on the MNIST digits dataset and save it as ``mnist_cnn.h5``.

Source: AutoWeightLogger1 / train_mnist_model.py (author: Sanjayraju30,
commit 323a5cf "Create train_mnist_model.py", 1.66 kB).
"""
import tensorflow as tf
from tensorflow.keras import layers, models
import logging
# Script-wide logging: ask the root logger for INFO-level output (note that
# logging.basicConfig is a no-op if the root logger already has handlers),
# then obtain this module's named logger used by the functions below.
_LOG_LEVEL = logging.INFO
logging.basicConfig(level=_LOG_LEVEL)
logger = logging.getLogger(__name__)
def _scale_images(images):
    """Return *images* reshaped to (N, 28, 28, 1) float32 and divided by 255.

    The reshape assumes every image has 28 * 28 pixels (true for MNIST), and
    the division maps 8-bit pixel values (0..255, per the Keras MNIST docs)
    onto [0, 1]. The trailing axis of size 1 is the single grayscale channel
    expected by the Conv2D layers in this script.
    """
    return images.reshape(-1, 28, 28, 1).astype('float32') / 255.0


def load_and_preprocess_data():
    """Load MNIST through Keras and prepare it for the CNN in this script.

    Returns:
        A 4-tuple ``(x_train, y_train, x_test, y_test)``. The ``x`` arrays are
        float32 with shape ``(N, 28, 28, 1)`` scaled to ``[0, 1]``. The ``y``
        arrays are the labels exactly as returned by
        ``tf.keras.datasets.mnist.load_data``.

        On any failure, the error is logged together with its traceback, and
        ``(None, None, None, None)`` is returned instead of raising. Callers
        must therefore check for ``None``.
    """
    try:
        (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
        x_train = _scale_images(x_train)
        x_test = _scale_images(x_test)
    except Exception as e:
        # Best-effort by design: callers test for None rather than catching.
        # logger.exception also records the traceback, which a plain
        # logger.error with only str(e) would lose.
        logger.exception("Error loading MNIST data: %s", e)
        return None, None, None, None
    logger.info("MNIST dataset loaded and preprocessed successfully")
    return x_train, y_train, x_test, y_test
def _build_model():
    """Return an uncompiled CNN for 28x28x1 inputs with a 10-way softmax output."""
    return models.Sequential([
        # An explicit Input layer replaces the `input_shape=` keyword on the
        # first Conv2D, which newer Keras versions deprecate.
        layers.Input(shape=(28, 28, 1)),
        layers.Conv2D(32, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.Dense(10, activation='softmax'),
    ])


def train_model(epochs=5, model_path='mnist_cnn.h5'):
    """Train the MNIST CNN and save it to disk.

    Args:
        epochs: Number of passes over the training set (default 5).
        model_path: Destination passed to ``model.save``. The default ``.h5``
            suffix selects the HDF5 format, which recent Keras versions treat
            as legacy. Pass a ``.keras`` path to use the native format.

    Returns:
        The trained Keras model, or ``None`` if the data could not be loaded
        or training failed. Failures are logged with a traceback, not raised.
        If only saving fails, the trained model is still returned.
    """
    x_train, y_train, x_test, y_test = load_and_preprocess_data()
    if x_train is None:
        # load_and_preprocess_data has already logged the cause.
        return None

    model = _build_model()
    # The sparse loss is used because the labels are integer class ids,
    # not one-hot vectors.
    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )

    try:
        # NOTE(review): the test split doubles as validation data here, so
        # the reported val_accuracy is not an independent held-out estimate.
        model.fit(x_train, y_train, epochs=epochs, validation_data=(x_test, y_test))
    except Exception as e:
        logger.exception("Error training model: %s", e)
        return None

    try:
        model.save(model_path)
    except Exception as e:
        # Reported separately so a save failure is not mislabelled as a
        # training error; the in-memory model is still usable.
        logger.exception("Error saving model to %s: %s", model_path, e)
        return model
    logger.info("Model trained and saved as %s", model_path)
    return model
def main():
    """Script entry point: run train_model with its default arguments."""
    train_model()


if __name__ == "__main__":
    main()