Feng Wang
commited on
Commit
·
24c0eb8
1
Parent(s):
27cdfe4
fix(model): compatible meshgrid and CUDA OOM error
Browse files- hubconf.py +1 -1
- yolox/models/yolo_head.py +8 -4
- yolox/utils/__init__.py +1 -0
- yolox/utils/compat.py +15 -0
hubconf.py
CHANGED
@@ -8,7 +8,7 @@ Usage example:
|
|
8 |
"""
|
9 |
dependencies = ["torch"]
|
10 |
|
11 |
-
from yolox.models import ( # noqa: F401, E402
|
12 |
yolox_tiny,
|
13 |
yolox_nano,
|
14 |
yolox_s,
|
|
|
8 |
"""
|
9 |
dependencies = ["torch"]
|
10 |
|
11 |
+
from yolox.models import ( # isort:skip # noqa: F401, E402
|
12 |
yolox_tiny,
|
13 |
yolox_nano,
|
14 |
yolox_s,
|
yolox/models/yolo_head.py
CHANGED
@@ -9,7 +9,7 @@ import torch
|
|
9 |
import torch.nn as nn
|
10 |
import torch.nn.functional as F
|
11 |
|
12 |
-
from yolox.utils import bboxes_iou
|
13 |
|
14 |
from .losses import IOUloss
|
15 |
from .network_blocks import BaseConv, DWConv
|
@@ -220,7 +220,7 @@ class YOLOXHead(nn.Module):
|
|
220 |
n_ch = 5 + self.num_classes
|
221 |
hsize, wsize = output.shape[-2:]
|
222 |
if grid.shape[2:4] != output.shape[2:4]:
|
223 |
-
yv, xv =
|
224 |
grid = torch.stack((xv, yv), 2).view(1, 1, hsize, wsize, 2).type(dtype)
|
225 |
self.grids[k] = grid
|
226 |
|
@@ -237,7 +237,7 @@ class YOLOXHead(nn.Module):
|
|
237 |
grids = []
|
238 |
strides = []
|
239 |
for (hsize, wsize), stride in zip(self.hw, self.strides):
|
240 |
-
yv, xv =
|
241 |
grid = torch.stack((xv, yv), 2).view(1, -1, 2)
|
242 |
grids.append(grid)
|
243 |
shape = grid.shape[:2]
|
@@ -321,7 +321,11 @@ class YOLOXHead(nn.Module):
|
|
321 |
labels,
|
322 |
imgs,
|
323 |
)
|
324 |
-
except RuntimeError:
|
|
|
|
|
|
|
|
|
325 |
logger.error(
|
326 |
"OOM RuntimeError is raised due to the huge memory cost during label assignment. \
|
327 |
CPU mode is applied in this batch. If you want to avoid this issue, \
|
|
|
9 |
import torch.nn as nn
|
10 |
import torch.nn.functional as F
|
11 |
|
12 |
+
from yolox.utils import bboxes_iou, meshgrid
|
13 |
|
14 |
from .losses import IOUloss
|
15 |
from .network_blocks import BaseConv, DWConv
|
|
|
220 |
n_ch = 5 + self.num_classes
|
221 |
hsize, wsize = output.shape[-2:]
|
222 |
if grid.shape[2:4] != output.shape[2:4]:
|
223 |
+
yv, xv = meshgrid([torch.arange(hsize), torch.arange(wsize)])
|
224 |
grid = torch.stack((xv, yv), 2).view(1, 1, hsize, wsize, 2).type(dtype)
|
225 |
self.grids[k] = grid
|
226 |
|
|
|
237 |
grids = []
|
238 |
strides = []
|
239 |
for (hsize, wsize), stride in zip(self.hw, self.strides):
|
240 |
+
yv, xv = meshgrid([torch.arange(hsize), torch.arange(wsize)])
|
241 |
grid = torch.stack((xv, yv), 2).view(1, -1, 2)
|
242 |
grids.append(grid)
|
243 |
shape = grid.shape[:2]
|
|
|
321 |
labels,
|
322 |
imgs,
|
323 |
)
|
324 |
+
except RuntimeError as e:
|
325 |
+
# TODO: the string might change, consider a better way
|
326 |
+
if "CUDA out of memory. " not in str(e):
|
327 |
+
raise # RuntimeError might not caused by CUDA OOM
|
328 |
+
|
329 |
logger.error(
|
330 |
"OOM RuntimeError is raised due to the huge memory cost during label assignment. \
|
331 |
CPU mode is applied in this batch. If you want to avoid this issue, \
|
yolox/utils/__init__.py
CHANGED
@@ -5,6 +5,7 @@
|
|
5 |
from .allreduce_norm import *
|
6 |
from .boxes import *
|
7 |
from .checkpoint import load_ckpt, save_checkpoint
|
|
|
8 |
from .demo_utils import *
|
9 |
from .dist import *
|
10 |
from .ema import *
|
|
|
5 |
from .allreduce_norm import *
|
6 |
from .boxes import *
|
7 |
from .checkpoint import load_ckpt, save_checkpoint
|
8 |
+
from .compat import meshgrid
|
9 |
from .demo_utils import *
|
10 |
from .dist import *
|
11 |
from .ema import *
|
yolox/utils/compat.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding:utf-8 -*-
|
3 |
+
|
4 |
+
import torch
|
5 |
+
|
6 |
+
_TORCH_VER = [int(x) for x in torch.__version__.split(".")[:2]]
|
7 |
+
|
8 |
+
__all__ = ["meshgrid"]
|
9 |
+
|
10 |
+
|
11 |
+
def meshgrid(*tensors):
|
12 |
+
if _TORCH_VER >= [1, 10]:
|
13 |
+
return torch.meshgrid(*tensors, indexing="ij")
|
14 |
+
else:
|
15 |
+
return torch.meshgrid(*tensors)
|