#!/usr/bin/env python3
# Scene Text Recognition Model Hub
# Copyright 2022 Darwin Bautista
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import math
import os
import shutil
from pathlib import Path

import hydra
import numpy as np
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, open_dict

from ray import air, train, tune
from ray.tune import CLIReporter
from ray.tune.integration.pytorch_lightning import TuneReportCheckpointCallback
from ray.tune.schedulers import MedianStoppingRule
from ray.tune.search.ax import AxSearch

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.loggers import TensorBoardLogger

from strhub.data.module import SceneTextDataModule
from strhub.models.base import BaseSystem

log = logging.getLogger(__name__)


class MetricTracker(tune.Stopper):
    """Tracks the trend of the metric. Stops downward/stagnant trials. Assumes metric is being maximized."""

    def __init__(self, metric, max_t, patience: int = 3, window: int = 3) -> None:
        super().__init__()
        self.metric = metric
        self.trial_history = {}
        self.max_t = max_t
        self.training_iteration = 0
        self.eps = 0.01  # sensitivity
        self.patience = patience  # number of consecutive downward/stagnant samples to trigger early stoppage.
        self.kernel = self.gaussian_pdf(np.arange(window) - window // 2, sigma=0.6)
        # Extra samples to keep in order to have better MAs + gradients for the middle p samples.
        self.buffer = 2 * (len(self.kernel) // 2) + 2

    @staticmethod
    def gaussian_pdf(x, sigma=1.0):
        return np.exp(-((x / sigma) ** 2) / 2) / (sigma * np.sqrt(2 * np.pi))

    @staticmethod
    def moving_average(x, k):
        return np.convolve(x, k, 'valid') / k.sum()

    def __call__(self, trial_id, result):
        self.training_iteration = result['training_iteration']
        if np.isnan(result['loss']) or self.training_iteration >= self.max_t:
            try:
                del self.trial_history[trial_id]
            except KeyError:
                pass
            return True
        history = self.trial_history.get(trial_id, [])
        # FIFO queue of metric values.
        history = history[-(self.patience + self.buffer - 1) :] + [result[self.metric]]
        # Only start checking once we have enough data. At least one non-zero sample is required.
        if len(history) == self.patience + self.buffer and sum(history) > 0:
            smooth_grad = np.gradient(self.moving_average(history, self.kernel))[1:-1]  # discard edge values.
            # Check if trend is downward or stagnant
            if (smooth_grad < self.eps).all():
                log.info(f'Stopping trial = {trial_id}, hist = {history}, grad = {smooth_grad}')
                try:
                    del self.trial_history[trial_id]
                except KeyError:
                    pass
                return True
        self.trial_history[trial_id] = history
        return False

    def stop_all(self):
        return False


class TuneReportCheckpointPruneCallback(TuneReportCheckpointCallback):

    def _handle(self, trainer: Trainer, pl_module: LightningModule):
        super()._handle(trainer, pl_module)
        # Prune older checkpoints
        trial_dir = train.get_context().get_trial_dir()
        for old in sorted(Path(trial_dir).glob('checkpoint_epoch=*-step=*'), key=os.path.getmtime)[:-1]:
            log.info(f'Deleting old checkpoint: {old}')
            shutil.rmtree(old)


def trainable(hparams, config):
    with open_dict(config):
        config.model.lr = hparams['lr']
        # config.model.weight_decay = hparams['wd']
    model: BaseSystem = hydra.utils.instantiate(config.model)
    datamodule: SceneTextDataModule = hydra.utils.instantiate(config.data)

    tune_callback = TuneReportCheckpointPruneCallback({
        'loss': 'val_loss',
        'NED': 'val_NED',
        'accuracy': 'val_accuracy',
    })

    if checkpoint := train.get_checkpoint():
        with checkpoint.as_directory() as checkpoint_dir:
            ckpt_path = os.path.join(checkpoint_dir, 'checkpoint')
    else:
        ckpt_path = None

    trainer: Trainer = hydra.utils.instantiate(
        config.trainer,
        enable_progress_bar=False,
        enable_checkpointing=False,
        logger=TensorBoardLogger(save_dir=train.get_context().get_trial_dir(), name='', version='.'),
        callbacks=[tune_callback],
    )
    trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path)


@hydra.main(config_path='configs', config_name='tune', version_base='1.2')
def main(config: DictConfig):
    # Special handling for PARSeq
    if config.model.get('perm_mirrored', False):
        assert config.model.perm_num % 2 == 0, 'perm_num should be even if perm_mirrored = True'

    # Modify config
    with open_dict(config):
        # Use mixed-precision training
        if config.trainer.get('gpus', 0):
            config.trainer.precision = 16
        # Resolve absolute path to data.root_dir
        config.data.root_dir = hydra.utils.to_absolute_path(config.data.root_dir)

    hparams = {
        'lr': tune.loguniform(config.tune.lr.min, config.tune.lr.max),
        # 'wd': tune.loguniform(config.tune.wd.min, config.tune.wd.max),
    }

    steps_per_epoch = len(hydra.utils.instantiate(config.data).train_dataloader())
    val_steps = steps_per_epoch * config.trainer.max_epochs / config.trainer.val_check_interval
    # Cap each trial at 75% of the total number of validation steps.
    max_t = round(0.75 * val_steps)
    # Don't apply the median stopping rule during LR warmup.
    warmup_t = round(config.model.warmup_pct * val_steps)
    scheduler = MedianStoppingRule(time_attr='training_iteration', grace_period=warmup_t)

    # Always start by evenly dividing the range in log scale.
    lr = hparams['lr']
    start = np.log10(lr.lower)
    stop = np.log10(lr.upper)
    num = math.ceil(stop - start) + 1
    initial_points = [{'lr': np.clip(x, lr.lower, lr.upper).item()} for x in reversed(np.logspace(start, stop, num))]
    search_alg = AxSearch(points_to_evaluate=initial_points)

    reporter = CLIReporter(parameter_columns=['lr'], metric_columns=['loss', 'accuracy', 'training_iteration'])
    out_dir = Path(HydraConfig.get().runtime.output_dir if config.tune.resume_dir is None else config.tune.resume_dir)

    resources_per_trial = {
        'cpu': 1,
        'gpu': config.tune.gpus_per_trial,
    }
    wrapped_trainable = tune.with_parameters(tune.with_resources(trainable, resources_per_trial), config=config)
    if config.tune.resume_dir is None:
        tuner = tune.Tuner(
            wrapped_trainable,
            param_space=hparams,
            tune_config=tune.TuneConfig(
                mode='max',
                metric='NED',
                search_alg=search_alg,
                scheduler=scheduler,
                num_samples=config.tune.num_samples,
            ),
            run_config=air.RunConfig(
                name=out_dir.name,
                stop=MetricTracker('NED', max_t),
                progress_reporter=reporter,
                local_dir=str(out_dir.parent.absolute()),
            ),
        )
    else:
        tuner = tune.Tuner.restore(config.tune.resume_dir, wrapped_trainable)
    results = tuner.fit()

    best_result = results.get_best_result()
    print('Best hyperparameters found were:', best_result.config)
    print('with result:\n', best_result)


if __name__ == '__main__':
    main()
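
# Usage sketch: this is a Hydra entry point, so the run is configured via command-line
# overrides. The tune.* keys below are the ones read by main() above; the concrete values
# (and the resume directory) are illustrative assumptions, not prescribed defaults.
#
#   ./tune.py tune.num_samples=20 tune.gpus_per_trial=1 tune.lr.min=1e-5 tune.lr.max=1e-2
#
# Resuming an interrupted tuning run (directory is hypothetical):
#
#   ./tune.py tune.resume_dir=outputs/tune/<previous-run>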