Spaces:
Runtime error
Runtime error
File size: 7,748 Bytes
1bc9b9d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 |
import config as CFG
import json
from models import PoemTextModel
import torch
import random
from datasets import PoemTextDataset, get_transforms, CLIPDataset
from tqdm import tqdm
import numpy as np
class AvgMeter:
    """
    Used to keep track of batch losses during training / validation.

    The meter accumulates a count-weighted sum of the values it is given and
    exposes their running average.

    Attributes:
    -----------
    name : str
        label printed in front of the average by __repr__
    count : int
        number of data whose train/val loss has been metered
    sum : int or float
        sum of all losses metered (each value weighted by its count)
    avg : int or float
        average of metered losses (0 while nothing has been metered)

    Methods:
    --------
    reset():
        Sets count, sum and avg to 0.
    update(val, count=1):
        Updates loss sum, count and avg.
    __repr__():
        string representation of this class.
    """

    def __init__(self, name="Metric"):
        """Sets the name of the avg meter. sets avg, sum & count to 0."""
        self.name = name
        self.reset()

    def reset(self):
        """Sets avg, sum & count to 0."""
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, count=1):
        """
        Updates loss sum, count and avg using val and count.

        Parameters:
        -----------
        val: int or float
            the (mean) value to add to the meter.
        count: int, optional
            how many data points val stands for; val is weighted by it.
        """
        self.count += count
        self.sum += val * count
        # Guard the division: if only zero-sized updates were seen so far the
        # total count is still 0, and the average is kept at 0 instead of
        # raising ZeroDivisionError.
        self.avg = self.sum / self.count if self.count else 0

    def __repr__(self):
        """String representation of this class: "<name>: <avg with 4 decimals>"."""
        return f"{self.name}: {self.avg:.4f}"
def get_lr(optimizer):
    """
    Returns the learning rate of the input optimizer.

    Only the first entry of optimizer.param_groups is consulted; if there are
    no parameter groups at all, None is returned.
    """
    first_group_lr = next(
        (group["lr"] for group in optimizer.param_groups),
        None,
    )
    return first_group_lr
def get_datasets():
    """
    Loads the dataset json file at CFG.dataset_path and returns its train, validation & test splits.

    The file content is shuffled in a reproducible way (seeded with CFG.random_seed) and then cut
    at two positions derived from CFG.train_propotion and CFG.val_propotion; whatever is left after
    the train and validation parts becomes the test split.

    Returns:
    --------
    train_dataset: sequence of dict
        Train split
    val_dataset: sequence of dict
        Validation split
    test_dataset: sequence of dict
        Test split

    Note: the splits are produced by np.split, so they are NumPy sub-arrays rather than plain lists.
    """
    with open(CFG.dataset_path, encoding="utf-8") as json_file:
        samples = json.load(json_file)

    # a dedicated, seeded generator keeps the shuffle independent of the global random state
    rng = random.Random(CFG.random_seed)
    rng.shuffle(samples)

    # cut positions: end of the train part and end of the validation part
    # (int() truncates, so fractional positions are rounded down)
    total = len(samples)
    train_end = int(CFG.train_propotion * total)
    val_end = int((CFG.train_propotion + CFG.val_propotion) * total)

    # https://stackoverflow.com/questions/38250710/how-to-split-data-into-3-sets-train-validation-and-test
    train_dataset, val_dataset, test_dataset = np.split(samples, [train_end, val_end])
    return train_dataset, val_dataset, test_dataset
def build_loaders(dataset_dict, mode):
    """
    Wraps a list of dictionaries (dataset_dict) into a torch Dataloader.

    dataset_dict is first turned into a PoemTextDataset; the Dataloader is then built on top of it
    with the batch size and worker count taken from the configs.

    Parameters:
    -----------
    dataset_dict: list of dict
        the dataset to return a dataloader of.
    mode: str ("train" or any other word)
        shuffling is switched on only when mode is "train".

    Returns:
    --------
    dataloader: torch.utils.data.DataLoader
        the torch Dataloader created from dataset_dict using PoemTextDataset and configs.
    """
    poem_text_dataset = PoemTextDataset(dataset_dict)
    shuffle_batches = mode == "train"
    return torch.utils.data.DataLoader(
        poem_text_dataset,
        batch_size=CFG.batch_size,
        num_workers=CFG.num_workers,
        shuffle=shuffle_batches,
    )
def get_clip_datasets(dataset_dict):
    """
    (Used for clip model training) Returns train, validation & test split from input.

    This function takes a list of dict as dataset, shuffles a copy of it with CFG.random_seed seed,
    then splits it using CFG.train_propotion & CFG.val_propotion. The input list itself is left
    untouched, so calling the function repeatedly on the same list always yields the same splits.

    Parameters:
    -----------
    dataset_dict: list of dict
        the input dataset

    Returns:
    --------
    train_dataset: sequence of dict
        Train split
    val_dataset: sequence of dict
        Validation split
    test_dataset: sequence of dict
        Test split

    Note: the splits are produced by np.split, so they are NumPy sub-arrays rather than plain lists.
    """
    # Shuffle a shallow copy: shuffling the caller's list in place would make a second call
    # start from already-shuffled data and return different splits despite the fixed seed.
    shuffled = list(dataset_dict)
    random.Random(CFG.random_seed).shuffle(shuffled)

    # cut positions: end of the train part and end of the validation part
    total = len(shuffled)
    train_end = int(CFG.train_propotion * total)
    val_end = int((CFG.train_propotion + CFG.val_propotion) * total)

    # https://stackoverflow.com/questions/38250710/how-to-split-data-into-3-sets-train-validation-and-test
    train_dataset, val_dataset, test_dataset = np.split(shuffled, [train_end, val_end])
    return train_dataset, val_dataset, test_dataset
def build_image_loaders(dataset_dict, mode):
    """
    (Used for clip model training) Wraps a list of dictionaries (dataset_dict) into a torch Dataloader.

    The transforms for the given mode are fetched with get_transforms, a CLIPDataset is built from
    dataset_dict with them (is_image_poem_pair=False), and the Dataloader is created on top of it
    with the batch size and worker count taken from the configs.

    Parameters:
    -----------
    dataset_dict: list of dict
        the dataset to return a dataloader of.
    mode: str ("train" or any other word)
        passed on to get_transforms; shuffling is switched on only when mode is "train".

    Returns:
    --------
    dataloader: torch.utils.data.DataLoader
        the torch Dataloader created from dataset_dict using CLIPDataset and configs.
    """
    image_transforms = get_transforms(mode=mode)
    clip_dataset = CLIPDataset(dataset_dict, image_transforms, is_image_poem_pair=False)
    shuffle_batches = mode == "train"
    return torch.utils.data.DataLoader(
        clip_dataset,
        batch_size=CFG.batch_size,
        num_workers=CFG.num_workers,
        shuffle=shuffle_batches,
    )
def get_poem_embeddings(test_dataset, model=None):
    """
    Returns embeddings of the poems existing in test_dataset.

    Parameters:
    -----------
    test_dataset: list of dict
        dataset to get poems from. each of its dictionaries must have a "beyt" key.
    model: PoemTextModel or CLIPModel, optional
        The model to get poem embeddings from.
        If None is given, instantiates a new PoemTextModel (with all of its parts in pretrained settings)
        using configurations provided in config.py.

    Returns:
    --------
    model: The model used for creating poem embeddings (put in eval mode)
    poem_embeddings (torch.Tensor): embeddings of all batches concatenated with torch.cat

    Raises:
    -------
    RuntimeError
        if model is neither a PoemTextModel nor a CLIPModel.
    """
    test_loader = build_loaders(test_dataset, mode="test")  # building a dataloder (which also tokenizes the poems)

    if model is None:
        model = PoemTextModel(True, False, True, False, poem_projection_pretrained=True, text_projection_pretrained=True).to(CFG.device)
    model.eval()

    # Pick the encoder / projection pair once, before iterating, so an unsupported model fails
    # immediately with a clear message. The check is done by class name, as in the original.
    model_kind = model.__class__.__name__
    if model_kind == "PoemTextModel":
        encoder, projection = model.poem_encoder, model.poem_projection
    elif model_kind == "CLIPModel":
        encoder, projection = model.encoder, model.text_projection
    else:
        # RuntimeError is what the previous bare `raise` produced; the type is kept for callers.
        raise RuntimeError(
            f"Unsupported model type {model_kind!r}: expected PoemTextModel or CLIPModel"
        )

    poem_embeddings = []
    with torch.no_grad():
        for batch in tqdm(test_loader):
            # get poem embeddings by passing tokenizer output of the poems
            # to the model's poem encoder and projection
            beyts = {
                key: values.to(CFG.device)
                for key, values in batch["beyt"].items()
            }
            poem_features = encoder(input_ids=beyts["input_ids"], attention_mask=beyts["attention_mask"])
            poem_embeddings.append(projection(poem_features))

    return model, torch.cat(poem_embeddings)
|