glenn-jocher commited on
Commit
b50fdf1
·
1 Parent(s): 9a9c4f1

model.names multi-GPU bug fix #94

Browse files
Files changed (2) hide show
  1. detect.py +1 -1
  2. train.py +1 -1
detect.py CHANGED
@@ -46,7 +46,7 @@ def detect(save_img=False):
46
  dataset = LoadImages(source, img_size=imgsz)
47
 
48
  # Get names and colors
49
- names = model.names if hasattr(model, 'names') else model.modules.names
50
  colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(len(names))]
51
 
52
  # Run inference
 
46
  dataset = LoadImages(source, img_size=imgsz)
47
 
48
  # Get names and colors
49
+ names = model.module.names if hasattr(model, 'module') else model.names
50
  colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(len(names))]
51
 
52
  # Run inference
train.py CHANGED
@@ -79,6 +79,7 @@ def train(hyp):
79
  # Create model
80
  model = Model(opt.cfg).to(device)
81
  assert model.md['nc'] == nc, '%s nc=%g classes but %s nc=%g classes' % (opt.data, nc, opt.cfg, model.md['nc'])
 
82
 
83
  # Image sizes
84
  gs = int(max(model.stride)) # grid size (max stride)
@@ -193,7 +194,6 @@ def train(hyp):
193
  model.hyp = hyp # attach hyperparameters to model
194
  model.gr = 1.0 # giou loss ratio (obj_loss = 1.0 or giou)
195
  model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
196
- model.names = data_dict['names']
197
 
198
  # Class frequency
199
  labels = np.concatenate(dataset.labels, 0)
 
79
  # Create model
80
  model = Model(opt.cfg).to(device)
81
  assert model.md['nc'] == nc, '%s nc=%g classes but %s nc=%g classes' % (opt.data, nc, opt.cfg, model.md['nc'])
82
+ model.names = data_dict['names']
83
 
84
  # Image sizes
85
  gs = int(max(model.stride)) # grid size (max stride)
 
194
  model.hyp = hyp # attach hyperparameters to model
195
  model.gr = 1.0 # giou loss ratio (obj_loss = 1.0 or giou)
196
  model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
 
197
 
198
  # Class frequency
199
  labels = np.concatenate(dataset.labels, 0)