libokj committed · Commit 07abb03 · verified · 1 Parent(s): 81e4dc5

Delete deepscreen/utils/hydra.py.bak

Files changed (1)
  1. deepscreen/utils/hydra.py.bak +0 -182
deepscreen/utils/hydra.py.bak DELETED
@@ -1,182 +0,0 @@
-from datetime import datetime
-from pathlib import Path
-import re
-from typing import Any, Tuple
-
-import pandas as pd
-from hydra import TaskFunction
-from hydra.core.hydra_config import HydraConfig
-from hydra.core.override_parser.overrides_parser import OverridesParser
-from hydra.core.utils import _save_config
-from hydra.experimental.callbacks import Callback
-from hydra.types import RunMode
-from hydra._internal.config_loader_impl import ConfigLoaderImpl
-from omegaconf import DictConfig, OmegaConf
-from omegaconf.errors import MissingMandatoryValue
-
-from deepscreen.utils import get_logger
-
-log = get_logger(__name__)
-
-
-class CSVExperimentSummary(Callback):
-    """On multirun end, aggregate the results from each job's metrics.csv and save them in metrics_summary.csv."""
-
-    def __init__(self, filename: str = 'experiment_summary.csv', prefix: str | Tuple[str] = 'test/'):
-        self.filename = filename
-        self.prefix = prefix if isinstance(prefix, str) else tuple(prefix)
-        self.input_experiment_summary = None
-        self.time = {}
-
-    def on_multirun_start(self, config: DictConfig, **kwargs: Any) -> None:
-        if config.hydra.get('overrides') and config.hydra.overrides.get('task'):
-            for i, override in enumerate(config.hydra.overrides.task):
-                if override.startswith("ckpt_path"):
-                    ckpt_path = override.split('=', 1)[1]
-                    if ckpt_path.endswith(('.csv', '.txt', '.tsv', '.ssv', '.psv')):
-                        config.hydra.overrides.task[i] = self.parse_ckpt_path_from_experiment_summary(ckpt_path)
-                        log.info(ckpt_path)
-                    break
-        if config.hydra.sweeper.get('params'):
-            if config.hydra.sweeper.params.get('ckpt_path'):
-                ckpt_path = str(config.hydra.sweeper.params.ckpt_path).strip("'\"")
-                if ckpt_path.endswith(('.csv', '.txt', '.tsv', '.ssv', '.psv')):
-                    config.hydra.sweeper.params.ckpt_path = self.parse_ckpt_path_from_experiment_summary(ckpt_path)
-                    log.info(ckpt_path)
-    def on_job_start(self, config: DictConfig, *, task_function: TaskFunction, **kwargs: Any) -> None:
-        self.time['start'] = datetime.now()
-
-    def on_job_end(self, config: DictConfig, job_return, **kwargs: Any) -> None:
-        # Skip callback if job is DDP subprocess
-        if "ddp" in job_return.hydra_cfg.hydra.job.name:
-            return
-
-        try:
-            self.time['end'] = datetime.now()
-            if config.hydra.mode == RunMode.RUN:
-                summary_file_path = Path(config.hydra.run.dir) / self.filename
-            elif config.hydra.mode == RunMode.MULTIRUN:
-                summary_file_path = Path(config.hydra.sweep.dir) / self.filename
-            else:
-                raise RuntimeError('Invalid Hydra `RunMode`.')
-
-            if summary_file_path.is_file():
-                summary_df = pd.read_csv(summary_file_path)
-            else:
-                summary_df = pd.DataFrame()
-
-            # Add job and override info
-            info_dict = {}
-            if job_return.overrides:
-                info_dict = dict(override.split('=', 1) for override in job_return.overrides)
-            info_dict['job_status'] = job_return.status.name
-            info_dict['job_id'] = job_return.hydra_cfg.hydra.job.id
-            info_dict['wall_time'] = str(self.time['end'] - self.time['start'])
-
-            # Add checkpoint info
-            if info_dict.get('ckpt_path'):
-                info_dict['ckpt_path'] = str(info_dict['ckpt_path']).strip("'\"")
-
-            ckpt_path = str(job_return.cfg.ckpt_path).strip("'\"")
-            if Path(ckpt_path).is_file():
-                if info_dict.get('ckpt_path') and ckpt_path != info_dict['ckpt_path']:
-                    info_dict['previous_ckpt_path'] = info_dict['ckpt_path']
-                info_dict['ckpt_path'] = ckpt_path
-            if info_dict.get('ckpt_path'):
-                info_dict['best_epoch'] = int(re.search(r'epoch_(\d+)', info_dict['ckpt_path']).group(1))
-
-            # Add metrics info
-            metrics_df = pd.DataFrame()
-            if config.get('logger'):
-                output_dir = Path(config.hydra.runtime.output_dir).resolve()
-                csv_metrics_path = output_dir / config.logger.csv.name / "metrics.csv"
-                if csv_metrics_path.is_file():
-                    log.info(f"Summarizing metrics with prefix `{self.prefix}` from {csv_metrics_path}")
-                    metrics_df = pd.read_csv(csv_metrics_path)
-                    # Find rows where 'test/' columns are not null and reset its epoch to the best model epoch
-                    test_columns = [col for col in metrics_df.columns if col.startswith('test/')]
-                    if test_columns:
-                        mask = metrics_df[test_columns].notna().any(axis=1)
-                        metrics_df.loc[mask, 'epoch'] = info_dict['best_epoch']
-                    # Group and filter by best epoch
-                    metrics_df = metrics_df.groupby('epoch').first()
-                    metrics_df = metrics_df[metrics_df.index == info_dict['best_epoch']]
-                else:
-                    log.info(f"No metrics.csv found in {output_dir}")
-
-            if metrics_df.empty:
-                metrics_df = pd.DataFrame(data=info_dict, index=[0])
-            else:
-                metrics_df = metrics_df.assign(**info_dict)
-                metrics_df.index = [0]
-
-            # Add extra info from the input batch experiment summary
-            if self.input_experiment_summary is not None and 'ckpt_path' in metrics_df.columns:
-                log.info(self.input_experiment_summary['ckpt_path'])
-                log.info(metrics_df['ckpt_path'])
-                orig_meta = self.input_experiment_summary[
-                    self.input_experiment_summary['ckpt_path'] == metrics_df['ckpt_path'][0]
-                ].head(1)
-                if not orig_meta.empty:
-                    orig_meta.index = [0]
-                    metrics_df = metrics_df.astype('O').combine_first(orig_meta.astype('O'))
-
-            summary_df = pd.concat([summary_df, metrics_df])
-
-            # Drop empty columns
-            summary_df.dropna(inplace=True, axis=1, how='all')
-            summary_df.to_csv(summary_file_path, index=False, mode='w')
-            log.info(f"Experiment summary saved to {summary_file_path}")
-        except Exception as e:
-            log.exception("Unable to save the experiment summary due to an error.", exc_info=e)
-
-    def parse_ckpt_path_from_experiment_summary(self, ckpt_path):
-        log.info(ckpt_path)
-        try:
-            self.input_experiment_summary = pd.read_csv(
-                ckpt_path, usecols=lambda col: not col.startswith(self.prefix)
-            )
-            self.input_experiment_summary['ckpt_path'] = self.input_experiment_summary['ckpt_path'].apply(
-                lambda x: x.strip("'\"")
-            )
-            ckpt_list = list(set(self.input_experiment_summary['ckpt_path']))
-            parsed_ckpt_path = ','.join([f"'{ckpt}'" for ckpt in ckpt_list])
-            return parsed_ckpt_path
-
-        except Exception as e:
-            log.exception(
-                f'Error in parsing checkpoint paths from experiment_summary file ({ckpt_path}).',
-                exc_info=e
-            )
-
-
-def checkpoint_rerun_config(config: DictConfig):
-    hydra_cfg = HydraConfig.get()
-
-    if hydra_cfg.get('output_subdir'):
-        ckpt_cfg_path = Path(config.ckpt_path).parents[1] / hydra_cfg.output_subdir / 'config.yaml'
-        hydra_output = Path(hydra_cfg.runtime.output_dir) / hydra_cfg.output_subdir
-
-        if ckpt_cfg_path.is_file():
-            log.info(f"Found config file for the checkpoint at {str(ckpt_cfg_path)}; "
-                     f"merging config overrides with checkpoint config...")
-            ckpt_cfg = OmegaConf.load(ckpt_cfg_path)
-
-            # Recompose checkpoint config with overrides
-
-            if hydra_cfg.overrides.get('task'):
-                parser = OverridesParser.create()
-                parsed_overrides = parser.parse_overrides(overrides=hydra_cfg.overrides.task)
-                filtered_overrides = []
-                for override in parsed_overrides:
-                    if not override.is_force_add():
-                        OmegaConf.update(ckpt_cfg, override.key_or_group, override.value())
-                        filtered_overrides.append(override)
-                log.info(filtered_overrides)
-                ConfigLoaderImpl._apply_overrides_to_config(filtered_overrides, config)
-
-        _save_config(config, "config.yaml", hydra_output)
-
-    return config
-
-