cheng-hust's picture
Upload 91 files
e8861c0 verified
"""by lyuwenyu
"""
from pprint import pprint
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.cuda.amp.grad_scaler import GradScaler
from typing import Callable, List, Dict
__all__ = ['BaseConfig', ]
class BaseConfig(object):
    """Holds the components and hyper-parameters of a training run.

    Components (model, postprocessor, criterion, optimizer, lr_scheduler,
    dataloaders, ema, scaler) are stored on ``_``-prefixed attributes and
    exposed through properties. The setters for ``model``, ``postprocessor``,
    ``criterion``, ``optimizer``, ``lr_scheduler``, the shuffle flags, the
    per-split batch sizes and ``print_freq`` type-check their argument with
    ``assert``. Several getters fall back to a default when nothing was set
    (printing a warning for the shuffle flags and batch sizes). The
    dataloaders and the grad scaler are built lazily on first access.
    """

    # TODO property
    def __init__(self) -> None:
        super().__init__()

        self.task: str = None

        # Components, exposed through the properties defined below.
        self._model: nn.Module = None
        self._postprocessor: nn.Module = None
        self._criterion: nn.Module = None
        self._optimizer: Optimizer = None
        self._lr_scheduler: LRScheduler = None
        self._train_dataloader: DataLoader = None
        self._val_dataloader: DataLoader = None
        self._ema: nn.Module = None
        self._scaler: GradScaler = None

        # Data; the dataloader properties build loaders from these on demand.
        self.train_dataset: Dataset = None
        self.val_dataset: Dataset = None
        self.num_workers: int = 0
        self.collate_fn: Callable = None

        # Shared batch size; the per-split getters fall back to it when unset.
        self.batch_size: int = None
        self._train_batch_size: int = None
        self._val_batch_size: int = None
        self._train_shuffle: bool = None
        self._val_shuffle: bool = None

        self.evaluator: Callable[[nn.Module, DataLoader, str], object] = None

        # runtime
        self.resume: str = None
        self.tuning: str = None

        self.epoches: int = None
        self.last_epoch: int = -1
        self.end_epoch: int = None

        self.use_amp: bool = False
        self.use_ema: bool = False
        self.sync_bn: bool = False
        self.clip_max_norm: float = None
        self.find_unused_parameters: bool = None
        # self.ema_decay: float = 0.9999
        # self.grad_clip_: Callable = None

        self.log_dir: str = './logs/'
        self.log_step: int = 10
        self._output_dir: str = None
        self._print_freq: int = None
        self.checkpoint_step: int = 1

        # Use CUDA when it is available, otherwise fall back to CPU.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)

    @property
    def model(self, ) -> nn.Module:
        """The model being trained (``None`` until assigned)."""
        return self._model

    @model.setter
    def model(self, m):
        assert isinstance(m, nn.Module), f'{type(m)} != nn.Module, please check your model class'
        self._model = m

    @property
    def postprocessor(self, ) -> nn.Module:
        """Post-processing module (``None`` until assigned)."""
        return self._postprocessor

    @postprocessor.setter
    def postprocessor(self, m):
        assert isinstance(m, nn.Module), f'{type(m)} != nn.Module, please check your model class'
        self._postprocessor = m

    @property
    def criterion(self, ) -> nn.Module:
        """Loss module (``None`` until assigned)."""
        return self._criterion

    @criterion.setter
    def criterion(self, m):
        assert isinstance(m, nn.Module), f'{type(m)} != nn.Module, please check your model class'
        self._criterion = m

    @property
    def optimizer(self, ) -> Optimizer:
        """Optimizer (``None`` until assigned)."""
        return self._optimizer

    @optimizer.setter
    def optimizer(self, m):
        assert isinstance(m, Optimizer), f'{type(m)} != optim.Optimizer, please check your model class'
        self._optimizer = m

    @property
    def lr_scheduler(self, ) -> LRScheduler:
        """Learning-rate scheduler (``None`` until assigned)."""
        return self._lr_scheduler

    @lr_scheduler.setter
    def lr_scheduler(self, m):
        assert isinstance(m, LRScheduler), f'{type(m)} != LRScheduler, please check your model class'
        self._lr_scheduler = m

    @property
    def train_dataloader(self):
        """Training loader, built lazily from ``train_dataset``.

        Returns the loader assigned through the setter, if any. Otherwise, when
        ``train_dataset`` is set, builds a ``DataLoader`` from
        ``train_batch_size``, ``num_workers``, ``collate_fn`` and
        ``train_shuffle``, caches it and returns it. Returns ``None`` when
        neither a loader nor a dataset is available.
        """
        if self._train_dataloader is None and self.train_dataset is not None:
            # Read each property exactly once (batch size first, as before) so
            # any default-value warning is printed only once per build.
            batch_size = self.train_batch_size
            shuffle = self.train_shuffle
            loader = DataLoader(self.train_dataset,
                                batch_size=batch_size,
                                num_workers=self.num_workers,
                                collate_fn=self.collate_fn,
                                shuffle=shuffle, )
            # Record the flag on the loader so callers can read `loader.shuffle`.
            loader.shuffle = shuffle
            self._train_dataloader = loader
        return self._train_dataloader

    @train_dataloader.setter
    def train_dataloader(self, loader):
        self._train_dataloader = loader

    @property
    def val_dataloader(self):
        """Validation loader, built lazily from ``val_dataset``.

        Same contract as ``train_dataloader``, but it uses ``val_batch_size``
        and ``val_shuffle`` and passes ``drop_last=False``.
        """
        if self._val_dataloader is None and self.val_dataset is not None:
            # Read each property exactly once so warnings are not duplicated.
            batch_size = self.val_batch_size
            shuffle = self.val_shuffle
            loader = DataLoader(self.val_dataset,
                                batch_size=batch_size,
                                num_workers=self.num_workers,
                                drop_last=False,
                                collate_fn=self.collate_fn,
                                shuffle=shuffle)
            loader.shuffle = shuffle
            self._val_dataloader = loader
        return self._val_dataloader

    @val_dataloader.setter
    def val_dataloader(self, loader):
        self._val_dataloader = loader

    # TODO method: build lazily when use_ema is set, e.g.
    # `self._ema = ModelEMA(self.model, self.ema_decay)` once a model exists.
    @property
    def ema(self, ) -> nn.Module:
        """EMA model wrapper (``None`` until assigned)."""
        return self._ema

    @ema.setter
    def ema(self, obj):
        self._ema = obj

    @property
    def scaler(self) -> GradScaler:
        """AMP grad scaler.

        Created on first access when ``use_amp`` is true and CUDA is
        available. Otherwise it is whatever was assigned (``None`` by default).
        """
        if self._scaler is None and self.use_amp and torch.cuda.is_available():
            self._scaler = GradScaler()
        return self._scaler

    @scaler.setter
    def scaler(self, obj: GradScaler):
        self._scaler = obj

    @property
    def val_shuffle(self):
        """Whether to shuffle validation data; defaults to ``False`` with a warning."""
        if self._val_shuffle is None:
            print('warning: set default val_shuffle=False')
            return False
        return self._val_shuffle

    @val_shuffle.setter
    def val_shuffle(self, shuffle):
        assert isinstance(shuffle, bool), 'shuffle must be bool'
        self._val_shuffle = shuffle

    @property
    def train_shuffle(self):
        """Whether to shuffle training data; defaults to ``True`` with a warning."""
        if self._train_shuffle is None:
            print('warning: set default train_shuffle=True')
            return True
        return self._train_shuffle

    @train_shuffle.setter
    def train_shuffle(self, shuffle):
        assert isinstance(shuffle, bool), 'shuffle must be bool'
        self._train_shuffle = shuffle

    @property
    def train_batch_size(self):
        """Training batch size.

        When unset, falls back to ``batch_size`` (with a warning) if that is
        an int. Otherwise it returns ``None``.
        """
        if self._train_batch_size is None and isinstance(self.batch_size, int):
            print(f'warning: set train_batch_size=batch_size={self.batch_size}')
            return self.batch_size
        return self._train_batch_size

    @train_batch_size.setter
    def train_batch_size(self, batch_size):
        assert isinstance(batch_size, int), 'batch_size must be int'
        self._train_batch_size = batch_size

    @property
    def val_batch_size(self):
        """Validation batch size; same fallback rule as ``train_batch_size``."""
        # Guard on an int `batch_size` like train_batch_size does, so no
        # misleading "batch_size=None" warning is printed when nothing is set.
        if self._val_batch_size is None and isinstance(self.batch_size, int):
            print(f'warning: set val_batch_size=batch_size={self.batch_size}')
            return self.batch_size
        return self._val_batch_size

    @val_batch_size.setter
    def val_batch_size(self, batch_size):
        assert isinstance(batch_size, int), 'batch_size must be int'
        self._val_batch_size = batch_size

    @property
    def output_dir(self):
        """Output directory; defaults to ``log_dir`` when unset."""
        if self._output_dir is None:
            return self.log_dir
        return self._output_dir

    @output_dir.setter
    def output_dir(self, root):
        self._output_dir = root

    @property
    def print_freq(self):
        """Print frequency; defaults to ``log_step`` when unset."""
        if self._print_freq is None:
            return self.log_step
        return self._print_freq

    @print_freq.setter
    def print_freq(self, n):
        assert isinstance(n, int), 'print_freq must be int'
        self._print_freq = n

    # TODO: __repr__