#os.environ["WANDB_DISABLED"] = "true"
import csv
import os

import torch
from transformers import VisionEncoderDecoderModel, AutoTokenizer, AutoFeatureExtractor, Seq2SeqTrainer, training_args

from datasets import load_dataset, Image
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
import evaluate
import numpy as np


import nltk
from transformers import default_data_collator

import PIL

import wandb
import nltk
nltk.download('punkt')
import os
os.environ["WANDB_DISABLED"] = "true"

import torch
import torch_xla.core.xla_model as xm

dev = xm.xla_device()
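# `dev` is not referenced again below; Seq2SeqTrainer handles XLA device
# placement itself when the script is launched under torch_xla.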

# text preprocessing step
def tokenization_fn(captions, max_target_length):
    """Tokenize captions, padding and truncating every example to `max_target_length`."""
    labels = tokenizer(captions,
                       padding="max_length",
                       truncation=True,
                       max_length=max_target_length).input_ids

    return labels


# image preprocessing step
def feature_extraction_fn(image_paths, check_image=True):
    """
    Run feature extraction on images.
    If `check_image` is `True`, examples that fail during `Image.open()` are
    caught and discarded, and a keep-mask is returned so the caller can drop
    the matching labels. Otherwise, an exception is raised.
    """
    if check_image:
        images = []
        to_keep = []
        for image_file in image_paths:
            try:
                images.append(PIL.Image.open(image_file))
                to_keep.append(True)
            except Exception:
                to_keep.append(False)
    else:
        images = [PIL.Image.open(image_file) for image_file in image_paths]
        to_keep = [True] * len(images)

    encoder_inputs = feature_extractor(images=images, return_tensors="np")

    return encoder_inputs.pixel_values, to_keep

def transform(example_batch):
    """Alternative transform for datasets that already carry decoded PIL images
    in an `image` column; not used below, where `preprocess_fn` works from file paths."""
    # Take a list of PIL images and turn them into pixel values
    inputs = feature_extractor([x for x in example_batch['image']], return_tensors='pt')

    # Don't forget to include the labels!
    inputs['labels'] = example_batch['labels']
    return inputs

def preprocess_fn(example_batch):
    """Run tokenization + image feature extraction, keeping pixel values and labels aligned."""
    pixel_values, to_keep = feature_extraction_fn(example_batch['image_path'])
    labels = tokenization_fn(example_batch['tags'], 128)
    model_inputs = {
        'pixel_values': pixel_values,
        # drop the labels of any example whose image failed to open
        'labels': [label for label, keep in zip(labels, to_keep) if keep],
    }
    return model_inputs

def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [label.strip() for label in labels]

    # rougeLSum expects newline after each sentence
    preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
    labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]

    return preds, labels


def compute_metrics(eval_preds):
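    """Decode generated ids and labels, restore the -100-masked padding,
    and report ROUGE scores (scaled to 0-100) plus the mean generated length."""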
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    if ignore_pad_token_for_loss:
        # Replace -100 in the labels as we can't decode them.
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds,
                                                     decoded_labels)

    result = metric.compute(predictions=decoded_preds,
                            references=decoded_labels,
                            use_stemmer=True)
    result = {k: round(v * 100, 4) for k, v in result.items()}
    prediction_lens = [
        np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds
    ]
    result["gen_len"] = np.mean(prediction_lens)
    return result

def load_csv_as_dict(file_path):
    """Read a two-column CSV into a {first_column: second_column} dict (helper, not used below)."""
    with open(file_path, mode='r') as csv_file:
        reader = csv.reader(csv_file)
        result = {rows[0]: rows[1] for rows in reader}
    return result

image_encoder_model = "google/vit-base-patch16-224"# actual use "google/vit-large-patch16-384"#google/vit-large-patch16-224-in21k
text_decode_model = "Thouph/GPT-E6-small"

model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    image_encoder_model, text_decode_model)

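# The pretrained ViT encoder and GPT decoder weights are frozen below; only the
# cross-attention blocks that tie them together (and their layer norms) are left trainable.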
model.eval()
for p in model.parameters():
    p.requires_grad = False

# only allow training of cross attention parameters
for layer in model.decoder.transformer.h:
    layer.crossattention.train()
    for p in layer.crossattention.parameters():
        p.requires_grad = True
    layer.ln_cross_attn.train()
    for p in layer.ln_cross_attn.parameters():
        p.requires_grad = True

# image feature extractor
feature_extractor = AutoFeatureExtractor.from_pretrained(image_encoder_model)
# text tokenizer
tokenizer = AutoTokenizer.from_pretrained("Thouph/six_tokenizer_filtered_space_merge")

# GPT2 only has bos/eos tokens but not decoder_start/pad tokens
tokenizer.pad_token = tokenizer.eos_token

# update the model config
model.config.eos_token_id = tokenizer.eos_token_id
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.pad_token_id = tokenizer.pad_token_id
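# Save the freshly assembled, still untrained model together with its feature
# extractor and tokenizer so this starting configuration can be reloaded later.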
output_dir = "vit-gpt-model"
model.save_pretrained(output_dir)
# re-assert the freeze: everything except the cross-attention blocks
# (and their layer norms, unfrozen above) stays fixed
for name, param in model.named_parameters():
    if "crossattention" not in name and "ln_cross_attn" not in name:
        param.requires_grad = False
feature_extractor.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)



dataset = load_dataset('csv', data_files=r"posts-2023-04-17_MD5_caption_sifted_no_symbol_purged_folder.csv")
print(dataset)
def add_image_path(example):
    """Build the absolute image path for each row from its folder name and image id."""
    image_name = [i + '.jpg' for i in example["image_id"]]
    folder_name = example["folder_name"]
    image_path = [os.path.join(f"/home/user/dump_small/{folder_name[i]}", image_name[i]) for i in range(len(image_name))]
    example['image_path'] = image_path
    return example

ds = dataset.map(add_image_path, batched=True, batch_size=8192)["train"]
print(ds)

ds = ds.train_test_split(test_size=0.02)
print(ds['train'][0:2])
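# set_transform applies preprocess_fn lazily: images are loaded and captions
# tokenized only when a batch is actually accessed.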
ds.set_transform(preprocess_fn)
print(ds['train'][0:2])


training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy="steps",
    eval_steps=100,
    gradient_accumulation_steps=4,
    per_device_train_batch_size=128,
    weight_decay=0.1,
    max_steps=10000,
    warmup_steps=1000,
    logging_strategy="steps",
    save_steps=5000,
    fp16=True,
    tpu_num_cores=8,
    per_device_eval_batch_size=128,
    output_dir="image-captioning-output",
    learning_rate=5e-4,
    lr_scheduler_type="cosine",
)
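# With a per-device batch of 128, gradient accumulation of 4, and 8 TPU cores,
# one optimizer step covers an effective batch of 128 * 4 * 8 = 4096 examples.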

def collate_fn(batch):
    # pixel values come out of the transform as NumPy arrays, so convert before stacking
    return {
        'pixel_values': torch.stack([torch.as_tensor(x['pixel_values']) for x in batch]),
        'labels': torch.tensor([x['labels'] for x in batch])
    }

metric = evaluate.load("rouge")
ignore_pad_token_for_loss = True

# instantiate trainer
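# Passing the feature extractor as `tokenizer` lets the Trainer save it
# alongside each checkpoint.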
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=feature_extractor,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=ds['train'],
    eval_dataset=ds['test'],
    data_collator=collate_fn,
)


trainer.train()


trainer.save_model("image-captioning-output1")
tokenizer.save_pretrained("image-captioning-output1")