henry000 commited on
Commit
593a50f
·
1 Parent(s): de99a93

🔒️ [Fix] torch version diff in meshgrid

Browse files
Files changed (1) hide show
  1. yolo/utils/bounding_box_utils.py +4 -1
yolo/utils/bounding_box_utils.py CHANGED
@@ -129,7 +129,10 @@ def generate_anchors(image_size: List[int], strides: List[int]):
129
  shift = stride // 2
130
  h = torch.arange(0, H, stride) + shift
131
  w = torch.arange(0, W, stride) + shift
132
- anchor_h, anchor_w = torch.meshgrid(h, w, indexing="ij")
 
 
 
133
  anchor = torch.stack([anchor_w.flatten(), anchor_h.flatten()], dim=-1)
134
  anchors.append(anchor)
135
  all_anchors = torch.cat(anchors, dim=0)
 
129
  shift = stride // 2
130
  h = torch.arange(0, H, stride) + shift
131
  w = torch.arange(0, W, stride) + shift
132
+ if torch.__version__ >= "1.10.0":
133
+ anchor_h, anchor_w = torch.meshgrid(h, w, indexing="ij")
134
+ else:
135
+ anchor_h, anchor_w = torch.meshgrid(h, w)
136
  anchor = torch.stack([anchor_w.flatten(), anchor_h.flatten()], dim=-1)
137
  anchors.append(anchor)
138
  all_anchors = torch.cat(anchors, dim=0)