🔒️ [Fix] torch version diff in meshgrid
Browse files
yolo/utils/bounding_box_utils.py
CHANGED
@@ -129,7 +129,10 @@ def generate_anchors(image_size: List[int], strides: List[int]):
|
|
129 |
shift = stride // 2
|
130 |
h = torch.arange(0, H, stride) + shift
|
131 |
w = torch.arange(0, W, stride) + shift
|
132 |
-
|
|
|
|
|
|
|
133 |
anchor = torch.stack([anchor_w.flatten(), anchor_h.flatten()], dim=-1)
|
134 |
anchors.append(anchor)
|
135 |
all_anchors = torch.cat(anchors, dim=0)
|
|
|
129 |
shift = stride // 2
|
130 |
h = torch.arange(0, H, stride) + shift
|
131 |
w = torch.arange(0, W, stride) + shift
|
132 |
+
if torch.__version__ >= "1.10.0":
|
133 |
+
anchor_h, anchor_w = torch.meshgrid(h, w, indexing="ij")
|
134 |
+
else:
|
135 |
+
anchor_h, anchor_w = torch.meshgrid(h, w)
|
136 |
anchor = torch.stack([anchor_w.flatten(), anchor_h.flatten()], dim=-1)
|
137 |
anchors.append(anchor)
|
138 |
all_anchors = torch.cat(anchors, dim=0)
|