# Scene Text Recognition Model Hub
# Copyright 2022 Darwin Bautista
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import math
from typing import Any, Tuple, List, Optional

import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from pytorch_lightning.utilities.types import STEP_OUTPUT
from timm.optim.optim_factory import param_groups_weight_decay

from strhub.models.base import CrossEntropySystem
from strhub.models.utils import init_weights

from .model_abinet_iter import ABINetIterModel as Model

log = logging.getLogger(__name__)


class ABINet(CrossEntropySystem):

    def __init__(self, charset_train: str, charset_test: str, max_label_length: int,
                 batch_size: int, lr: float, warmup_pct: float, weight_decay: float,
                 iter_size: int, d_model: int, nhead: int, d_inner: int, dropout: float, activation: str,
                 v_loss_weight: float, v_attention: str, v_attention_mode: str, v_backbone: str, v_num_layers: int,
                 l_loss_weight: float, l_num_layers: int, l_detach: bool, l_use_self_attn: bool,
                 l_lr: float, a_loss_weight: float, lm_only: bool = False, **kwargs) -> None:
        super().__init__(charset_train, charset_test, batch_size, lr, warmup_pct, weight_decay)
        self.scheduler = None
        self.save_hyperparameters()
        self.max_label_length = max_label_length
        self.num_classes = len(self.tokenizer) - 2  # We don't predict <bos> nor <pad>
        self.model = Model(max_label_length, self.eos_id, self.num_classes, iter_size, d_model, nhead, d_inner,
                           dropout, activation, v_loss_weight, v_attention, v_attention_mode, v_backbone, v_num_layers,
                           l_loss_weight, l_num_layers, l_detach, l_use_self_attn, a_loss_weight)
        self.model.apply(init_weights)
        # FIXME: doesn't support resumption from checkpoint yet
        self._reset_alignment = True
        self._reset_optimizers = True
        self.l_lr = l_lr
        self.lm_only = lm_only
        # Train LM only. Freeze other submodels.
        if lm_only:
            self.l_lr = lr  # for tuning
            self.model.vision.requires_grad_(False)
            self.model.alignment.requires_grad_(False)

    @property
    def _pretraining(self):
        # In the original work, the VM was pretrained for 8 epochs while the full model was trained for an additional 10 epochs.
        # The same schedule is mirrored here: roughly the first 8/18 of the total optimization steps count as pretraining.
        total_steps = self.trainer.estimated_stepping_batches * self.trainer.accumulate_grad_batches
        return self.global_step < (8 / (8 + 10)) * total_steps

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'model.language.proj.weight'}

    def _add_weight_decay(self, model: nn.Module, skip_list=()):
        if self.weight_decay:
            return param_groups_weight_decay(model, self.weight_decay, skip_list)
        else:
            return [{'params': model.parameters()}]

    def configure_optimizers(self):
        agb = self.trainer.accumulate_grad_batches
        # Scale the base LR with the effective batch size: linear in grad accumulation and per-device batch size
        # (relative to a reference batch size of 256), and sqrt in the number of devices used with DDP.
        lr_scale = agb * math.sqrt(self.trainer.num_devices) * self.batch_size / 256.
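        # Worked example (hypothetical values): with accumulate_grad_batches=1, 2 devices, and batch_size=384,
        # lr_scale = 1 * sqrt(2) * 384 / 256 ≈ 2.12, so each group's base LR is ~2.12x its configured value.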
        lr = lr_scale * self.lr
        l_lr = lr_scale * self.l_lr
        params = []
        params.extend(self._add_weight_decay(self.model.vision))
        params.extend(self._add_weight_decay(self.model.alignment))
        # We use a different learning rate for the LM.
        for p in self._add_weight_decay(self.model.language, ('proj.weight',)):
            p['lr'] = l_lr
            params.append(p)
        max_lr = [p.get('lr', lr) for p in params]
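        # Note: OneCycleLR accepts a per-parameter-group list of max_lr values, so the LM groups
        # peak at the scaled l_lr while the vision/alignment groups peak at the scaled lr.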
        optim = AdamW(params, lr)
        self.scheduler = OneCycleLR(optim, max_lr, self.trainer.estimated_stepping_batches,
                                    pct_start=self.warmup_pct, cycle_momentum=False)
        return {'optimizer': optim, 'lr_scheduler': {'scheduler': self.scheduler, 'interval': 'step'}}

    def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor:
        max_length = self.max_label_length if max_length is None else min(max_length, self.max_label_length)
        logits = self.model.forward(images)[0]['logits']
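        # Keep max_length + 1 positions: up to max_length characters plus the final <eos> slot
        # (the class set already excludes <bos> and <pad>).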
        return logits[:, :max_length + 1]  # truncate

    def calc_loss(self, targets, *res_lists) -> Tensor:
        total_loss = 0
        for res_list in res_lists:
            loss = 0
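            # A result may be a single dict (e.g. the vision output) or a list of dicts, one per refinement
            # iteration (e.g. all_l_res / all_a_res in training_step); normalize to a list, average the
            # cross-entropy over iterations, and weight the result by the submodel's loss_weight.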
            if isinstance(res_list, dict):
                res_list = [res_list]
            for res in res_list:
                logits = res['logits'].flatten(end_dim=1)
                loss += F.cross_entropy(logits, targets.flatten(), ignore_index=self.pad_id)
            loss /= len(res_list)
            self.log('loss_' + res_list[0]['name'], loss)
            total_loss += res_list[0]['loss_weight'] * loss
        return total_loss

    def on_train_batch_start(self, batch: Any, batch_idx: int) -> None:
        if not self._pretraining and self._reset_optimizers:
            log.info('Pretraining ends. Updating base LRs.')
            self._reset_optimizers = False
            # Make base_lr the same for all groups
            base_lr = self.scheduler.base_lrs[0]  # base_lr of group 0 - VM
            self.scheduler.base_lrs = [base_lr] * len(self.scheduler.base_lrs)

    def _prepare_inputs_and_targets(self, labels):
        # Use a dummy label to ensure the sequence length is constant.
        dummy = ['0' * self.max_label_length]
        targets = self.tokenizer.encode(dummy + list(labels), self.device)[1:]
        targets = targets[:, 1:]  # remove <bos>. Unused here.
        # Inputs are padded with eos_id
        inputs = torch.where(targets == self.pad_id, self.eos_id, targets)
        inputs = F.one_hot(inputs, self.num_classes).float()
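        # Rough shape sketch (assuming the tokenizer pads every label to the dummy's length):
        # targets: (B, max_label_length + 1) after dropping the dummy row and the <bos> column;
        # inputs:  (B, max_label_length + 1, num_classes) one-hot vectors for the LM.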
        lengths = torch.as_tensor(list(map(len, labels)), device=self.device) + 1  # +1 for eos
        return inputs, lengths, targets

    def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
        images, labels = batch
        inputs, lengths, targets = self._prepare_inputs_and_targets(labels)
        if self.lm_only:
            l_res = self.model.language(inputs, lengths)
            loss = self.calc_loss(targets, l_res)
        # Pretrain submodels independently first
        elif self._pretraining:
            # Vision
            v_res = self.model.vision(images)
            # Language
            l_res = self.model.language(inputs, lengths)
            # We also train the alignment model to 'satisfy' DDP requirements (all parameters should be used).
            # We'll reset its parameters prior to joint training.
            a_res = self.model.alignment(l_res['feature'].detach(), v_res['feature'].detach())
            loss = self.calc_loss(targets, v_res, l_res, a_res)
        else:
            # Reset alignment model's parameters once prior to full model training.
            if self._reset_alignment:
                log.info('Pretraining ends. Resetting alignment model.')
                self._reset_alignment = False
                self.model.alignment.apply(init_weights)
            all_a_res, all_l_res, v_res = self.model.forward(images)
            loss = self.calc_loss(targets, v_res, all_l_res, all_a_res)
        self.log('loss', loss)
        return loss

    def forward_logits_loss(self, images: Tensor, labels: List[str]) -> Tuple[Tensor, Tensor, int]:
        if self.lm_only:
            inputs, lengths, targets = self._prepare_inputs_and_targets(labels)
            l_res = self.model.language(inputs, lengths)
            loss = self.calc_loss(targets, l_res)
            loss_numel = (targets != self.pad_id).sum()
            return l_res['logits'], loss, loss_numel
        else:
            return super().forward_logits_loss(images, labels)
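

# Minimal usage sketch (illustrative only; the hyperparameter values below are hypothetical and the real
# ones come from the project's training configs, not from this file):
#
#   model = ABINet(charset_train='0123456789abcdefghijklmnopqrstuvwxyz',
#                  charset_test='0123456789abcdefghijklmnopqrstuvwxyz',
#                  max_label_length=25, batch_size=384, lr=3e-4, warmup_pct=0.075, weight_decay=0.0,
#                  iter_size=3, d_model=512, nhead=8, d_inner=2048, dropout=0.1, activation='relu',
#                  v_loss_weight=1.0, v_attention='position', v_attention_mode='nearest',
#                  v_backbone='transformer', v_num_layers=3,
#                  l_loss_weight=1.0, l_num_layers=4, l_detach=True, l_use_self_attn=False,
#                  l_lr=3e-4, a_loss_weight=1.0)
#   logits = model(torch.rand(1, 3, 32, 128))  # assuming 32x128 RGB crops; shape (1, max_label_length + 1, num_classes)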