Sanjayraju30 commited on
Commit
323a5cf
·
verified ·
1 Parent(s): 235bed0

Create train_mnist_model.py

Browse files
Files changed (1) hide show
  1. train_mnist_model.py +47 -0
train_mnist_model.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ from tensorflow.keras import layers, models
3
+ import logging
4
+
5
+ # Set up logging
6
+ logging.basicConfig(level=logging.INFO)
7
+ logger = logging.getLogger(__name__)
8
+
9
+ # Load and preprocess MNIST dataset
10
+ def load_and_preprocess_data():
11
+ try:
12
+ (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
13
+ x_train = x_train.reshape(-1, 28, 28, 1).astype('float32') / 255.0
14
+ x_test = x_test.reshape(-1, 28, 28, 1).astype('float32') / 255.0
15
+ logger.info("MNIST dataset loaded and preprocessed successfully")
16
+ return x_train, y_train, x_test, y_test
17
+ except Exception as e:
18
+ logger.error(f"Error loading MNIST data: {e}")
19
+ return None, None, None, None
20
+
21
+ # Build and train CNN model
22
+ def train_model():
23
+ x_train, y_train, x_test, y_test = load_and_preprocess_data()
24
+ if x_train is None:
25
+ return
26
+
27
+ model = models.Sequential([
28
+ layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
29
+ layers.MaxPooling2D((2, 2)),
30
+ layers.Conv2D(64, (3, 3), activation='relu'),
31
+ layers.MaxPooling2D((2, 2)),
32
+ layers.Flatten(),
33
+ layers.Dense(128, activation='relu'),
34
+ layers.Dense(10, activation='softmax')
35
+ ])
36
+
37
+ model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
38
+
39
+ try:
40
+ model.fit(x_train, y_train, epochs=5, validation_data=(x_test, y_test))
41
+ model.save('mnist_cnn.h5')
42
+ logger.info("Model trained and saved as mnist_cnn.h5")
43
+ except Exception as e:
44
+ logger.error(f"Error training model: {e}")
45
+
46
+ if __name__ == "__main__":
47
+ train_model()