glenn-jocher commited on
Commit
4103ce9
·
unverified ·
1 Parent(s): 771ac6c

Simplify callbacks (#4289)

Browse files
Files changed (1) hide show
  1. utils/callbacks.py +20 -21
utils/callbacks.py CHANGED
@@ -58,12 +58,11 @@ class Callbacks:
58
  else:
59
  return self._callbacks
60
 
61
- @staticmethod
62
- def run_callbacks(register, *args, **kwargs):
63
  """
64
  Loop through the registered actions and fire all callbacks
65
  """
66
- for logger in register:
67
  # print(f"Running callbacks.{logger['callback'].__name__}()")
68
  logger['callback'](*args, **kwargs)
69
 
@@ -71,106 +70,106 @@ class Callbacks:
71
  """
72
  Fires all registered callbacks at the start of each pretraining routine
73
  """
74
- self.run_callbacks(self._callbacks['on_pretrain_routine_start'], *args, **kwargs)
75
 
76
  def on_pretrain_routine_end(self, *args, **kwargs):
77
  """
78
  Fires all registered callbacks at the end of each pretraining routine
79
  """
80
- self.run_callbacks(self._callbacks['on_pretrain_routine_end'], *args, **kwargs)
81
 
82
  def on_train_start(self, *args, **kwargs):
83
  """
84
  Fires all registered callbacks at the start of each training
85
  """
86
- self.run_callbacks(self._callbacks['on_train_start'], *args, **kwargs)
87
 
88
  def on_train_epoch_start(self, *args, **kwargs):
89
  """
90
  Fires all registered callbacks at the start of each training epoch
91
  """
92
- self.run_callbacks(self._callbacks['on_train_epoch_start'], *args, **kwargs)
93
 
94
  def on_train_batch_start(self, *args, **kwargs):
95
  """
96
  Fires all registered callbacks at the start of each training batch
97
  """
98
- self.run_callbacks(self._callbacks['on_train_batch_start'], *args, **kwargs)
99
 
100
  def optimizer_step(self, *args, **kwargs):
101
  """
102
  Fires all registered callbacks on each optimizer step
103
  """
104
- self.run_callbacks(self._callbacks['optimizer_step'], *args, **kwargs)
105
 
106
  def on_before_zero_grad(self, *args, **kwargs):
107
  """
108
  Fires all registered callbacks before zero grad
109
  """
110
- self.run_callbacks(self._callbacks['on_before_zero_grad'], *args, **kwargs)
111
 
112
  def on_train_batch_end(self, *args, **kwargs):
113
  """
114
  Fires all registered callbacks at the end of each training batch
115
  """
116
- self.run_callbacks(self._callbacks['on_train_batch_end'], *args, **kwargs)
117
 
118
  def on_train_epoch_end(self, *args, **kwargs):
119
  """
120
  Fires all registered callbacks at the end of each training epoch
121
  """
122
- self.run_callbacks(self._callbacks['on_train_epoch_end'], *args, **kwargs)
123
 
124
  def on_val_start(self, *args, **kwargs):
125
  """
126
  Fires all registered callbacks at the start of the validation
127
  """
128
- self.run_callbacks(self._callbacks['on_val_start'], *args, **kwargs)
129
 
130
  def on_val_batch_start(self, *args, **kwargs):
131
  """
132
  Fires all registered callbacks at the start of each validation batch
133
  """
134
- self.run_callbacks(self._callbacks['on_val_batch_start'], *args, **kwargs)
135
 
136
  def on_val_image_end(self, *args, **kwargs):
137
  """
138
  Fires all registered callbacks at the end of each val image
139
  """
140
- self.run_callbacks(self._callbacks['on_val_image_end'], *args, **kwargs)
141
 
142
  def on_val_batch_end(self, *args, **kwargs):
143
  """
144
  Fires all registered callbacks at the end of each validation batch
145
  """
146
- self.run_callbacks(self._callbacks['on_val_batch_end'], *args, **kwargs)
147
 
148
  def on_val_end(self, *args, **kwargs):
149
  """
150
  Fires all registered callbacks at the end of the validation
151
  """
152
- self.run_callbacks(self._callbacks['on_val_end'], *args, **kwargs)
153
 
154
  def on_fit_epoch_end(self, *args, **kwargs):
155
  """
156
  Fires all registered callbacks at the end of each fit (train+val) epoch
157
  """
158
- self.run_callbacks(self._callbacks['on_fit_epoch_end'], *args, **kwargs)
159
 
160
  def on_model_save(self, *args, **kwargs):
161
  """
162
  Fires all registered callbacks after each model save
163
  """
164
- self.run_callbacks(self._callbacks['on_model_save'], *args, **kwargs)
165
 
166
  def on_train_end(self, *args, **kwargs):
167
  """
168
  Fires all registered callbacks at the end of training
169
  """
170
- self.run_callbacks(self._callbacks['on_train_end'], *args, **kwargs)
171
 
172
  def teardown(self, *args, **kwargs):
173
  """
174
  Fires all registered callbacks before teardown
175
  """
176
- self.run_callbacks(self._callbacks['teardown'], *args, **kwargs)
 
58
  else:
59
  return self._callbacks
60
 
61
+ def run_callbacks(self, hook, *args, **kwargs):
 
62
  """
63
  Loop through the registered actions and fire all callbacks
64
  """
65
+ for logger in self._callbacks[hook]:
66
  # print(f"Running callbacks.{logger['callback'].__name__}()")
67
  logger['callback'](*args, **kwargs)
68
 
 
70
  """
71
  Fires all registered callbacks at the start of each pretraining routine
72
  """
73
+ self.run_callbacks('on_pretrain_routine_start', *args, **kwargs)
74
 
75
  def on_pretrain_routine_end(self, *args, **kwargs):
76
  """
77
  Fires all registered callbacks at the end of each pretraining routine
78
  """
79
+ self.run_callbacks('on_pretrain_routine_end', *args, **kwargs)
80
 
81
  def on_train_start(self, *args, **kwargs):
82
  """
83
  Fires all registered callbacks at the start of each training
84
  """
85
+ self.run_callbacks('on_train_start', *args, **kwargs)
86
 
87
  def on_train_epoch_start(self, *args, **kwargs):
88
  """
89
  Fires all registered callbacks at the start of each training epoch
90
  """
91
+ self.run_callbacks('on_train_epoch_start', *args, **kwargs)
92
 
93
  def on_train_batch_start(self, *args, **kwargs):
94
  """
95
  Fires all registered callbacks at the start of each training batch
96
  """
97
+ self.run_callbacks('on_train_batch_start', *args, **kwargs)
98
 
99
  def optimizer_step(self, *args, **kwargs):
100
  """
101
  Fires all registered callbacks on each optimizer step
102
  """
103
+ self.run_callbacks('optimizer_step', *args, **kwargs)
104
 
105
  def on_before_zero_grad(self, *args, **kwargs):
106
  """
107
  Fires all registered callbacks before zero grad
108
  """
109
+ self.run_callbacks('on_before_zero_grad', *args, **kwargs)
110
 
111
  def on_train_batch_end(self, *args, **kwargs):
112
  """
113
  Fires all registered callbacks at the end of each training batch
114
  """
115
+ self.run_callbacks('on_train_batch_end', *args, **kwargs)
116
 
117
  def on_train_epoch_end(self, *args, **kwargs):
118
  """
119
  Fires all registered callbacks at the end of each training epoch
120
  """
121
+ self.run_callbacks('on_train_epoch_end', *args, **kwargs)
122
 
123
  def on_val_start(self, *args, **kwargs):
124
  """
125
  Fires all registered callbacks at the start of the validation
126
  """
127
+ self.run_callbacks('on_val_start', *args, **kwargs)
128
 
129
  def on_val_batch_start(self, *args, **kwargs):
130
  """
131
  Fires all registered callbacks at the start of each validation batch
132
  """
133
+ self.run_callbacks('on_val_batch_start', *args, **kwargs)
134
 
135
  def on_val_image_end(self, *args, **kwargs):
136
  """
137
  Fires all registered callbacks at the end of each val image
138
  """
139
+ self.run_callbacks('on_val_image_end', *args, **kwargs)
140
 
141
  def on_val_batch_end(self, *args, **kwargs):
142
  """
143
  Fires all registered callbacks at the end of each validation batch
144
  """
145
+ self.run_callbacks('on_val_batch_end', *args, **kwargs)
146
 
147
  def on_val_end(self, *args, **kwargs):
148
  """
149
  Fires all registered callbacks at the end of the validation
150
  """
151
+ self.run_callbacks('on_val_end', *args, **kwargs)
152
 
153
  def on_fit_epoch_end(self, *args, **kwargs):
154
  """
155
  Fires all registered callbacks at the end of each fit (train+val) epoch
156
  """
157
+ self.run_callbacks('on_fit_epoch_end', *args, **kwargs)
158
 
159
  def on_model_save(self, *args, **kwargs):
160
  """
161
  Fires all registered callbacks after each model save
162
  """
163
+ self.run_callbacks('on_model_save', *args, **kwargs)
164
 
165
  def on_train_end(self, *args, **kwargs):
166
  """
167
  Fires all registered callbacks at the end of training
168
  """
169
+ self.run_callbacks('on_train_end', *args, **kwargs)
170
 
171
  def teardown(self, *args, **kwargs):
172
  """
173
  Fires all registered callbacks before teardown
174
  """
175
+ self.run_callbacks('teardown', *args, **kwargs)