glenn-jocher commited on
Commit
8918e63
·
unverified ·
1 Parent(s): ba48f86

Increase FLOPS robustness (#1608)

Browse files
Files changed (1) hide show
  1. utils/torch_utils.py +2 -2
utils/torch_utils.py CHANGED
@@ -1,12 +1,12 @@
1
  # PyTorch utils
2
 
3
  import logging
4
- import math
5
  import os
6
  import time
7
  from contextlib import contextmanager
8
  from copy import deepcopy
9
 
 
10
  import torch
11
  import torch.backends.cudnn as cudnn
12
  import torch.nn as nn
@@ -152,7 +152,7 @@ def model_info(model, verbose=False, img_size=640):
152
 
153
  try: # FLOPS
154
  from thop import profile
155
- stride = int(model.stride.max())
156
  img = torch.zeros((1, 3, stride, stride), device=next(model.parameters()).device) # input
157
  flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2 # stride FLOPS
158
  img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float
 
1
  # PyTorch utils
2
 
3
  import logging
 
4
  import os
5
  import time
6
  from contextlib import contextmanager
7
  from copy import deepcopy
8
 
9
+ import math
10
  import torch
11
  import torch.backends.cudnn as cudnn
12
  import torch.nn as nn
 
152
 
153
  try: # FLOPS
154
  from thop import profile
155
+ stride = int(model.stride.max()) if hasattr(model, 'stride') else 32
156
  img = torch.zeros((1, 3, stride, stride), device=next(model.parameters()).device) # input
157
  flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2 # stride FLOPS
158
  img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float