arifsy commited on
Commit
96ba5ec
1 Parent(s): 6003da2

Add training script for a ViT-based Fashion-MNIST image classifier

Browse files
Files changed (1) hide show
  1. neural_models.py +152 -0
neural_models.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_dataset
2
+ from transformers import AutoImageProcessor, create_optimizer, TFAutoModelForImageClassification, KerasMetricCallback, \
3
+ PushToHubCallback, pipeline
4
+ import tensorflow as tf
5
+ from tensorflow.python import keras
6
+ from keras import layers, losses
7
+ import numpy as np
8
+ from PIL import Image
9
+ from transformers import DefaultDataCollator
10
+ import evaluate
11
+
12
+
13
def convert_to_tf_tensor(image: Image):
    """Turn a PIL image into a TF tensor with a leading batch axis.

    The batch axis is required because the Keras augmentation layers
    used elsewhere in this file operate on batched inputs.
    """
    # PIL -> ndarray -> tensor, then prepend the batch dimension.
    return tf.expand_dims(tf.convert_to_tensor(np.array(image)), 0)
20
+
21
+
22
def preprocess_train(example_batch):
    """Apply the training augmentation pipeline across a batch.

    Each PIL image is converted to RGB, batched, run through
    ``train_data_augmentation``, then squeezed back to a single image
    and transposed before being stored under ``"pixel_values"``.
    """
    pixel_values = []
    for img in example_batch["image"]:
        augmented = train_data_augmentation(convert_to_tf_tensor(img.convert("RGB")))
        # Drop the batch axis added for augmentation, then transpose.
        pixel_values.append(tf.transpose(tf.squeeze(augmented)))
    example_batch["pixel_values"] = pixel_values
    return example_batch
29
+
30
+
31
def preprocess_val(example_batch):
    """Apply the (deterministic) validation transforms across a batch.

    Mirrors ``preprocess_train`` but routes images through
    ``val_data_augmentation`` instead of the randomized pipeline.
    """
    pixel_values = []
    for img in example_batch["image"]:
        transformed = val_data_augmentation(convert_to_tf_tensor(img.convert("RGB")))
        # Drop the batch axis added for the transform, then transpose.
        pixel_values.append(tf.transpose(tf.squeeze(transformed)))
    example_batch["pixel_values"] = pixel_values
    return example_batch
38
+
39
+
40
def compute_metrics(eval_pred):
    """Compute accuracy from an (predictions, labels) evaluation pair.

    ``eval_pred`` carries per-class scores; the argmax over axis 1
    picks the predicted class id, which is compared against the
    reference labels via the module-level ``accuracy`` metric.
    """
    scores, references = eval_pred
    predicted_ids = np.argmax(scores, axis=1)
    return accuracy.compute(predictions=predicted_ids, references=references)
44
+
45
+
46
# --- Dataset -----------------------------------------------------------------
# Load a 4000-example slice of Fashion-MNIST, then split 80/20 into train/test.
fashion = load_dataset("fashion_mnist", split="train[:4000]")
fashion = fashion.train_test_split(test_size=0.2)
# Print one example to sanity-check the schema.
print(fashion["train"][0])

# Map label names to a (string) integer id and vice-versa. Both mappings are
# passed to the model below so predictions can be decoded into class names.
labels = fashion["train"].features["label"].names
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):
    label2id[label] = str(i)
    id2label[str(i)] = label

# Should convert label id into a name
# print(label2id)
# print(id2label)

# --- Pre-processing with ViT -------------------------------------------------
# Load the image processor; it supplies the target input size for this
# checkpoint (presumably 224x224 given the checkpoint name — TODO confirm).
checkpoint = "google/vit-base-patch16-224-in21k"
image_processor = AutoImageProcessor.from_pretrained(checkpoint)

# Use Keras preprocessing layers to define augmentations for the training set;
# the random transforms help reduce overfitting and make the model more robust.
size = (image_processor.size["height"], image_processor.size["width"])

# NOTE(review): Fashion-MNIST images are much smaller than the ViT input size,
# so cropping to `size` relies on the crop layers' behavior for undersized
# inputs — verify the resulting tensors match the model's expected resolution.
train_data_augmentation = keras.Sequential(
    [
        layers.RandomCrop(size[0], size[1]),
        # Map pixel values from [0, 255] into [-1, 1].
        layers.Rescaling(scale=1.0 / 127.5, offset=-1),
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(factor=0.02),
        layers.RandomZoom(height_factor=0.2, width_factor=0.2),
    ],
    name="train_data_augmentation",
)

# Validation uses deterministic transforms only: center crop + same rescaling.
val_data_augmentation = keras.Sequential(
    [
        layers.CenterCrop(size[0], size[1]),
        layers.Rescaling(scale=1.0 / 127.5, offset=-1),
    ],
    name="val_data_augmentation",
)

# Apply the transforms lazily, on access, instead of materializing them.
fashion["train"].set_transform(preprocess_train)
fashion["test"].set_transform(preprocess_val)

data_collator = DefaultDataCollator(return_tensors="tf")

# Accuracy metric consumed by compute_metrics via the KerasMetricCallback.
accuracy = evaluate.load("accuracy")

# --- Hyperparameters ---------------------------------------------------------
batch_size = 16
num_epochs = 4
num_train_steps = len(fashion["train"]) * num_epochs
learning_rate = 3e-5
weight_decay_rate = 0.01

# Optimizer + learning-rate schedule from the transformers helper
# (no warmup steps here).
optimizer, lr_schedule = create_optimizer(
    init_lr=learning_rate,
    num_train_steps=num_train_steps,
    weight_decay_rate=weight_decay_rate,
    num_warmup_steps=0,
)

# Load ViT with a classification head sized/labelled by the id mappings above.
model = TFAutoModelForImageClassification.from_pretrained(
    checkpoint,
    id2label=id2label,
    label2id=label2id,
)

# --- tf.data pipelines -------------------------------------------------------
tf_train_dataset = fashion["train"].to_tf_dataset(
    columns="pixel_values", label_cols="label", shuffle=True, batch_size=batch_size, collate_fn=data_collator
)

# NOTE(review): shuffle=True on the eval set is unusual — it does not affect
# metric values but makes per-batch inspection non-deterministic; confirm
# whether this is intentional.
tf_eval_dataset = fashion["test"].to_tf_dataset(
    columns="pixel_values", label_cols="label", shuffle=True, batch_size=batch_size, collate_fn=data_collator
)

# --- Training configuration --------------------------------------------------
# The model outputs raw logits, hence from_logits=True.
loss = losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer=optimizer, loss=loss)


# Report accuracy on the eval set after each epoch, and push the model/
# processor to the Hub from ../fashion_classifier (no intermediate saves).
metric_callback = KerasMetricCallback(metric_fn=compute_metrics, eval_dataset=tf_eval_dataset)
push_to_hub_callback = PushToHubCallback(
    output_dir="../fashion_classifier",
    tokenizer=image_processor,
    save_strategy="no",
)

callbacks = [metric_callback, push_to_hub_callback]

model.fit(tf_train_dataset, validation_data=tf_eval_dataset, epochs=num_epochs, callbacks=callbacks)

# model.push_to_hub()

# Example inference with the trained model (kept for reference):
# ds = load_dataset("fashion_mnist", split="test[:10]")
# image = ds["image"][0]
# classifier = pipeline("image-classification", model="my_awesome_fashion_model")
# print(classifier(image))
152
+