yxNONG
committed on
Update train.py
train.py CHANGED
@@ -147,15 +147,6 @@ def train(hyp):
     # https://discuss.pytorch.org/t/a-problem-occured-when-resuming-an-optimizer/28822
     # plot_lr_scheduler(optimizer, scheduler, epochs)

-    # Initialize distributed training
-    if device.type != 'cpu' and torch.cuda.device_count() > 1 and torch.distributed.is_available():
-        dist.init_process_group(backend='nccl',  # distributed backend
-                                init_method='tcp://127.0.0.1:9999',  # init method
-                                world_size=1,  # number of nodes
-                                rank=0)  # node rank
-        model = torch.nn.parallel.DistributedDataParallel(model)
-        # pip install torch==1.4.0+cu100 torchvision==0.5.0+cu100 -f https://download.pytorch.org/whl/torch_stable.html
-
     # Trainloader
     dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
                                             hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect)
@@ -173,6 +164,15 @@ def train(hyp):
     model.gr = 1.0  # giou loss ratio (obj_loss = 1.0 or giou)
     model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device)  # attach class weights
     model.names = data_dict['names']
+
+    # Initialize distributed training
+    if device.type != 'cpu' and torch.cuda.device_count() > 1 and torch.distributed.is_available():
+        dist.init_process_group(backend='nccl',  # distributed backend
+                                init_method='tcp://127.0.0.1:9999',  # init method
+                                world_size=1,  # number of nodes
+                                rank=0)  # node rank
+        model = torch.nn.parallel.DistributedDataParallel(model)
+        # pip install torch==1.4.0+cu100 torchvision==0.5.0+cu100 -f https://download.pytorch.org/whl/torch_stable.html

     # Class frequency
     labels = np.concatenate(dataset.labels, 0)
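The relocation matters because torch.nn.parallel.DistributedDataParallel (like nn.DataParallel) wraps the model and exposes the original object only as .module, so metadata must be attached to the raw model before wrapping to stay reachable afterwards. A minimal sketch of that behavior, using nn.DataParallel and a toy nn.Linear as hypothetical stand-ins (DDP itself needs an initialized process group, so it is not shown directly):

    import torch.nn as nn

    model = nn.Linear(4, 2)           # toy stand-in for the detection model
    model.names = ['person', 'car']   # attach metadata BEFORE wrapping, as the diff now does

    wrapped = nn.DataParallel(model)  # DDP hides custom attributes the same way
    print(hasattr(wrapped, 'names'))  # False: the wrapper does not forward custom attributes
    print(wrapped.module.names)       # ['person', 'car']: the raw model still carries them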
@@ -289,7 +289,7 @@ def train(hyp):
                                      batch_size=batch_size,
                                      imgsz=imgsz_test,
                                      save_json=final_epoch and opt.data.endswith(os.sep + 'coco.yaml'),
-                                     model=ema.ema,
+                                     model=ema.ema.module if hasattr(model, 'module') else ema.ema,
                                      single_cls=opt.single_cls,
                                      dataloader=testloader)
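The one-line change at line 292 is the usual unwrap idiom for a model that may or may not be wrapped: pass the EMA's underlying module to test when training is multi-GPU, and the EMA model itself otherwise. A self-contained sketch of the pattern (unwrap is a hypothetical helper, not part of train.py):

    import torch.nn as nn

    def unwrap(model):
        # Return the underlying module whether `model` is bare or wrapped
        # by nn.DataParallel / DistributedDataParallel.
        return model.module if hasattr(model, 'module') else model

    net = nn.Linear(4, 2)
    assert unwrap(net) is net
    assert unwrap(nn.DataParallel(net)) is net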
@@ -315,14 +315,6 @@ def train(hyp):
         # Save model
         save = (not opt.nosave) or (final_epoch and not opt.evolve)
         if save:
-            if hasattr(model, 'module'):
-                # Duplicate Model parameters for Multi-GPU save
-                ema.ema.module.nc = model.nc  # attach number of classes to model
-                ema.ema.module.hyp = model.hyp  # attach hyperparameters to model
-                ema.ema.module.gr = model.gr = 1.0  # giou loss ratio (obj_loss = 1.0 or giou)
-                ema.ema.module.class_weights = model.class_weights  # attach class weights
-                ema.ema.module.names = data_dict['names']
-
             with open(results_file, 'r') as f:  # create checkpoint
                 ckpt = {'epoch': epoch,
                         'best_fitness': best_fitness,