# Copyright (c) OpenMMLab. All rights reserved.
import unittest | |
import numpy as np | |
import torch | |
from mmdet.structures import DetDataSample | |
from mmdet.structures.bbox import HorizontalBoxes | |
from mmengine.structures import InstanceData | |
from mmyolo.datasets import BatchShapePolicy, yolov5_collate | |
def _rand_bboxes(rng, num_boxes, w, h):
    """Generate random corner-format boxes clipped to a ``w`` x ``h`` area.

    Args:
        rng (np.random.RandomState): Random source. A single
            ``rng.rand(num_boxes, 4)`` draw is consumed.
        num_boxes (int): Number of boxes to generate.
        w (int | float): Upper clipping bound for x coordinates.
        h (int | float): Upper clipping bound for y coordinates.

    Returns:
        np.ndarray: Array of shape ``(num_boxes, 4)`` whose columns are
        ``tl_x, tl_y, br_x, br_y``.
    """
    # Each drawn row is (cx, cy, bw, bh): a centre and a size, all as
    # fractions in [0, 1) of the target area.
    cx, cy, bw, bh = rng.rand(num_boxes, 4).T

    # Scale to absolute coordinates and clip so no corner leaves the area.
    tl_x = ((cx * w) - (w * bw / 2)).clip(0, w)
    tl_y = ((cy * h) - (h * bh / 2)).clip(0, h)
    br_x = ((cx * w) + (w * bw / 2)).clip(0, w)
    br_y = ((cy * h) + (h * bh / 2)).clip(0, h)

    bboxes = np.vstack([tl_x, tl_y, br_x, br_y]).T
    return bboxes
class TestYOLOv5Collate(unittest.TestCase):
    """Tests for ``yolov5_collate`` with and without multi-scale training."""

    def setUp(self):
        """Build one ``inputs`` / ``data_samples`` pair with four GT boxes."""
        rng = np.random.RandomState(0)
        self.inputs = torch.randn((3, 10, 10))

        bboxes = _rand_bboxes(rng, 4, 6, 8)
        gt_instances = InstanceData()
        gt_instances.bboxes = HorizontalBoxes(bboxes, dtype=torch.float32)
        # randint's upper bound is exclusive, so every label here is 1.
        labels = rng.randint(1, 2, size=len(bboxes))
        gt_instances.labels = torch.LongTensor(labels)

        self.data_samples = DetDataSample()
        self.data_samples.gt_instances = gt_instances

    def _make_batch(self, batch_size):
        """Return ``batch_size`` references to the same sample dict."""
        sample = dict(inputs=self.inputs, data_samples=self.data_samples)
        return [sample] * batch_size

    def test_yolov5_collate(self):
        """Default collate returns batched inputs and merged labels."""
        for batch_size in (1, 2):
            with self.subTest(batch_size=batch_size):
                out = yolov5_collate(self._make_batch(batch_size))

                self.assertIsInstance(out, dict)
                self.assertEqual(
                    tuple(out['inputs'].shape), (batch_size, 3, 10, 10))
                # The original ``assertTrue(x, dict)`` only checked
                # truthiness (``dict`` was taken as the failure message).
                self.assertIsInstance(out['data_samples'], dict)
                # Four boxes per sample are concatenated along dim 0.
                self.assertEqual(
                    tuple(out['data_samples']['bboxes_labels'].shape),
                    (4 * batch_size, 6))

    def test_yolov5_collate_with_multi_scale(self):
        """With ``use_ms_training=True`` the inputs stay a list of images."""
        for batch_size in (1, 2):
            with self.subTest(batch_size=batch_size):
                out = yolov5_collate(
                    self._make_batch(batch_size), use_ms_training=True)

                self.assertIsInstance(out, dict)
                self.assertIsInstance(out['inputs'], list)
                self.assertEqual(tuple(out['inputs'][0].shape), (3, 10, 10))
                self.assertIsInstance(out['data_samples'], dict)

                bboxes_labels = out['data_samples']['bboxes_labels']
                self.assertIsInstance(bboxes_labels, torch.Tensor)
                self.assertEqual(
                    tuple(bboxes_labels.shape), (4 * batch_size, 6))
class TestBatchShapePolicy(unittest.TestCase):
    """Tests for ``BatchShapePolicy``."""

    def test_batch_shape_policy(self):
        """Check output order and ``batch_shape`` with ``batch_size=2``."""
        src_data_infos = [
            dict(height=height, width=100)
            for height in (20, 11, 21, 30, 10)
        ]
        # Expected (height, width, batch_shape) of each output entry, in
        # the order the policy is expected to return them.
        expected = [
            (10, 100, np.array([96, 672])),
            (11, 100, np.array([96, 672])),
            (20, 100, np.array([160, 672])),
            (21, 100, np.array([160, 672])),
            (30, 100, np.array([224, 672])),
        ]

        batch_shapes_policy = BatchShapePolicy(batch_size=2)
        out_data_infos = batch_shapes_policy(src_data_infos)

        # Guard against extra or missing entries, which a fixed-range loop
        # over the first five items would not detect.
        self.assertEqual(len(out_data_infos), len(expected))

        for idx, (info, (height, width, batch_shape)) in enumerate(
                zip(out_data_infos, expected)):
            with self.subTest(index=idx):
                self.assertEqual((height, width),
                                 (info['height'], info['width']))
                np.testing.assert_array_equal(info['batch_shape'],
                                              batch_shape)