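"""Train / validation / test loops for the poem-text and poem-image (CLIP-style) models."""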
import numpy as np
from tqdm import tqdm

import torch
# project-local modules
import config as CFG
from models import CLIPModel
from utils import AvgMeter, get_lr
from utils import get_datasets, build_loaders
def train_epoch(model, train_loader, optimizer, lr_scheduler, step):
"""
Performs one epoch of training.
Parameters:
-----------
model: PoemTextModel or CLIPModel
model to train
train_loader: torch.utils.data.DataLoader
dataloader to get batches from
optimizer: torch.optim.Optimizer
optimizer used for training
lr_scheduler: torch.optim.lr_scheduler.LRScheduler
scheduler used for training
    step: str ("batch" or "epoch")
        if "batch", lr_scheduler steps (updates the learning rate) after every
        batch of the loader; if "epoch", per-epoch stepping is left to the
        caller (as train() below does).
Returns:
--------
    loss_meter: AvgMeter
        meter holding the running average of this epoch's training loss
"""
loss_meter = AvgMeter() # to track average of loss
tqdm_object = tqdm(train_loader, total=len(train_loader))
for batch_cpu in tqdm_object:
# put batch data on device
        batch = {k: {dict_k: dict_v.to(CFG.device) for dict_k, dict_v in v.items()}
                 for k, v in batch_cpu.items() if k not in ("id", "image")}
        if "image" in batch_cpu:
            batch["image"] = batch_cpu["image"].to(CFG.device)
        # get the model's embeddings and compute the loss
        poem_or_img_embeddings, text_embeddings = model(batch)
        loss = model.calculate_loss(poem_or_img_embeddings, text_embeddings)
# backpropagate and step
optimizer.zero_grad()
loss.backward()
optimizer.step()
if step == "batch":
lr_scheduler.step()
        # update the running loss and progress bar
        count = batch["text"]["input_ids"].size(0)
        loss_meter.update(loss.item(), count)
        tqdm_object.set_postfix(train_loss=loss_meter.avg, lr=get_lr(optimizer))
return loss_meter
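# Usage sketch for train_epoch with a per-batch scheduler (illustrative only;
# OneCycleLR is an assumption here, not what train() below actually uses):
#
#   optimizer = torch.optim.AdamW(model.parameters(), lr=CFG.lr)
#   scheduler = torch.optim.lr_scheduler.OneCycleLR(
#       optimizer, max_lr=CFG.lr, total_steps=len(train_loader) * CFG.epochs)
#   model.train()
#   loss_meter = train_epoch(model, train_loader, optimizer, scheduler, step="batch")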
def valid_epoch(model, valid_loader):
"""
Performs one epoch of validation.
Parameters:
-----------
model: PoemTextModel or CLIPModel
model to validate
valid_loader: torch.utils.data.DataLoader
dataloader to get batches from.
Returns:
--------
    loss_meter: AvgMeter
        meter holding the running average of this epoch's validation loss
"""
loss_meter = AvgMeter() # to track average of loss
tqdm_object = tqdm(valid_loader, total=len(valid_loader))
for batch_cpu in tqdm_object:
# put batch data on device
        batch = {k: {dict_k: dict_v.to(CFG.device) for dict_k, dict_v in v.items()}
                 for k, v in batch_cpu.items() if k not in ("id", "image")}
        if "image" in batch_cpu:
            batch["image"] = batch_cpu["image"].to(CFG.device)
        # get the model's embeddings and compute the loss
        poem_or_img_embeddings, text_embeddings = model(batch)
        loss = model.calculate_loss(poem_or_img_embeddings, text_embeddings)
        # update the running loss and progress bar
        count = batch["text"]["input_ids"].size(0)
        loss_meter.update(loss.item(), count)
        tqdm_object.set_postfix(valid_loss=loss_meter.avg)
return loss_meter
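# Note: valid_epoch does not disable gradient tracking itself; call it in eval
# mode under torch.no_grad(), as train() does below:
#
#   model.eval()
#   with torch.no_grad():
#       valid_loss = valid_epoch(model, valid_loader)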
def test(model, test_dataset):
"""
Calculates accuracy on test set.
    This function is meant for the PoemTextModel, since the other model (CLIPModel) has no test set of image-poem pairs.
Parameters:
-----------
model: PoemTextModel
model to test
test_dataset: list of dict
the list containing dict of data to perform test on (must have "text" and "poem" keys)
Returns:
--------
    accuracy: float
        the accuracy of the model on the given test set
"""
test_loader = build_loaders(test_dataset, mode="test")
accuracy = 0
tqdm_object = tqdm(test_loader, total=len(test_loader))
model.eval()
with torch.no_grad():
for batch_cpu in tqdm_object:
# put batch data on device
            batch = {k: {dict_k: dict_v.to(CFG.device) for dict_k, dict_v in v.items()}
                     for k, v in batch_cpu.items() if k not in ("id", "image")}
if "image" in batch_cpu:
batch["image"] = batch_cpu["image"].to(CFG.device)
            # get the model's predictions: a numpy array of indices showing which poem the model matches to each text
pred = model.predict(batch).cpu().numpy()
count = batch["text"]["input_ids"].size(0)
            # each text is paired with the poem at the same index, so np.arange(count) gives the ground-truth labels
acc = np.sum(pred == np.arange(count))
accuracy += acc
tqdm_object.set_postfix(accuracy=acc / count)
accuracy /= len(test_dataset)
return accuracy
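# Usage sketch for test() (assumptions, not guaranteed by this file: get_datasets()
# returns (train, valid, test) splits, and models exposes PoemTextModel as the
# docstring above suggests; the PoemTextModel constructor args are unknown here):
#
#   from models import PoemTextModel
#   _, _, test_dataset = get_datasets()
#   model = PoemTextModel(...).to(CFG.device)
#   print(f"test accuracy: {test(model, test_dataset):.4f}")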
def train(model, train_loader, valid_loader, epochs=CFG.epochs):
"""
Performs train and validation for (epochs) epochs.
Parameters:
-----------
model: PoemTextModel or CLIPModel
model to train
train_loader: torch.utils.data.DataLoader
train dataloader to get batches from
valid_loader: torch.utils.data.DataLoader
validation dataloader to get batches from
epochs: int, optional
the number of epochs to train
Returns:
--------
model: PoemTextModel or CLIPModel
trained model
loss_history: dict
a dict containing train and validation average loss for each epoch.
"""
# Using AdamW optimizer and ReduceLROnPlateau lr-scheduler with settings from config
optimizer = torch.optim.AdamW(
model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay
)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode="min", patience=CFG.patience, factor=CFG.factor
)
    # ReduceLROnPlateau takes the validation loss as its metric, so the scheduler
    # steps once per epoch (step="epoch") rather than after every batch.
    step = "epoch"
    loss_history = {"train": [], "valid": []}
# to keep track of best validation loss
best_loss = float('inf')
    for epoch in range(epochs):
print(f"Epoch: {epoch + 1}")
# train for one epoch
model.train()
train_loss = train_epoch(model, train_loader, optimizer, lr_scheduler, step)
loss_history["train"].append(train_loss.avg)
# validate trained model
model.eval()
with torch.no_grad():
valid_loss = valid_epoch(model, valid_loader)
loss_history["valid"].append(valid_loss.avg)
# if this epoch's avg validation loss is lower than best loss, save and keep this model.
if valid_loss.avg < best_loss:
best_loss = valid_loss.avg
model.save_current()
print("Saved Best Model!")
if step == "epoch":
lr_scheduler.step(valid_loss.avg)
return model, loss_history
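

# Minimal end-to-end sketch of how these loops fit together. Assumptions not
# guaranteed by this file: get_datasets() returns (train, valid, test) splits,
# build_loaders accepts mode="train"/"valid" (only mode="test" is used above),
# and CLIPModel() reads its hyperparameters from config.
if __name__ == "__main__":
    train_dataset, valid_dataset, test_dataset = get_datasets()
    train_loader = build_loaders(train_dataset, mode="train")
    valid_loader = build_loaders(valid_dataset, mode="valid")
    model = CLIPModel().to(CFG.device)
    model, loss_history = train(model, train_loader, valid_loader, epochs=CFG.epochs)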