from typing import Dict, Any from torch import nn from data.datasets.ab_dataset import ABDataset from abc import ABC, abstractmethod from utils.common.log import logger import json import os from utils.common.others import backup_key_codes from .model import BaseModel from data import Scenario from schema import Schema from utils.common.data_record import write_json class BaseAlg(ABC): def __init__(self, models: Dict[str, BaseModel], res_save_dir): self.models = models self.res_save_dir = res_save_dir self.get_required_models_schema().validate(models) os.makedirs(res_save_dir) logger.info(f'[alg] init alg: {self.__class__.__name__}, res saved in {res_save_dir}') @abstractmethod def get_required_models_schema(self) -> Schema: raise NotImplementedError @abstractmethod def get_required_hyp_schema(self) -> Schema: raise NotImplementedError @abstractmethod def run(self, scenario: Scenario, hyps: Dict) -> Dict[str, Any]: """ return metrics """ self.get_required_hyp_schema().validate(hyps) try: write_json(os.path.join(self.res_save_dir, 'hyps.json'), hyps, ensure_obj_serializable=True) except: with open(os.path.join(self.res_save_dir, 'hyps.txt'), 'w') as f: f.write(str(hyps)) write_json(os.path.join(self.res_save_dir, 'scenario.json'), scenario.to_json()) logger.info(f'[alg] alg {self.__class__.__name__} start running') backup_key_codes(os.path.join(self.res_save_dir, 'backup_codes'))