File size: 7,401 Bytes
1bc9b9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
import os
import gc
import numpy as np
import pandas as pd
from tqdm import tqdm
import random
import json


import torch
from torch import nn

#FIX
import config as CFG
from models import CLIPModel
from utils import AvgMeter, get_lr
from utils import get_datasets, build_loaders

def train_epoch(model, train_loader, optimizer, lr_scheduler, step):
    """
    Run a single training pass over ``train_loader``.

    For every batch the tensors are moved to ``CFG.device``, the model is
    called to obtain the two embedding sets, the loss is computed with
    ``model.calculate_loss`` and one optimizer update is applied.

    Parameters
    ----------
    model : PoemTextModel or CLIPModel
        Model being trained. Must be callable on a batch dict (returning two
        embedding tensors) and expose ``calculate_loss``.
    train_loader : torch.utils.data.DataLoader
        Loader yielding batch dicts. Every entry except ``"id"`` and
        ``"image"`` is expected to be a dict of tensors; the batch must
        contain ``batch["text"]["input_ids"]``, whose first dimension is
        used as the batch size.
    optimizer : torch.optim.Optimizer
        Optimizer updated once per batch.
    lr_scheduler : torch.optim.lr_scheduler.LRScheduler
        Scheduler, stepped here (with no arguments) only when ``step`` is
        ``"batch"``.
    step : str
        ``"batch"`` to step the scheduler after every batch; with any other
        value the scheduler is left untouched by this function.

    Returns
    -------
    AvgMeter
        Meter holding the running average loss of this epoch.
    """
    meter = AvgMeter()  # accumulates the batch-size weighted loss
    progress = tqdm(train_loader, total=len(train_loader))
    skipped_keys = ("id", "image")

    for cpu_batch in progress:
        # Move every nested tensor dict (everything but "id"/"image") to the device.
        device_batch = {}
        for name, fields in cpu_batch.items():
            if name in skipped_keys:
                continue
            device_batch[name] = {
                field: tensor.to(CFG.device) for field, tensor in fields.items()
            }
        # "image" is a plain tensor rather than a dict, so it is moved separately.
        if "image" in cpu_batch:
            device_batch["image"] = cpu_batch["image"].to(CFG.device)

        # Forward pass and loss.
        first_embeddings, text_embeddings = model(device_batch)
        loss = model.calculate_loss(first_embeddings, text_embeddings)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if step == "batch":
            lr_scheduler.step()

        # Bookkeeping: weight the loss by the number of samples in the batch.
        batch_size = device_batch["text"]["input_ids"].size(0)
        meter.update(loss.item(), batch_size)
        progress.set_postfix(train_loss=meter.avg, lr=get_lr(optimizer))

    return meter


def valid_epoch(model, valid_loader):
    """
    Run a single validation pass over ``valid_loader``.

    Each batch is moved to ``CFG.device``, passed through the model, and its
    loss is accumulated. No optimizer or scheduler is involved. This function
    does not itself switch the model to eval mode or disable gradients; the
    caller is responsible for that.

    Parameters
    ----------
    model : PoemTextModel or CLIPModel
        Model being validated. Must be callable on a batch dict (returning
        two embedding tensors) and expose ``calculate_loss``.
    valid_loader : torch.utils.data.DataLoader
        Loader yielding batch dicts. Every entry except ``"id"`` and
        ``"image"`` is expected to be a dict of tensors; the batch must
        contain ``batch["text"]["input_ids"]``.

    Returns
    -------
    AvgMeter
        Meter holding the running average loss of this validation pass.
    """
    meter = AvgMeter()  # accumulates the batch-size weighted loss
    progress = tqdm(valid_loader, total=len(valid_loader))
    skipped_keys = ("id", "image")

    for cpu_batch in progress:
        # Move every nested tensor dict (everything but "id"/"image") to the device.
        device_batch = {}
        for name, fields in cpu_batch.items():
            if name in skipped_keys:
                continue
            device_batch[name] = {
                field: tensor.to(CFG.device) for field, tensor in fields.items()
            }
        # "image" is a plain tensor rather than a dict, so it is moved separately.
        if "image" in cpu_batch:
            device_batch["image"] = cpu_batch["image"].to(CFG.device)

        # Forward pass and loss.
        first_embeddings, text_embeddings = model(device_batch)
        loss = model.calculate_loss(first_embeddings, text_embeddings)

        # Bookkeeping: weight the loss by the number of samples in the batch.
        batch_size = device_batch["text"]["input_ids"].size(0)
        meter.update(loss.item(), batch_size)
        progress.set_postfix(valid_loss=meter.avg)

    return meter

def test(model, test_dataset):
    """
    Compute the model's accuracy on a test set.

    Intended for PoemTextModel; per the original note, CLIPModel has no test
    set of image-poem pairs. The model is put in eval mode and evaluated
    under ``torch.no_grad()``. Within a batch, the text at position ``i`` is
    considered correctly classified when ``model.predict`` returns ``i`` for
    it, i.e. the expected labels are ``np.arange(batch_size)``.

    Parameters
    ----------
    model : PoemTextModel
        Model to evaluate; must expose ``predict(batch)`` returning a tensor
        of predicted indices.
    test_dataset : list of dict
        Samples to evaluate on (must have "text" and "poem" keys). Its length
        is used as the denominator of the accuracy, so it must be non-empty.

    Returns
    -------
    float
        Number of correct predictions divided by ``len(test_dataset)``.
    """
    loader = build_loaders(test_dataset, mode="test")
    correct = 0
    progress = tqdm(loader, total=len(loader))
    skipped_keys = ("id", "image")

    model.eval()
    with torch.no_grad():
        for cpu_batch in progress:
            # Move every nested tensor dict (everything but "id"/"image") to the device.
            device_batch = {}
            for name, fields in cpu_batch.items():
                if name in skipped_keys:
                    continue
                device_batch[name] = {
                    field: tensor.to(CFG.device) for field, tensor in fields.items()
                }
            # "image" is a plain tensor rather than a dict, so it is moved separately.
            if "image" in cpu_batch:
                device_batch["image"] = cpu_batch["image"].to(CFG.device)

            # Predicted poem index for each text, as a numpy array.
            predicted = model.predict(device_batch).cpu().numpy()

            # Text i is paired with poem i, so the true labels are 0..batch_size-1.
            batch_size = device_batch["text"]["input_ids"].size(0)
            hits = np.sum(predicted == np.arange(batch_size))
            correct += hits

            progress.set_postfix(accuracy=hits / batch_size)

    return correct / len(test_dataset)

def train(model, train_loader, valid_loader, epochs=CFG.epochs):
    """
    Train and validate ``model`` for ``epochs`` epochs.

    Uses an AdamW optimizer and a ReduceLROnPlateau scheduler configured from
    ``CFG``. After every epoch the average validation loss is compared with
    the best one seen so far; on improvement ``model.save_current()`` is
    called. The scheduler is stepped once per epoch with the average
    validation loss.

    Parameters
    ----------
    model : PoemTextModel or CLIPModel
        Model to train; must expose ``save_current()``.
    train_loader : torch.utils.data.DataLoader
        Loader providing training batches.
    valid_loader : torch.utils.data.DataLoader
        Loader providing validation batches.
    epochs : int, optional
        Number of epochs to run. Defaults to ``CFG.epochs`` (bound when this
        function is defined).

    Returns
    -------
    model : PoemTextModel or CLIPModel
        The model after the last epoch (not necessarily the best-scoring
        state, which is the one persisted via ``save_current``).
    loss_history : dict
        ``{"train": [...], "valid": [...]}`` with the average loss of each
        epoch.
    """
    # AdamW optimizer and ReduceLROnPlateau lr-scheduler with settings from config.
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay
    )
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", patience=CFG.patience, factor=CFG.factor
    )

    # With step="batch" train_epoch would step the scheduler after each batch.
    # ReduceLROnPlateau needs a metric, so it is stepped here once per epoch.
    step = "epoch"
    loss_history = {"train": [], "valid": []}

    # Best (lowest) average validation loss seen so far.
    best_loss = float('inf')
    # Honour the caller-supplied epoch count instead of always reading CFG.epochs.
    for epoch in range(epochs):
        print(f"Epoch: {epoch + 1}")
        # Train for one epoch.
        model.train()
        train_loss = train_epoch(model, train_loader, optimizer, lr_scheduler, step)
        loss_history["train"].append(train_loss.avg)

        # Validate the freshly trained weights without tracking gradients.
        model.eval()
        with torch.no_grad():
            valid_loss = valid_epoch(model, valid_loader)
            loss_history["valid"].append(valid_loss.avg)

        # Persist the model whenever the validation loss improves.
        if valid_loss.avg < best_loss:
            best_loss = valid_loss.avg
            model.save_current()
            print("Saved Best Model!")

        if step == "epoch":
            lr_scheduler.step(valid_loss.avg)
    return model, loss_history