okeowo1014 commited on
Commit
0cfc9ed
·
verified ·
1 Parent(s): 12e32c5

Upload train3.py

Browse files
Files changed (1) hide show
  1. train3.py +78 -0
train3.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ from tensorflow.keras.preprocessing.image import ImageDataGenerator
3
+ from tensorflow.keras.applications import VGG16
4
+ from tensorflow.keras.layers import Flatten, Dense
5
+
6
+ # Define data paths (modify as needed)
7
+ train_data_dir = 'tt'
8
+ validation_data_dir = 'tt'
9
+ test_data_dir = 'tt'
10
+
11
+ # Set image dimensions (adjust if necessary)
12
+ img_width, img_height = 224, 224 # VGG16 expects these dimensions
13
+
14
+ # Data augmentation for improved generalization (optional)
15
+ train_datagen = ImageDataGenerator(
16
+ rescale=1./255, # Normalize pixel values
17
+ shear_range=0.2,
18
+ zoom_range=0.2,
19
+ horizontal_flip=True,
20
+ fill_mode='nearest'
21
+ )
22
+
23
+ validation_datagen = ImageDataGenerator(rescale=1./255) # Only rescale for validation
24
+
25
+ # Load training and validation data
26
+ train_generator = train_datagen.flow_from_directory(
27
+ train_data_dir,
28
+ target_size=(img_width, img_height),
29
+ batch_size=32, # Adjust batch size based on GPU memory
30
+ class_mode='binary' # Two classes: cat or dog
31
+ )
32
+
33
+ validation_generator = validation_datagen.flow_from_directory(
34
+ validation_data_dir,
35
+ target_size=(img_width, img_height),
36
+ batch_size=32,
37
+ class_mode='binary'
38
+ )
39
+
40
+ # Load pre-trained VGG16 model (without the top layers)
41
+ base_model = VGG16(weights='imagenet', include_top=False, input_shape=(img_width, img_height, 3))
42
+
43
+ # Freeze the base model layers (optional - experiment with unfreezing for fine-tuning)
44
+ base_model.trainable = False
45
+
46
+ # Add custom layers for classification on top of the pre-trained model
47
+ x = base_model.output
48
+ x = Flatten()(x)
49
+ predictions = Dense(1, activation='sigmoid')(x) # Sigmoid for binary classification
50
+
51
+ # Create the final model
52
+ model = tf.keras.Model(inputs=base_model.input, outputs=predictions)
53
+
54
+ # Compile the model
55
+ model.compile(loss='binary_crossentropy',
56
+ optimizer='adam',
57
+ metrics=['accuracy'])
58
+
59
+ # Train the model
60
+ history = model.fit(
61
+ train_generator,
62
+ epochs=10, # Adjust number of epochs based on dataset size and validation performance
63
+ validation_data=validation_generator
64
+ )
65
+
66
+ # Evaluate the model on test data (optional)
67
+ test_generator = validation_datagen.flow_from_directory(
68
+ test_data_dir,
69
+ target_size=(img_width, img_height),
70
+ batch_size=32,
71
+ class_mode='binary'
72
+ )
73
+
74
+ test_loss, test_acc = model.evaluate(test_generator)
75
+ print('Test accuracy:', test_acc)
76
+
77
+ # Save the model for future use (optional)
78
+ model.save('cat_dog_classifier.h5')