# --- Standard library -------------------------------------------------------
import csv
import os

# --- Third-party (kept in the order they were first imported) ----------------
import torch
from transformers import (
    VisionEncoderDecoderModel,
    AutoTokenizer,
    AutoFeatureExtractor,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    default_data_collator,
    # NOTE(review): this module name is rebound further down to the
    # Seq2SeqTrainingArguments instance.
    training_args,
)
# `Image` here is the datasets feature type, not PIL's Image module.
from datasets import load_dataset, Image
import evaluate
import numpy as np
import nltk
# NOTE(review): `import PIL` alone does not import the PIL.Image submodule;
# `PIL.Image` only resolves because another library has already imported it.
import PIL
import wandb
import torch_xla.core.xla_model as xm

# Sentence-splitting model needed by nltk.sent_tokenize (used in postprocess_text).
nltk.download('punkt')
# Keep the Trainer from reporting to Weights & Biases.
os.environ["WANDB_DISABLED"] = "true"
# XLA (TPU) device handle. NOTE(review): `dev` is not referenced later in this script.
dev = xm.xla_device()
def tokenization_fn(captions, max_target_length):
    """Tokenize captions into fixed-length label id lists.

    Args:
        captions: list of caption strings.
        max_target_length: every sequence is padded *and truncated* to this
            many tokens.

    Returns:
        list of token-id lists, one per caption, all of the same length.
    """
    # padding="max_length" on its own never shortens a caption that is longer
    # than max_length; such rows would come back longer than the rest, and the
    # ragged batch cannot be turned into one tensor by torch.tensor().
    labels = tokenizer(captions,
                       padding="max_length",
                       max_length=max_target_length,
                       truncation=True).input_ids
    return labels
def feature_extraction_fn(image_paths, check_image=True):
    """Load images from disk and convert them to model pixel values.

    Args:
        image_paths: iterable of image file paths.
        check_image: when True, images that fail to open *or decode* are
            skipped, so the result may have fewer rows than ``image_paths``.
            When False, the first failure propagates as an exception.

    Returns:
        The ``pixel_values`` array produced by the global ``feature_extractor``
        (called with ``return_tensors="np"``).
    """
    # Import the submodule explicitly: `import PIL` alone does not load PIL.Image.
    import PIL.Image

    def _load_rgb(path):
        # Image.open only parses the header. convert() forces the full decode,
        # so corrupt files fail here rather than later in the feature
        # extractor, and it yields 3-channel RGB. The `with` closes the file
        # once the pixels are in memory.
        with PIL.Image.open(path) as img:
            return img.convert("RGB")

    if check_image:
        images = []
        for image_file in image_paths:
            try:
                images.append(_load_rgb(image_file))
            except Exception:
                # Best-effort by design: unreadable images are dropped.
                # NOTE: callers that pair rows with other per-example data
                # (e.g. captions) must check that the lengths still match.
                continue
    else:
        images = [_load_rgb(image_file) for image_file in image_paths]
    encoder_inputs = feature_extractor(images=images, return_tensors="np")
    return encoder_inputs.pixel_values
def transform(example_batch):
    """Encode a batch of in-memory images with ``feature_extractor`` (PyTorch
    tensors) and pass the batch's ``labels`` through unchanged.

    NOTE(review): the dataset transform installed by this script is
    ``preprocess_fn``; this helper is not used by the visible code.
    """
    encoded = feature_extractor(list(example_batch['image']), return_tensors='pt')
    encoded['labels'] = example_batch['labels']
    return encoded
def preprocess_fn(example_batch, max_target_length=128):
    """Turn a raw batch into model inputs: image pixel values plus caption labels.

    Installed as the dataset's on-the-fly transform (``ds.set_transform``), so
    ``example_batch`` maps each column name to a list of values.

    Args:
        example_batch: batch dict; reads the ``image_path`` and ``tags`` columns.
        max_target_length: token length that the caption labels are padded and
            truncated to.

    Returns:
        dict with ``pixel_values`` and ``labels``, one row per example.

    Raises:
        ValueError: if some images could not be loaded. Returning fewer image
            rows than captions would silently pair images with the wrong
            captions.
    """
    pixel_values = feature_extraction_fn(list(example_batch['image_path']))
    labels = tokenization_fn(list(example_batch['tags']), max_target_length)
    # feature_extraction_fn drops images it cannot load. Fail loudly rather
    # than let the remaining rows shift against their captions.
    if len(pixel_values) != len(labels):
        raise ValueError(
            f"{len(labels) - len(pixel_values)} of {len(labels)} images in this "
            "batch could not be loaded, so images and captions would be "
            "misaligned; remove unreadable image files from the dataset first."
        )
    return {'pixel_values': pixel_values, 'labels': labels}
def postprocess_text(preds, labels):
    """Normalize decoded predictions and references before ROUGE scoring.

    Each string is stripped of surrounding whitespace and re-joined with one
    sentence per line (sentences found by ``nltk.sent_tokenize``). Newline-
    separated sentences are the format that sentence-level ROUGE (rougeLsum)
    splits on.

    Returns:
        ``(preds, labels)`` as two new lists.
    """
    def _one_sentence_per_line(text):
        return "\n".join(nltk.sent_tokenize(text.strip()))

    cleaned_preds = [_one_sentence_per_line(p) for p in preds]
    cleaned_labels = [_one_sentence_per_line(ref) for ref in labels]
    return cleaned_preds, cleaned_labels
def compute_metrics(eval_preds):
    """Decode generated captions and score them against the references with ROUGE.

    Args:
        eval_preds: ``(predictions, label_ids)`` pair supplied by the trainer.
            ``predictions`` may be a tuple whose first element holds the
            generated token ids.

    Returns:
        dict mapping each ROUGE key to its score times 100 (rounded to 4
        places), plus ``gen_len``: the mean count of non-pad tokens per
        prediction.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    # Generated sequences of different lengths can arrive padded with -100.
    # That is not a valid token id and makes batch_decode fail, so map it to
    # the pad id first. The pad id is then dropped as a special token and is
    # excluded from gen_len.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    if ignore_pad_token_for_loss:
        # Same substitution for labels that were masked with -100.
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    decoded_preds, decoded_labels = postprocess_text(decoded_preds,
                                                     decoded_labels)
    result = metric.compute(predictions=decoded_preds,
                            references=decoded_labels,
                            use_stemmer=True)
    result = {k: round(v * 100, 4) for k, v in result.items()}
    prediction_lens = [
        np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds
    ]
    result["gen_len"] = np.mean(prediction_lens)
    return result
def load_csv_as_dict(file_path, encoding=None):
    """Read a CSV file into a ``{first_column: second_column}`` dict.

    Args:
        file_path: path of the CSV file.
        encoding: text encoding passed to ``open``. ``None`` keeps the
            platform default, which is the previous behavior.

    Returns:
        dict built from the first two fields of every non-blank row. Extra
        columns are ignored, and a later row overwrites an earlier row with
        the same key.
    """
    # newline='' is required by the csv module so that quoted fields
    # containing line breaks are read back verbatim.
    with open(file_path, mode='r', newline='', encoding=encoding) as csv_file:
        # csv.reader yields [] for blank lines; skip them instead of crashing.
        return {row[0]: row[1] for row in csv.reader(csv_file) if row}
image_encoder_model = "google/vit-base-patch16-224"
text_decode_model = "Thouph/GPT-E6-small"
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    image_encoder_model, text_decode_model)

# Freeze every parameter, then re-enable gradients only for each decoder
# layer's cross-attention block and its LayerNorm (ln_cross_attn).
# NOTE(review): assumes a GPT-2-style decoder exposing transformer.h[i] with
# .crossattention / .ln_cross_attn. The Trainer switches the whole model back
# to train mode when training starts, so these eval()/train() calls most
# likely do not persist; requires_grad is what decides what gets updated.
model.eval()
for p in model.parameters():
    p.requires_grad = False
for layer in model.decoder.transformer.h:
    layer.crossattention.train()
    for p in layer.crossattention.parameters():
        p.requires_grad = True
    layer.ln_cross_attn.train()
    for p in layer.ln_cross_attn.parameters():
        p.requires_grad = True

feature_extractor = AutoFeatureExtractor.from_pretrained(image_encoder_model)
tokenizer = AutoTokenizer.from_pretrained("Thouph/six_tokenizer_filtered_space_merge")
# Reuse EOS as the padding token, and copy the special token ids into the
# model config, which VisionEncoderDecoderModel uses to build decoder inputs
# and to generate.
tokenizer.pad_token = tokenizer.eos_token
model.config.eos_token_id = tokenizer.eos_token_id
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.pad_token_id = tokenizer.pad_token_id
output_dir = "vit-gpt-model"
model.save_pretrained(output_dir)
# A second freezing pass used to sit here: it froze every parameter whose
# name lacks "crossattention". The only thing that changed was re-freezing
# ln_cross_attn, which contradicts the explicit unfreeze above, so it is
# dropped. Report what is trainable instead.
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"trainable parameters: {n_trainable:,}")
feature_extractor.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
dataset = load_dataset('csv', data_files=r"posts-2023-04-17_MD5_caption_sifted_no_symbol_purged_folder.csv")
print(dataset)
def add_image_path(example, image_root="/home/user/dump_small"):
    """Add an ``image_path`` column built from ``folder_name`` and ``image_id``.

    Intended for ``Dataset.map(..., batched=True)``: every column of
    ``example`` is a list. Row i gets the path
    ``<image_root>/<folder_name[i]>/<image_id[i]>.jpg``.

    Args:
        example: batch dict with ``image_id`` and ``folder_name`` columns.
            Values are formatted with str(), so numeric folder names work.
        image_root: directory that holds the per-folder image directories.

    Returns:
        The same dict, with ``image_path`` added.
    """
    example['image_path'] = [
        os.path.join(f"{image_root}/{folder}", f"{image_id}.jpg")
        for image_id, folder in zip(example["image_id"], example["folder_name"])
    ]
    return example
# Add the image_path column and take the "train" split (a single CSV passed
# as data_files lands entirely in "train").
ds = dataset.map(add_image_path, batched=True, batch_size=8192)["train"]
print(ds)
# A fixed seed keeps the train/eval split identical across runs and resumes.
ds = ds.train_test_split(test_size=0.02, seed=42)
print(ds['train'][0:2])
# Images are decoded and captions tokenized lazily, when rows are accessed.
ds.set_transform(preprocess_fn)
print(ds['train'][0:2])
training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy="steps",
    eval_steps=100,
    gradient_accumulation_steps=4,
    per_device_train_batch_size=128,
    weight_decay=0.1,
    max_steps=10000,
    warmup_steps=1000,
    logging_strategy="steps",
    save_steps=5000,
    # fp16 (AMP) is rejected by TrainingArguments on non-CUDA devices such as
    # TPU; bf16 is the mixed-precision mode TPUs support.
    bf16=True,
    # NOTE(review): only honored when launched through an XLA multi-process
    # launcher; confirm how this script is started.
    tpu_num_cores=8,
    per_device_eval_batch_size=128,
    output_dir="image-captioning-output",
    learning_rate=5e-4,
    lr_scheduler_type="cosine",
    # The raw columns (image_path, tags) are not model.forward arguments. The
    # default column pruning would strip them before the set_transform above
    # turns them into pixel_values/labels.
    remove_unused_columns=False,
)
def collate_fn(batch):
    """Stack a list of per-example dicts into one batch of tensors.

    Args:
        batch: list of dicts, each with ``pixel_values`` (numpy array or
            tensor of identical shape) and ``labels`` (equal-length sequences
            of token ids).

    Returns:
        dict with ``pixel_values`` stacked along a new leading dimension and
        ``labels`` as a 2-D tensor.
    """
    # The dataset transform emits numpy arrays (return_tensors="np"), but
    # torch.stack only accepts tensors. torch.as_tensor converts arrays and
    # leaves existing tensors unchanged.
    pixel_values = torch.stack([torch.as_tensor(x['pixel_values']) for x in batch])
    labels = torch.tensor([x['labels'] for x in batch])
    return {'pixel_values': pixel_values, 'labels': labels}
# Globals read by compute_metrics at evaluation time.
ignore_pad_token_for_loss = True
metric = evaluate.load("rouge")

final_output_dir = "image-captioning-output1"
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=ds['train'],
    eval_dataset=ds['test'],
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    # The feature extractor goes in the `tokenizer` slot so the trainer
    # writes it out together with the saved model.
    tokenizer=feature_extractor,
)
trainer.train()
trainer.save_model(final_output_dir)
# The text tokenizer is saved separately, next to the model.
tokenizer.save_pretrained(final_output_dir)