import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt

# Load CIFAR-10 dataset
(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()

# Normalize pixel values to be between 0 and 1
train_images, test_images = train_images / 255.0, test_images / 255.0

# Define the CNN model: three 3x3 conv layers (the first two followed by 2x2
# max-pooling), then a small dense head. The final Dense(10) has no activation,
# so the model emits raw logits -- this is why the loss uses from_logits=True.
model = models.Sequential()
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.Flatten())
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(10))

# Compile the model
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)


class AccuracyCallback(tf.keras.callbacks.Callback):
    """Stop training once validation accuracy exceeds a threshold.

    Args:
        threshold: Training stops after the first epoch whose
            ``val_accuracy`` is strictly greater than this value.
            Defaults to 0.90.
    """

    def __init__(self, threshold=0.90):
        super().__init__()
        self.threshold = threshold

    def on_epoch_end(self, epoch, logs=None):
        """Check ``val_accuracy`` in *logs* and request a stop if it is high enough."""
        # Avoid a shared mutable default argument; treat missing logs as empty.
        logs = logs or {}
        val_accuracy = logs.get('val_accuracy')
        # 'val_accuracy' is absent when fit() gets no validation data;
        # comparing None against a float would raise TypeError.
        if val_accuracy is not None and val_accuracy > self.threshold:
            print(f"\nReached {self.threshold:.0%} accuracy, stopping training...")
            self.model.stop_training = True


accuracy_callback = AccuracyCallback()

# Train the model.
# NOTE(review): the test set doubles as validation data here, and the callback
# stops training based on it, so the "Test accuracy" reported below is not a
# fully held-out estimate. Consider validation_split=... (or a separate
# validation slice of the training data) instead.
history = model.fit(
    train_images,
    train_labels,
    epochs=50,
    validation_data=(test_images, test_labels),
    callbacks=[accuracy_callback],
)

# Evaluate the model
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)
print('Test accuracy:', test_acc)

# Plot accuracy and loss curves
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
# Number epochs from 1 so the x-axis matches Keras' "Epoch N/50" output.
epochs = range(1, len(acc) + 1)

plt.figure(figsize=(10, 5))

plt.subplot(1, 2, 1)
plt.plot(epochs, acc, 'r', label='Training accuracy')
plt.plot(epochs, val_acc, 'b', label='Validation accuracy')
plt.title('Training and validation accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(epochs, loss, 'r', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

plt.tight_layout()
plt.show()