Commit
·
09402a2
1
Parent(s):
4fb8cb3
torch.from_tensor() bug fix
Browse files
train.py
CHANGED
@@ -225,7 +225,7 @@ def train(hyp, opt, device, tb_writer=None):
|
|
225 |
if rank != -1:
|
226 |
indices = torch.zeros([dataset.n], dtype=torch.int)
|
227 |
if rank == 0:
|
228 |
-
indices[:] = torch.
|
229 |
dist.broadcast(indices, 0)
|
230 |
if rank != 0:
|
231 |
dataset.indices = indices.cpu().numpy()
|
|
|
225 |
if rank != -1:
|
226 |
indices = torch.zeros([dataset.n], dtype=torch.int)
|
227 |
if rank == 0:
|
228 |
+
indices[:] = torch.tensor(dataset.indices, dtype=torch.int)
|
229 |
dist.broadcast(indices, 0)
|
230 |
if rank != 0:
|
231 |
dataset.indices = indices.cpu().numpy()
|