Commit
·
b50fdf1
1
Parent(s):
9a9c4f1
model.names multi-GPU bug fix #94
Browse files
detect.py
CHANGED
@@ -46,7 +46,7 @@ def detect(save_img=False):
|
|
46 |
dataset = LoadImages(source, img_size=imgsz)
|
47 |
|
48 |
# Get names and colors
|
49 |
-
names = model.names if hasattr(model, '
|
50 |
colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(len(names))]
|
51 |
|
52 |
# Run inference
|
|
|
46 |
dataset = LoadImages(source, img_size=imgsz)
|
47 |
|
48 |
# Get names and colors
|
49 |
+
names = model.module.names if hasattr(model, 'module') else model.names
|
50 |
colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(len(names))]
|
51 |
|
52 |
# Run inference
|
train.py
CHANGED
@@ -79,6 +79,7 @@ def train(hyp):
|
|
79 |
# Create model
|
80 |
model = Model(opt.cfg).to(device)
|
81 |
assert model.md['nc'] == nc, '%s nc=%g classes but %s nc=%g classes' % (opt.data, nc, opt.cfg, model.md['nc'])
|
|
|
82 |
|
83 |
# Image sizes
|
84 |
gs = int(max(model.stride)) # grid size (max stride)
|
@@ -193,7 +194,6 @@ def train(hyp):
|
|
193 |
model.hyp = hyp # attach hyperparameters to model
|
194 |
model.gr = 1.0 # giou loss ratio (obj_loss = 1.0 or giou)
|
195 |
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
|
196 |
-
model.names = data_dict['names']
|
197 |
|
198 |
# Class frequency
|
199 |
labels = np.concatenate(dataset.labels, 0)
|
|
|
79 |
# Create model
|
80 |
model = Model(opt.cfg).to(device)
|
81 |
assert model.md['nc'] == nc, '%s nc=%g classes but %s nc=%g classes' % (opt.data, nc, opt.cfg, model.md['nc'])
|
82 |
+
model.names = data_dict['names']
|
83 |
|
84 |
# Image sizes
|
85 |
gs = int(max(model.stride)) # grid size (max stride)
|
|
|
194 |
model.hyp = hyp # attach hyperparameters to model
|
195 |
model.gr = 1.0 # giou loss ratio (obj_loss = 1.0 or giou)
|
196 |
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
|
|
|
197 |
|
198 |
# Class frequency
|
199 |
labels = np.concatenate(dataset.labels, 0)
|