|
import tensorflow as tf |
|
from tensorflow.keras import datasets, layers, models |
|
import matplotlib.pyplot as plt |
|
|
|
|
|
# Load CIFAR-10 as NumPy arrays: 50k training and 10k test RGB images with
# integer class labels.
(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()

# Scale pixel values from [0, 255] to [0, 1]. Cast to float32 before dividing:
# dividing the uint8 arrays directly yields float64, doubling memory (~1.2 GB vs
# ~0.6 GB for the training images) with no benefit, since Keras' default float
# dtype is float32 and inputs are converted to it anyway.
train_images = train_images.astype('float32') / 255.0
test_images = test_images.astype('float32') / 255.0
|
|
|
|
|
# Small CNN for 32x32x3 images: three 3x3 conv layers (the first two each followed
# by 2x2 max pooling), flattened into a 64-unit ReLU layer, then a 10-unit output
# layer with no activation, i.e. it emits raw logits (the loss is configured with
# from_logits=True).
model = models.Sequential([
    layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.Flatten(),
    layers.Dense(64, activation='relu'),
    layers.Dense(10),
])
|
|
|
|
|
# The final Dense layer has no softmax, so the loss must be told it receives
# logits; labels are integer class ids, hence the "sparse" variant.
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)
|
|
|
|
|
class AccuracyCallback(tf.keras.callbacks.Callback):
    """Stop training once validation accuracy exceeds a threshold.

    Args:
        threshold: Accuracy as a fraction in [0, 1]. Training stops at the end
            of the first epoch whose logged ``val_accuracy`` is strictly greater
            than this value. Defaults to 0.90.
    """

    def __init__(self, threshold=0.90):
        super().__init__()
        self.threshold = threshold

    def on_epoch_end(self, epoch, logs=None):
        # Use None instead of a mutable ``{}`` default shared across calls.
        logs = logs or {}
        val_acc = logs.get('val_accuracy')
        # 'val_accuracy' is only present when fit() has validation data;
        # comparing None to a float would raise TypeError, so skip the check.
        if val_acc is None:
            return
        if val_acc > self.threshold:
            # For the default 0.90 this prints "Reached 90% accuracy, ...".
            print(f"\nReached {self.threshold:.0%} accuracy, stopping training...")
            self.model.stop_training = True


accuracy_callback = AccuracyCallback()
|
|
|
|
|
# Validate on a held-out 10% of the *training* data (Keras takes the last 10% of
# the arrays, before shuffling). The test set must not be used here: the
# stopping callback and any model selection from the curves would then be tuned
# on the very data the final evaluation reports, inflating test accuracy.
history = model.fit(train_images, train_labels, epochs=50,
                    validation_split=0.1,
                    callbacks=[accuracy_callback])
|
|
|
|
|
# Final evaluation on the test set; returns [loss, accuracy] per the compiled
# metrics, and verbose=2 prints a single summary line.
test_loss, test_acc = model.evaluate(x=test_images, y=test_labels, verbose=2)
print('Test accuracy:', test_acc)
|
|
|
|
|
# Per-epoch metrics recorded by fit(); the val_* keys exist because fit() was
# given validation data.
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

# Number epochs from 1 to match Keras' own "Epoch 1/N" progress output.
epochs = range(1, len(acc) + 1)

plt.figure(figsize=(10, 5))

# Left panel: accuracy curves.
plt.subplot(1, 2, 1)
plt.plot(epochs, acc, 'r', label='Training accuracy')
plt.plot(epochs, val_acc, 'b', label='Validation accuracy')
plt.title('Training and validation accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()

# Right panel: loss curves.
plt.subplot(1, 2, 2)
plt.plot(epochs, loss, 'r', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

# Prevent the two subplots' titles/labels from overlapping.
plt.tight_layout()
plt.show()
|
|