mansesa3 commited on
Commit
db2ade3
·
verified ·
1 Parent(s): 5600d5a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +202 -0
app.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %%
2
+ import os
3
+ os.environ['KMP_DUPLICATE_LIB_OK']= 'True'
4
+
5
+ import tensorflow as tf
6
+ tf.__version__
7
+
8
+ # %%
9
+
10
+ import matplotlib.pyplot as plt
11
+ import numpy as np
12
+
13
+
14
+ # %%
15
+ # Set the paths to your dataset directories
16
+ train_dir = r'pokemon/train'
17
+ val_dir = r'pokemon/val'
18
+
19
+ # Ensure the paths are correctly formatted
20
+ train_dir = os.path.normpath(train_dir)
21
+ val_dir = os.path.normpath(val_dir)
22
+
23
+ # Load the datasets
24
+ train_ds = tf.keras.utils.image_dataset_from_directory(
25
+ directory=train_dir,
26
+ labels='inferred',
27
+ label_mode='int', # Use 'int' for sparse_categorical_crossentropy loss
28
+ batch_size=12,
29
+ image_size=(150, 150))
30
+
31
+ validation_ds = tf.keras.utils.image_dataset_from_directory(
32
+ directory=val_dir,
33
+ labels='inferred',
34
+ label_mode='int',
35
+ batch_size=12,
36
+ image_size=(150, 150))
37
+
38
+ # %%
39
+ val_batches = tf.data.experimental.cardinality(validation_ds)
40
+ test_ds = validation_ds.take(val_batches // 5)
41
+ validation_ds = validation_ds.skip(val_batches // 5)
42
+
43
+ # %%
44
+ print('Number of training batches: %d' % tf.data.experimental.cardinality(train_ds))
45
+ print('Number of validation batches: %d' % tf.data.experimental.cardinality(validation_ds))
46
+ print('Number of test batches: %d' % tf.data.experimental.cardinality(test_ds))
47
+
48
+ # %%
49
+ class_names = train_ds.class_names
50
+
51
+ plt.figure(figsize=(10, 10))
52
+ for images, labels in train_ds.take(1):
53
+ for i in range(9):
54
+ ax = plt.subplot(3, 3, i + 1)
55
+ plt.imshow(images[i].numpy().astype("uint8"))
56
+ plt.title(class_names[labels[i]])
57
+ plt.axis("off")
58
+
59
+ # %%
60
+ number_of_classes = len(train_ds.class_names)
61
+ print(number_of_classes)
62
+ print(class_names)
63
+
64
+ # %%
65
+ #resize 150x150?
66
+ resize_fn = tf.keras.layers.Resizing(150, 150)
67
+
68
+ train_ds = train_ds.map(lambda x, y: (resize_fn(x), y))
69
+ validation_ds = validation_ds.map(lambda x, y: (resize_fn(x), y))
70
+ test_ds = test_ds.map(lambda x, y: (resize_fn(x), y))
71
+
72
+ # %%
73
+
74
+ # Build the model
75
+ base_model = tf.keras.applications.Xception(
76
+ weights="imagenet", # Load weights pre-trained on ImageNet.
77
+ input_shape=(150, 150, 3),
78
+ include_top=False,
79
+ ) # Do not include the ImageNet classifier at the top.
80
+
81
+ # Freeze the base_model
82
+ base_model.trainable = False
83
+
84
+ # Create new model on top
85
+ inputs = tf.keras.Input(shape=(150, 150, 3))
86
+
87
+ # Pre-trained Xception weights require that input be scaled
88
+ # from (0, 255) to a range of (-1., +1.), the rescaling layer
89
+ # outputs: `(inputs * scale) + offset`
90
+ scale_layer = tf.keras.layers.Rescaling(scale=1 / 127.5, offset=-1)
91
+ x = scale_layer(inputs)
92
+
93
+ # The base model contains batchnorm layers. We want to keep them in inference mode
94
+ # when we unfreeze the base model for fine-tuning, so we make sure that the
95
+ # base_model is running in inference mode here.
96
+ x = base_model(x, training=False)
97
+ x = tf.keras.layers.GlobalAveragePooling2D()(x)
98
+ x = tf.keras.layers.Dropout(0.2)(x) # Regularize with dropout
99
+ outputs = tf.keras.layers.Dense(number_of_classes, activation="softmax")(x)
100
+ model = tf.keras.Model(inputs, outputs)
101
+
102
+ model.summary(show_trainable=True)
103
+
104
+
105
+ # %%
106
+ model.compile(optimizer=tf.keras.optimizers.Adam(),
107
+ loss="sparse_categorical_crossentropy",
108
+ metrics=['accuracy'])
109
+
110
+ initial_epochs = 4
111
+ print("Fitting the top layer of the model")
112
+ history = model.fit(train_ds, epochs=initial_epochs, validation_data=validation_ds)
113
+
114
+
115
+ # %%
116
+ acc = history.history['accuracy']
117
+ val_acc = history.history['val_accuracy']
118
+
119
+ loss = history.history['loss']
120
+ val_loss = history.history['val_loss']
121
+
122
+ plt.figure(figsize=(8, 8))
123
+ plt.subplot(2, 1, 1)
124
+ plt.plot(acc, label='Training Accuracy')
125
+ plt.plot(val_acc, label='Validation Accuracy')
126
+ plt.legend(loc='lower right')
127
+ plt.ylabel('Accuracy')
128
+ plt.ylim([min(plt.ylim()),1])
129
+ plt.title('Training and Validation Accuracy')
130
+
131
+ plt.subplot(2, 1, 2)
132
+ plt.plot(loss, label='Training Loss')
133
+ plt.plot(val_loss, label='Validation Loss')
134
+ plt.legend(loc='upper right')
135
+ plt.ylabel('Cross Entropy')
136
+ plt.title('Training and Validation Loss')
137
+ plt.xlabel('epoch')
138
+ plt.show()
139
+
140
+ # %%
141
+ base_model.trainable = True
142
+ model.summary(show_trainable=True)
143
+
144
+ model.compile(
145
+ optimizer=tf.keras.optimizers.Adam(1e-5), # Low learning rate
146
+ loss="sparse_categorical_crossentropy",
147
+ metrics=['accuracy']
148
+ )
149
+
150
+ epochs = 1
151
+ print("Fitting the end-to-end model")
152
+ history_fine = model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
153
+
154
+
155
+ # %%
156
+ acc += history_fine.history['accuracy']
157
+ val_acc += history_fine.history['val_accuracy']
158
+
159
+ loss += history_fine.history['loss']
160
+ val_loss += history_fine.history['val_loss']
161
+
162
+ plt.figure(figsize=(8, 8))
163
+ plt.subplot(2, 1, 1)
164
+ plt.plot(acc, label='Training Accuracy')
165
+ plt.plot(val_acc, label='Validation Accuracy')
166
+ plt.ylim([0.4, 1]) # set the y-axis limits
167
+ plt.plot([initial_epochs-1,initial_epochs-1],
168
+ plt.ylim(), label='Start Fine Tuning')
169
+ plt.legend(loc='lower right')
170
+ plt.title('Training and Validation Accuracy')
171
+
172
+ plt.subplot(2, 1, 2)
173
+ plt.plot(loss, label='Training Loss')
174
+ plt.plot(val_loss, label='Validation Loss')
175
+ plt.plot([initial_epochs-1,initial_epochs-1],
176
+ plt.ylim(), label='Start Fine Tuning')
177
+ plt.legend(loc='upper right')
178
+ plt.title('Training and Validation Loss')
179
+ plt.xlabel('epoch')
180
+ plt.show()
181
+
182
+ # %%
183
+ print("Test dataset evaluation")
184
+ model.evaluate(test_ds)
185
+
186
+ # %%
187
+ image_batch, label_batch = test_ds.as_numpy_iterator().next()
188
+ predictions_in_percentage = model.predict_on_batch(image_batch)
189
+ predictions = np.argmax(predictions_in_percentage, axis=-1)
190
+ print('Predictions:\n', predictions)
191
+ print('Labels:\n', label_batch)
192
+ plt.figure(figsize=(10, 10))
193
+ for i in range(9):
194
+ ax = plt.subplot(3, 3, i + 1)
195
+ plt.imshow(image_batch[i].astype("uint8"))
196
+ plt.title('pred. ' + class_names[predictions[i]] + ' was ' + class_names[label_batch[i]] + ' ' + str(np.round(predictions_in_percentage[i], 2)), fontsize=8)
197
+ plt.axis("off")
198
+
199
+ # %%
200
+ model.save('pokemon-model_transferlearning.keras')
201
+
202
+