# Hugging Face Space (status: Running)
import logging

import tensorflow as tf
from tensorflow.keras import layers, models

# Logging configuration: basicConfig attaches a default stderr handler to the
# root logger at INFO level (it does nothing if the root logger already has
# handlers), and this module logs through its own named logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)

# Load and preprocess MNIST dataset
def _to_model_input(images):
    """Reshape images to (N, 28, 28, 1) float32 with values scaled by 1/255.

    Dividing by 255.0 assumes 8-bit pixel values (0-255), as returned by
    ``tf.keras.datasets.mnist.load_data``.
    """
    return images.reshape(-1, 28, 28, 1).astype('float32') / 255.0


def load_and_preprocess_data():
    """Load MNIST and convert the image arrays into CNN-ready inputs.

    Returns:
        ``(x_train, y_train, x_test, y_test)``. The image arrays are
        reshaped to ``(N, 28, 28, 1)``, cast to float32 and scaled by
        1/255; the label arrays are returned unchanged. On any failure the
        error is logged with its traceback and ``(None, None, None, None)``
        is returned, so callers can test ``x_train is None``.
    """
    try:
        (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
        x_train = _to_model_input(x_train)
        x_test = _to_model_input(x_test)
    except Exception as e:
        # Best-effort contract: report and return placeholders instead of
        # raising, but keep the traceback (logger.exception) for diagnosis.
        logger.exception("Error loading MNIST data: %s", e)
        return None, None, None, None
    logger.info("MNIST dataset loaded and preprocessed successfully")
    return x_train, y_train, x_test, y_test

# Build and train CNN model
def _build_model():
    """Build and compile the small MNIST CNN used by ``train_model``.

    Architecture: two Conv2D(3x3, ReLU) + MaxPooling2D(2x2) stages (32 then
    64 filters), Flatten, Dense(128, ReLU), Dense(10, softmax). Compiled
    with Adam and sparse categorical cross-entropy, so labels are integer
    class ids rather than one-hot vectors.
    """
    model = models.Sequential([
        # An explicit Input is the form Keras recommends for Sequential
        # models; passing input_shape= to the first layer triggers a
        # UserWarning in Keras 3.
        layers.Input(shape=(28, 28, 1)),
        layers.Conv2D(32, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.Dense(10, activation='softmax'),
    ])
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    return model


def train_model(*, epochs=5, model_path='mnist_cnn.h5'):
    """Load MNIST, train the CNN, and save it to ``model_path``.

    Args:
        epochs: Number of training epochs (default 5).
        model_path: Destination passed to ``model.save``. The default keeps
            the original HDF5 file name; a ``.keras`` path selects the
            native Keras format that newer Keras versions recommend.

    Returns:
        The trained model, or ``None`` if the data could not be loaded or
        training raised. If only saving fails, the error is logged and the
        trained (unsaved) model is still returned.
    """
    x_train, y_train, x_test, y_test = load_and_preprocess_data()
    if x_train is None:
        # load_and_preprocess_data logs the failure before returning None.
        return None

    model = _build_model()
    try:
        # NOTE(review): the test split is also used as validation data;
        # there is no separate validation split.
        model.fit(x_train, y_train, epochs=epochs, validation_data=(x_test, y_test))
    except Exception as e:
        logger.exception("Error training model: %s", e)
        return None

    # Saved separately so a disk/format error is not reported as a
    # training failure.
    try:
        model.save(model_path)
    except Exception as e:
        logger.exception("Error saving model to %s: %s", model_path, e)
        return model
    logger.info("Model trained and saved as %s", model_path)
    return model

if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    train_model()