# test_temp / train.py
# Uploaded by Thouph ("Upload train.py", commit 997f017).
# --- Imports and process-level setup ------------------------------------------
# WANDB_DISABLED is assigned further down in this header, after `import wandb`.
# NOTE(review): os, torch and nltk are each imported twice below; the repeats
# are harmless no-ops but could be consolidated.
import csv
import os
import torch
# NOTE(review): `training_args` imported here is shadowed later in the script by
# the Seq2SeqTrainingArguments instance assigned to the same name.
from transformers import VisionEncoderDecoderModel, AutoTokenizer, AutoFeatureExtractor, Seq2SeqTrainer, training_args
# `Image` here is the `datasets` feature type; image files are opened via `PIL.Image`.
from datasets import load_dataset, Image
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
import evaluate
import numpy as np
import nltk
from transformers import default_data_collator
# NOTE(review): `import PIL` by itself may not load the `PIL.Image` submodule that
# the code calls as `PIL.Image.open`; it presumably gets loaded as a side effect of
# another import. Consider `import PIL.Image` explicitly.
import PIL
import wandb
import nltk
# Tokenizer data needed by nltk.sent_tokenize (used in postprocess_text).
nltk.download('punkt')
import os
# Opt out of Weights & Biases logging.
# NOTE(review): presumably read when the Trainer sets up its reporting callbacks
# (after this line), so setting it after `import wandb` still works -- confirm for
# the installed transformers version.
os.environ["WANDB_DISABLED"] = "true"
import torch
import torch_xla.core.xla_model as xm
# XLA (TPU) device handle.
# NOTE(review): `dev` is not referenced anywhere else in the visible source.
dev = xm.xla_device()
# text preprocessing step
def tokenization_fn(captions, max_target_length):
    """Tokenize captions into fixed-length label id sequences.

    Uses the module-level ``tokenizer``.

    Args:
        captions: list of caption strings.
        max_target_length: every returned sequence is padded and truncated to
            exactly this many tokens.

    Returns:
        A list with one list of token ids per caption.
    """
    # truncation=True matters: with padding="max_length" alone, captions longer
    # than max_target_length are not cut, which yields ragged label lists that
    # collate_fn cannot stack into a single tensor.
    labels = tokenizer(
        captions,
        padding="max_length",
        max_length=max_target_length,
        truncation=True,
    ).input_ids
    return labels
# image preprocessing step
def feature_extraction_fn(image_paths, check_image=True, *, return_mask=False):
    """Load image files and convert them to model pixel values.

    Uses the module-level ``feature_extractor``.

    Args:
        image_paths: iterable of image file paths.
        check_image: if True, images that fail to open or decode are caught and
            left out of the result. If False, the first failure raises.
        return_mask: if True, also return ``to_keep``, a list with one bool per
            input path that is True where the image was kept. Callers can use it
            to drop the matching labels so inputs and targets stay aligned.

    Returns:
        The ``pixel_values`` NumPy array produced by ``feature_extractor``, or
        ``(pixel_values, to_keep)`` when ``return_mask`` is True.
    """
    # Make sure the submodule is loaded; `import PIL` alone does not guarantee it.
    import PIL.Image

    def _load_rgb(image_file):
        # Image.open is lazy. convert("RGB") forces a full decode, so corrupt
        # files fail here (inside the check below). It also normalises
        # grayscale/RGBA/palette/CMYK images to the 3 channels the encoder
        # expects. The `with` closes the file handle; the converted image is an
        # independent copy.
        with PIL.Image.open(image_file) as img:
            return img.convert("RGB")

    if check_image:
        images = []
        to_keep = []
        for image_file in image_paths:
            try:
                images.append(_load_rgb(image_file))
            except Exception:
                # Deliberate best effort: unreadable images are discarded.
                to_keep.append(False)
            else:
                to_keep.append(True)
    else:
        images = [_load_rgb(image_file) for image_file in image_paths]
        to_keep = [True] * len(images)

    encoder_inputs = feature_extractor(images=images, return_tensors="np")
    if return_mask:
        return encoder_inputs.pixel_values, to_keep
    return encoder_inputs.pixel_values
def transform(example_batch):
    """Encode a batch of PIL images as PyTorch pixel tensors and attach its labels.

    Args:
        example_batch: dict of columns with an ``image`` column (PIL images) and
            a ``labels`` column, which is passed through unchanged.

    Returns:
        The ``feature_extractor`` output with an extra ``labels`` entry.
    """
    pil_images = list(example_batch['image'])
    encoded = feature_extractor(pil_images, return_tensors='pt')
    encoded['labels'] = example_batch['labels']
    return encoded
def preprocess_fn(example_batch, max_target_length=128):
    """Turn a raw batch into model inputs: image pixel values plus caption labels.

    Installed as the dataset's on-access transform (``ds.set_transform``).

    Args:
        example_batch: dict of columns. Reads ``image_path`` (image files) and
            ``tags`` (caption text).
        max_target_length: label length passed to ``tokenization_fn``.
            Defaults to the 128 tokens used originally.

    Returns:
        dict with ``pixel_values`` and ``labels``, one entry per row.

    Raises:
        ValueError: if some images in the batch could not be loaded.
            feature_extraction_fn silently drops those, so pixel values and
            labels would otherwise no longer line up row for row.
    """
    pixel_values = feature_extraction_fn(list(example_batch['image_path']))
    labels = tokenization_fn(list(example_batch['tags']), max_target_length)
    if len(pixel_values) != len(labels):
        raise ValueError(
            f"{len(labels) - len(pixel_values)} of {len(labels)} images in this "
            "batch could not be loaded; pixel values and labels would be misaligned"
        )
    return {'pixel_values': pixel_values, 'labels': labels}
def postprocess_text(preds, labels):
    """Normalise decoded predictions and references for ROUGE scoring.

    Each text is stripped of surrounding whitespace and then rewritten with
    exactly one sentence per line, because rougeLsum splits on newlines.

    Returns:
        ``(preds, labels)`` as two new lists of strings.
    """
    def _one_sentence_per_line(text):
        return "\n".join(nltk.sent_tokenize(text.strip()))

    cleaned_preds = [_one_sentence_per_line(p) for p in preds]
    cleaned_labels = [_one_sentence_per_line(ref) for ref in labels]
    return cleaned_preds, cleaned_labels
def compute_metrics(eval_preds):
    """Decode generated captions and references and score them with ROUGE.

    Uses the module-level ``tokenizer``, ``metric`` and
    ``ignore_pad_token_for_loss``.

    Args:
        eval_preds: ``(predictions, label_ids)`` pair. ``predictions`` may be a
            tuple whose first element holds the generated token ids.

    Returns:
        dict of ROUGE scores scaled to percent (rounded to 4 decimals), plus
        ``gen_len``: the mean number of non-pad tokens per prediction.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    # Predictions gathered over several eval batches may be padded with -100.
    # -100 is not a token id: batch_decode cannot decode it, and it would also be
    # counted as a real token in gen_len. Map it to the pad id first.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    if ignore_pad_token_for_loss:
        # Replace -100 in the labels as we can't decode them.
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(
        predictions=decoded_preds,
        references=decoded_labels,
        use_stemmer=True,
    )
    result = {k: round(v * 100, 4) for k, v in result.items()}

    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    return result
def load_csv_as_dict(file_path):
    """Read a CSV file into a ``{first_column: second_column}`` dict.

    Blank lines are skipped. Columns after the second are ignored. For duplicate
    keys the last row wins. A non-blank row with a single column raises
    ``IndexError``.

    Args:
        file_path: path to the CSV file (opened with the platform default encoding).

    Returns:
        dict mapping each row's first field to its second field.
    """
    # newline='' is what the csv module requires so that line breaks inside
    # quoted fields are read verbatim instead of being newline-translated.
    with open(file_path, mode='r', newline='') as csv_file:
        return {row[0]: row[1] for row in csv.reader(csv_file) if row}
# Checkpoints stitched together into a single image-to-text model.
# The original note says the actual run uses "google/vit-large-patch16-384"
# (also mentioned: "google/vit-large-patch16-224-in21k").
image_encoder_model = "google/vit-base-patch16-224"
text_decode_model = "Thouph/GPT-E6-small"
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    image_encoder_model,
    text_decode_model,
)

# Freeze the whole model and put it in eval mode. Then re-enable gradients and
# train mode only for each decoder block's cross-attention and its layer norm.
# NOTE(review): the Trainer presumably calls model.train() on the whole model when
# training starts, so these eval()/train() mode switches likely do not persist --
# confirm if dropout in the frozen parts matters.
# NOTE(review): if the encoder and decoder hidden sizes differ, the combined model
# may add an encoder-to-decoder projection that stays frozen (untrained) here --
# confirm the sizes match.
model.eval()
model.requires_grad_(False)
for block in model.decoder.transformer.h:
    for trainable in (block.crossattention, block.ln_cross_attn):
        trainable.train()
        trainable.requires_grad_(True)
# Image preprocessor matching the encoder checkpoint, and the caption tokenizer.
feature_extractor = AutoFeatureExtractor.from_pretrained(image_encoder_model)
tokenizer = AutoTokenizer.from_pretrained("Thouph/six_tokenizer_filtered_space_merge")
# GPT-2 style tokenizers ship bos/eos but no pad token; reuse eos for padding.
tokenizer.pad_token = tokenizer.eos_token

# Tell the combined model which special ids to use for generation and loss.
cfg = model.config
cfg.eos_token_id = tokenizer.eos_token_id
cfg.decoder_start_token_id = tokenizer.bos_token_id
cfg.pad_token_id = tokenizer.pad_token_id

output_dir = "vit-gpt-model"
model.save_pretrained(output_dir)

# Freeze every parameter whose name lacks "crossattention".
# NOTE(review): this also freezes the `ln_cross_attn` parameters (that name does
# not contain "crossattention"), undoing the explicit unfreeze of ln_cross_attn
# earlier in the script -- confirm which behaviour is intended.
for param_name, weight in model.named_parameters():
    if "crossattention" in param_name:
        continue
    weight.requires_grad = False

for component in (feature_extractor, tokenizer):
    component.save_pretrained(output_dir)
# Caption metadata. The rest of the script reads the "image_id", "folder_name"
# and "tags" columns from it.
captions_csv = r"posts-2023-04-17_MD5_caption_sifted_no_symbol_purged_folder.csv"
dataset = load_dataset('csv', data_files=captions_csv)
print(dataset)
def add_image_path(example, root="/home/user/dump_small"):
    """Add an ``image_path`` column pointing at each row's JPEG file.

    Written for ``Dataset.map(..., batched=True)``: ``example`` is a batch, i.e.
    a dict of equal-length column lists. Each path is
    ``{root}/{folder_name}`` joined with ``{image_id}.jpg``.

    Args:
        example: batch dict with an ``image_id`` column (strings) and a
            ``folder_name`` column (formatted with ``str``-style interpolation).
        root: directory that holds the per-folder image dumps. Defaults to the
            location that used to be hard-coded here.

    Returns:
        The same dict, modified in place, with an added ``image_path`` list.
    """
    example['image_path'] = [
        os.path.join(f"{root}/{folder}", image_id + '.jpg')
        for image_id, folder in zip(example["image_id"], example["folder_name"])
    ]
    return example
# Resolve image paths in large batches, then keep the single "train" split that
# the CSV loader produced.
ds = dataset.map(add_image_path, batched=True, batch_size=8192)["train"]
print(ds)
# Hold out 2% for evaluation. Without a seed the held-out set changes on every
# run or restart, making eval metrics incomparable; a fixed seed keeps it stable.
ds = ds.train_test_split(test_size=0.02, seed=42)
print(ds['train'][0:2])
# Load images and tokenize captions on access, rather than materialising all
# pixel values up front.
ds.set_transform(preprocess_fn)
print(ds['train'][0:2])
training_args = Seq2SeqTrainingArguments(
    output_dir="image-captioning-output",
    # Optimisation schedule.
    max_steps=10000,
    warmup_steps=1000,
    learning_rate=5e-4,
    lr_scheduler_type="cosine",
    weight_decay=0.1,
    per_device_train_batch_size=128,
    gradient_accumulation_steps=4,
    # Evaluation: every 100 steps, generate captions and score them via compute_metrics.
    evaluation_strategy="steps",
    eval_steps=100,
    per_device_eval_batch_size=128,
    predict_with_generate=True,
    # NOTE(review): generation_max_length is not set, so generation presumably
    # falls back to the model config's max_length, which may be far shorter than
    # the 128-token labels built in preprocess_fn. Consider generation_max_length=128
    # (at a higher eval cost).
    logging_strategy="steps",
    save_steps=5000,
    tpu_num_cores=8,
    # The dataset holds only raw columns (image_path, tags, ...) until preprocess_fn
    # runs on access via set_transform. With the default remove_unused_columns=True
    # the Trainer drops every column that is not a model.forward argument (all of
    # them here) before the transform runs, so preprocess_fn would fail.
    remove_unused_columns=False,
    # fp16 (CUDA AMP) is rejected by TrainingArguments on TPU/XLA devices.
    # NOTE(review): for reduced precision on TPU use bf16 (e.g. bf16=True or the
    # XLA_USE_BF16=1 env var) -- confirm for the installed transformers/torch_xla.
    fp16=False,
)
def collate_fn(batch):
    """Stack per-example features into one training batch.

    Args:
        batch: list of dicts, each with ``pixel_values`` (NumPy array or tensor
            for one image) and ``labels`` (sequence of token ids; all the same
            length).

    Returns:
        dict with ``pixel_values`` stacked along a new leading batch dimension
        and ``labels`` as a 2-D integer tensor.
    """
    # preprocess_fn yields NumPy arrays (feature_extractor(..., return_tensors="np")),
    # but torch.stack only accepts tensors. torch.as_tensor converts arrays and
    # lists, and passes tensors through without copying.
    return {
        'pixel_values': torch.stack([torch.as_tensor(x['pixel_values']) for x in batch]),
        'labels': torch.stack([torch.as_tensor(x['labels']) for x in batch]),
    }
# ROUGE scorer and the flag that compute_metrics reads when deciding whether to
# map -100 entries in the labels back to the pad id before decoding.
metric = evaluate.load("rouge")
ignore_pad_token_for_loss = True

# The feature extractor is passed as `tokenizer`, presumably so the Trainer saves
# it alongside the model checkpoints.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=ds['train'],
    eval_dataset=ds['test'],
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    tokenizer=feature_extractor,
)
trainer.train()

# Final artefacts: model (+ feature extractor via the trainer) and the text tokenizer.
final_output_dir = "image-captioning-output1"
trainer.save_model(final_output_dir)
tokenizer.save_pretrained(final_output_dir)