BLIP_fine_tune / finetuning.py
datnguyentien204's picture
Update finetuning.py
db88be6 verified
import os
import requests
from transformers import BlipProcessor, BlipForQuestionAnswering
from datasets import load_dataset
import torch
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm
import pickle
# Hub identifier shared by the model weights and the matching processor.
_BLIP_CHECKPOINT = "Salesforce/blip-vqa-base"

# Load the pretrained VQA model and its processor from the same checkpoint.
model = BlipForQuestionAnswering.from_pretrained(_BLIP_CHECKPOINT)
processor = BlipProcessor.from_pretrained(_BLIP_CHECKPOINT)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)

torch.cuda.empty_cache()
# Fixed seed for PyTorch's random number generator.
torch.manual_seed(42)
class VQADataset(torch.utils.data.Dataset):
    """Question/answer/image records encoded for BLIP VQA fine-tuning.

    Every record of ``dataset`` is read through its ``question``, ``answer``
    and ``pid`` keys. The image belonging to a record is loaded from
    ``{image_root}/{pid}/image.png``.

    Args:
        dataset: Indexable collection of records; must support ``len()`` and
            integer indexing.
        processor: Callable invoked as ``processor(image, question, ...)``
            whose ``tokenizer`` attribute is used to encode the answer.
        image_root: Directory holding one sub-directory per ``pid``.
        max_answer_length: Fixed token length the answer is padded or
            truncated to.
    """

    def __init__(self, dataset, processor,
                 image_root="Data/train_fill_in_blank", max_answer_length=8):
        self.dataset = dataset
        self.processor = processor
        self.image_root = image_root
        self.max_answer_length = max_answer_length

    def __len__(self):
        """Return the number of records in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the processor encoding of record ``idx`` plus its ``labels``.

        The processor is called with ``return_tensors="pt"``, so every value
        carries a leading batch axis of size 1; that axis is removed before
        the encoding is returned.
        """
        # Fetch the record once instead of re-indexing the dataset per field.
        record = self.dataset[idx]

        image_path = f"{self.image_root}/{record['pid']}/image.png"
        # Close the file handle as soon as the pixels have been converted.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        encoding = self.processor(
            image,
            record["question"],
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # Explicit padding/truncation replaces the deprecated
        # ``pad_to_max_length`` flag and keeps every label tensor the same
        # length.
        encoding["labels"] = self.processor.tokenizer.encode(
            record["answer"],
            max_length=self.max_answer_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        # Drop only the leading batch axis; a bare squeeze() would also
        # remove any other axis that happens to have size 1.
        for key, value in encoding.items():
            encoding[key] = value.squeeze(0)
        return encoding
# Split the single JSONL file 90/10 into training and validation records.
training_dataset = load_dataset("json", data_files="Data/train.jsonl", split="train[:90%]")
valid_dataset = load_dataset("json", data_files="Data/train.jsonl", split="train[90%:]")
print("Training sets: {} - Validating set: {}".format(len(training_dataset), len(valid_dataset)))

# Wrap the raw records so that indexing yields encoded tensors.
train_dataset = VQADataset(dataset=training_dataset, processor=processor)
valid_dataset = VQADataset(dataset=valid_dataset, processor=processor)

batch_size = 8
# Shuffle the training samples each epoch so batches are not always drawn in
# file order; validation order stays fixed. Page-locked host memory is only
# requested when the batches will actually be copied to a GPU.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=device.type == "cuda",
)
valid_dataloader = DataLoader(
    valid_dataset,
    batch_size=batch_size,
    shuffle=False,
    pin_memory=device.type == "cuda",
)

optimizer = torch.optim.AdamW(model.parameters(), lr=4e-5)
# The learning rate is multiplied by 0.9 on every scheduler step. The
# deprecated ``verbose`` argument and the default ``last_epoch`` are omitted.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

num_epochs = 100
patience = 10  # early-stopping budget checked by the training loop
min_eval_loss = float("inf")
early_stopping_hook = 0  # consecutive epochs without a new best eval loss
tracking_information = []  # one (train loss, eval loss, lr) tuple per epoch

# Loss scaling is only active when training on CUDA.
scaler = torch.cuda.amp.GradScaler(enabled=device.type == "cuda")
def _batch_loss(batch):
    """Run one collated batch through the model and return its loss.

    Used for both training and validation so the two phases feed the model
    the same inputs, including the attention mask.
    """
    input_ids = batch["input_ids"].to(device)
    pixel_values = batch["pixel_values"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    labels = batch["labels"].to(device)
    # Mixed precision is only enabled on CUDA; elsewhere the context is a
    # no-op and the forward pass runs in the model's own dtype.
    with torch.amp.autocast(
        device_type=device.type,
        dtype=torch.float16,
        enabled=device.type == "cuda",
    ):
        outputs = model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            labels=labels,
        )
    return outputs.loss


for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()
    epoch_loss = 0
    for batch in tqdm(train_dataloader, desc='Training batch: ...'):
        loss = _batch_loss(batch)
        epoch_loss += loss.item()
        optimizer.zero_grad()
        # Scale the loss before backward, then let the scaler drive the
        # optimizer step and update its own scale factor.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

    # ---- validation pass ----
    model.eval()
    eval_loss = 0
    # No gradients are needed for evaluation; skipping the autograd graph
    # keeps memory use down.
    with torch.no_grad():
        for batch in tqdm(valid_dataloader, desc='Validating batch: ...'):
            eval_loss += _batch_loss(batch).item()

    mean_train_loss = epoch_loss / len(train_dataloader)
    mean_eval_loss = eval_loss / len(valid_dataloader)
    current_lr = optimizer.param_groups[0]["lr"]
    tracking_information.append((mean_train_loss, mean_eval_loss, current_lr))
    print("Epoch: {} - Training loss: {} - Eval Loss: {} - LR: {}".format(epoch+1, mean_train_loss, mean_eval_loss, current_lr))
    scheduler.step()

    # ---- checkpointing / early stopping on the summed validation loss ----
    if eval_loss < min_eval_loss:
        model.save_pretrained("Model/blip-saved-model")
        print("Saved model to Model/blip-saved-model")
        min_eval_loss = eval_loss
        early_stopping_hook = 0
    else:
        early_stopping_hook += 1
        # Stop once the count of non-improving epochs exceeds `patience`.
        if early_stopping_hook > patience:
            break
# Persist the per-epoch (train loss, eval loss, lr) history. The context
# manager makes sure the file is flushed and closed before the script exits.
with open("tracking_information.pkl", "wb") as tracking_file:
    pickle.dump(tracking_information, tracking_file)
print("The finetuning process has done!")