tkarras commited on
Commit
ecfea65
·
1 Parent(s): 1d25833

Fix metrics to work with grayscale datasets (#9)

Browse files
Files changed (1) hide show
  1. metrics/metric_utils.py +6 -1
metrics/metric_utils.py CHANGED
@@ -213,6 +213,8 @@ def compute_feature_stats_for_dataset(opts, detector_url, detector_kwargs, rel_l
213
  # Main loop.
214
  item_subset = [(i * opts.num_gpus + opts.rank) % num_items for i in range((num_items - 1) // opts.num_gpus + 1)]
215
  for images, _labels in torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs):
 
 
216
  features = detector(images.to(opts.device), **detector_kwargs)
217
  stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
218
  progress.update(stats.num_items)
@@ -262,7 +264,10 @@ def compute_feature_stats_for_generator(opts, detector_url, detector_kwargs, rel
262
  c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_gen)]
263
  c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
264
  images.append(run_generator(z, c))
265
- features = detector(torch.cat(images), **detector_kwargs)
 
 
 
266
  stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
267
  progress.update(stats.num_items)
268
  return stats
 
213
  # Main loop.
214
  item_subset = [(i * opts.num_gpus + opts.rank) % num_items for i in range((num_items - 1) // opts.num_gpus + 1)]
215
  for images, _labels in torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs):
216
+ if images.shape[1] == 1:
217
+ images = images.repeat([1, 3, 1, 1])
218
  features = detector(images.to(opts.device), **detector_kwargs)
219
  stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
220
  progress.update(stats.num_items)
 
264
  c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_gen)]
265
  c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
266
  images.append(run_generator(z, c))
267
+ images = torch.cat(images)
268
+ if images.shape[1] == 1:
269
+ images = images.repeat([1, 3, 1, 1])
270
+ features = detector(images, **detector_kwargs)
271
  stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
272
  progress.update(stats.num_items)
273
  return stats