Simplify callbacks (#4289)
Browse files- utils/callbacks.py +20 -21
utils/callbacks.py
CHANGED
@@ -58,12 +58,11 @@ class Callbacks:
|
|
58 |
else:
|
59 |
return self._callbacks
|
60 |
|
61 |
-
|
62 |
-
def run_callbacks(register, *args, **kwargs):
|
63 |
"""
|
64 |
Loop through the registered actions and fire all callbacks
|
65 |
"""
|
66 |
-
for logger in
|
67 |
# print(f"Running callbacks.{logger['callback'].__name__}()")
|
68 |
logger['callback'](*args, **kwargs)
|
69 |
|
@@ -71,106 +70,106 @@ class Callbacks:
|
|
71 |
"""
|
72 |
Fires all registered callbacks at the start of each pretraining routine
|
73 |
"""
|
74 |
-
self.run_callbacks(
|
75 |
|
76 |
def on_pretrain_routine_end(self, *args, **kwargs):
|
77 |
"""
|
78 |
Fires all registered callbacks at the end of each pretraining routine
|
79 |
"""
|
80 |
-
self.run_callbacks(
|
81 |
|
82 |
def on_train_start(self, *args, **kwargs):
|
83 |
"""
|
84 |
Fires all registered callbacks at the start of each training
|
85 |
"""
|
86 |
-
self.run_callbacks(
|
87 |
|
88 |
def on_train_epoch_start(self, *args, **kwargs):
|
89 |
"""
|
90 |
Fires all registered callbacks at the start of each training epoch
|
91 |
"""
|
92 |
-
self.run_callbacks(
|
93 |
|
94 |
def on_train_batch_start(self, *args, **kwargs):
|
95 |
"""
|
96 |
Fires all registered callbacks at the start of each training batch
|
97 |
"""
|
98 |
-
self.run_callbacks(
|
99 |
|
100 |
def optimizer_step(self, *args, **kwargs):
|
101 |
"""
|
102 |
Fires all registered callbacks on each optimizer step
|
103 |
"""
|
104 |
-
self.run_callbacks(
|
105 |
|
106 |
def on_before_zero_grad(self, *args, **kwargs):
|
107 |
"""
|
108 |
Fires all registered callbacks before zero grad
|
109 |
"""
|
110 |
-
self.run_callbacks(
|
111 |
|
112 |
def on_train_batch_end(self, *args, **kwargs):
|
113 |
"""
|
114 |
Fires all registered callbacks at the end of each training batch
|
115 |
"""
|
116 |
-
self.run_callbacks(
|
117 |
|
118 |
def on_train_epoch_end(self, *args, **kwargs):
|
119 |
"""
|
120 |
Fires all registered callbacks at the end of each training epoch
|
121 |
"""
|
122 |
-
self.run_callbacks(
|
123 |
|
124 |
def on_val_start(self, *args, **kwargs):
|
125 |
"""
|
126 |
Fires all registered callbacks at the start of the validation
|
127 |
"""
|
128 |
-
self.run_callbacks(
|
129 |
|
130 |
def on_val_batch_start(self, *args, **kwargs):
|
131 |
"""
|
132 |
Fires all registered callbacks at the start of each validation batch
|
133 |
"""
|
134 |
-
self.run_callbacks(
|
135 |
|
136 |
def on_val_image_end(self, *args, **kwargs):
|
137 |
"""
|
138 |
Fires all registered callbacks at the end of each val image
|
139 |
"""
|
140 |
-
self.run_callbacks(
|
141 |
|
142 |
def on_val_batch_end(self, *args, **kwargs):
|
143 |
"""
|
144 |
Fires all registered callbacks at the end of each validation batch
|
145 |
"""
|
146 |
-
self.run_callbacks(
|
147 |
|
148 |
def on_val_end(self, *args, **kwargs):
|
149 |
"""
|
150 |
Fires all registered callbacks at the end of the validation
|
151 |
"""
|
152 |
-
self.run_callbacks(
|
153 |
|
154 |
def on_fit_epoch_end(self, *args, **kwargs):
|
155 |
"""
|
156 |
Fires all registered callbacks at the end of each fit (train+val) epoch
|
157 |
"""
|
158 |
-
self.run_callbacks(
|
159 |
|
160 |
def on_model_save(self, *args, **kwargs):
|
161 |
"""
|
162 |
Fires all registered callbacks after each model save
|
163 |
"""
|
164 |
-
self.run_callbacks(
|
165 |
|
166 |
def on_train_end(self, *args, **kwargs):
|
167 |
"""
|
168 |
Fires all registered callbacks at the end of training
|
169 |
"""
|
170 |
-
self.run_callbacks(
|
171 |
|
172 |
def teardown(self, *args, **kwargs):
|
173 |
"""
|
174 |
Fires all registered callbacks before teardown
|
175 |
"""
|
176 |
-
self.run_callbacks(
|
|
|
58 |
else:
|
59 |
return self._callbacks
|
60 |
|
61 |
+
def run_callbacks(self, hook, *args, **kwargs):
|
|
|
62 |
"""
|
63 |
Loop through the registered actions and fire all callbacks
|
64 |
"""
|
65 |
+
for logger in self._callbacks[hook]:
|
66 |
# print(f"Running callbacks.{logger['callback'].__name__}()")
|
67 |
logger['callback'](*args, **kwargs)
|
68 |
|
|
|
70 |
"""
|
71 |
Fires all registered callbacks at the start of each pretraining routine
|
72 |
"""
|
73 |
+
self.run_callbacks('on_pretrain_routine_start', *args, **kwargs)
|
74 |
|
75 |
def on_pretrain_routine_end(self, *args, **kwargs):
|
76 |
"""
|
77 |
Fires all registered callbacks at the end of each pretraining routine
|
78 |
"""
|
79 |
+
self.run_callbacks('on_pretrain_routine_end', *args, **kwargs)
|
80 |
|
81 |
def on_train_start(self, *args, **kwargs):
|
82 |
"""
|
83 |
Fires all registered callbacks at the start of each training
|
84 |
"""
|
85 |
+
self.run_callbacks('on_train_start', *args, **kwargs)
|
86 |
|
87 |
def on_train_epoch_start(self, *args, **kwargs):
|
88 |
"""
|
89 |
Fires all registered callbacks at the start of each training epoch
|
90 |
"""
|
91 |
+
self.run_callbacks('on_train_epoch_start', *args, **kwargs)
|
92 |
|
93 |
def on_train_batch_start(self, *args, **kwargs):
|
94 |
"""
|
95 |
Fires all registered callbacks at the start of each training batch
|
96 |
"""
|
97 |
+
self.run_callbacks('on_train_batch_start', *args, **kwargs)
|
98 |
|
99 |
def optimizer_step(self, *args, **kwargs):
|
100 |
"""
|
101 |
Fires all registered callbacks on each optimizer step
|
102 |
"""
|
103 |
+
self.run_callbacks('optimizer_step', *args, **kwargs)
|
104 |
|
105 |
def on_before_zero_grad(self, *args, **kwargs):
|
106 |
"""
|
107 |
Fires all registered callbacks before zero grad
|
108 |
"""
|
109 |
+
self.run_callbacks('on_before_zero_grad', *args, **kwargs)
|
110 |
|
111 |
def on_train_batch_end(self, *args, **kwargs):
|
112 |
"""
|
113 |
Fires all registered callbacks at the end of each training batch
|
114 |
"""
|
115 |
+
self.run_callbacks('on_train_batch_end', *args, **kwargs)
|
116 |
|
117 |
def on_train_epoch_end(self, *args, **kwargs):
|
118 |
"""
|
119 |
Fires all registered callbacks at the end of each training epoch
|
120 |
"""
|
121 |
+
self.run_callbacks('on_train_epoch_end', *args, **kwargs)
|
122 |
|
123 |
def on_val_start(self, *args, **kwargs):
|
124 |
"""
|
125 |
Fires all registered callbacks at the start of the validation
|
126 |
"""
|
127 |
+
self.run_callbacks('on_val_start', *args, **kwargs)
|
128 |
|
129 |
def on_val_batch_start(self, *args, **kwargs):
|
130 |
"""
|
131 |
Fires all registered callbacks at the start of each validation batch
|
132 |
"""
|
133 |
+
self.run_callbacks('on_val_batch_start', *args, **kwargs)
|
134 |
|
135 |
def on_val_image_end(self, *args, **kwargs):
|
136 |
"""
|
137 |
Fires all registered callbacks at the end of each val image
|
138 |
"""
|
139 |
+
self.run_callbacks('on_val_image_end', *args, **kwargs)
|
140 |
|
141 |
def on_val_batch_end(self, *args, **kwargs):
|
142 |
"""
|
143 |
Fires all registered callbacks at the end of each validation batch
|
144 |
"""
|
145 |
+
self.run_callbacks('on_val_batch_end', *args, **kwargs)
|
146 |
|
147 |
def on_val_end(self, *args, **kwargs):
|
148 |
"""
|
149 |
Fires all registered callbacks at the end of the validation
|
150 |
"""
|
151 |
+
self.run_callbacks('on_val_end', *args, **kwargs)
|
152 |
|
153 |
def on_fit_epoch_end(self, *args, **kwargs):
|
154 |
"""
|
155 |
Fires all registered callbacks at the end of each fit (train+val) epoch
|
156 |
"""
|
157 |
+
self.run_callbacks('on_fit_epoch_end', *args, **kwargs)
|
158 |
|
159 |
def on_model_save(self, *args, **kwargs):
|
160 |
"""
|
161 |
Fires all registered callbacks after each model save
|
162 |
"""
|
163 |
+
self.run_callbacks('on_model_save', *args, **kwargs)
|
164 |
|
165 |
def on_train_end(self, *args, **kwargs):
|
166 |
"""
|
167 |
Fires all registered callbacks at the end of training
|
168 |
"""
|
169 |
+
self.run_callbacks('on_train_end', *args, **kwargs)
|
170 |
|
171 |
def teardown(self, *args, **kwargs):
|
172 |
"""
|
173 |
Fires all registered callbacks before teardown
|
174 |
"""
|
175 |
+
self.run_callbacks('teardown', *args, **kwargs)
|