kashif HF staff commited on
Commit
a028d0b
·
1 Parent(s): 45e60de

Upload run_experiment.py

Browse files
Files changed (1) hide show
  1. run_experiment.py +154 -0
run_experiment.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import click
2
+ import datetime
3
+ import pprint
4
+ from typing import Optional
5
+
6
+
7
+ from src import (
8
+ load_dataset,
9
+ fit_predict_with_model,
10
+ score_predictions,
11
+ AVAILABLE_DATASETS,
12
+ AVAILABLE_MODELS,
13
+ SEASONALITY_MAP,
14
+ )
15
+
16
+
17
+ def apply_ablation(ablation: str, model_kwargs: dict) -> dict:
18
+ if ablation == "NoEnsemble":
19
+ model_kwargs["enable_ensemble"] = False
20
+ elif ablation == "NoDeepModels":
21
+ model_kwargs["hyperparameters"] = {
22
+ "Naive": {},
23
+ "SeasonalNaive": {},
24
+ "ARIMA": {},
25
+ "ETS": {},
26
+ "AutoETS": {},
27
+ "AutoARIMA": {},
28
+ "Theta": {},
29
+ "AutoGluonTabular": {},
30
+ }
31
+ elif ablation == "NoStatModels":
32
+ model_kwargs["hyperparameters"] = {
33
+ "AutoGluonTabular": {},
34
+ "DeepAR": {},
35
+ "SimpleFeedForward": {},
36
+ "TemporalFusionTransformer": {},
37
+ }
38
+ elif ablation == "NoTreeModels":
39
+ model_kwargs["hyperparameters"] = {
40
+ "Naive": {},
41
+ "SeasonalNaive": {},
42
+ "ARIMA": {},
43
+ "ETS": {},
44
+ "AutoETS": {},
45
+ "AutoARIMA": {},
46
+ "Theta": {},
47
+ "DeepAR": {},
48
+ "SimpleFeedForward": {},
49
+ "TemporalFusionTransformer": {},
50
+ }
51
+ return model_kwargs
52
+
53
+
54
+ @click.command(
55
+ context_settings=dict(
56
+ ignore_unknown_options=True,
57
+ allow_extra_args=True,
58
+ )
59
+ )
60
+ @click.option(
61
+ "--dataset_name",
62
+ "-d",
63
+ required=True,
64
+ default="m3_other",
65
+ help="The dataset to train the model on",
66
+ type=click.Choice(AVAILABLE_DATASETS),
67
+ )
68
+ @click.option(
69
+ "--model_name",
70
+ "-m",
71
+ default="autogluon",
72
+ help="Model to train",
73
+ type=click.Choice(AVAILABLE_MODELS),
74
+ )
75
+ @click.option(
76
+ "--eval_metric",
77
+ "-e",
78
+ default="MASE",
79
+ type=click.Choice(["MASE", "mean_wQuantileLoss"]),
80
+ )
81
+ @click.option(
82
+ "--seed",
83
+ "-s",
84
+ default=1,
85
+ type=int,
86
+ )
87
+ @click.option(
88
+ "--time_limit",
89
+ "-t",
90
+ default=4 * 3600,
91
+ type=int,
92
+ )
93
+ @click.option(
94
+ "--ablation",
95
+ "-a",
96
+ default=None,
97
+ type=click.Choice(["NoEnsemble", "NoDeepModels", "NoStatModels", "NoTreeModels"]),
98
+ )
99
+ @click.pass_context
100
+ def main(
101
+ ctx,
102
+ dataset_name: str,
103
+ model_name: str,
104
+ eval_metric: str,
105
+ seed: int,
106
+ time_limit: int,
107
+ ablation: Optional[str],
108
+ ):
109
+ print(f"Evaluating {model_name} on {dataset_name}")
110
+ dataset = load_dataset(dataset_name)
111
+ task_kwargs = {
112
+ "prediction_length": dataset.metadata.prediction_length,
113
+ "freq": dataset.metadata.freq,
114
+ "eval_metric": eval_metric,
115
+ "seasonality": SEASONALITY_MAP[dataset.metadata.freq],
116
+ }
117
+ print("Task definition:")
118
+ pprint.pprint(task_kwargs)
119
+
120
+ # Additional command line arguments like `--name value` are parsed as {"name": "value"}
121
+ model_kwargs = {ctx.args[i][2:]: ctx.args[i + 1] for i in range(0, len(ctx.args), 2)}
122
+ model_kwargs["seed"] = seed
123
+ model_kwargs["time_limit"] = time_limit
124
+
125
+ if ablation is not None:
126
+ assert model_name == "autogluon", f"{model_name} does not support ablations"
127
+ model_kwargs = apply_ablation(ablation, model_kwargs)
128
+
129
+ if len(model_kwargs) > 0:
130
+ print("Model kwargs:")
131
+ pprint.pprint(model_kwargs)
132
+
133
+ print(f"Starting training {datetime.datetime.now()}")
134
+
135
+ predictions, info = fit_predict_with_model(
136
+ model_name, dataset.train, **task_kwargs, **model_kwargs
137
+ )
138
+
139
+ metrics = score_predictions(
140
+ dataset=dataset.test,
141
+ predictions=predictions,
142
+ prediction_length=task_kwargs["prediction_length"],
143
+ seasonality=task_kwargs["seasonality"],
144
+ )
145
+ print("================================================")
146
+ print(f"model: {model_name}")
147
+ print(f"dataset: {dataset_name}")
148
+ print(f"total_run_time: {info['run_time']:.2f}")
149
+ print(f"mase: {metrics['MASE']:.4f}")
150
+ print(f"mean_wQuantileLoss: {metrics['mean_wQuantileLoss']:.4f}")
151
+
152
+
153
+ if __name__ == "__main__":
154
+ main()