import os
import pickle

import requests
import torch
from datasets import load_dataset
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import BlipProcessor, BlipForQuestionAnswering

# Load the pretrained BLIP VQA checkpoint and its matching processor.
model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

torch.cuda.empty_cache()
torch.manual_seed(42)  # fix the seed for reproducibility

class VQADataset(torch.utils.data.Dataset):
    """VQA (v2) dataset."""

    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        question = self.dataset[idx]['question']
        answer = self.dataset[idx]['answer']
        image_id = self.dataset[idx]['pid']
        image_path = f"Data/train_fill_in_blank/{image_id}/image.png"
        image = Image.open(image_path).convert("RGB")

        # Encode the image/question pair; encode the answer as the labels.
        encoding = self.processor(image, question, padding="max_length", truncation=True, return_tensors="pt")
        labels = self.processor.tokenizer.encode(
            answer, max_length=8, padding="max_length", truncation=True, return_tensors='pt'
        )
        encoding["labels"] = labels

        # Drop the batch dimension that return_tensors="pt" adds.
        for k, v in encoding.items():
            encoding[k] = v.squeeze()
        return encoding

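# Assumed data layout (inferred from the paths used above): Data/train.jsonl holds
# one JSON object per line with at least "question", "answer", and "pid" fields,
# and each image lives at Data/train_fill_in_blank/<pid>/image.png.
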
training_dataset = load_dataset("json", data_files="Data/train.jsonl", split="train[:90%]")
valid_dataset = load_dataset("json", data_files="Data/train.jsonl", split="train[90%:]")
print("Training set: {} - Validation set: {}".format(len(training_dataset), len(valid_dataset)))

train_dataset = VQADataset(dataset=training_dataset, processor=processor)
valid_dataset = VQADataset(dataset=valid_dataset, processor=processor)

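# Optional sanity check (a minimal sketch, safe to delete): pull one sample and
# confirm the tensor shapes the model will see.
sample = train_dataset[0]
print({k: v.shape for k, v in sample.items()})
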
batch_size = 12
# Shuffle the training batches each epoch; keep validation order fixed.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)

optimizer = torch.optim.AdamW(model.parameters(), lr=4e-5)
# Decay the learning rate by 10% after every epoch.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9, last_epoch=-1)

num_epochs = 100
patience = 10  # epochs without improvement before early stopping
min_eval_loss = float("inf")
early_stopping_hook = 0
tracking_information = []
scaler = torch.cuda.amp.GradScaler()

for epoch in range(num_epochs):
    epoch_loss = 0
    model.train()
    for batch in tqdm(train_dataloader, desc='Training batch'):
        input_ids = batch.pop('input_ids').to(device)
        pixel_values = batch.pop('pixel_values').to(device)
        attention_mask = batch.pop('attention_mask').to(device)
        labels = batch.pop('labels').to(device)

        # Forward pass in mixed precision; pass the attention mask so padding
        # tokens are ignored.
        with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
            outputs = model(input_ids=input_ids,
                            pixel_values=pixel_values,
                            attention_mask=attention_mask,
                            labels=labels)

        loss = outputs.loss
        epoch_loss += loss.item()

        # Backpropagate through the GradScaler to avoid fp16 underflow.
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

    model.eval()
    eval_loss = 0
    with torch.no_grad():  # no gradients needed for validation
        for batch in tqdm(valid_dataloader, desc='Validating batch'):
            input_ids = batch.pop('input_ids').to(device)
            pixel_values = batch.pop('pixel_values').to(device)
            attention_mask = batch.pop('attention_mask').to(device)
            labels = batch.pop('labels').to(device)

            with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
                outputs = model(input_ids=input_ids,
                                pixel_values=pixel_values,
                                attention_mask=attention_mask,
                                labels=labels)

            loss = outputs.loss
            eval_loss += loss.item()

    tracking_information.append((epoch_loss / len(train_dataloader),
                                 eval_loss / len(valid_dataloader),
                                 optimizer.param_groups[0]["lr"]))
    print("Epoch: {} - Training loss: {} - Eval loss: {} - LR: {}".format(
        epoch + 1,
        epoch_loss / len(train_dataloader),
        eval_loss / len(valid_dataloader),
        optimizer.param_groups[0]["lr"]))
    scheduler.step()

    # Keep only the checkpoint with the lowest validation loss; stop early once
    # the validation loss has not improved for `patience` epochs.
    if eval_loss < min_eval_loss:
        model.save_pretrained("Model/blip-saved-model")
        print("Saved model to Model/blip-saved-model")
        min_eval_loss = eval_loss
        early_stopping_hook = 0
    else:
        early_stopping_hook += 1
        if early_stopping_hook > patience:
            break

with open("tracking_information.pkl", "wb") as f:
    pickle.dump(tracking_information, f)
print("The fine-tuning process is done!")

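# A minimal inference sketch against the saved checkpoint. "example.png" and the
# question below are placeholders, not part of the original pipeline.
best_model = BlipForQuestionAnswering.from_pretrained("Model/blip-saved-model").to(device)
best_model.eval()

test_image = Image.open("example.png").convert("RGB")
inputs = processor(test_image, "What is in the picture?", return_tensors="pt").to(device)
with torch.no_grad():
    generated_ids = best_model.generate(**inputs)
print("Predicted answer:", processor.decode(generated_ids[0], skip_special_tokens=True))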