Remove `image_weights` DDP code (#4579)
Browse files
train.py
CHANGED
@@ -265,21 +265,13 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
|
265 |
for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
|
266 |
model.train()
|
267 |
|
268 |
-
# Update image weights (optional)
|
269 |
if opt.image_weights:
|
270 |
-
#
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
# Broadcast if DDP
|
276 |
-
if RANK != -1:
|
277 |
-
indices = (torch.tensor(dataset.indices) if RANK == 0 else torch.zeros(dataset.n)).int()
|
278 |
-
dist.broadcast(indices, 0)
|
279 |
-
if RANK != 0:
|
280 |
-
dataset.indices = indices.cpu().numpy()
|
281 |
-
|
282 |
-
# Update mosaic border
|
283 |
# b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
|
284 |
# dataset.mosaic_border = [b - imgsz, -b] # height, width borders
|
285 |
|
|
|
265 |
for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
|
266 |
model.train()
|
267 |
|
268 |
+
# Update image weights (optional, single-GPU only)
|
269 |
if opt.image_weights:
|
270 |
+
cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc # class weights
|
271 |
+
iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw) # image weights
|
272 |
+
dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n) # rand weighted idx
|
273 |
+
|
274 |
+
# Update mosaic border (optional)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
275 |
# b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
|
276 |
# dataset.mosaic_border = [b - imgsz, -b] # height, width borders
|
277 |
|