import csv
import os

os.environ["WANDB_DISABLED"] = "true"

import numpy as np
import PIL
import torch
import evaluate
import nltk
from datasets import load_dataset
from transformers import (
    AutoFeatureExtractor,
    AutoTokenizer,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    VisionEncoderDecoderModel,
)
import torch_xla.core.xla_model as xm

nltk.download('punkt')

dev = xm.xla_device()


# text preprocessing step
def tokenization_fn(captions, max_target_length):
    """Run tokenization on captions."""
    labels = tokenizer(captions,
                       padding="max_length",
                       truncation=True,  # keep label length bounded so batches stack cleanly
                       max_length=max_target_length).input_ids
    return labels


# image preprocessing step
def feature_extraction_fn(image_paths, check_image=True):
    """
    Run feature extraction on images.

    If `check_image` is `True`, examples that fail during `Image.open()` are
    caught and discarded. Otherwise, an exception is raised.
    """
    if check_image:
        images = []
        to_keep = []
        for image_file in image_paths:
            try:
                img = PIL.Image.open(image_file)
                images.append(img)
                to_keep.append(True)
            except Exception:
                to_keep.append(False)
    else:
        images = [PIL.Image.open(image_file) for image_file in image_paths]

    encoder_inputs = feature_extractor(images=images, return_tensors="np")
    return encoder_inputs.pixel_values


def transform(example_batch):
    # Take a list of PIL images and turn them into pixel values
    inputs = feature_extractor([x for x in example_batch['image']], return_tensors='pt')
    # Don't forget to include the labels!
    inputs['labels'] = example_batch['labels']
    return inputs


def preprocess_fn(example_batch):
    """Run tokenization + image feature extraction."""
    model_inputs = {}
    model_inputs['pixel_values'] = feature_extraction_fn([x for x in example_batch['image_path']])
    model_inputs['labels'] = tokenization_fn([x for x in example_batch['tags']], 128)
    return model_inputs


def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [label.strip() for label in labels]

    # rougeLSum expects a newline after each sentence
    preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
    labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]

    return preds, labels


def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    if ignore_pad_token_for_loss:
        # Replace -100 in the labels as we can't decode them.
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(predictions=decoded_preds,
                            references=decoded_labels,
                            use_stemmer=True)
    result = {k: round(v * 100, 4) for k, v in result.items()}
    prediction_lens = [
        np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds
    ]
    result["gen_len"] = np.mean(prediction_lens)
    return result
def load_csv_as_dict(file_path):
    with open(file_path, mode='r') as csv_file:
        reader = csv.reader(csv_file)
        result = {rows[0]: rows[1] for rows in reader}
    return result


image_encoder_model = "google/vit-base-patch16-224"  # actual run uses "google/vit-large-patch16-384"; alternative: "google/vit-large-patch16-224-in21k"
text_decode_model = "Thouph/GPT-E6-small"

model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    image_encoder_model, text_decode_model)
model.eval()
for p in model.parameters():
    p.requires_grad = False

# only allow training of the cross-attention parameters
for layer in model.decoder.transformer.h:
    layer.crossattention.train()
    for p in layer.crossattention.parameters():
        p.requires_grad = True
    layer.ln_cross_attn.train()
    for p in layer.ln_cross_attn.parameters():
        p.requires_grad = True

# image feature extractor
feature_extractor = AutoFeatureExtractor.from_pretrained(image_encoder_model)
# text tokenizer
tokenizer = AutoTokenizer.from_pretrained("Thouph/six_tokenizer_filtered_space_merge")

# GPT2 only has bos/eos tokens but not decoder_start/pad tokens
tokenizer.pad_token = tokenizer.eos_token

# update the model config
model.config.eos_token_id = tokenizer.eos_token_id
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.pad_token_id = tokenizer.pad_token_id

output_dir = "vit-gpt-model"
model.save_pretrained(output_dir)

# re-assert the freeze: everything except the cross-attention blocks
# (and their layer norms, unfrozen above) stays frozen
for name, param in model.named_parameters():
    if "crossattention" not in name and "ln_cross_attn" not in name:
        param.requires_grad = False

feature_extractor.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

dataset = load_dataset(
    'csv',
    data_files=r"posts-2023-04-17_MD5_caption_sifted_no_symbol_purged_folder.csv")
print(dataset)


def add_image_path(example):
    image_name = [i + '.jpg' for i in example["image_id"]]
    folder_name = example["folder_name"]
    image_path = [
        os.path.join(rf"/home/user/dump_small/{folder_name[i]}", image_name[i])
        for i in range(len(image_name))
    ]
    example['image_path'] = image_path
    return example


ds = dataset.map(add_image_path, batched=True, batch_size=8192)["train"]
print(ds)
ds = ds.train_test_split(test_size=0.02)
print(ds['train'][0:2])
ds.set_transform(preprocess_fn)
print(ds['train'][0:2])

training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy="steps",
    eval_steps=100,
    gradient_accumulation_steps=4,
    per_device_train_batch_size=128,
    weight_decay=0.1,
    max_steps=10000,
    warmup_steps=1000,
    logging_strategy="steps",
    save_steps=5000,
    fp16=True,
    tpu_num_cores=8,
    per_device_eval_batch_size=128,
    output_dir="image-captioning-output",
    learning_rate=5e-4,
    lr_scheduler_type="cosine",
)


def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['labels'] for x in batch]),
    }


metric = evaluate.load("rouge")
ignore_pad_token_for_loss = True
# instantiate trainer
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=feature_extractor,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=ds['train'],
    eval_dataset=ds['test'],
    data_collator=collate_fn,
)
trainer.train()

trainer.save_model("image-captioning-output1")
tokenizer.save_pretrained("image-captioning-output1")
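

# Minimal usage sketch: reload the saved checkpoint and generate tags for one
# image. The helper name `generate_tags`, the example image path, and the
# generation settings (max_length, num_beams) are illustrative assumptions,
# not values prescribed by the training run above.
def generate_tags(image_path, checkpoint="image-captioning-output1"):
    # load the fine-tuned model plus the feature extractor/tokenizer saved with it
    model = VisionEncoderDecoderModel.from_pretrained(checkpoint)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)

    # preprocess the image and run autoregressive generation
    image = PIL.Image.open(image_path).convert("RGB")
    pixel_values = feature_extractor(images=[image], return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values, max_length=128, num_beams=4)
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

# Example call (hypothetical image path):
# print(generate_tags("example.jpg"))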