import csv
import os

import numpy as np
import PIL
import torch
import torch_xla.core.xla_model as xm

import evaluate
import nltk
from datasets import load_dataset
from transformers import (
    AutoFeatureExtractor,
    AutoTokenizer,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    VisionEncoderDecoderModel,
)

# Sentence tokenizer used by postprocess_text() below.
nltk.download('punkt')

# Disable Weights & Biases logging for this run.
os.environ["WANDB_DISABLED"] = "true"

# Acquire the TPU device through torch_xla (the Trainer handles device placement,
# but this makes the XLA runtime available up front).
dev = xm.xla_device()


def tokenization_fn(captions, max_target_length):
    """Run tokenization on captions."""
    labels = tokenizer(captions,
                       padding="max_length",
                       max_length=max_target_length).input_ids
    return labels


def feature_extraction_fn(image_paths, check_image=True):
    """
    Run feature extraction on images.

    If `check_image` is `True`, examples that fail during `Image.open()` are
    caught and discarded. Otherwise, an exception is raised.
    """
    if check_image:
        images = []
        to_keep = []  # tracks which examples opened successfully
        for image_file in image_paths:
            try:
                img = PIL.Image.open(image_file)
                images.append(img)
                to_keep.append(True)
            except Exception:
                to_keep.append(False)
    else:
        images = [PIL.Image.open(image_file) for image_file in image_paths]

    encoder_inputs = feature_extractor(images=images, return_tensors="np")
    return encoder_inputs.pixel_values


def transform(example_batch):
    # On-the-fly transform for batches with a decoded `image` column (kept for
    # reference; the training pipeline below uses preprocess_fn instead).
    inputs = feature_extractor([x for x in example_batch['image']], return_tensors='pt')
    inputs['labels'] = example_batch['labels']
    return inputs


def preprocess_fn(example_batch):
    """Run tokenization + image feature extraction."""
    model_inputs = {}
    model_inputs['pixel_values'] = feature_extraction_fn([x for x in example_batch['image_path']])
    model_inputs['labels'] = tokenization_fn([x for x in example_batch['tags']], 128)
    return model_inputs


def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [label.strip() for label in labels]

    # rougeLSum expects a newline after each sentence.
    preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
    labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]

    return preds, labels


def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    if ignore_pad_token_for_loss:
        # Replace -100 (ignored positions) with the pad token id so they can be decoded.
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    decoded_preds, decoded_labels = postprocess_text(decoded_preds,
                                                     decoded_labels)

    result = metric.compute(predictions=decoded_preds,
                            references=decoded_labels,
                            use_stemmer=True)
    result = {k: round(v * 100, 4) for k, v in result.items()}
    prediction_lens = [
        np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds
    ]
    result["gen_len"] = np.mean(prediction_lens)
    return result


def load_csv_as_dict(file_path):
    # Small helper that maps the first CSV column to the second (not used below).
    with open(file_path, mode='r') as csv_file:
        reader = csv.reader(csv_file)
        result = {rows[0]: rows[1] for rows in reader}
        return result


image_encoder_model = "google/vit-base-patch16-224"
text_decode_model = "Thouph/GPT-E6-small"

model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    image_encoder_model, text_decode_model)

# Freeze the entire encoder-decoder model first...
model.eval()
for p in model.parameters():
    p.requires_grad = False

# ...then unfreeze only the newly added cross-attention blocks (and their layer
# norms) in the GPT-style decoder, so only those weights are trained.
for layer in model.decoder.transformer.h:
    layer.crossattention.train()
    for p in layer.crossattention.parameters():
        p.requires_grad = True
    layer.ln_cross_attn.train()
    for p in layer.ln_cross_attn.parameters():
        p.requires_grad = True


feature_extractor = AutoFeatureExtractor.from_pretrained(image_encoder_model)
tokenizer = AutoTokenizer.from_pretrained("Thouph/six_tokenizer_filtered_space_merge")

# Use the EOS token for padding.
tokenizer.pad_token = tokenizer.eos_token

model.config.eos_token_id = tokenizer.eos_token_id
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.pad_token_id = tokenizer.pad_token_id

output_dir = "vit-gpt-model"
model.save_pretrained(output_dir)

# Re-assert the freeze after saving; keep the cross-attention layer norms
# trainable as well, matching the unfreezing above.
for name, param in model.named_parameters():
    if "crossattention" not in name and "ln_cross_attn" not in name:
        param.requires_grad = False

feature_extractor.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
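
# Optional sanity check (not part of the original pipeline): confirm that only
# the cross-attention parameters remain trainable before training starts.
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable_params:,} / {total_params:,}")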


# The CSV is expected to provide at least `image_id`, `folder_name`, and `tags`
# columns; an `image_path` column is derived from the first two below.
dataset = load_dataset('csv', data_files=r"posts-2023-04-17_MD5_caption_sifted_no_symbol_purged_folder.csv")
print(dataset)


def add_image_path(example):
    image_name = [i + '.jpg' for i in example["image_id"]]
    folder_name = example["folder_name"]
    image_path = [os.path.join(rf"/home/user/dump_small/{folder_name[i]}", image_name[i]) for i in range(len(image_name))]
    example['image_path'] = image_path
    return example


ds = dataset.map(add_image_path, batched=True, batch_size=8192)["train"]
print(ds)
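
# Optional sanity check (illustrative, not part of the original script): make sure
# the first few constructed image paths exist on disk before a long training run.
sample_paths = ds.select(range(min(100, len(ds))))['image_path']
missing = [p for p in sample_paths if not os.path.exists(p)]
print(f"Missing image files among the first {len(sample_paths)} examples: {len(missing)}")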


ds = ds.train_test_split(test_size=0.02)
print(ds['train'][0:2])

# Apply tokenization + feature extraction lazily, at access time.
ds.set_transform(preprocess_fn)
print(ds['train'][0:2])


training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy="steps",
    eval_steps=100,
    gradient_accumulation_steps=4,
    per_device_train_batch_size=128,
    weight_decay=0.1,
    max_steps=10000,
    warmup_steps=1000,
    logging_strategy="steps",
    save_steps=5000,
    fp16=True,
    tpu_num_cores=8,
    per_device_eval_batch_size=128,
    output_dir="image-captioning-output",
    learning_rate=5e-4,
    lr_scheduler_type="cosine",
)


def collate_fn(batch):
    # `preprocess_fn` returns NumPy pixel values, so convert to tensors before stacking.
    return {
        'pixel_values': torch.stack([torch.as_tensor(x['pixel_values']) for x in batch]),
        'labels': torch.tensor([x['labels'] for x in batch])
    }


metric = evaluate.load("rouge")
ignore_pad_token_for_loss = True


trainer = Seq2SeqTrainer(
    model=model,
    # Passing the feature extractor here makes the Trainer save it with each checkpoint.
    tokenizer=feature_extractor,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=ds['train'],
    eval_dataset=ds['test'],
    data_collator=collate_fn,
)


trainer.train()

trainer.save_model("image-captioning-output1")
tokenizer.save_pretrained("image-captioning-output1")
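

# A minimal post-training inference sketch (illustrative; the image path in the
# usage example is a placeholder). It reloads the fine-tuned checkpoint saved
# above and generates a tag string for a single image.
def generate_tags(image_path, checkpoint_dir="image-captioning-output1", max_length=128):
    trained_model = VisionEncoderDecoderModel.from_pretrained(checkpoint_dir)
    trained_extractor = AutoFeatureExtractor.from_pretrained(checkpoint_dir)
    trained_tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
    image = PIL.Image.open(image_path).convert("RGB")
    pixel_values = trained_extractor(images=[image], return_tensors="pt").pixel_values
    output_ids = trained_model.generate(pixel_values, max_length=max_length)
    return trained_tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]

# Example usage (replace with a real image path):
# print(generate_tags("/path/to/example.jpg"))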
|