Update torch_utils.py
utils/torch_utils.py CHANGED (+13 -7)
@@ -54,6 +54,11 @@ def time_synchronized():
     return time.time()
 
 
+def is_parallel(model):
+    # is model is parallel with DP or DDP
+    return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
+
+
 def initialize_weights(model):
     for m in model.modules():
         t = type(m)
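The added is_parallel() helper centralizes the check for a model wrapped in DataParallel or DistributedDataParallel, which the ModelEMA changes below rely on. A minimal sketch of how it behaves, assuming the repo's utils.torch_utils import path; the toy Sequential model is only an illustration:

import torch.nn as nn
from utils.torch_utils import is_parallel  # assumed import path for this file

net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())  # any plain nn.Module
dp = nn.DataParallel(net)                            # DP-wrapped view of the same module

print(is_parallel(net))   # False: bare module
print(is_parallel(dp))    # True: wrapped in nn.DataParallel
inner = dp.module if is_parallel(dp) else dp         # typical unwrap pattern used below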
@@ -111,8 +116,8 @@ def model_info(model, verbose=False):
 
     try:  # FLOPS
         from thop import profile
-
-        fs = ', %.1f GFLOPS' % (
+        flops = profile(deepcopy(model), inputs=(torch.zeros(1, 3, 64, 64),), verbose=False)[0] / 1E9 * 2
+        fs = ', %.1f GFLOPS' % (flops * 100)  # 640x640 FLOPS
     except:
         fs = ''
 
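The new FLOPS estimate profiles a deepcopy of the model on a small 1x3x64x64 input: thop reports multiply-accumulate counts, so the result is doubled to get FLOPs, and because convolutional cost grows with input area the 64x64 figure is multiplied by (640/64)^2 = 100 to approximate 640x640 inference. A standalone sketch of the same arithmetic on a throwaway conv model (not the repo's network):

from copy import deepcopy
import torch
import torch.nn as nn
from thop import profile  # pip install thop

model = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), nn.Conv2d(16, 32, 3, padding=1))
macs = profile(deepcopy(model), inputs=(torch.zeros(1, 3, 64, 64),), verbose=False)[0]  # deepcopy keeps profiling hooks off the real model
gflops_64 = macs / 1E9 * 2                 # MACs -> FLOPs, in units of 1e9
gflops_640 = gflops_64 * (640 / 64) ** 2   # area scaling: (640/64)^2 = 100
print(', %.1f GFLOPS' % gflops_640)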
@@ -185,7 +190,7 @@ class ModelEMA:
         self.updates += 1
         d = self.decay(self.updates)
         with torch.no_grad():
-            if
+            if is_parallel(model):
                 msd, esd = model.module.state_dict(), self.ema.module.state_dict()
             else:
                 msd, esd = model.state_dict(), self.ema.state_dict()
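ModelEMA.update() now calls is_parallel() to decide whether to read state dicts through .module, so EMA and model keys line up whether or not the model is DP/DDP-wrapped. A hedged sketch of where update() typically sits in a training loop; the toy model, optimizer, and data are placeholders, and the ModelEMA(model) constructor call is assumed rather than shown in this diff:

import torch
import torch.nn as nn
from utils.torch_utils import ModelEMA  # assumed import path for this file

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))  # stand-in network
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
ema = ModelEMA(model)                      # assumed: the EMA copy is built from the live model

for _ in range(3):                         # a few fake training steps
    x, y = torch.randn(16, 4), torch.randn(16, 1)
    loss = nn.functional.mse_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.update(model)                      # blend the new weights into the EMA copy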
@@ -196,7 +201,8 @@ class ModelEMA:
                 v += (1. - d) * msd[k].detach()
 
     def update_attr(self, model):
-        #
-
-
-
+        # Update class attributes
+        ema = self.ema.module if is_parallel(model) else self.ema
+        for k, v in model.__dict__.items():
+            if not k.startswith('_') and k != 'module':
+                setattr(ema, k, v)
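update_attr() copies plain instance attributes (skipping underscore-prefixed keys and 'module') from the live model onto the EMA copy, so metadata attached to the model travels with the EMA weights. A small sketch of that filter; the nc, names, and _skip_me attributes are made up for illustration, and the ModelEMA(model) constructor is again assumed:

import torch.nn as nn
from utils.torch_utils import ModelEMA  # assumed import path for this file

model = nn.Sequential(nn.Linear(4, 2))
ema = ModelEMA(model)                      # assumed constructor, as above

model.nc = 80                              # made-up attributes a detector might carry
model.names = ['person', 'car']
model._skip_me = 'private'                 # underscore-prefixed: filtered out

ema.update_attr(model)
print(ema.ema.nc, ema.ema.names)           # 80 ['person', 'car']
print(hasattr(ema.ema, '_skip_me'))        # False: '_'-prefixed keys are not copied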