|
|
|
import datasets
import pandas as pd
from PIL import Image
import multiprocessing as mp
from sklearn.model_selection import train_test_split

import torch
from torchvision import transforms
from torch.utils.data import Dataset

from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor
from transformers import AutoTokenizer, GPT2TokenizerFast, default_data_collator

import os

# Disable Weights & Biases logging for the Trainer.
os.environ["WANDB_DISABLED"] = "true"

# TPU support (requires the torch_xla package).
import torch_xla.core.xla_model as xm

dev = xm.xla_device()
|
if torch.cuda.is_available():
    device = torch.device("cuda")
    print('There are %d GPU(s) available.' % torch.cuda.device_count())
    print('We will use the GPU:', torch.cuda.get_device_name(0))
else:
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")
|
class config:
    ENCODER = "google/vit-base-patch16-224"
    DECODER = "gpt2"
    TRAIN_BATCH_SIZE = 64
    VAL_BATCH_SIZE = 64
    VAL_EPOCHS = 1
    LR = 5e-5
    SEED = 42
    MAX_LEN = 128
    SUMMARY_LEN = 20
    WEIGHT_DECAY = 0.01
    MEAN = (0.485, 0.456, 0.406)
    STD = (0.229, 0.224, 0.225)
    TRAIN_PCT = 0.95
    NUM_WORKERS = mp.cpu_count()
    EPOCHS = 1
    IMG_SIZE = (224, 224)
    LABEL_MASK = -100
    TOP_K = 10
    TOP_P = 0.95
|
# GPT-2 does not add special tokens on its own, so wrap each caption in BOS ... EOS.
# The override goes on GPT2TokenizerFast, the class actually returned by
# AutoTokenizer.from_pretrained("gpt2"); assigning it to AutoTokenizer has no effect
# on the instantiated tokenizer.
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
    outputs = [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
    return outputs

GPT2TokenizerFast.build_inputs_with_special_tokens = build_inputs_with_special_tokens
|
# ROUGE metric (needs the `rouge_score` package installed).
rouge = datasets.load_metric("rouge")
|
def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    # Masked label positions (-100) cannot be decoded, so map them back to the pad token first.
    labels_ids[labels_ids == -100] = tokenizer.pad_token_id
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

    rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])["rouge2"].mid

    return {
        "rouge2_precision": round(rouge_output.precision, 4),
        "rouge2_recall": round(rouge_output.recall, 4),
        "rouge2_fmeasure": round(rouge_output.fmeasure, 4),
    }
|
feature_extractor = ViTFeatureExtractor.from_pretrained(config.ENCODER)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# GPT-2 has no pad token; reuse the unk token for padding.
tokenizer.pad_token = tokenizer.unk_token

# Tensor conversion plus [-1, 1] normalization (kept under img_transforms to avoid
# shadowing the torchvision.transforms module).
img_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.5, 0.5, 0.5],
            std=[0.5, 0.5, 0.5],
        ),
    ]
)
|
class ImgDataset(torch.utils.data.Dataset):
    def __init__(self, df, root_dir, tokenizer, feature_extractor, transform):
        self.df = df
        self.transform = transform
        self.root_dir = root_dir
        self.tokenizer = tokenizer
        self.feature_extractor = feature_extractor
        self.max_length = 128

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        caption = self.df.tags.iloc[idx]
        image = self.df.image_id.iloc[idx] + ".jpg"
        folder_name = str(self.df.folder_name.iloc[idx])
        img_path = os.path.join(self.root_dir, folder_name, image)
        img = Image.open(img_path).convert("RGB")

        img = self.transform(img)

        # Normalize(0.5, 0.5) maps pixels to [-1, 1]; shift back to [0, 1]
        # before handing the tensor to the ViT feature extractor.
        if img.min() < 0.0:
            img = (img + 1.0) / 2.0

        pixel_values = self.feature_extractor(img, return_tensors="pt").pixel_values
        captions = self.tokenizer(caption,
                                  padding='max_length',
                                  max_length=self.max_length,
                                  truncation=True).input_ids
        # Replace pad token ids with -100 so they are ignored by the loss.
        captions = [token if token != self.tokenizer.pad_token_id else -100 for token in captions]
        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(captions)}
        return encoding
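
# Optional sanity check (commented out; a sketch that assumes the CSV and image paths
# used in the loop below exist). It builds the dataset on one folder and inspects a
# single item's shapes.
#
# _df = pd.read_csv("posts/posts-2023-04-17_MD5_caption_sifted_no_symbol_purged_folder_1.csv")
# _ds = ImgDataset(_df, root_dir="dump_small", tokenizer=tokenizer,
#                  feature_extractor=feature_extractor, transform=img_transforms)
# _item = _ds[0]
# print(_item["pixel_values"].shape)  # expected: torch.Size([3, 224, 224])
# print(_item["labels"].shape)        # expected: torch.Size([128])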
|
|
|
# Fine-tune sequentially over the 179 sifted caption CSVs, one folder at a time.
for j in range(1, 179 + 1):
    df = pd.read_csv(rf"posts/posts-2023-04-17_MD5_caption_sifted_no_symbol_purged_folder_{j}.csv")
    train_df, val_df = train_test_split(df, test_size=0.02)
    print(df.head(3))

    train_dataset = ImgDataset(
        train_df,
        root_dir=rf"dump_small",
        tokenizer=tokenizer,
        feature_extractor=feature_extractor,
        transform=img_transforms,
    )

    val_dataset = ImgDataset(
        val_df,
        root_dir=rf"dump_small",
        tokenizer=tokenizer,
        feature_extractor=feature_extractor,
        transform=img_transforms,
    )

    # Resume from the model saved after the previous folder if it exists,
    # otherwise start from the pretrained encoder/decoder pair.
    if os.path.exists('VIT_large_gpt2_model'):
        model = VisionEncoderDecoderModel.from_pretrained('VIT_large_gpt2_model')
    else:
        model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(config.ENCODER, config.DECODER)

    model.config.pad_token_id = tokenizer.pad_token_id
    model.config.vocab_size = model.config.decoder.vocab_size

    # GPT-2 has no CLS/SEP tokens: generation starts at BOS and stops at EOS.
    model.config.decoder_start_token_id = tokenizer.bos_token_id
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.max_length = 128
    model.config.early_stopping = True
    model.config.no_repeat_ngram_size = 2
    model.config.length_penalty = 2.0
    model.config.num_beams = 2

    training_args = Seq2SeqTrainingArguments(
        output_dir='VIT_large_gpt2',
        per_device_train_batch_size=config.TRAIN_BATCH_SIZE,
        per_device_eval_batch_size=config.VAL_BATCH_SIZE,
        predict_with_generate=True,
        evaluation_strategy="steps",
        do_train=True,
        do_eval=True,
        logging_steps=1000,
        save_steps=1000,
        warmup_steps=200,
        # Decay the learning rate slightly with each successive folder.
        learning_rate=5e-5 - j * 2.2e-7,
        num_train_epochs=config.EPOCHS,
        overwrite_output_dir=True,
        save_total_limit=3,
    )

    """import transformers.trainer
    from transformers.trainer import SequentialSampler


    def sampler_monkey_patch(dataset, generator):
        return SequentialSampler(dataset)


    transformers.trainer.RandomSampler = sampler_monkey_patch"""

    trainer = Seq2SeqTrainer(
        tokenizer=feature_extractor,
        model=model,
        args=training_args,
        compute_metrics=compute_metrics,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=default_data_collator,
    )
    # 'VIT_large_gpt2_model' is a saved model directory, not a Trainer checkpoint;
    # if resuming from it fails, fall back to a fresh training run.
    try:
        trainer.train(resume_from_checkpoint='VIT_large_gpt2_model')
    except Exception:
        trainer.train()
    trainer.save_model('VIT_large_gpt2_model')
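
# Minimal generation sketch after training: load the model saved above and caption one
# image, reusing the feature_extractor and tokenizer defined earlier. `sample_image_path`
# is a placeholder path, not part of the original pipeline.
sample_image_path = "example.jpg"
if os.path.exists(sample_image_path):
    captioning_model = VisionEncoderDecoderModel.from_pretrained("VIT_large_gpt2_model").to(device)
    sample_pixel_values = feature_extractor(
        Image.open(sample_image_path).convert("RGB"), return_tensors="pt"
    ).pixel_values.to(device)
    generated_ids = captioning_model.generate(sample_pixel_values, max_length=config.MAX_LEN)
    print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0])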
|
|
|
|