Thouph committed
Commit 9228b7a · 1 Parent(s): 306310a

Upload train.py

Files changed (1):
  1. train.py  +225 -0
train.py ADDED
@@ -0,0 +1,225 @@
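# Fine-tunes a VisionEncoderDecoderModel (ViT image encoder + GPT-style tag decoder)
# to generate tag strings for images. Only the decoder's cross-attention parameters
# are trained; everything else stays frozen. Intended to run on a TPU host via
# torch_xla, with images and a tag CSV laid out as described below.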
import csv
import os

os.environ["WANDB_DISABLED"] = "true"  # turn off Weights & Biases logging

import numpy as np
import PIL
import torch

import evaluate
import nltk
from datasets import load_dataset
from transformers import (
    AutoFeatureExtractor,
    AutoTokenizer,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    VisionEncoderDecoderModel,
    default_data_collator,
)

import torch_xla.core.xla_model as xm

nltk.download("punkt")

dev = xm.xla_device()
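# `dev` is not used directly below; when run under torch_xla the Seq2SeqTrainer
# handles placing the model and batches on the TPU device itself.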

# text preprocessing step
def tokenization_fn(captions, max_target_length):
    """Run tokenization on captions."""
    labels = tokenizer(captions,
                       padding="max_length",
                       max_length=max_target_length).input_ids

    return labels


# image preprocessing step
def feature_extraction_fn(image_paths, check_image=True):
    """
    Run feature extraction on images.
    If `check_image` is `True`, examples that fail during `Image.open()` are caught and discarded.
    Otherwise, an exception is raised.
    """
    if check_image:
        images = []
        to_keep = []
        for image_file in image_paths:
            try:
                img = PIL.Image.open(image_file)
                images.append(img)
                to_keep.append(True)
            except Exception:
                to_keep.append(False)
    else:
        images = [PIL.Image.open(image_file) for image_file in image_paths]
        to_keep = [True] * len(images)

    encoder_inputs = feature_extractor(images=images, return_tensors="np")

    # Also return the keep-mask so the caller can drop the captions of unreadable images.
    return encoder_inputs.pixel_values, to_keep


def preprocess_fn(examples, max_target_length, check_image=True):
    """Run tokenization + image feature extraction."""
    image_paths = examples["image_path"]
    captions = examples["tags"]

    pixel_values, to_keep = feature_extraction_fn(image_paths, check_image=check_image)
    # Keep labels aligned with pixel_values: drop captions whose images could not be opened.
    captions = [caption for caption, keep in zip(captions, to_keep) if keep]

    model_inputs = {}
    model_inputs["labels"] = tokenization_fn(captions, max_target_length)
    model_inputs["pixel_values"] = pixel_values

    return model_inputs


def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [label.strip() for label in labels]

    # rougeLSum expects a newline after each sentence
    preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
    labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]

    return preds, labels


def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    if ignore_pad_token_for_loss:
        # Replace -100 in the labels as we can't decode them.
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(predictions=decoded_preds,
                            references=decoded_labels,
                            use_stemmer=True)
    result = {k: round(v * 100, 4) for k, v in result.items()}
    prediction_lens = [
        np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds
    ]
    result["gen_len"] = np.mean(prediction_lens)
    return result


def load_csv_as_dict(file_path):
    with open(file_path, mode="r") as csv_file:
        reader = csv.reader(csv_file)
        result = {rows[0]: rows[1] for rows in reader}
    return result


# encoder/decoder checkpoints; for the actual run use "google/vit-large-patch16-384"
# or "google/vit-large-patch16-224-in21k" as the encoder
image_encoder_model = "google/vit-base-patch16-224"
text_decode_model = "Thouph/GPT-E6-small"

model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    image_encoder_model, text_decode_model)

model.eval()
for p in model.parameters():
    p.requires_grad = False

# only allow training of the cross-attention parameters
for layer in model.decoder.transformer.h:
    layer.crossattention.train()
    for p in layer.crossattention.parameters():
        p.requires_grad = True
    layer.ln_cross_attn.train()
    for p in layer.ln_cross_attn.parameters():
        p.requires_grad = True

# image feature extractor
feature_extractor = AutoFeatureExtractor.from_pretrained(image_encoder_model)
# text tokenizer
tokenizer = AutoTokenizer.from_pretrained("Thouph/six_tokenizer_filtered_space_merge")

# GPT2 only has bos/eos tokens but not decoder_start/pad tokens
tokenizer.pad_token = tokenizer.eos_token

# update the model config
model.config.eos_token_id = tokenizer.eos_token_id
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.pad_token_id = tokenizer.pad_token_id

output_dir = "vit-gpt-model"
model.save_pretrained(output_dir)

# re-assert the freeze after saving, keeping both the cross-attention blocks
# and their layer norms trainable
for name, param in model.named_parameters():
    if "crossattention" not in name and "ln_cross_attn" not in name:
        param.requires_grad = False

feature_extractor.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

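# The CSV is expected to provide `image_id`, `folder_name`, and `tags` columns;
# `image_id` + ".jpg" is resolved inside ~/dump_small/<folder_name>/.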

dataset = load_dataset("csv", data_files=r"posts-2023-04-17_MD5_caption_sifted_no_symbol_purged_folder.csv")
print(dataset)


def add_image_path(example):
    image_name = [i + ".jpg" for i in example["image_id"]]
    folder_name = example["folder_name"]
    # expand "~" so the files can actually be opened by PIL
    image_path = [
        os.path.join(os.path.expanduser(f"~/dump_small/{folder_name[i]}"), image_name[i])
        for i in range(len(image_name))
    ]
    example["image_path"] = image_path
    return example


ds = dataset.map(add_image_path, batched=True, batch_size=1024)["train"]
print(ds)

ds = ds.train_test_split(test_size=0.02)

print(ds["train"][0])
processed_dataset = ds.map(
    function=preprocess_fn,
    batched=True,
    fn_kwargs={"max_target_length": 128},
    # dropping the original columns lets preprocess_fn discard unreadable images
    remove_columns=ds["train"].column_names,
)


training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy="steps",
    eval_steps=100,
    gradient_accumulation_steps=4,
    per_device_train_batch_size=1,
    weight_decay=0.1,
    max_steps=1000,
    warmup_steps=1000,
    logging_strategy="steps",
    save_steps=200,
    fp16=True,
    tpu_num_cores=8,
    per_device_eval_batch_size=1,
    output_dir="image-captioning-output",
    learning_rate=5e-4,
    lr_scheduler_type="cosine",
)


metric = evaluate.load("rouge")
ignore_pad_token_for_loss = True

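# The image feature extractor is passed to the Trainer as `tokenizer` so that it is
# saved alongside each checkpoint; text decoding in compute_metrics still uses the
# GPT tokenizer loaded above.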
# instantiate trainer
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=feature_extractor,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=processed_dataset["train"],
    eval_dataset=processed_dataset["test"],
    data_collator=default_data_collator,
)

trainer.train()

trainer.save_model("image-captioning-output1")
tokenizer.save_pretrained("image-captioning-output1")