#!/usr/bin/env python3
# Copyright (c) Megvii Inc. All rights reserved.

import ast
import pprint
from abc import ABCMeta, abstractmethod
from typing import Dict
from tabulate import tabulate

import torch
from torch.nn import Module

from yolox.utils import LRScheduler


class BaseExp(metaclass=ABCMeta):
    """Basic class for any experiment."""

    def __init__(self):
        self.seed = None
        self.output_dir = "./YOLOX_outputs"
        self.print_interval = 100
        self.eval_interval = 10
        self.dataset = None

    @abstractmethod
    def get_model(self) -> Module:
        pass

    @abstractmethod
    def get_dataset(self, cache: bool = False, cache_type: str = "ram"):
        pass

    @abstractmethod
    def get_data_loader(
        self, batch_size: int, is_distributed: bool
    ) -> Dict[str, torch.utils.data.DataLoader]:
        pass

    @abstractmethod
    def get_optimizer(self, batch_size: int) -> torch.optim.Optimizer:
        pass

    @abstractmethod
    def get_lr_scheduler(
        self, lr: float, iters_per_epoch: int, **kwargs
    ) -> LRScheduler:
        pass

    @abstractmethod
    def get_evaluator(self):
        pass

    @abstractmethod
    def eval(self, model, evaluator, weights):
        pass

    def __repr__(self):
        table_header = ["keys", "values"]
        exp_table = [
            (str(k), pprint.pformat(v))
            for k, v in vars(self).items()
            if not k.startswith("_")
        ]
        return tabulate(exp_table, headers=table_header, tablefmt="fancy_grid")

    def merge(self, cfg_list):
        """Merge a flat list of ``key, value`` string pairs (e.g. command-line
        overrides) into this experiment, casting each value to the type of the
        existing attribute before assignment.
        """
        assert len(cfg_list) % 2 == 0, f"length must be even, check value here: {cfg_list}"
        for k, v in zip(cfg_list[0::2], cfg_list[1::2]):
            # only update values whose key already exists on the experiment
            if hasattr(self, k):
                src_value = getattr(self, k)
                src_type = type(src_value)

                # pre-process the input if the source type is list or tuple
                if isinstance(src_value, (list, tuple)):
                    v = v.strip("[]()")
                    v = [t.strip() for t in v.split(",")]

                    # cast each element to the type of the original items
                    if len(src_value) > 0:
                        src_item_type = type(src_value[0])
                        v = [src_item_type(t) for t in v]

                if src_value is not None and src_type != type(v):
                    try:
                        v = src_type(v)
                    except Exception:
                        v = ast.literal_eval(v)
                setattr(self, k, v)
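

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the YOLOX API): how `merge` consumes flat
# key/value string pairs such as those forwarded from the command line.
# `_DemoExp`, `max_epoch`, and `mosaic_scale` below are hypothetical names
# used only for this demonstration; real experiments subclass the concrete
# `yolox.exp.Exp` and implement the abstract methods properly.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class _DemoExp(BaseExp):
        def __init__(self):
            super().__init__()
            self.max_epoch = 300          # int attribute
            self.mosaic_scale = (0.1, 2)  # tuple attribute

        # Trivial overrides so the abstract base class can be instantiated here.
        def get_model(self): ...
        def get_dataset(self, cache=False, cache_type="ram"): ...
        def get_data_loader(self, batch_size, is_distributed): ...
        def get_optimizer(self, batch_size): ...
        def get_lr_scheduler(self, lr, iters_per_epoch, **kwargs): ...
        def get_evaluator(self): ...
        def eval(self, model, evaluator, weights): ...

    exp = _DemoExp()
    # Values arrive as strings and are cast back to each attribute's original type.
    exp.merge(["max_epoch", "100", "mosaic_scale", "(0.5, 1.5)"])
    assert exp.max_epoch == 100
    assert exp.mosaic_scale == (0.5, 1.5)
    print(exp)  # __repr__ renders all public attributes as a table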