Increase FLOPS robustness (#1608)
utils/torch_utils.py (+2 -2)

@@ -1,12 +1,12 @@
 # PyTorch utils

 import logging
-import math
 import os
 import time
 from contextlib import contextmanager
 from copy import deepcopy

+import math
 import torch
 import torch.backends.cudnn as cudnn
 import torch.nn as nn
@@ -152,7 +152,7 @@ def model_info(model, verbose=False, img_size=640):

     try:  # FLOPS
         from thop import profile
-        stride = int(model.stride.max())
+        stride = int(model.stride.max()) if hasattr(model, 'stride') else 32
         img = torch.zeros((1, 3, stride, stride), device=next(model.parameters()).device)  # input
         flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2  # stride FLOPS
         img_size = img_size if isinstance(img_size, list) else [img_size, img_size]  # expand if int/float