from datetime import timedelta
from pathlib import Path
import re
from time import time
from typing import Any, Tuple

import pandas as pd
from hydra import TaskFunction
from hydra.core.hydra_config import HydraConfig
from hydra.core.utils import _save_config
from hydra.experimental.callbacks import Callback
from hydra.types import RunMode
from omegaconf import DictConfig, OmegaConf

from deepscreen.utils import get_logger

log = get_logger(__name__)

# A `ckpt_path` value ending in one of these suffixes is treated as an experiment
# summary table listing checkpoints rather than as a checkpoint file itself.
_SUMMARY_FILE_SUFFIXES = ('.csv', '.txt', '.tsv', '.ssv', '.psv')
# Quote characters stripped from checkpoint paths taken from overrides / tables.
_QUOTE_CHARS = "'\""


class CSVExperimentSummary(Callback):
    """Append one row per finished job to a shared experiment summary CSV.

    At the end of every job the callback collects the job's overrides, status,
    id, wall time and checkpoint information, joins them with the best-epoch
    row of the job's ``metrics.csv`` (when available) and appends the result to
    ``<run or sweep dir>/<filename>``.

    On multirun start, a ``ckpt_path`` pointing at a previous summary table is
    expanded into the comma-separated list of checkpoints found in that table.
    """

    def __init__(self, filename: str = 'experiment_summary.csv', prefix: str | Tuple[str, ...] = 'test/'):
        """
        Args:
            filename: Name of the summary file written in the run/sweep directory.
            prefix: Column-name prefix (or prefixes) identifying the metrics of
                interest. Columns with this prefix are dropped when reading an
                input summary table, and rows of ``metrics.csv`` carrying such
                columns are attributed to the best epoch.
        """
        self.filename = filename
        # `str.startswith` accepts a str or a tuple of str, never a list.
        self.prefix = prefix if isinstance(prefix, str) else tuple(prefix)
        # Table loaded by `parse_ckpt_path_from_experiment_summary`, if any.
        self.input_experiment_summary = None
        # Holds the 'start' / 'end' timestamps (seconds) of the current job.
        self.time = {}

    def on_multirun_start(self, config: DictConfig, **kwargs: Any) -> None:
        """Expand a `ckpt_path` that points at a summary table into a checkpoint sweep.

        Both the command-line task overrides and `hydra.sweeper.params` are
        inspected. The config is left untouched when the table cannot be parsed.
        """
        if config.hydra.get('overrides') and config.hydra.overrides.get('task'):
            for i, override in enumerate(config.hydra.overrides.task):
                if not override.startswith("ckpt_path"):
                    continue
                key, sep, value = override.partition('=')
                value = value.strip(_QUOTE_CHARS)
                if sep and value.endswith(_SUMMARY_FILE_SUFFIXES):
                    parsed = self.parse_ckpt_path_from_experiment_summary(value)
                    if parsed is not None:
                        # Keep the `key=` part: a task override must stay a valid
                        # `key=value` string for the sweeper to parse it.
                        config.hydra.overrides.task[i] = f"{key}={parsed}"
                # Only the first `ckpt_path` override is considered.
                break

        if config.hydra.sweeper.get('params') and config.hydra.sweeper.params.get('ckpt_path'):
            value = str(config.hydra.sweeper.params.ckpt_path).strip(_QUOTE_CHARS)
            if value.endswith(_SUMMARY_FILE_SUFFIXES):
                parsed = self.parse_ckpt_path_from_experiment_summary(value)
                if parsed is not None:
                    config.hydra.sweeper.params.ckpt_path = parsed

    def on_job_start(self, config: DictConfig, *, task_function: TaskFunction, **kwargs: Any) -> None:
        """Record the job start time used to compute `wall_time`."""
        self.time['start'] = time()

    def on_job_end(self, config: DictConfig, job_return, **kwargs: Any) -> None:
        """Append this job's info and best-epoch metrics to the summary file.

        Any error while building or writing the summary is logged and swallowed
        so that a reporting failure never fails the job itself.
        """
        # Skip callback if job is DDP subprocess
        # NOTE(review): the expression after `in` was lost in the reviewed source;
        # the job name is assumed to be what is inspected here -- confirm.
        if "ddp" in job_return.hydra_cfg.hydra.job.name:
            return

        try:
            self.time['end'] = time()
            summary_file_path = self._summary_file_path(config)

            if summary_file_path.is_file():
                summary_df = pd.read_csv(summary_file_path)
            else:
                summary_df = pd.DataFrame()

            info_dict = self._collect_job_info(job_return)
            metrics_df = self._best_epoch_metrics(config, info_dict.get('best_epoch'))

            if metrics_df.empty:
                metrics_df = pd.DataFrame(data=info_dict, index=[0])
            else:
                metrics_df = metrics_df.assign(**info_dict)
                metrics_df.index = [0]

            # Add extra info from the input batch experiment summary
            if self.input_experiment_summary is not None and 'ckpt_path' in metrics_df.columns:
                same_ckpt = self.input_experiment_summary['ckpt_path'] == metrics_df['ckpt_path'][0]
                orig_meta = self.input_experiment_summary[same_ckpt].head(1).reset_index(drop=True)
                if not orig_meta.empty:
                    # Values computed for this job win; the input table only fills gaps.
                    metrics_df = metrics_df.combine_first(orig_meta)

            summary_df = pd.concat([summary_df, metrics_df])

            # Drop empty columns
            summary_df.dropna(inplace=True, axis=1, how='all')
            summary_df.to_csv(summary_file_path, index=False, mode='w')
            log.info(f"Experiment summary saved to {summary_file_path}")
        except Exception as e:
            log.exception("Unable to save the experiment summary due to an error.", exc_info=e)

    def _summary_file_path(self, config: DictConfig) -> Path:
        """Return the summary file location for the current Hydra run mode."""
        if config.hydra.mode == RunMode.RUN:
            # NOTE(review): reconstructed -- the argument of `Path(` was lost in the
            # reviewed source; `hydra.run.dir` mirrors the MULTIRUN branch. Confirm.
            return Path(config.hydra.run.dir) / self.filename
        if config.hydra.mode == RunMode.MULTIRUN:
            return Path(config.hydra.sweep.dir) / self.filename
        raise RuntimeError('Invalid Hydra `RunMode`.')

    def _collect_job_info(self, job_return) -> dict:
        """Build the per-job columns: overrides, status, id, wall time and checkpoint."""
        info_dict = {}
        if job_return.overrides:
            for override in job_return.overrides:
                # `partition` tolerates overrides without '=' (value becomes '').
                key, _, value = override.partition('=')
                info_dict[key] = value

        # NOTE(review): the right-hand sides of `job_status` and `job_id` were lost
        # in the reviewed source; the attributes below are assumed -- confirm.
        info_dict['job_status'] = job_return.status.name
        info_dict['job_id'] = job_return.hydra_cfg.hydra.job.id
        # `time()` differences are seconds; a positional timedelta argument is days.
        info_dict['wall_time'] = str(timedelta(seconds=self.time['end'] - self.time['start']))

        # Add checkpoint info
        if info_dict.get('ckpt_path'):
            info_dict['ckpt_path'] = str(info_dict['ckpt_path']).strip(_QUOTE_CHARS)

        ckpt_path = str(job_return.cfg.ckpt_path).strip(_QUOTE_CHARS)
        if Path(ckpt_path).is_file():
            if info_dict.get('ckpt_path') and ckpt_path != info_dict['ckpt_path']:
                info_dict['previous_ckpt_path'] = info_dict['ckpt_path']
            info_dict['ckpt_path'] = ckpt_path
            epoch_match = re.search(r'epoch_(\d+)', info_dict['ckpt_path'])
            if epoch_match:
                info_dict['best_epoch'] = int(epoch_match.group(1))
            else:
                log.warning(f"Could not infer the best epoch from checkpoint path {ckpt_path}")

        return info_dict

    def _best_epoch_metrics(self, config: DictConfig, best_epoch: int | None) -> pd.DataFrame:
        """Return the `metrics.csv` row of `best_epoch`, or an empty frame if unavailable."""
        metrics_df = pd.DataFrame()
        if not config.get('logger'):
            return metrics_df

        output_dir = Path(config.hydra.runtime.output_dir).resolve()
        # NOTE(review): the middle path component was lost in the reviewed source;
        # the CSV logger's configured name is assumed -- confirm.
        csv_metrics_path = output_dir / config.logger.csv.name / "metrics.csv"
        if not csv_metrics_path.is_file():
            log.info(f"No metrics.csv found in {output_dir}")
            return metrics_df

        if best_epoch is None:
            # Without a best epoch no row can be selected; still record the job info.
            log.warning(f"Best epoch is unknown; skipping metrics from {csv_metrics_path}")
            return metrics_df

        log.info(f"Summarizing metrics with prefix `{self.prefix}` from {csv_metrics_path}")
        metrics_df = pd.read_csv(csv_metrics_path)

        # Rows holding prefixed (by default 'test/') metrics are re-labelled with the
        # best model epoch so they merge with that epoch's other metrics below.
        prefixed_columns = [col for col in metrics_df.columns if col.startswith(self.prefix)]
        if prefixed_columns:
            mask = metrics_df[prefixed_columns].notna().any(axis=1)
            metrics_df.loc[mask, 'epoch'] = best_epoch

        # Group and filter by best epoch
        metrics_df = metrics_df.groupby('epoch').first()
        return metrics_df[metrics_df.index == best_epoch]

    def parse_ckpt_path_from_experiment_summary(self, ckpt_path) -> str | None:
        """Load a summary table and return its checkpoints as a sweep string.

        The table (minus columns starting with `self.prefix`) is kept in
        `self.input_experiment_summary` for later use in `on_job_end`.

        Args:
            ckpt_path: Path of the summary table; it must have a `ckpt_path` column.

        Returns:
            The unique checkpoint paths, single-quoted and comma-separated in
            order of first appearance, or None if the table could not be parsed
            (the error is logged).
        """
        try:
            summary = pd.read_csv(
                ckpt_path, usecols=lambda col: not col.startswith(self.prefix)
            )
            # Rows without a checkpoint cannot be swept over.
            summary = summary.dropna(subset=['ckpt_path'])
            summary['ckpt_path'] = summary['ckpt_path'].astype(str).str.strip(_QUOTE_CHARS)
            self.input_experiment_summary = summary

            # `dict.fromkeys` de-duplicates while keeping a deterministic order.
            ckpt_list = list(dict.fromkeys(summary['ckpt_path']))
            return ','.join(f"'{ckpt}'" for ckpt in ckpt_list)
        except Exception as e:
            log.exception(
                f'Error in parsing checkpoint paths from experiment_summary file ({ckpt_path}).',
                exc_info=e
            )
            return None


def checkpoint_rerun_config(config: DictConfig):
    """Merge the config saved next to a checkpoint into the current config.

    The checkpoint's config is looked up at
    `<ckpt_path>/../../<hydra output_subdir>/config.yaml`. When Hydra's
    `output_subdir` is disabled or that file does not exist, `config` is
    returned unchanged. Otherwise the merged config is also saved as
    `config.yaml` in the current run's Hydra output subdirectory.

    Args:
        config: The current job config; must define `ckpt_path`.

    Returns:
        The merged config, or `config` itself when nothing was merged.
    """
    # NOTE(review): block nesting was reconstructed from a flattened source; saving
    # the config is assumed to happen only when a checkpoint config was merged.
    hydra_cfg = HydraConfig.get()
    if hydra_cfg.output_subdir is None:
        return config

    ckpt_cfg_path = Path(config.ckpt_path).parents[1] / hydra_cfg.output_subdir / 'config.yaml'
    hydra_output = Path(hydra_cfg.runtime.output_dir) / hydra_cfg.output_subdir
    if not ckpt_cfg_path.is_file():
        return config

    log.info(f"Found config file for the checkpoint at {str(ckpt_cfg_path)}; "
             f"merging config overrides with checkpoint config...")
    ckpt_cfg = OmegaConf.load(ckpt_cfg_path)

    # Merge checkpoint config with test config by overriding specified nodes.
    # Only these top-level nodes are taken from the checkpoint config ...
    ckpt_cfg = OmegaConf.masked_copy(ckpt_cfg, ['model', 'data', 'task', 'seed'])
    # ... and the dataset selection keys always come from the current config.
    ckpt_cfg.data = OmegaConf.masked_copy(ckpt_cfg.data, [
        key for key in ckpt_cfg.data if key not in ['data_file', 'split', 'train_val_test_split']
    ])

    # These nodes are forced to the checkpoint's values before merging, so the
    # current config cannot override them.
    ckpt_override_keys = ['task', 'data.drug_featurizer', 'data.protein_featurizer', 'data.collator',
                          'model.predictor', 'model.out', 'model.loss', 'model.activation', 'model.metrics']
    for key in ckpt_override_keys:
        OmegaConf.update(config, key, OmegaConf.select(ckpt_cfg, key), force_add=True)

    # For everything else the current config takes precedence over the checkpoint's.
    config = OmegaConf.merge(ckpt_cfg, config)

    _save_config(config, "config.yaml", hydra_output)
    return config