import datasets
import pandas as pd
from PIL import Image
import multiprocessing as mp
from sklearn.model_selection import train_test_split
import torch
from torchvision import transforms
from torch.utils.data import Dataset
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor
from transformers import AutoTokenizer, GPT2TokenizerFast, default_data_collator
import os
os.environ["WANDB_DISABLED"] = "true"
# Optional TPU handle via torch_xla; `dev` and `device` below are informational
# only, since Seq2SeqTrainer selects its own device.
try:
    import torch_xla.core.xla_model as xm
    dev = xm.xla_device()
except ImportError:
    dev = None
if torch.cuda.is_available():
    device = torch.device("cuda")
    print('There are %d GPU(s) available.' % torch.cuda.device_count())
    print('We will use the GPU:', torch.cuda.get_device_name(0))
else:
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")
class config:
    ENCODER = "google/vit-base-patch16-224"
    DECODER = "gpt2"
    TRAIN_BATCH_SIZE = 64
    VAL_BATCH_SIZE = 64
    VAL_EPOCHS = 1
    LR = 5e-5
    SEED = 42
    MAX_LEN = 128
    SUMMARY_LEN = 20
    WEIGHT_DECAY = 0.01
    MEAN = (0.485, 0.456, 0.406)
    STD = (0.229, 0.224, 0.225)
    TRAIN_PCT = 0.95
    NUM_WORKERS = mp.cpu_count()
    EPOCHS = 1
    IMG_SIZE = (224, 224)
    LABEL_MASK = -100
    TOP_K = 10
    TOP_P = 0.95
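# Note: many of these constants (LR, MAX_LEN, MEAN/STD, IMG_SIZE, LABEL_MASK,
# TRAIN_PCT, SEED, NUM_WORKERS, WEIGHT_DECAY, VAL_EPOCHS, SUMMARY_LEN, TOP_K,
# TOP_P) are shadowed by hard-coded literals below and have no effect as written.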
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
    # Wrap each caption as [BOS] tokens [EOS].
    return [self.bos_token_id] + token_ids_0 + [self.eos_token_id]

# Patch the concrete class: patching AutoTokenizer is a no-op, since
# from_pretrained("gpt2") returns a GPT2TokenizerFast instance.
GPT2TokenizerFast.build_inputs_with_special_tokens = build_inputs_with_special_tokens
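# Illustrative effect of the override (GPT-2 reuses <|endoftext|>, id 50256, as
# both BOS and EOS):
#   build_inputs_with_special_tokens(tokenizer, token_ids) -> [50256, *token_ids, 50256]
# Caution: GPT2TokenizerFast may add special tokens in its Rust post-processor,
# so the override is only guaranteed on slow-tokenizer code paths.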
rouge = datasets.load_metric("rouge")
def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    # -100 is the loss ignore index; swap it back to the pad id before decoding.
    labels_ids[labels_ids == -100] = tokenizer.pad_token_id
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
    rouge_output = rouge.compute(
        predictions=pred_str, references=label_str, rouge_types=["rouge2"]
    )["rouge2"].mid
    return {
        "rouge2_precision": round(rouge_output.precision, 4),
        "rouge2_recall": round(rouge_output.recall, 4),
        "rouge2_fmeasure": round(rouge_output.fmeasure, 4),
    }
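# Because predict_with_generate=True is set below, pred.predictions holds
# generated token ids, so ROUGE-2 here measures bigram overlap between
# generated and reference tag strings.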
feature_extractor = ViTFeatureExtractor.from_pretrained(config.ENCODER)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.unk_token
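# GPT-2 has no dedicated pad token; its unk token is <|endoftext|> (id 50256),
# so padded label positions share that id and are masked to -100 in the dataset.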
# Renamed from `transforms` to avoid shadowing the torchvision module.
image_transform = transforms.Compose(
    [
        # transforms.Resize(config.IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.5, 0.5, 0.5],
            std=[0.5, 0.5, 0.5],
        ),
    ]
)
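# Note: Normalize(0.5, 0.5) maps pixels to [-1, 1]; ImgDataset shifts them back
# to [0, 1] before ViTFeatureExtractor applies its own resize and normalization,
# so the extractor's preprocessing is effectively what the model sees.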
class ImgDataset(Dataset):
    def __init__(self, df, root_dir, tokenizer, feature_extractor, transform):
        self.df = df
        self.transform = transform
        self.root_dir = root_dir
        self.tokenizer = tokenizer
        self.feature_extractor = feature_extractor
        self.max_length = 128

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        caption = self.df.tags.iloc[idx]
        image = self.df.image_id.iloc[idx] + ".jpg"
        folder_name = str(self.df.folder_name.iloc[idx])
        img_path = os.path.join(self.root_dir, folder_name, image)
        img = Image.open(img_path).convert("RGB")
        img = self.transform(img)
        # Undo the [-1, 1] normalization so the feature extractor receives
        # values in [0, 1].
        if img.min() < 0.0:
            img = (img + 1.0) / 2.0
        pixel_values = self.feature_extractor(img, return_tensors="pt").pixel_values
        captions = self.tokenizer(
            caption,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
        ).input_ids
        # Mask padding positions with -100 so the loss ignores them.
        captions = [c if c != self.tokenizer.pad_token_id else -100 for c in captions]
        return {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(captions)}
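# Caveat: pad, BOS and EOS all share id 50256 in GPT-2's tokenizer, so the
# masking above also hides any BOS/EOS markers in the labels. Illustrative
# shapes for one sample:
#   item = dataset[0]
#   item["pixel_values"].shape -> torch.Size([3, 224, 224])
#   item["labels"].shape       -> torch.Size([128])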
# Train sequentially over the 179 sharded caption CSVs, one shard per pass.
for j in range(1, 179 + 1):
    df = pd.read_csv(rf"posts/posts-2023-04-17_MD5_caption_sifted_no_symbol_purged_folder_{j}.csv")
    train_df, val_df = train_test_split(df, test_size=0.02)
    print(df.head(3))
    train_dataset = ImgDataset(
        train_df,
        root_dir="dump_small",
        tokenizer=tokenizer,
        feature_extractor=feature_extractor,
        transform=image_transform,
    )
    val_dataset = ImgDataset(
        val_df,
        root_dir="dump_small",
        tokenizer=tokenizer,
        feature_extractor=feature_extractor,
        transform=image_transform,
    )
    if os.path.exists('VIT_large_gpt2_model'):
        # A combined checkpoint saved by save_model() is reloaded with
        # from_pretrained, not from_encoder_decoder_pretrained.
        model = VisionEncoderDecoderModel.from_pretrained('VIT_large_gpt2_model')
    else:
        model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(config.ENCODER, config.DECODER)
    # GPT-2 has no CLS/SEP tokens, so use BOS/EOS instead.
    model.config.decoder_start_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id
    # make sure vocab size is set correctly
    model.config.vocab_size = model.config.decoder.vocab_size
    # set beam search parameters
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.max_length = 128
    model.config.early_stopping = True
    model.config.no_repeat_ngram_size = 2
    model.config.length_penalty = 2.0
    model.config.num_beams = 2
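    # These settings drive generation during evaluation (predict_with_generate):
    # 2-beam search capped at 128 tokens, repeated bigrams forbidden, and a
    # length penalty that favors longer tag strings.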
    training_args = Seq2SeqTrainingArguments(
        output_dir='VIT_large_gpt2',
        per_device_train_batch_size=config.TRAIN_BATCH_SIZE,
        per_device_eval_batch_size=config.VAL_BATCH_SIZE,
        predict_with_generate=True,
        evaluation_strategy="steps",
        do_train=True,
        do_eval=True,
        logging_steps=1000,
        save_steps=1000,
        warmup_steps=200,
        learning_rate=5e-5 - j * 2.2e-7,
        # max_steps=400,  # uncomment for a short smoke test
        num_train_epochs=config.EPOCHS,
        overwrite_output_dir=True,
        save_total_limit=3,
    )
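    # Each shard restarts the LR schedule: the base rate decays linearly with
    # the shard index (about 4.98e-5 at j=1 down to about 1.06e-5 at j=179),
    # with 200 fresh warmup steps per shard.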
"""import transformers.trainer
from transformers.trainer import SequentialSampler
def sampler_monkey_patch(dataset, generator):
return SequentialSampler(dataset)
transformers.trainer.RandomSampler = sampler_monkey_patch"""
    trainer = Seq2SeqTrainer(
        # Passing the feature extractor as `tokenizer` makes the Trainer save
        # its preprocessing config alongside each checkpoint.
        tokenizer=feature_extractor,
        model=model,
        args=training_args,
        compute_metrics=compute_metrics,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=default_data_collator,
    )
    try:
        # Try to resume from the previously saved model directory; if it is
        # not a usable checkpoint this raises and we fall back to a fresh run.
        trainer.train(resume_from_checkpoint='VIT_large_gpt2_model')
    except Exception:
        trainer.train()
    trainer.save_model('VIT_large_gpt2_model')