# -*- coding: utf-8 -*-
r"""
Model Base
==============
Abstract base class used to build new modules inside Polos.
This class is just an extension of the main PyTorch Lightning module:
https://pytorch-lightning.readthedocs.io/en/0.8.4/lightning-module.html
"""
from argparse import Namespace
from os import path
import os
from typing import Dict, Generator, List, Tuple, Union

import click
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, RandomSampler, Subset, Dataset
from PIL import Image, ImageFile
import pytorch_lightning as ptl

from polos.models.encoders import Encoder, str2encoder
from polos.schedulers import str2scheduler

ImageFile.LOAD_TRUNCATED_IMAGES = True


class ModelBase(ptl.LightningModule):
    """
    Extends PyTorch Lightning with a common structure and interface
    that will be shared across all architectures.

    :param hparams: Namespace with hyper-parameters
    """

    class ModelConfig:
        """
        The ModelConfig class is used to define model hyper-parameters that
        are used to initialize our Lightning Modules. These parameters are
        then overwritten with the values defined in the YAML file and converted
        to a Namespace to initialize the model.

        :param model: Model class name (to be replaced with the model specified in the YAML)

        -------------------- Training Parameters -------------------------
        :param batch_size: Batch size used during training.
        :param nr_frozen_epochs: Number of epochs we keep the encoder model frozen.
        :param keep_embeddings_frozen: Keeping the embeddings frozen is a useful way to save some GPU memory.
            This is critical to fine-tune large models on GPUs with less than 32GB of memory.

        -------------------- Optimizer Parameters -------------------------
        :param optimizer: Optimizer class to be used.
        :param learning_rate: Overall learning rate.

        -------------------- Scheduler Parameters -------------------------
        :param scheduler: Scheduler class to be used.
        :param warmup_steps: Warmup steps (only used for schedulers with a warmup period).

        -------------------- Architecture Parameters -------------------------
        :param encoder_model: Encoder class to be used.
        :param pretrained_model: Encoder checkpoint (e.g. xlmr.base vs xlmr.large).
        :param pool: Pooling technique to extract the sentence embeddings.
            Options: {max, avg, default, cls} where `default` uses the default sentence embedding
            returned by the encoder (e.g. BERT pooler_output) and `cls` is the first token of the
            sequence and depends on the selected layer.
        :param load_weights: Loads weights from a checkpoint file that matches the architecture.

        -------------------- Data Parameters -------------------------
        :param train_path: Path to the training data.
        :param val_path: Path to the validation data.
        :param test_path: Path to the test data.
        :param loader_workers: Number of workers used to load and tokenize data during training.
        :param monitor: Metric to be displayed in the tqdm bar. Same as the trainer monitor flag!
        """

        model: str = None

        # TODO: these may only be needed for the ranker model? Needs investigation.
        # from here
        encoder_learning_rate: float = 1e-06
        layerwise_decay: float = 1.0
        layer: str = "mix"
        scalar_mix_dropout: float = 0.0
        loss: str = "mse"
        hidden_sizes: str = "1024"
        activations: str = "Tanh"
        dropout: float = 0.1
        final_activation: str = "Sigmoid"
        # up to here

        # Training details
        batch_size: int = 8
        nr_frozen_epochs: int = 0
        keep_embeddings_frozen: bool = False

        # Optimizer
        optimizer: str = "Adam"
        learning_rate: float = 1e-05

        # Scheduler
        scheduler: str = "constant"
        warmup_steps: int = None

        # Architecture Definition
        encoder_model: str = "XLMR"
        pretrained_model: str = "xlmr.base"
        pool: str = "avg"
        load_weights: str = False

        # Data
        train_path: str = None
        val_path: str = None
        test_path: str = None
        train_img_dir_path: str = None
        val_img_dir_path: str = None
        test_img_dir_path: str = None
        loader_workers: int = 8
        monitor: str = "kendall"

        def __init__(self, initial_data: dict) -> None:
            for key in initial_data:
                if hasattr(self, key):
                    setattr(self, key, initial_data[key])

        def namespace(self) -> Namespace:
            return Namespace(
                **{
                    name: getattr(self, name)
                    for name in dir(self)
                    if not callable(getattr(self, name)) and not name.startswith("__")
                }
            )
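
        # Example (illustrative values): a YAML-derived dict overrides the defaults above
        # and is converted into the Namespace consumed by ModelBase.__init__:
        #   cfg = ModelBase.ModelConfig({"batch_size": 16, "optimizer": "AdamW"})
        #   hparams = cfg.namespace()
        #   hparams.batch_size  -> 16         (overridden)
        #   hparams.monitor     -> "kendall"  (default kept)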

    def __init__(self, hparams: Namespace) -> None:
        super(ModelBase, self).__init__()
        if isinstance(hparams, dict):
            self.hparams = Namespace(**hparams)
        else:
            self.hparams = hparams
        self.encoder = self._build_encoder()

        # Model initialization
        self._build_model()

        # Loss criterion initialization.
        self._build_loss()

        # The encoder starts frozen whenever nr_frozen_epochs > 0.
        if self.hparams.nr_frozen_epochs > 0:
            self._frozen = True
            self.freeze_encoder()
        else:
            self._frozen = False

        if (
            hasattr(self.hparams, "keep_embeddings_frozen")
            and self.hparams.keep_embeddings_frozen
        ):
            self.encoder.freeze_embeddings()

        self.nr_frozen_epochs = self.hparams.nr_frozen_epochs

    def _build_loss(self):
        """ Initializes the loss function/s. """
        pass

    def _build_model(self) -> ptl.LightningModule:
        """
        Initializes the estimator architecture.
        """
        # Compatibility with previous Polos versions
        if (
            hasattr(self.hparams, "load_weights")
            and self.hparams.load_weights
            and path.exists(self.hparams.load_weights)
        ):
            click.secho(f"Loading weights from {self.hparams.load_weights}", fg="red")
            self.load_weights(self.hparams.load_weights)

    def _build_encoder(self) -> Encoder:
        """
        Initializes the encoder.
        """
        try:
            return str2encoder[self.hparams.encoder_model].from_pretrained(self.hparams)
        except KeyError:
            raise Exception(f"{self.hparams.encoder_model} invalid encoder model!")

    def _build_optimizer(self, parameters: Generator) -> torch.optim.Optimizer:
        """
        Initializes the Optimizer.

        :param parameters: Module.parameters.
        """
        if hasattr(torch.optim, self.hparams.optimizer):
            return getattr(torch.optim, self.hparams.optimizer)(
                params=parameters, lr=self.hparams.learning_rate
            )
        else:
            raise Exception(f"{self.hparams.optimizer} invalid optimizer!")

    def _build_scheduler(
        self, optimizer: torch.optim.Optimizer
    ) -> torch.optim.lr_scheduler.LambdaLR:
        """
        Initializes the Scheduler.

        :param optimizer: PyTorch optimizer
        """
        self.epoch_total_steps = len(self.train_dataset) // (
            self.hparams.batch_size * max(1, self.trainer.num_gpus)
        )
        self.total_steps = self.epoch_total_steps * float(self.trainer.max_epochs)
        try:
            return {
                "scheduler": str2scheduler[self.hparams.scheduler].from_hparams(
                    optimizer, self.hparams, num_training_steps=self.total_steps
                ),
                "interval": "step",  # called after each training step
            }
        except KeyError:
            raise Exception(f"{self.hparams.scheduler} invalid scheduler!")
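
    # Worked example for the step computation above (illustrative numbers): with
    # 10,000 training examples, batch_size=8 and 2 GPUs, epoch_total_steps is
    # 10000 // (8 * 2) = 625, and with max_epochs=4 the scheduler is configured
    # with total_steps = 625 * 4.0 = 2500.0.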

    def read_csv(self, path: str) -> List[dict]:
        """Reads a comma separated value file.

        :param path: path to a csv file.

        :return: List of records as dictionaries
        """
        df = pd.read_csv(path)
        return df.to_dict("records")
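
    # For example (column names are illustrative; the real columns depend on the dataset
    # files), a CSV with header "imgid,refs,score" and row "123.jpg,a dog on the beach,0.8"
    # is returned as [{"imgid": "123.jpg", "refs": "a dog on the beach", "score": 0.8}].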

    def freeze_encoder(self) -> None:
        """ Freezes the encoder layer. """
        self.encoder.freeze()

    def unfreeze_encoder(self) -> None:
        """ Un-freezes the encoder layer. """
        if self._frozen:
            if self.trainer.is_global_zero:
                click.secho("\nEncoder model fine-tuning", fg="red")
            self.encoder.unfreeze()
            self._frozen = False
            if (
                hasattr(self.hparams, "keep_embeddings_frozen")
                and self.hparams.keep_embeddings_frozen
            ):
                self.encoder.freeze_embeddings()

    def on_epoch_end(self):
        """ Hook used to unfreeze the encoder during training. """
        if self.current_epoch + 1 >= self.nr_frozen_epochs and self._frozen:
            self.unfreeze_encoder()
            self._frozen = False

    def predict(
        self, samples: Dict[str, str]
    ) -> Tuple[Dict[str, Union[str, float]], List[float]]:
        """Runs a model prediction.

        :param samples: dictionary with the expected model sequences.
            You can also pass a list of dictionaries to predict an entire batch.

        :return: Dictionary with input samples + scores and a list with just the scores.
        """
        pass
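
    # Expected usage (field names are illustrative and depend on the concrete subclass):
    #   samples, scores = model.predict([{"imgid": "123.jpg", "mt": "a dog on the beach"}])
    #   scores -> [0.73]  # one score per input sample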

    def forward(self, *args, **kwargs) -> Dict[str, torch.Tensor]:
        """
        PyTorch Forward.

        :return: Dictionary with model outputs to be passed to the loss function.
        """
        pass

    def compute_loss(
        self, model_out: Dict[str, torch.Tensor], targets: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """
        Computes the loss value according to a loss function.

        :param model_out: model specific output.
        :param targets: Target score values [batch_size]
        """
        pass

    def prepare_sample(
        self, sample: List[Dict[str, Union[str, float]]], inference: bool = False
    ) -> Union[
        Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]], Dict[str, torch.Tensor]
    ]:
        """
        Prepares a sample to be fed to the model.

        :param sample: List of dictionaries.
        :param inference: If set to True, prepares only the model inputs.

        :returns: Tuple with 2 dictionaries (model inputs and targets). If `inference=True`,
            returns only the model inputs.
        """
        pass
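
    # Note: prepare_sample is also used as the DataLoader collate_fn below, so during
    # training it receives a list of raw dataset records (e.g. the dicts produced by
    # CVPRDataset.__getitem__) and must return the (model inputs, targets) tensor dicts.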

    def configure_optimizers(
        self,
    ) -> Tuple[List[torch.optim.Optimizer], List[torch.optim.lr_scheduler.LambdaLR]]:
        """
        Function for setting up the optimizers and the schedulers to be used during training.

        :returns: List with as many optimizers as we need and a list with the respective schedulers.
        """
        optimizer = self._build_optimizer(self.parameters())
        scheduler = self._build_scheduler(optimizer)
        return [optimizer], [scheduler]

    def compute_metrics(
        self, outputs: List[Dict[str, torch.Tensor]]
    ) -> Dict[str, torch.Tensor]:
        """
        Function that computes metrics of interest based on the list of outputs
        you defined in validation_step.
        """
        pass

    def training_step(
        self,
        batch: Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]],
        batch_nb: int,
        *args,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """
        Runs one training step.
        This usually consists of the forward function followed by the loss function.

        :param batch: The output of your prepare_sample function.
        :param batch_nb: Integer displaying which batch this is.

        :returns: dictionary containing the loss and the metrics to be added to the lightning logger.
        """
        batch_input, batch_target = batch
        batch_prediction = self.forward(**batch_input)
        loss_value = self.compute_loss(batch_prediction, batch_target)

        # In DP mode (default), make sure that if the result is a scalar there's another dim at the beginning
        if self.trainer.use_dp or self.trainer.use_ddp2:
            loss_value = loss_value.unsqueeze(0)

        if (
            self.nr_frozen_epochs < 1.0
            and self.nr_frozen_epochs > 0.0
            and batch_nb > self.epoch_total_steps * self.nr_frozen_epochs
        ):
            self.unfreeze_encoder()
            self._frozen = False

        self.log("train_loss", loss_value, on_step=True, on_epoch=True)
        return loss_value
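
    # A fractional nr_frozen_epochs (e.g. 0.3) unfreezes the encoder mid-epoch:
    # with epoch_total_steps = 1000 (illustrative), the encoder is unfrozen once
    # batch_nb exceeds 1000 * 0.3 = 300 training steps of the first epoch.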

    def validation_step(
        self,
        batch: Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]],
        batch_nb: int,
        dataloader_idx: int,
    ) -> Dict[str, torch.Tensor]:
        """
        Similar to the training step but with the model in eval mode.

        :param batch: The output of your prepare_sample function.
        :param batch_nb: Integer displaying which batch this is.
        :param dataloader_idx: Integer displaying which dataloader this is.

        :returns: dictionary passed to the validation_end function.
        """
        batch_input, batch_target = batch
        batch_prediction = self.forward(**batch_input)
        loss_value = self.compute_loss(batch_prediction, batch_target)

        # In DP mode (default), make sure that if the result is a scalar there's another dim at the beginning
        if self.trainer.use_dp or self.trainer.use_ddp2:
            loss_value = loss_value.unsqueeze(0)

        return {
            "val_loss": loss_value,
            "val_prediction": batch_prediction,
            "val_target": batch_target,
        }

    def validation_epoch_end(
        self, outputs: List[Dict[str, torch.Tensor]]
    ) -> Dict[str, Dict[str, torch.Tensor]]:
        """
        Function that takes as input a list of dictionaries returned by the validation_step
        and measures the model performance across the entire validation set.

        :param outputs: One list of validation_step outputs per validation dataloader.

        :returns: Dictionary with metrics to be added to the lightning logger.
        """
        # val_dataloader returns two loaders: the 2k training subset first, then the
        # actual validation set, so `outputs` unpacks in that order.
        train_outs, val_outs = outputs
        train_loss = torch.stack([x["val_loss"] for x in train_outs]).mean()
        val_loss = torch.stack([x["val_loss"] for x in val_outs]).mean()

        # Store Metrics for Reporting...
        val_metrics = self.compute_metrics(val_outs)
        val_metrics["avg_loss"] = val_loss
        self.log_dict(val_metrics, prog_bar=True)

        train_metrics = self.compute_metrics(train_outs)
        train_metrics["avg_loss"] = train_loss
        self.log_dict({"train_" + k: v for k, v in train_metrics.items()})

    def test_step(
        self,
        batch: Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]],
        batch_nb: int,
        *args,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """ Redirects to the validation_step function. """
        return self.validation_step(batch, batch_nb, 0)

    def test_epoch_end(
        self, outputs: List[Dict[str, torch.Tensor]]
    ) -> Dict[str, Dict[str, torch.Tensor]]:
        """ Computes metrics. """
        return self.compute_metrics(outputs)

    def setup(self, stage) -> None:
        """Data preparation function called before training by Lightning.
        Equivalent to prepare_data in previous Lightning versions."""
        self.train_dataset = self.read_csv(self.hparams.train_path)
        self.val_dataset = self.read_csv(self.hparams.val_path)
        self.train_dataset = CVPRDataset(self.train_dataset, self.hparams.train_img_dir_path)
        self.val_dataset = CVPRDataset(self.val_dataset, self.hparams.val_img_dir_path)
        print("[SPLIT]: Train: {}, Val: {}".format(len(self.train_dataset), len(self.val_dataset)))

        # Always validate the model with 2k examples from training to monitor overfitting.
        train_subset = np.random.choice(a=len(self.train_dataset), size=2000)
        self.train_subset = Subset(self.train_dataset, train_subset)

        if self.hparams.test_path:
            self.test_dataset = self.read_csv(self.hparams.test_path)
            self.test_dataset = CVPRDataset(self.test_dataset, self.hparams.test_img_dir_path)

    def train_dataloader(self) -> DataLoader:
        """ Function that loads the train set. """
        return DataLoader(
            dataset=self.train_dataset,
            sampler=RandomSampler(self.train_dataset),
            batch_size=self.hparams.batch_size,
            collate_fn=self.prepare_sample,
            num_workers=self.hparams.loader_workers,
        )

    def val_dataloader(self) -> DataLoader:
        """ Function that loads the validation set. """
        return [
            DataLoader(
                dataset=self.train_subset,
                batch_size=self.hparams.batch_size,
                collate_fn=self.prepare_sample,
                num_workers=self.hparams.loader_workers,
            ),
            DataLoader(
                dataset=self.val_dataset,
                batch_size=self.hparams.batch_size,
                collate_fn=self.prepare_sample,
                num_workers=self.hparams.loader_workers,
            ),
        ]

    def test_dataloader(self) -> DataLoader:
        """ Function that loads the test set. """
        return DataLoader(
            dataset=self.test_dataset,
            batch_size=self.hparams.batch_size,
            collate_fn=self.prepare_sample,
            num_workers=self.hparams.loader_workers,
        )


class CVPRDataset(Dataset):
    def __init__(self, dataset, img_dir_path):
        self.dataset = dataset
        self.img_dir_path = img_dir_path
        # Ensure every image file referenced by the dataset can be opened.
        for data in self.dataset:
            assert self.is_image_ok(path.join(img_dir_path, f"{data['imgid']}"))

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        from copy import deepcopy

        # Build the image path from the record's image id
        imgid = self.dataset[idx]["imgid"]
        img_name = path.join(self.img_dir_path, f"{imgid}")

        # Copy the record so the stored dataset is not mutated
        labels = deepcopy(self.dataset[idx])

        # Open the image file
        # (alternatively: from detectron2.data.detection_utils import read_image;
        #  img = read_image(img_name, format="RGB"))
        img = Image.open(img_name).convert("RGB")
        labels["img"] = img
        return labels
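
    # Example of a returned item (columns besides "imgid" are illustrative and come
    # straight from the CSV record):
    #   {"imgid": "123.jpg", "refs": "a dog on the beach", "score": 0.8,
    #    "img": <PIL.Image.Image mode=RGB ...>}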

    def is_image_ok(self, img_path):
        # Try to open and verify the image file
        try:
            img = Image.open(img_path)
            img.verify()
            return True
        except (IOError, SyntaxError):
            return False
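

# Minimal usage sketch (hypothetical: `MyEstimator` stands for a concrete ModelBase
# subclass that implements _build_model, _build_loss, forward, compute_loss and
# prepare_sample; paths and hyper-parameter values are illustrative):
#
#   hparams = MyEstimator.ModelConfig(
#       {"train_path": "data/train.csv", "val_path": "data/val.csv",
#        "train_img_dir_path": "data/images", "val_img_dir_path": "data/images",
#        "batch_size": 16, "nr_frozen_epochs": 1}
#   ).namespace()
#   model = MyEstimator(hparams)
#   trainer = ptl.Trainer(max_epochs=2, gpus=1)
#   trainer.fit(model)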