import os
import gc
import numpy as np
import pandas as pd
from tqdm import tqdm
import random
import json
import torch
from torch import nn

# project-local modules
import config as CFG
from models import CLIPModel
from utils import AvgMeter, get_lr
from utils import get_datasets, build_loaders


def train_epoch(model, train_loader, optimizer, lr_scheduler, step):
    """
    Performs one epoch of training.

    Parameters:
    -----------
    model: PoemTextModel or CLIPModel
        model to train
    train_loader: torch.utils.data.DataLoader
        dataloader to get batches from
    optimizer: torch.optim.Optimizer
        optimizer used for training
    lr_scheduler: torch.optim.lr_scheduler.LRScheduler
        scheduler used for training
    step: str ("batch" or "epoch")
        if "batch", lr_scheduler steps (updates) after each batch of the loader;
        otherwise it only steps after each finished epoch.

    Returns:
    --------
    loss_meter: AvgMeter
        AvgMeter instance holding the average training loss of this epoch
    """
    loss_meter = AvgMeter()  # tracks the running average of the loss
    tqdm_object = tqdm(train_loader, total=len(train_loader))
    for batch_cpu in tqdm_object:
        # move batch data (nested dicts of tensors) onto the configured device
        batch = {
            k: {dict_k: dict_v.to(CFG.device) for dict_k, dict_v in v.items()}
            for k, v in batch_cpu.items()
            if k not in ["id", "image"]
        }
        if "image" in batch_cpu:
            batch["image"] = batch_cpu["image"].to(CFG.device)

        # get the model's embeddings and calculate loss
        poem_or_img_embeddings, text_embeddings = model(batch)
        loss = model.calculate_loss(poem_or_img_embeddings, text_embeddings)

        # backpropagate and step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if step == "batch":
            lr_scheduler.step()

        # update training info
        count = batch["text"]["input_ids"].size(0)
        loss_meter.update(loss.item(), count)
        tqdm_object.set_postfix(train_loss=loss_meter.avg, lr=get_lr(optimizer))
        # print('train loss: ', loss_meter.avg)
    return loss_meter


def valid_epoch(model, valid_loader):
    """
    Performs one epoch of validation.

    Parameters:
    -----------
    model: PoemTextModel or CLIPModel
        model to validate
    valid_loader: torch.utils.data.DataLoader
        dataloader to get batches from

    Returns:
    --------
    loss_meter: AvgMeter
        AvgMeter instance holding the average validation loss of this epoch
    """
    loss_meter = AvgMeter()  # tracks the running average of the loss
    tqdm_object = tqdm(valid_loader, total=len(valid_loader))
    for batch_cpu in tqdm_object:
        # move batch data (nested dicts of tensors) onto the configured device
        batch = {
            k: {dict_k: dict_v.to(CFG.device) for dict_k, dict_v in v.items()}
            for k, v in batch_cpu.items()
            if k not in ["id", "image"]
        }
        if "image" in batch_cpu:
            batch["image"] = batch_cpu["image"].to(CFG.device)

        # get the model's embeddings and calculate loss
        poem_or_img_embeddings, text_embeddings = model(batch)
        loss = model.calculate_loss(poem_or_img_embeddings, text_embeddings)

        # update validation info
        count = batch["text"]["input_ids"].size(0)
        loss_meter.update(loss.item(), count)
        tqdm_object.set_postfix(valid_loss=loss_meter.avg)
        # print('validation loss: ', loss_meter.avg)
    return loss_meter


def test(model, test_dataset):
    """
    Calculates accuracy on the test set.
    This function is used for the PoemTextModel, since the other model (CLIPModel)
    does not have a test set containing image-poem pairs.

    Parameters:
    -----------
    model: PoemTextModel
        model to test
    test_dataset: list of dict
        the data to test on (each dict must have "text" and "poem" keys)

    Returns:
    --------
    accuracy: float
        the accuracy of the model on the given test set
    """
    test_loader = build_loaders(test_dataset, mode="test")
    accuracy = 0
    tqdm_object = tqdm(test_loader, total=len(test_loader))
    model.eval()
    with torch.no_grad():
        for batch_cpu in tqdm_object:
            # move batch data (nested dicts of tensors) onto the configured device
            batch = {
                k: {dict_k: dict_v.to(CFG.device) for dict_k, dict_v in v.items()}
                for k, v in batch_cpu.items()
                if k not in ["id", "image"]
            }
            if "image" in batch_cpu:
                batch["image"] = batch_cpu["image"].to(CFG.device)

            # get the model's prediction for each text
            # (a numpy array of indices/labels showing which poem belongs to which text)
            pred = model.predict(batch).cpu().numpy()
            count = batch["text"]["input_ids"].size(0)
            # each text is paired with the poem at the same index,
            # so np.arange(count) gives the ground-truth labels
            acc = np.sum(pred == np.arange(count))
            accuracy += acc
            tqdm_object.set_postfix(accuracy=acc / count)
    accuracy /= len(test_dataset)
    return accuracy


def train(model, train_loader, valid_loader, epochs=CFG.epochs):
    """
    Performs training and validation for (epochs) epochs.

    Parameters:
    -----------
    model: PoemTextModel or CLIPModel
        model to train
    train_loader: torch.utils.data.DataLoader
        train dataloader to get batches from
    valid_loader: torch.utils.data.DataLoader
        validation dataloader to get batches from
    epochs: int, optional
        the number of epochs to train

    Returns:
    --------
    model: PoemTextModel or CLIPModel
        trained model
    loss_history: dict
        a dict containing the average train and validation loss for each epoch
    """
    # AdamW optimizer and ReduceLROnPlateau lr scheduler with settings from config
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay
    )
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", patience=CFG.patience, factor=CFG.factor
    )
    # if step == "batch", lr_scheduler steps (updates) after each batch of the loader;
    # otherwise it only steps after each finished epoch (this case).
    step = "epoch"

    loss_history = {"train": [], "valid": []}
    best_loss = float('inf')  # to keep track of the best validation loss
    for epoch in range(epochs):
        print(f"Epoch: {epoch + 1}")

        # train for one epoch
        model.train()
        train_loss = train_epoch(model, train_loader, optimizer, lr_scheduler, step)
        loss_history["train"].append(train_loss.avg)

        # validate the trained model
        model.eval()
        with torch.no_grad():
            valid_loss = valid_epoch(model, valid_loader)
        loss_history["valid"].append(valid_loss.avg)

        # if this epoch's average validation loss is lower than the best so far,
        # save and keep this model.
        if valid_loss.avg < best_loss:
            best_loss = valid_loss.avg
            model.save_current()
            print("Saved Best Model!")

        if step == "epoch":
            lr_scheduler.step(valid_loss.avg)
    return model, loss_history
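

# Hedged usage sketch (not part of the original pipeline): it shows how the functions
# above could be wired together into a training run. The exact signatures of
# get_datasets(), build_loaders() and CLIPModel() are assumptions inferred from how
# they are used in this file and may differ from the real implementations in
# utils.py / models.py.
if __name__ == "__main__":
    # Assumption: get_datasets() returns train/validation/test splits as lists of dicts.
    train_dataset, valid_dataset, test_dataset = get_datasets()
    # Assumption: build_loaders() takes a dataset and a mode string, as in test() above;
    # the "train" and "valid" mode names are illustrative.
    train_loader = build_loaders(train_dataset, mode="train")
    valid_loader = build_loaders(valid_dataset, mode="valid")

    # Assumption: CLIPModel can be constructed without arguments and moved to CFG.device.
    model = CLIPModel().to(CFG.device)
    model, loss_history = train(model, train_loader, valid_loader, epochs=CFG.epochs)
    print("train/valid loss per epoch:", loss_history)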