update demo
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- LICENSE +20 -0
- README.md +6 -5
- app.py +79 -0
- configs/_base_/datasets/parking_instance.py +48 -0
- configs/_base_/datasets/parking_instance_coco.py +49 -0
- configs/_base_/datasets/people_real_coco.py +49 -0
- configs/_base_/datasets/walt_people.py +49 -0
- configs/_base_/datasets/walt_vehicle.py +49 -0
- configs/_base_/default_runtime.py +16 -0
- configs/_base_/models/mask_rcnn_swin_fpn.py +127 -0
- configs/_base_/models/occ_mask_rcnn_swin_fpn.py +127 -0
- configs/_base_/schedules/schedule_1x.py +11 -0
- configs/walt/walt_people.py +80 -0
- configs/walt/walt_vehicle.py +80 -0
- cwalt/CWALT.py +161 -0
- cwalt/Clip_WALT_Generate.py +284 -0
- cwalt/Download_Detections.py +28 -0
- cwalt/clustering_utils.py +132 -0
- cwalt/kmedoid.py +55 -0
- cwalt/utils.py +168 -0
- cwalt_generate.py +14 -0
- docker/Dockerfile +52 -0
- github_vis/cwalt.gif +0 -0
- github_vis/vis_cars.gif +0 -0
- github_vis/vis_people.gif +0 -0
- infer.py +118 -0
- mmcv_custom/__init__.py +5 -0
- mmcv_custom/checkpoint.py +500 -0
- mmcv_custom/runner/__init__.py +8 -0
- mmcv_custom/runner/checkpoint.py +85 -0
- mmcv_custom/runner/epoch_based_runner.py +104 -0
- mmdet/__init__.py +28 -0
- mmdet/apis/__init__.py +10 -0
- mmdet/apis/inference.py +217 -0
- mmdet/apis/test.py +189 -0
- mmdet/apis/train.py +185 -0
- mmdet/core/__init__.py +7 -0
- mmdet/core/anchor/__init__.py +11 -0
- mmdet/core/anchor/anchor_generator.py +727 -0
- mmdet/core/anchor/builder.py +7 -0
- mmdet/core/anchor/point_generator.py +37 -0
- mmdet/core/anchor/utils.py +71 -0
- mmdet/core/bbox/__init__.py +27 -0
- mmdet/core/bbox/assigners/__init__.py +16 -0
- mmdet/core/bbox/assigners/approx_max_iou_assigner.py +145 -0
- mmdet/core/bbox/assigners/assign_result.py +204 -0
- mmdet/core/bbox/assigners/atss_assigner.py +178 -0
- mmdet/core/bbox/assigners/base_assigner.py +9 -0
- mmdet/core/bbox/assigners/center_region_assigner.py +335 -0
- mmdet/core/bbox/assigners/grid_assigner.py +155 -0
LICENSE
ADDED
@@ -0,0 +1,20 @@
Copyright (c) 2022-2022 dinesh reddy and others

Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:

The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
README.md
CHANGED
@@ -1,12 +1,13 @@
 ---
-title: WALT
-emoji:
-colorFrom:
-colorTo:
+title: WALT DEMO
+emoji: ⚡
+colorFrom: indigo
+colorTo: indigo
 sdk: gradio
-sdk_version: 3.0.
+sdk_version: 3.0.20
 app_file: app.py
 pinned: false
+license: mit
 ---

 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,79 @@
import numpy as np
import torch
import gradio as gr
from infer import detections
'''
import os
os.system("mkdir data")
os.system("mkdir data/models")
os.system("wget https://www.cs.cmu.edu/~walt/models/walt_people.pth -O data/models/walt_people.pth")
os.system("wget https://www.cs.cmu.edu/~walt/models/walt_vehicle.pth -O data/models/walt_vehicle.pth")
'''
def walt_demo(input_img, confidence_threshold):
    #detect_people = detections('configs/walt/walt_people.py', 'cuda:0', model_path='data/models/walt_people.pth')
    # run on GPU when available, otherwise fall back to CPU
    if torch.cuda.is_available() == False:
        device = 'cpu'
    else:
        device = 'cuda:0'
    #detect_people = detections('configs/walt/walt_people.py', device, model_path='data/models/walt_people.pth')
    # build the vehicle detector from the WALT config and checkpoint
    detect = detections('configs/walt/walt_vehicle.py', device, model_path='data/models/walt_vehicle.pth', threshold=confidence_threshold)

    count = 0
    #img = detect_people.run_on_image(input_img)
    output_img = detect.run_on_image(input_img)
    #try:
    #except:
    #    print("detecting on image failed")

    return output_img

description = """
WALT Demo on WALT dataset. After watching and automatically learning for several days, this approach shows significant performance improvement in detecting and segmenting occluded people and vehicles, over human-supervised amodal approaches.
<center>
<a href="https://www.cs.cmu.edu/~walt/">
<img style="display:inline" alt="Project page" src="https://img.shields.io/badge/Project%20Page-WALT-green">
</a>
<a href="https://www.cs.cmu.edu/~walt/pdf/walt.pdf"><img style="display:inline" src="https://img.shields.io/badge/Paper-Pdf-red"></a>
<a href="https://github.com/dineshreddy91/WALT"><img style="display:inline" src="https://img.shields.io/github/stars/dineshreddy91/WALT?style=social"></a>
</center>
"""
title = "WALT: Watch And Learn 2D Amodal Representation using Time-lapse Imagery"
article = """
<center>
<img src='https://visitor-badge.glitch.me/badge?page_id=anhquancao.MonoScene&left_color=darkmagenta&right_color=purple' alt='visitor badge'>
</center>
"""

examples = [
    ['demo/images/img_1.jpg', 0.8],
    ['demo/images/img_2.jpg', 0.8],
    ['demo/images/img_4.png', 0.85],
]

'''
import cv2
filename='demo/images/img_1.jpg'
img=cv2.imread(filename)
img=walt_demo(img)
cv2.imwrite(filename.replace('/images/','/results/'),img)
cv2.imwrite('check.png',img)
'''
# Gradio UI: an input image and a confidence slider in, rendered amodal detections out
confidence_threshold = gr.Slider(minimum=0.3,
                                 maximum=1.0,
                                 step=0.01,
                                 value=1.0,
                                 label="Amodal Detection Confidence Threshold")
inputs = [gr.Image(), confidence_threshold]
demo = gr.Interface(walt_demo,
                    outputs="image",
                    inputs=inputs,
                    article=article,
                    title=title,
                    enable_queue=True,
                    examples=examples,
                    description=description)

#demo.launch(server_name="0.0.0.0", server_port=7000)
demo.launch()
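Note that the checkpoint-download block at the top of app.py is commented out, so the weights have to be in place before the interface starts. The following is a minimal preparation sketch, not part of this commit; it reuses the same checkpoint URLs and target paths as the commented-out block, with os.makedirs/urllib instead of shelling out to mkdir/wget:

import os
import urllib.request

# create the directory layout app.py expects
os.makedirs('data/models', exist_ok=True)

# same checkpoint URLs as in the commented-out block of app.py
for name in ('walt_people.pth', 'walt_vehicle.pth'):
    url = 'https://www.cs.cmu.edu/~walt/models/' + name
    dst = 'data/models/' + name
    if not os.path.exists(dst):
        urllib.request.urlretrieve(url, dst)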
configs/_base_/datasets/parking_instance.py
ADDED
@@ -0,0 +1,48 @@
dataset_type = 'ParkingDataset'
data_root = 'data/parking/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_bboxes_3d', 'gt_bboxes_3d_proj']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    samples_per_gpu=1,
    workers_per_gpu=1,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'GT_data/',
        img_prefix=data_root + 'images/',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'GT_data/',
        img_prefix=data_root + 'images/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'GT_data/',
        img_prefix=data_root + 'images/',
        pipeline=test_pipeline))
evaluation = dict(metric=['bbox'])  # , 'segm'])
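These _base_/datasets files are not used on their own; they are merged into the top-level entry points under configs/walt/ listed above. A small inspection sketch, assuming the MMCV 1.x Config API that this Swin/MMDetection code base is normally paired with:

from mmcv import Config

# load a composed config and look at what the dataset _base_ file contributes
cfg = Config.fromfile('configs/walt/walt_vehicle.py')
print(cfg.data.train.type)           # dataset class registered in mmdet, e.g. 'WaltDataset'
print(cfg.data.samples_per_gpu)      # per-GPU batch size from the dataset _base_
print([t['type'] for t in cfg.data.train.pipeline])  # training pipeline steps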
configs/_base_/datasets/parking_instance_coco.py
ADDED
@@ -0,0 +1,49 @@
dataset_type = 'ParkingCocoDataset'
data_root = 'data/parking/'
data_root_test = 'data/parking_highres/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    samples_per_gpu=6,
    workers_per_gpu=6,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'GT_data/',
        img_prefix=data_root + 'images/',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=data_root_test + 'GT_data/',
        img_prefix=data_root_test + 'images',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root_test + 'GT_data/',
        img_prefix=data_root_test + 'images',
        pipeline=test_pipeline))
evaluation = dict(metric=['bbox', 'segm'])
configs/_base_/datasets/people_real_coco.py
ADDED
@@ -0,0 +1,49 @@
dataset_type = 'WaltDataset'
data_root = 'data/cwalt_train/'
data_root_test = 'data/cwalt_test/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    samples_per_gpu=8,
    workers_per_gpu=8,
    train=dict(
        type=dataset_type,
        ann_file=data_root + '/',
        img_prefix=data_root + '/',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=data_root_test + '/',
        img_prefix=data_root_test + '/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root_test + '/',
        img_prefix=data_root_test + '/',
        pipeline=test_pipeline))
evaluation = dict(metric=['bbox', 'segm'])
configs/_base_/datasets/walt_people.py
ADDED
@@ -0,0 +1,49 @@
dataset_type = 'WaltDataset'
data_root = 'data/cwalt_train/'
data_root_test = 'data/cwalt_test/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    samples_per_gpu=8,
    workers_per_gpu=8,
    train=dict(
        type=dataset_type,
        ann_file=data_root + '/',
        img_prefix=data_root + '/',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=data_root_test + '/',
        img_prefix=data_root_test + '/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root_test + '/',
        img_prefix=data_root_test + '/',
        pipeline=test_pipeline))
evaluation = dict(metric=['bbox', 'segm'])
configs/_base_/datasets/walt_vehicle.py
ADDED
@@ -0,0 +1,49 @@
dataset_type = 'WaltDataset'
data_root = 'data/cwalt_train/'
data_root_test = 'data/cwalt_test/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    samples_per_gpu=5,
    workers_per_gpu=5,
    train=dict(
        type=dataset_type,
        ann_file=data_root + '/',
        img_prefix=data_root + '/',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=data_root_test + '/',
        img_prefix=data_root_test + '/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root_test + '/',
        img_prefix=data_root_test + '/',
        pipeline=test_pipeline))
evaluation = dict(metric=['bbox', 'segm'])
configs/_base_/default_runtime.py
ADDED
@@ -0,0 +1,16 @@
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
custom_hooks = [dict(type='NumClassCheckHook')]

dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
configs/_base_/models/mask_rcnn_swin_fpn.py
ADDED
@@ -0,0 +1,127 @@
# model settings
model = dict(
    type='MaskRCNN',
    pretrained=None,
    backbone=dict(
        type='SwinTransformer',
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        mlp_ratio=4.,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.,
        attn_drop_rate=0.,
        drop_path_rate=0.2,
        ape=False,
        patch_norm=True,
        out_indices=(0, 1, 2, 3),
        use_checkpoint=False),
    neck=dict(
        type='FPN',
        in_channels=[96, 192, 384, 768],
        out_channels=256,
        num_outs=5),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            scales=[8],
            ratios=[0.5, 1.0, 2.0],
            strides=[4, 8, 16, 32, 64]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
    roi_head=dict(
        type='StandardRoIHead',
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        bbox_head=dict(
            type='Shared2FCBBoxHead',
            in_channels=256,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=80,
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[0., 0., 0., 0.],
                target_stds=[0.1, 0.1, 0.2, 0.2]),
            reg_class_agnostic=False,
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
        mask_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        mask_head=dict(
            type='FCNMaskHead',
            num_convs=4,
            in_channels=256,
            conv_out_channels=256,
            num_classes=80,
            loss_mask=dict(
                type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
    # model training and testing settings
    train_cfg=dict(
        rpn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.7,
                neg_iou_thr=0.3,
                min_pos_iou=0.3,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=256,
                pos_fraction=0.5,
                neg_pos_ub=-1,
                add_gt_as_proposals=False),
            allowed_border=-1,
            pos_weight=-1,
            debug=False),
        rpn_proposal=dict(
            nms_pre=2000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.5,
                neg_iou_thr=0.5,
                min_pos_iou=0.5,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=512,
                pos_fraction=0.25,
                neg_pos_ub=-1,
                add_gt_as_proposals=True),
            mask_size=28,
            pos_weight=-1,
            debug=False)),
    test_cfg=dict(
        rpn=dict(
            nms_pre=1000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.5),
            max_per_img=100,
            mask_thr_binary=0.5)))
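For orientation, the train_cfg/test_cfg keys above live inside the model dict, which is the MMDetection 2.x convention. A short sketch of instantiating the detector from one of the composed configs, assuming the standard mmdet 2.x build_detector API is available in this repo's vendored code:

from mmcv import Config
from mmdet.models import build_detector

# the base model file is pulled in through a configs/walt/*.py entry point
cfg = Config.fromfile('configs/walt/walt_vehicle.py')

# MMDetection 2.x style: train_cfg/test_cfg are part of cfg.model
model = build_detector(cfg.model)
print(type(model).__name__)               # MaskRCNN
print(cfg.model.backbone.type)            # SwinTransformer
print(cfg.model.roi_head.mask_head.type)  # FCNOccMaskHead for the occ_ variant below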
configs/_base_/models/occ_mask_rcnn_swin_fpn.py
ADDED
@@ -0,0 +1,127 @@
# model settings
model = dict(
    type='MaskRCNN',
    pretrained=None,
    backbone=dict(
        type='SwinTransformer',
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        mlp_ratio=4.,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.,
        attn_drop_rate=0.,
        drop_path_rate=0.2,
        ape=False,
        patch_norm=True,
        out_indices=(0, 1, 2, 3),
        use_checkpoint=False),
    neck=dict(
        type='FPN',
        in_channels=[96, 192, 384, 768],
        out_channels=256,
        num_outs=5),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            scales=[8],
            ratios=[0.5, 1.0, 2.0],
            strides=[4, 8, 16, 32, 64]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
    roi_head=dict(
        type='StandardRoIHead',
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        bbox_head=dict(
            type='Shared2FCBBoxHead',
            in_channels=256,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=80,
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[0., 0., 0., 0.],
                target_stds=[0.1, 0.1, 0.2, 0.2]),
            reg_class_agnostic=False,
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
        mask_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        mask_head=dict(
            type='FCNOccMaskHead',
            num_convs=4,
            in_channels=256,
            conv_out_channels=256,
            num_classes=80,
            loss_mask=dict(
                type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
    # model training and testing settings
    train_cfg=dict(
        rpn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.7,
                neg_iou_thr=0.3,
                min_pos_iou=0.3,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=256,
                pos_fraction=0.5,
                neg_pos_ub=-1,
                add_gt_as_proposals=False),
            allowed_border=-1,
            pos_weight=-1,
            debug=False),
        rpn_proposal=dict(
            nms_pre=2000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.5,
                neg_iou_thr=0.5,
                min_pos_iou=0.5,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=512,
                pos_fraction=0.25,
                neg_pos_ub=-1,
                add_gt_as_proposals=True),
            mask_size=28,
            pos_weight=-1,
            debug=False)),
    test_cfg=dict(
        rpn=dict(
            nms_pre=1000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.5),
            max_per_img=100,
            mask_thr_binary=0.5)))
configs/_base_/schedules/schedule_1x.py
ADDED
@@ -0,0 +1,11 @@
# optimizer
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=0.001,
    step=[8, 11])
runner = dict(type='EpochBasedRunner', max_epochs=12)
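In words, this 1x schedule ramps the learning rate linearly over the first 500 iterations (starting at 0.001 of the base value), then drops it by 10x at the step epochs 8 and 11 of the 12-epoch run. A tiny illustrative sketch of the step rule after warmup (not the actual scheduler, which mmcv's LrUpdaterHook handles; exact epoch-indexing conventions vary slightly between versions):

def step_lr(epoch, base_lr=0.02, steps=(8, 11), gamma=0.1):
    # learning rate after warmup, as a function of the 0-indexed epoch
    decays = sum(epoch >= s for s in steps)
    return base_lr * (gamma ** decays)

for e in (0, 7, 8, 11):
    print(e, round(step_lr(e), 6))  # 0.02, 0.02, 0.002, 0.0002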
configs/walt/walt_people.py
ADDED
@@ -0,0 +1,80 @@
_base_ = [
    '../_base_/models/occ_mask_rcnn_swin_fpn.py',
    '../_base_/datasets/walt_people.py',
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]

model = dict(
    backbone=dict(
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        ape=False,
        drop_path_rate=0.1,
        patch_norm=True,
        use_checkpoint=False
    ),
    neck=dict(in_channels=[96, 192, 384, 768]))

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

# augmentation strategy originates from DETR / Sparse RCNN
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='AutoAugment',
         policies=[
             [
                 dict(type='Resize',
                      img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
                                 (608, 1333), (640, 1333), (672, 1333), (704, 1333),
                                 (736, 1333), (768, 1333), (800, 1333)],
                      multiscale_mode='value',
                      keep_ratio=True)
             ],
             [
                 dict(type='Resize',
                      img_scale=[(400, 1333), (500, 1333), (600, 1333)],
                      multiscale_mode='value',
                      keep_ratio=True),
                 dict(type='RandomCrop',
                      crop_type='absolute_range',
                      crop_size=(384, 600),
                      allow_negative_crop=True),
                 dict(type='Resize',
                      img_scale=[(480, 1333), (512, 1333), (544, 1333),
                                 (576, 1333), (608, 1333), (640, 1333),
                                 (672, 1333), (704, 1333), (736, 1333),
                                 (768, 1333), (800, 1333)],
                      multiscale_mode='value',
                      override=True,
                      keep_ratio=True)
             ]
         ]),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
data = dict(train=dict(pipeline=train_pipeline))

optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.05,
                 paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
                                                 'relative_position_bias_table': dict(decay_mult=0.),
                                                 'norm': dict(decay_mult=0.)}))
lr_config = dict(step=[8, 11])
runner = dict(type='EpochBasedRunnerAmp', max_epochs=12)

# do not use mmdet version fp16
fp16 = None
optimizer_config = dict(
    type="DistOptimizerHook",
    update_interval=1,
    grad_clip=None,
    coalesce=True,
    bucket_size_mb=-1,
    use_fp16=True,
)
configs/walt/walt_vehicle.py
ADDED
@@ -0,0 +1,80 @@
_base_ = [
    '../_base_/models/occ_mask_rcnn_swin_fpn.py',
    '../_base_/datasets/walt_vehicle.py',
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]

model = dict(
    backbone=dict(
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        ape=False,
        drop_path_rate=0.1,
        patch_norm=True,
        use_checkpoint=False
    ),
    neck=dict(in_channels=[96, 192, 384, 768]))

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

# augmentation strategy originates from DETR / Sparse RCNN
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='AutoAugment',
         policies=[
             [
                 dict(type='Resize',
                      img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
                                 (608, 1333), (640, 1333), (672, 1333), (704, 1333),
                                 (736, 1333), (768, 1333), (800, 1333)],
                      multiscale_mode='value',
                      keep_ratio=True)
             ],
             [
                 dict(type='Resize',
                      img_scale=[(400, 1333), (500, 1333), (600, 1333)],
                      multiscale_mode='value',
                      keep_ratio=True),
                 dict(type='RandomCrop',
                      crop_type='absolute_range',
                      crop_size=(384, 600),
                      allow_negative_crop=True),
                 dict(type='Resize',
                      img_scale=[(480, 1333), (512, 1333), (544, 1333),
                                 (576, 1333), (608, 1333), (640, 1333),
                                 (672, 1333), (704, 1333), (736, 1333),
                                 (768, 1333), (800, 1333)],
                      multiscale_mode='value',
                      override=True,
                      keep_ratio=True)
             ]
         ]),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
data = dict(train=dict(pipeline=train_pipeline))

optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.05,
                 paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
                                                 'relative_position_bias_table': dict(decay_mult=0.),
                                                 'norm': dict(decay_mult=0.)}))
lr_config = dict(step=[8, 11])
runner = dict(type='EpochBasedRunnerAmp', max_epochs=12)

# do not use mmdet version fp16
fp16 = None
optimizer_config = dict(
    type="DistOptimizerHook",
    update_interval=1,
    grad_clip=None,
    coalesce=True,
    bucket_size_mb=-1,
    use_fp16=True,
)
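The Space itself drives these configs through infer.detections (see app.py above). As a rough alternative, a direct-inference sketch, assuming the vendored mmdet/apis in this commit keeps the standard init_detector / inference_detector helpers and that the checkpoint has been downloaded as described earlier:

from mmdet.apis import init_detector, inference_detector

config_file = 'configs/walt/walt_vehicle.py'
checkpoint = 'data/models/walt_vehicle.pth'

# build the model from the composed config and load the trained weights
model = init_detector(config_file, checkpoint, device='cpu')
result = inference_detector(model, 'demo/images/img_1.jpg')

# for Mask R-CNN style models the result is typically (bbox_results, segm_results),
# each a per-class list; the exact layout may differ in this amodal variant
bbox_results, segm_results = result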
cwalt/CWALT.py
ADDED
@@ -0,0 +1,161 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Oct 19 19:14:47 2021

@author: dinesh
"""
import glob
from .utils import bb_intersection_over_union_unoccluded
import numpy as np
from PIL import Image
import datetime
import cv2
import os
from tqdm import tqdm


def get_image(time, folder):
    for week_loop in range(5):
        try:
            image = np.array(Image.open(folder+'/week' +str(week_loop)+'/'+ str(time).replace(' ','T').replace(':','-').split('+')[0] + '.jpg'))
            break
        except:
            continue
    if image is None:
        print('file not found')
    return image

def get_mask(segm, image):
    poly = np.array(segm).reshape((int(len(segm)/2), 2))
    mask = image.copy()*0
    cv2.fillConvexPoly(mask, poly, (255, 255, 255))
    return mask

def get_unoccluded(indices, tracks_all):
    unoccluded_indexes = []
    unoccluded_index_all =[]
    while 1:
        unoccluded_clusters = []
        len_unocc = len(unoccluded_indexes)
        for ind in indices:
            if ind in unoccluded_indexes:
                continue
            occ = False
            for ind_compare in indices:
                if ind_compare in unoccluded_indexes:
                    continue
                if bb_intersection_over_union_unoccluded(tracks_all[ind], tracks_all[ind_compare]) > 0.01 and ind_compare != ind:
                    occ = True
            if occ==False:
                unoccluded_indexes.extend([ind])
                unoccluded_clusters.extend([ind])
        if len(unoccluded_indexes) == len_unocc and len_unocc != 0:
            for ind in indices:
                if ind not in unoccluded_indexes:
                    unoccluded_indexes.extend([ind])
                    unoccluded_clusters.extend([ind])

        unoccluded_index_all.append(unoccluded_clusters)
        if len(unoccluded_indexes) > len(indices)-5:
            break
    return unoccluded_index_all

def primes(n): # simple sieve of multiples
    odds = range(3, n+1, 2)
    sieve = set(sum([list(range(q*q, n+1, q+q)) for q in odds], []))
    return [2] + [p for p in odds if p not in sieve]

def save_image(image_read, save_path, data, path):
    tracks = data['tracks_all_unoccluded']
    segmentations = data['segmentation_all_unoccluded']
    timestamps = data['timestamps_final_unoccluded']

    image = image_read.copy()
    indices = np.random.randint(len(tracks),size=30)
    prime_numbers = primes(1000)
    unoccluded_index_all = get_unoccluded(indices, tracks)

    # composite randomly sampled unoccluded instances onto the median background image
    mask_stacked = image*0
    mask_stacked_all =[]
    count = 0
    time = datetime.datetime.now()

    for l in indices:
        try:
            image_crop = get_image(timestamps[l], path)
        except:
            continue
        try:
            bb_left, bb_top, bb_width, bb_height, confidence = tracks[l]
        except:
            bb_left, bb_top, bb_width, bb_height, confidence, track_id = tracks[l]
        mask = get_mask(segmentations[l], image)

        image[mask > 0] = image_crop[mask > 0]
        mask[mask > 0] = 1
        # mark pixels where the new instance overlaps previously pasted ones
        for count, mask_inc in enumerate(mask_stacked_all):
            mask_stacked_all[count][cv2.bitwise_and(mask, mask_inc) > 0] = 2
        mask_stacked_all.append(mask)
        mask_stacked += mask
        count = count+1

    cv2.imwrite(save_path + '/images/'+str(time).replace(' ','T').replace(':','-').split('+')[0] + '.jpg', image[:, :, ::-1])
    cv2.imwrite(save_path + '/Segmentation/'+str(time).replace(' ','T').replace(':','-').split('+')[0] + '.jpg', mask_stacked[:, :, ::-1]*30)
    np.savez_compressed(save_path+'/Segmentation/'+str(time).replace(' ','T').replace(':','-').split('+')[0], mask=mask_stacked_all)

def CWALT_Generation(camera_name):
    save_path_train = 'data/cwalt_train'
    save_path_test = 'data/cwalt_test'

    json_file_path = 'data/{}/{}.json'.format(camera_name,camera_name) # iii1/iii1_7_test.json' # './data.json'
    path = 'data/' + camera_name

    data = np.load(json_file_path + '.npz', allow_pickle=True)

    ## split data

    data_train=dict()
    data_test=dict()

    split_index = int(len(data['timestamps_final_unoccluded'])*0.8)

    data_train['tracks_all_unoccluded'] = data['tracks_all_unoccluded'][0:split_index]
    data_train['segmentation_all_unoccluded'] = data['segmentation_all_unoccluded'][0:split_index]
    data_train['timestamps_final_unoccluded'] = data['timestamps_final_unoccluded'][0:split_index]

    data_test['tracks_all_unoccluded'] = data['tracks_all_unoccluded'][split_index:]
    data_test['segmentation_all_unoccluded'] = data['segmentation_all_unoccluded'][split_index:]
    data_test['timestamps_final_unoccluded'] = data['timestamps_final_unoccluded'][split_index:]

    image_read = np.array(Image.open(path + '/T18-median_image.jpg'))
    image_read = cv2.resize(image_read, (int(image_read.shape[1]/2), int(image_read.shape[0]/2)))

    try:
        os.mkdir(save_path_train)
    except:
        print(save_path_train)

    try:
        os.mkdir(save_path_train + '/images')
        os.mkdir(save_path_train + '/Segmentation')
    except:
        print(save_path_train+ '/images')

    try:
        os.mkdir(save_path_test)
    except:
        print(save_path_test)

    try:
        os.mkdir(save_path_test + '/images')
        os.mkdir(save_path_test + '/Segmentation')
    except:
        print(save_path_test+ '/images')

    for loop in tqdm(range(3000), desc="Generating training CWALT Images "):
        save_image(image_read, save_path_train, data_train, path)

    for loop in tqdm(range(300), desc="Generating testing CWALT Images "):
        save_image(image_read, save_path_test, data_test, path)
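The file list above also includes a 14-line cwalt_generate.py entry point whose body falls outside this 50-file view. A plausible driver sketch of how the two cwalt stages compose (hypothetical; the actual entry point may differ): stage 1 writes the <camera>.json.npz archive of unoccluded tracks, and stage 2 composites the CWALT training/testing images from it.

from cwalt.Clip_WALT_Generate import Get_unoccluded_objects
from cwalt.CWALT import CWALT_Generation

camera_name = 'cam2'                       # example camera folder under data/
Get_unoccluded_objects(camera_name)        # stage 1: filter unoccluded instances
CWALT_Generation(camera_name)              # stage 2: generate composited CWALT images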
cwalt/Clip_WALT_Generate.py
ADDED
@@ -0,0 +1,284 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri May 20 15:15:11 2022

@author: dinesh
"""

from collections import OrderedDict
from matplotlib import pyplot as plt
from .utils import *
import scipy.interpolate

from scipy import interpolate
from .clustering_utils import *
import glob
import cv2
from PIL import Image


import json
import cv2

import numpy as np
from tqdm import tqdm


def ignore_indexes(tracks_all, labels_all):
    # get repeating bounding boxes
    get_indexes = lambda x, xs: [i for (y, i) in zip(xs, range(len(xs))) if x == y]
    ignore_ind = []
    for index, track in enumerate(tracks_all):
        print('in ignore', index, len(tracks_all))
        if index in ignore_ind:
            continue

        if labels_all[index] < 1 or labels_all[index] > 3:
            ignore_ind.extend([index])

        ind = get_indexes(track, tracks_all)
        if len(ind) > 30:
            ignore_ind.extend(ind)

    return ignore_ind

def repeated_indexes_old(tracks_all,ignore_ind, unoccluded_indexes=None):
    # get repeating bounding boxes
    get_indexes = lambda x, xs: [i for (y, i) in zip(xs, range(len(xs))) if bb_intersection_over_union(x, y) > 0.8 and i not in ignore_ind]
    repeat_ind = []
    repeat_inds =[]
    if unoccluded_indexes == None:
        for index, track in enumerate(tracks_all):
            if index in repeat_ind or index in ignore_ind:
                continue
            ind = get_indexes(track, tracks_all)
            if len(ind) > 20:
                repeat_ind.extend(ind)
                repeat_inds.append([ind,track])
    else:
        for index in unoccluded_indexes:
            if index in repeat_ind or index in ignore_ind:
                continue
            ind = get_indexes(tracks_all[index], tracks_all)
            if len(ind) > 3:
                repeat_ind.extend(ind)
                repeat_inds.append([ind,tracks_all[index]])
    return repeat_inds

def get_unoccluded_instances(timestamps_final, tracks_all, ignore_ind=[], threshold = 0.01):
    get_indexes = lambda x, xs: [i for (y, i) in zip(xs, range(len(xs))) if x==y]
    unoccluded_indexes = []
    time_checked = []
    stationary_obj = []
    count =0

    for time in tqdm(np.unique(timestamps_final), desc="Detecting Unoccluded objects in Image "):
        count += 1
        if [time.year,time.month, time.day, time.hour, time.minute, time.second, time.microsecond] in time_checked:
            analyze_bb = []
            for ind in unoccluded_indexes_time:
                for ind_compare in same_time_instances:
                    iou = bb_intersection_over_union(tracks_all[ind], tracks_all[ind_compare])
                    if iou < 0.5 and iou > 0:
                        analyze_bb.extend([ind_compare])
                    if iou > 0.99:
                        stationary_obj.extend([str(ind_compare)+'+'+str(ind)])

            for ind in analyze_bb:
                occ = False
                for ind_compare in same_time_instances:
                    if bb_intersection_over_union_unoccluded(tracks_all[ind], tracks_all[ind_compare], threshold=threshold) > threshold and ind_compare != ind:
                        occ = True
                        break
                if occ == False:
                    unoccluded_indexes.extend([ind])
            continue

        same_time_instances = get_indexes(time,timestamps_final)
        unoccluded_indexes_time = []

        for ind in same_time_instances:
            if tracks_all[ind][4] < 0.9 or ind in ignore_ind:# or ind != 1859:
                continue
            occ = False
            for ind_compare in same_time_instances:
                if bb_intersection_over_union_unoccluded(tracks_all[ind], tracks_all[ind_compare], threshold=threshold) > threshold and ind_compare != ind and tracks_all[ind_compare][4] < 0.5:
                    occ = True
                    break
            if occ==False:
                unoccluded_indexes.extend([ind])
                unoccluded_indexes_time.extend([ind])
        time_checked.append([time.year,time.month, time.day, time.hour, time.minute, time.second, time.microsecond])
    return unoccluded_indexes,stationary_obj

def visualize_unoccluded_detection(timestamps_final,tracks_all,segmentation_all, unoccluded_indexes, cwalt_data_path, camera_name, ignore_ind=[]):
    tracks_final = []
    tracks_final.append([])
    try:
        os.mkdir(cwalt_data_path + '/' + camera_name+'_unoccluded_car_detection/')
    except:
        print('Unoccluded debugging exists')

    for time in tqdm(np.unique(timestamps_final), desc="Visualizing Unoccluded objects in Image "):
        get_indexes = lambda x, xs: [i for (y, i) in zip(xs, range(len(xs))) if x==y]
        ind = get_indexes(time, timestamps_final)
        image_unocc = False
        for index in ind:
            if index not in unoccluded_indexes:
                continue
            else:
                image_unocc = True
                break
        if image_unocc == False:
            continue

        for week_loop in range(5):
            try:
                image = np.array(Image.open(cwalt_data_path+'/week' +str(week_loop)+'/'+ str(time).replace(' ','T').replace(':','-').split('+')[0] + '.jpg'))
                break
            except:
                continue

        try:
            mask = image*0
        except:
            print('image not found for ' + str(time).replace(' ','T').replace(':','-').split('+')[0] + '.jpg' )
            continue
        image_original = image.copy()

        for index in ind:
            track = tracks_all[index]

            if index in ignore_ind:
                continue
            if index not in unoccluded_indexes:
                continue
            try:
                bb_left, bb_top, bb_width, bb_height, confidence, id = track
            except:
                bb_left, bb_top, bb_width, bb_height, confidence = track

            if confidence > 0.6:
                mask = poly_seg(image, segmentation_all[index])
                cv2.imwrite(cwalt_data_path + '/' + camera_name+'_unoccluded_car_detection/' + str(index)+'.png', mask[:, :, ::-1])

def repeated_indexes(tracks_all,ignore_ind, repeat_count = 10, unoccluded_indexes=None):
    get_indexes = lambda x, xs: [i for (y, i) in zip(xs, range(len(xs))) if bb_intersection_over_union(x, y) > 0.8 and i not in ignore_ind]
    repeat_ind = []
    repeat_inds =[]
    if unoccluded_indexes == None:
        for index, track in enumerate(tracks_all):
            if index in repeat_ind or index in ignore_ind:
                continue

            ind = get_indexes(track, tracks_all)
            if len(ind) > repeat_count:
                repeat_ind.extend(ind)
                repeat_inds.append([ind,track])
    else:
        for index in unoccluded_indexes:
            if index in repeat_ind or index in ignore_ind:
                continue
            ind = get_indexes(tracks_all[index], tracks_all)
            if len(ind) > repeat_count:
                repeat_ind.extend(ind)
                repeat_inds.append([ind,tracks_all[index]])


    return repeat_inds

def poly_seg(image, segm):
    poly = np.array(segm).reshape((int(len(segm)/2), 2))
    overlay = image.copy()
    alpha = 0.5
    cv2.fillPoly(overlay, [poly], color=(255, 255, 0))
    cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0, image)
    return image

def visualize_unoccuded_clusters(repeat_inds, tracks, segmentation_all, timestamps_final, cwalt_data_path):
    for index_, repeat_ind in enumerate(repeat_inds):
        image = np.array(Image.open(cwalt_data_path+'/'+'T18-median_image.jpg'))
        try:
            os.mkdir(cwalt_data_path+ '/Cwalt_database/')
        except:
            print('folder exists')
        try:
            os.mkdir(cwalt_data_path+ '/Cwalt_database/' + str(index_) +'/')
        except:
            print(cwalt_data_path+ '/Cwalt_database/' + str(index_) +'/')

        for i in repeat_ind[0]:
            try:
                bb_left, bb_top, bb_width, bb_height, confidence = tracks[i]#bbox
            except:
                bb_left, bb_top, bb_width, bb_height, confidence, track_id = tracks[i]#bbox

            cv2.rectangle(image,(int(bb_left), int(bb_top)),(int(bb_left+bb_width), int(bb_top+bb_height)),(0, 0, 255), 2)
            time = timestamps_final[i]
            for week_loop in range(5):
                try:
                    image1 = np.array(Image.open(cwalt_data_path+'/week' +str(week_loop)+'/'+ str(time).replace(' ','T').replace(':','-').split('+')[0] + '.jpg'))
                    break
                except:
                    continue

            crop = image1[int(bb_top): int(bb_top + bb_height), int(bb_left):int(bb_left + bb_width)]
            cv2.imwrite(cwalt_data_path+ '/Cwalt_database/' + str(index_) +'/o_' + str(i) +'.jpg', crop[:, :, ::-1])
            image1 = poly_seg(image1,segmentation_all[i])
            crop = image1[int(bb_top): int(bb_top + bb_height), int(bb_left):int(bb_left + bb_width)]
            cv2.imwrite(cwalt_data_path+ '/Cwalt_database/' + str(index_) +'/' + str(i)+'.jpg', crop[:, :, ::-1])
        if index_ > 100:
            break

        cv2.imwrite(cwalt_data_path+ '/Cwalt_database/' + str(index_) +'.jpg', image[:, :, ::-1])

def Get_unoccluded_objects(camera_name, debug = False, scale=True):
    cwalt_data_path = 'data/' + camera_name
    data_folder = cwalt_data_path
    json_file_path = cwalt_data_path + '/' + camera_name + '.json'

    with open(json_file_path, 'r') as j:
        annotations = json.loads(j.read())

    tracks_all = [parse_bbox(anno['bbox']) for anno in annotations]
    segmentation_all = [parse_bbox(anno['segmentation']) for anno in annotations]
    labels_all = [anno['label_id'] for anno in annotations]
    timestamps_final = [parse(anno['time']) for anno in annotations]

    if scale ==True:
        scale_factor = 2
        tracks_all_numpy = np.array(tracks_all)
        tracks_all_numpy[:,:4] = np.array(tracks_all)[:,:4]/scale_factor
        tracks_all = tracks_all_numpy.tolist()

        segmentation_all_scaled = []
        for list_loop in segmentation_all:
            segmentation_all_scaled.append((np.floor_divide(np.array(list_loop),scale_factor)).tolist())
        segmentation_all = segmentation_all_scaled

    if debug == True:
        timestamps_final = timestamps_final[:1000]
        labels_all = labels_all[:1000]
        segmentation_all = segmentation_all[:1000]
        tracks_all = tracks_all[:1000]

    unoccluded_indexes, stationary = get_unoccluded_instances(timestamps_final, tracks_all, threshold = 0.05)
    if debug == True:
        visualize_unoccluded_detection(timestamps_final, tracks_all, segmentation_all, unoccluded_indexes, cwalt_data_path, camera_name)

    tracks_all_unoccluded = [tracks_all[i] for i in unoccluded_indexes]
    segmentation_all_unoccluded = [segmentation_all[i] for i in unoccluded_indexes]
    labels_all_unoccluded = [labels_all[i] for i in unoccluded_indexes]
    timestamps_final_unoccluded = [timestamps_final[i] for i in unoccluded_indexes]
    np.savez(json_file_path,tracks_all_unoccluded=tracks_all_unoccluded, segmentation_all_unoccluded=segmentation_all_unoccluded, labels_all_unoccluded=labels_all_unoccluded, timestamps_final_unoccluded=timestamps_final_unoccluded )

    if debug == True:
        repeat_inds_clusters = repeated_indexes(tracks_all_unoccluded,[], repeat_count=1)
        visualize_unoccuded_clusters(repeat_inds_clusters, tracks_all_unoccluded, segmentation_all_unoccluded, timestamps_final_unoccluded, cwalt_data_path)
    else:
        repeat_inds_clusters = repeated_indexes(tracks_all_unoccluded,[], repeat_count=10)

    np.savez(json_file_path + '_clubbed', repeat_inds=repeat_inds_clusters)
    np.savez(json_file_path + '_stationary', stationary=stationary)
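Both cwalt scripts lean on bb_intersection_over_union and bb_intersection_over_union_unoccluded from cwalt/utils.py, which is added in this commit but not shown in this view. For orientation only, a standard IoU for the [left, top, width, height, ...] track format used above would look roughly like the sketch below; the repo's own helpers may differ (the *_unoccluded variant, in particular, appears to measure overlap relative to a single box rather than the union):

def bb_iou_sketch(box_a, box_b):
    # boxes follow the tracks_all layout: (bb_left, bb_top, bb_width, bb_height, ...)
    ax1, ay1, aw, ah = box_a[:4]
    bx1, by1, bw, bh = box_b[:4]
    ix1, iy1 = max(ax1, bx1), max(ay1, by1)
    ix2, iy2 = min(ax1 + aw, bx1 + bw), min(ay1 + ah, by1 + bh)
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    union = aw * ah + bw * bh - inter
    return inter / union if union > 0 else 0.0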
cwalt/Download_Detections.py
ADDED
@@ -0,0 +1,28 @@
import json
from psycopg2.extras import RealDictCursor
#import cv2
import psycopg2
import cv2


CONNECTION = "postgres://postgres:"

conn = psycopg2.connect(CONNECTION)
cursor = conn.cursor(cursor_factory=RealDictCursor)


def get_sample():
    camera_name, camera_id = 'cam2', 4

    print('Executing SQL command')

    cursor.execute("SELECT * FROM annotations WHERE camera_id = {} and time >='2021-05-01 00:00:00' and time <='2021-05-07 23:59:50' and label_id in (1,2)".format(camera_id))

    print('Dumping to json')
    annotations = json.dumps(cursor.fetchall(), indent=2, default=str)
    wjdata = json.loads(annotations)
    with open('{}_{}_test.json'.format(camera_name, camera_id), 'w') as f:
        json.dump(wjdata, f)
    print('Done dumping to json')

get_sample()
cwalt/clustering_utils.py
ADDED
@@ -0,0 +1,132 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Created on Fri May 20 15:18:20 2022
+
+@author: dinesh
+"""
+
+# 0 - Import related libraries
+
+import urllib
+import zipfile
+import os
+import scipy.io
+import math
+import numpy as np
+import matplotlib.pyplot as plt
+import seaborn as sns
+
+from scipy.spatial.distance import directed_hausdorff
+from sklearn.cluster import DBSCAN
+from sklearn.metrics.pairwise import pairwise_distances
+import scipy.spatial.distance
+
+from .kmedoid import kMedoids  # kMedoids code is adapted from https://github.com/letiantian/kmedoids
+
+# Some visualization stuff, not so important
+# sns.set()
+plt.rcParams['figure.figsize'] = (12, 12)
+
+# Utility Functions
+
+color_lst = plt.rcParams['axes.prop_cycle'].by_key()['color']
+color_lst.extend(['firebrick', 'olive', 'indigo', 'khaki', 'teal', 'saddlebrown',
+                  'skyblue', 'coral', 'darkorange', 'lime', 'darkorchid', 'dimgray'])
+
+
+def plot_cluster(image, traj_lst, cluster_lst):
+    '''
+    Plots given trajectories with a color that is specific for every trajectory's own cluster index.
+    Outlier trajectories which are specified with -1 in `cluster_lst` are plotted dashed with black color
+    '''
+    cluster_count = np.max(cluster_lst) + 1
+
+    for traj, cluster in zip(traj_lst, cluster_lst):
+
+        # if cluster == -1:
+        #     # Means it it a noisy trajectory, paint it black
+        #     plt.plot(traj[:, 0], traj[:, 1], c='k', linestyle='dashed')
+        #
+        # else:
+        plt.plot(traj[:, 0], traj[:, 1], c=color_lst[cluster % len(color_lst)])
+
+    plt.imshow(image)
+    # plt.show()
+    plt.axis('off')
+    plt.savefig('trajectory.png', bbox_inches='tight')
+    plt.show()
+
+
+# 3 - Distance matrix
+
+def hausdorff(u, v):
+    d = max(directed_hausdorff(u, v)[0], directed_hausdorff(v, u)[0])
+    return d
+
+
+def build_distance_matrix(traj_lst):
+    # 2 - Trajectory segmentation
+
+    print('Running trajectory segmentation...')
+    degree_threshold = 5
+
+    for traj_index, traj in enumerate(traj_lst):
+
+        hold_index_lst = []
+        previous_azimuth = 1000
+
+        for point_index, point in enumerate(traj[:-1]):
+            next_point = traj[point_index + 1]
+            diff_vector = next_point - point
+            azimuth = (math.degrees(math.atan2(*diff_vector)) + 360) % 360
+
+            if abs(azimuth - previous_azimuth) > degree_threshold:
+                hold_index_lst.append(point_index)
+                previous_azimuth = azimuth
+        hold_index_lst.append(traj.shape[0] - 1)  # Last point of trajectory is always added
+
+        traj_lst[traj_index] = traj[hold_index_lst, :]
+
+    print('Building distance matrix...')
+    traj_count = len(traj_lst)
+    D = np.zeros((traj_count, traj_count))
+
+    # This may take a while
+    for i in range(traj_count):
+        if i % 20 == 0:
+            print(i)
+        for j in range(i + 1, traj_count):
+            distance = hausdorff(traj_lst[i], traj_lst[j])
+            D[i, j] = distance
+            D[j, i] = distance
+
+    return D
+
+
+def run_kmedoids(image, traj_lst, D):
+    # 4 - Different clustering methods
+
+    # 4.1 - kmedoids
+
+    traj_count = len(traj_lst)
+
+    k = 3  # The number of clusters
+    medoid_center_lst, cluster2index_lst = kMedoids(D, k)
+
+    cluster_lst = np.empty((traj_count,), dtype=int)
+
+    for cluster in cluster2index_lst:
+        cluster_lst[cluster2index_lst[cluster]] = cluster
+
+    plot_cluster(image, traj_lst, cluster_lst)
+
+
+def run_dbscan(image, traj_lst, D):
+    mdl = DBSCAN(eps=400, min_samples=10)
+    cluster_lst = mdl.fit_predict(D)
+
+    plot_cluster(image, traj_lst, cluster_lst)
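For reference, a minimal sketch of how these helpers chain together; the trajectories and background image below are synthetic stand-ins for illustration, not WALT data:

    import numpy as np
    from cwalt.clustering_utils import build_distance_matrix, run_kmedoids, run_dbscan

    rng = np.random.default_rng(0)
    # three synthetic (N, 2) pixel trajectories drifting in different directions
    traj_lst = [np.cumsum(rng.normal(size=(50, 2)), axis=0) + 100.0 * i for i in range(3)]
    image = np.zeros((500, 500, 3), dtype=np.uint8)  # stand-in background frame

    D = build_distance_matrix(traj_lst)  # segments trajectories in place, then pairwise Hausdorff distances
    run_kmedoids(image, traj_lst, D)     # k = 3 medoid clusters drawn over the image
    # run_dbscan(image, traj_lst, D)     # density-based alternative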
cwalt/kmedoid.py
ADDED
@@ -0,0 +1,55 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Created on Fri May 20 15:18:56 2022
+
+@author: dinesh
+"""
+
+import numpy as np
+import math
+
+def kMedoids(D, k, tmax=100):
+    # determine dimensions of distance matrix D
+    m, n = D.shape
+
+    np.fill_diagonal(D, math.inf)
+
+    if k > n:
+        raise Exception('too many medoids')
+    # randomly initialize an array of k medoid indices
+    M = np.arange(n)
+    np.random.shuffle(M)
+    M = np.sort(M[:k])
+
+    # create a copy of the array of medoid indices
+    Mnew = np.copy(M)
+
+    # initialize a dictionary to represent clusters
+    C = {}
+    for t in range(tmax):
+        # determine clusters, i. e. arrays of data indices
+        J = np.argmin(D[:, M], axis=1)
+
+        for kappa in range(k):
+            C[kappa] = np.where(J == kappa)[0]
+        # update cluster medoids
+        for kappa in range(k):
+            J = np.mean(D[np.ix_(C[kappa], C[kappa])], axis=1)
+            j = np.argmin(J)
+            Mnew[kappa] = C[kappa][j]
+        np.sort(Mnew)
+        # check for convergence
+        if np.array_equal(M, Mnew):
+            break
+        M = np.copy(Mnew)
+    else:
+        # final update of cluster memberships
+        J = np.argmin(D[:, M], axis=1)
+        for kappa in range(k):
+            C[kappa] = np.where(J == kappa)[0]
+
+    np.fill_diagonal(D, 0)
+
+    # return results
+    return M, C
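A small self-contained sketch of `kMedoids` on a toy distance matrix (the six points below are invented purely for illustration):

    import numpy as np
    from cwalt.kmedoid import kMedoids

    points = np.array([[0, 0], [0, 1], [1, 0], [10, 10], [10, 11], [11, 10]], dtype=float)
    D = np.linalg.norm(points[:, None, :] - points[None, :, :], axis=-1)  # pairwise Euclidean distances

    medoids, clusters = kMedoids(D, k=2)
    print(medoids)   # indices of the two medoid points
    print(clusters)  # dict mapping cluster index -> array of member indices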
cwalt/utils.py
ADDED
@@ -0,0 +1,168 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Created on Fri May 20 15:16:56 2022
+
+@author: dinesh
+"""
+
+import json
+import cv2
+from PIL import Image
+import numpy as np
+from dateutil.parser import parse
+
+def bb_intersection_over_union(box1, box2):
+    #print(box1, box2)
+    boxA = box1.copy()
+    boxB = box2.copy()
+    boxA[2] = boxA[0] + boxA[2]
+    boxA[3] = boxA[1] + boxA[3]
+    boxB[2] = boxB[0] + boxB[2]
+    boxB[3] = boxB[1] + boxB[3]
+    # determine the (x, y)-coordinates of the intersection rectangle
+    xA = max(boxA[0], boxB[0])
+    yA = max(boxA[1], boxB[1])
+    xB = min(boxA[2], boxB[2])
+    yB = min(boxA[3], boxB[3])
+
+    # compute the area of intersection rectangle
+    interArea = abs(max((xB - xA, 0)) * max((yB - yA), 0))
+
+    if interArea == 0:
+        return 0
+    # compute the area of both the prediction and ground-truth
+    # rectangles
+    boxAArea = abs((boxA[2] - boxA[0]) * (boxA[3] - boxA[1]))
+    boxBArea = abs((boxB[2] - boxB[0]) * (boxB[3] - boxB[1]))
+
+    # compute the intersection over union by taking the intersection
+    # area and dividing it by the sum of prediction + ground-truth
+    # areas - the interesection area
+    iou = interArea / float(boxAArea + boxBArea - interArea)
+    return iou
+
+def bb_intersection_over_union_unoccluded(box1, box2, threshold=0.01):
+    #print(box1, box2)
+    boxA = box1.copy()
+    boxB = box2.copy()
+    boxA[2] = boxA[0] + boxA[2]
+    boxA[3] = boxA[1] + boxA[3]
+    boxB[2] = boxB[0] + boxB[2]
+    boxB[3] = boxB[1] + boxB[3]
+    # determine the (x, y)-coordinates of the intersection rectangle
+    xA = max(boxA[0], boxB[0])
+    yA = max(boxA[1], boxB[1])
+    xB = min(boxA[2], boxB[2])
+    yB = min(boxA[3], boxB[3])
+
+    # compute the area of intersection rectangle
+    interArea = abs(max((xB - xA, 0)) * max((yB - yA), 0))
+
+    if interArea == 0:
+        return 0
+    # compute the area of both the prediction and ground-truth
+    # rectangles
+    boxAArea = abs((boxA[2] - boxA[0]) * (boxA[3] - boxA[1]))
+    boxBArea = abs((boxB[2] - boxB[0]) * (boxB[3] - boxB[1]))
+
+    # compute the intersection over union by taking the intersection
+    # area and dividing it by the sum of prediction + ground-truth
+    # areas - the interesection area
+    iou = interArea / float(boxAArea + boxBArea - interArea)
+
+    #print(iou)
+    # return the intersection over union value
+    occlusion = False
+    if iou > threshold and iou < 1:
+        #print(boxA[3], boxB[3], boxB[1])
+        if boxA[3] < boxB[3]:  # and boxA[3] > boxB[1]:
+            if boxB[2] > boxA[0]:  # and boxB[2] < boxA[2]:
+                #print('first', (boxB[2] - boxA[0])/(boxA[2] - boxA[0]))
+                if (min(boxB[2], boxA[2]) - boxA[0]) / (boxA[2] - boxA[0]) > threshold:
+                    occlusion = True
+
+            if boxB[0] < boxA[2]:  # boxB[0] > boxA[0] and
+                #print('second', (boxA[2] - boxB[0])/(boxA[2] - boxA[0]))
+                if (boxA[2] - max(boxB[0], boxA[0])) / (boxA[2] - boxA[0]) > threshold:
+                    occlusion = True
+    if occlusion == False:
+        iou = iou * 0
+        #asas
+    # asas
+    #iou = 0.9 #iou*0
+    #print(box1, box2, iou, occlusion)
+    return iou
+
+def draw_tracks(image, tracks):
+    """
+    Draw on input image.
+
+    Args:
+        image (numpy.ndarray): image
+        tracks (list): list of tracks to be drawn on the image.
+
+    Returns:
+        numpy.ndarray: image with the track-ids drawn on it.
+    """
+
+    for trk in tracks:
+
+        trk_id = trk[1]
+        xmin = trk[2]
+        ymin = trk[3]
+        width = trk[4]
+        height = trk[5]
+
+        xcentroid, ycentroid = int(xmin + 0.5 * width), int(ymin + 0.5 * height)
+
+        text = "ID {}".format(trk_id)
+
+        cv2.putText(image, text, (xcentroid - 10, ycentroid - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
+        cv2.circle(image, (xcentroid, ycentroid), 4, (0, 255, 0), -1)
+
+    return image
+
+
+def draw_bboxes(image, tracks):
+    """
+    Draw the bounding boxes about detected objects in the image.
+
+    Args:
+        image (numpy.ndarray): Image or video frame.
+        bboxes (numpy.ndarray): Bounding boxes pixel coordinates as (xmin, ymin, width, height)
+        confidences (numpy.ndarray): Detection confidence or detection probability.
+        class_ids (numpy.ndarray): Array containing class ids (aka label ids) of each detected object.
+
+    Returns:
+        numpy.ndarray: image with the bounding boxes drawn on it.
+    """
+
+    for trk in tracks:
+        xmin = int(trk[2])
+        ymin = int(trk[3])
+        width = int(trk[4])
+        height = int(trk[5])
+        clr = (np.random.randint(0, 255), np.random.randint(0, 255), np.random.randint(0, 255))
+        cv2.rectangle(image, (xmin, ymin), (xmin + width, ymin + height), clr, 2)
+
+    return image
+
+
+def num(v):
+    number_as_float = float(v)
+    number_as_int = int(number_as_float)
+    return number_as_int if number_as_float == number_as_int else number_as_float
+
+
+def parse_bbox(bbox_str):
+    bbox_list = bbox_str.strip('{').strip('}').split(',')
+    bbox_list = [num(elem) for elem in bbox_list]
+    return bbox_list
+
+def parse_seg(bbox_str):
+    bbox_list = bbox_str.strip('{').strip('}').split(',')
+    bbox_list = [num(elem) for elem in bbox_list]
+    ret = bbox_list  # []
+    # for i in range(0, len(bbox_list) - 1, 2):
+    #     ret.append((bbox_list[i], bbox_list[i + 1]))
+    return ret
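`bb_intersection_over_union` takes boxes in `(x, y, width, height)` form and converts them to corner coordinates internally; a quick worked example:

    from cwalt.utils import bb_intersection_over_union

    box_a = [0, 0, 10, 10]
    box_b = [5, 5, 10, 10]
    # intersection 5 * 5 = 25, union 100 + 100 - 25 = 175
    print(bb_intersection_over_union(box_a, box_b))  # 0.1428...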
cwalt_generate.py
ADDED
@@ -0,0 +1,14 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Created on Sat Jun 4 16:55:58 2022
+
+@author: dinesh
+"""
+from cwalt.CWALT import CWALT_Generation
+from cwalt.Clip_WALT_Generate import Get_unoccluded_objects
+
+if __name__ == '__main__':
+    camera_name = 'cam2'
+    Get_unoccluded_objects(camera_name)
+    CWALT_Generation(camera_name)
docker/Dockerfile
ADDED
@@ -0,0 +1,52 @@
+ARG PYTORCH="1.9.0"
+ARG CUDA="11.1"
+ARG CUDNN="8"
+
+FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel
+
+ENV TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0+PTX"
+ENV TORCH_NVCC_FLAGS="-Xfatbin -compress-all"
+ENV CMAKE_PREFIX_PATH="$(dirname $(which conda))/../"
+RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub
+RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub
+RUN apt-get update && apt-get install -y ffmpeg libsm6 libxext6 git ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6 \
+    && apt-get clean \
+    && rm -rf /var/lib/apt/lists/*
+
+# Install MMCV
+#RUN pip install mmcv-full==1.3.8 -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.6.0/index.html
+# -f https://openmmlab.oss-accelerate.aliyuncs.com/mmcv/dist/index.html
+RUN pip install mmcv-full==1.4.0 -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.9.0/index.html
+# Install MMDetection
+RUN conda clean --all
+RUN git clone https://github.com/open-mmlab/mmdetection.git /mmdetection
+WORKDIR /mmdetection
+ENV FORCE_CUDA="1"
+RUN cd /mmdetection && git checkout 7bd39044f35aec4b90dd797b965777541a8678ff
+RUN pip install -r requirements/build.txt
+RUN pip install --no-cache-dir -e .
+RUN apt-get update
+RUN apt-get install -y vim
+RUN pip uninstall -y pycocotools
+RUN pip install mmpycocotools timm scikit-image imagesize
+
+
+# make sure we don't overwrite some existing directory called "apex"
+WORKDIR /tmp/unique_for_apex
+# uninstall Apex if present, twice to make absolutely sure :)
+RUN pip uninstall -y apex || :
+RUN pip uninstall -y apex || :
+# SHA is something the user can touch to force recreation of this Docker layer,
+# and therefore force cloning of the latest version of Apex
+RUN SHA=ToUcHMe git clone https://github.com/NVIDIA/apex.git
+WORKDIR /tmp/unique_for_apex/apex
+RUN pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .
+RUN pip install seaborn sklearn imantics gradio
+WORKDIR /code
+ENTRYPOINT ["python", "app.py"]
+
+#RUN git clone https://github.com/NVIDIA/apex
+#RUN cd apex
+#RUN pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .
+#RUN pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
github_vis/cwalt.gif
ADDED
github_vis/vis_cars.gif
ADDED
github_vis/vis_people.gif
ADDED
infer.py
ADDED
@@ -0,0 +1,118 @@
+from argparse import ArgumentParser
+
+from mmdet.apis import inference_detector, init_detector, show_result_pyplot
+from mmdet.core.mask.utils import encode_mask_results
+import numpy as np
+import mmcv
+import torch
+from imantics import Polygons, Mask
+import json
+import os
+import cv2, glob
+
+class detections():
+    def __init__(self, cfg_path, device, model_path='data/models/walt_vehicle.pth', threshold=0.85):
+        self.model = init_detector(cfg_path, model_path, device=device)
+        self.all_preds = []
+        self.all_scores = []
+        self.index = []
+        self.score_thr = threshold
+        self.result = []
+        self.record_dict = {'model': cfg_path, 'results': []}
+        self.detect_count = []
+
+
+    def run_on_image(self, image):
+        self.result = inference_detector(self.model, image)
+        image_labelled = self.model.show_result(image, self.result, score_thr=self.score_thr)
+        return image_labelled
+
+    def process_output(self, count):
+        result = self.result
+        infer_result = {'url': count,
+                        'boxes': [],
+                        'scores': [],
+                        'keypoints': [],
+                        'segmentation': [],
+                        'label_ids': [],
+                        'track': [],
+                        'labels': []}
+
+        if isinstance(result, tuple):
+            bbox_result, segm_result = result
+            #segm_result = encode_mask_results(segm_result)
+            if isinstance(segm_result, tuple):
+                segm_result = segm_result[0]  # ms rcnn
+        bboxes = np.vstack(bbox_result)
+        labels = [np.full(bbox.shape[0], i, dtype=np.int32) for i, bbox in enumerate(bbox_result)]
+
+        labels = np.concatenate(labels)
+        segms = None
+        if segm_result is not None and len(labels) > 0:  # non empty
+            segms = mmcv.concat_list(segm_result)
+            if isinstance(segms[0], torch.Tensor):
+                segms = torch.stack(segms, dim=0).detach().cpu().numpy()
+            else:
+                segms = np.stack(segms, axis=0)
+
+        for i, (bbox, label, segm) in enumerate(zip(bboxes, labels, segms)):
+            if bbox[-1].item() < 0.3:
+                continue
+            box = [bbox[0].item(), bbox[1].item(), bbox[2].item(), bbox[3].item()]
+            polygons = Mask(segm).polygons()
+
+            infer_result['boxes'].append(box)
+            infer_result['segmentation'].append(polygons.segmentation)
+            infer_result['scores'].append(bbox[-1].item())
+            infer_result['labels'].append(self.model.CLASSES[label])
+            infer_result['label_ids'].append(label)
+        self.record_dict['results'].append(infer_result)
+        self.detect_count = labels
+
+    def write_json(self, filename):
+        with open(filename + '.json', 'w') as f:
+            json.dump(self.record_dict, f)
+
+
+def main():
+    if torch.cuda.is_available() == False:
+        device = 'cpu'
+    else:
+        device = 'cuda:0'
+    detect_people = detections('configs/walt/walt_people.py', device, model_path='data/models/walt_people.pth')
+    detect = detections('configs/walt/walt_vehicle.py', device, model_path='data/models/walt_vehicle.pth')
+    filenames = sorted(glob.glob('demo/images/*'))
+    count = 0
+    for filename in filenames:
+        img = cv2.imread(filename)
+        try:
+            img = detect_people.run_on_image(img)
+            img = detect.run_on_image(img)
+        except:
+            continue
+        count = count + 1
+
+        try:
+            import os
+            os.makedirs(os.path.dirname(filename.replace('demo', 'demo/results/')))
+            os.mkdirs(os.path.dirname(filename))
+        except:
+            print('done')
+        cv2.imwrite(filename.replace('demo', 'demo/results/'), img)
+        if count == 30000:
+            break
+        try:
+            detect.process_output(count)
+        except:
+            continue
+    '''
+
+    np.savez('FC', a= detect.record_dict)
+    with open('check.json', 'w') as f:
+        json.dump(detect.record_dict, f)
+    detect.write_json('seq3')
+    asas
+    detect.process_output(0)
+    '''
+if __name__ == "__main__":
+    main()
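The `detections` wrapper can also be driven outside `main`; a minimal sketch, assuming the WALT config and checkpoint paths used above are present and that `frame.jpg` is a hypothetical input image:

    import cv2
    from infer import detections

    detector = detections('configs/walt/walt_vehicle.py', 'cpu',
                          model_path='data/models/walt_vehicle.pth')
    frame = cv2.imread('frame.jpg')            # hypothetical input image
    labelled = detector.run_on_image(frame)    # draws boxes/masks above score_thr
    detector.process_output(0)                 # append raw boxes/polygons to record_dict
    detector.write_json('frame_detections')    # writes frame_detections.json
    cv2.imwrite('frame_labelled.jpg', labelled)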
mmcv_custom/__init__.py
ADDED
@@ -0,0 +1,5 @@
+# -*- coding: utf-8 -*-
+
+from .checkpoint import load_checkpoint
+
+__all__ = ['load_checkpoint']
mmcv_custom/checkpoint.py
ADDED
@@ -0,0 +1,500 @@
1 |
+
# Copyright (c) Open-MMLab. All rights reserved.
|
2 |
+
import io
|
3 |
+
import os
|
4 |
+
import os.path as osp
|
5 |
+
import pkgutil
|
6 |
+
import time
|
7 |
+
import warnings
|
8 |
+
from collections import OrderedDict
|
9 |
+
from importlib import import_module
|
10 |
+
from tempfile import TemporaryDirectory
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torchvision
|
14 |
+
from torch.optim import Optimizer
|
15 |
+
from torch.utils import model_zoo
|
16 |
+
from torch.nn import functional as F
|
17 |
+
|
18 |
+
import mmcv
|
19 |
+
from mmcv.fileio import FileClient
|
20 |
+
from mmcv.fileio import load as load_file
|
21 |
+
from mmcv.parallel import is_module_wrapper
|
22 |
+
from mmcv.utils import mkdir_or_exist
|
23 |
+
from mmcv.runner import get_dist_info
|
24 |
+
|
25 |
+
ENV_MMCV_HOME = 'MMCV_HOME'
|
26 |
+
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
|
27 |
+
DEFAULT_CACHE_DIR = '~/.cache'
|
28 |
+
|
29 |
+
|
30 |
+
def _get_mmcv_home():
|
31 |
+
mmcv_home = os.path.expanduser(
|
32 |
+
os.getenv(
|
33 |
+
ENV_MMCV_HOME,
|
34 |
+
os.path.join(
|
35 |
+
os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmcv')))
|
36 |
+
|
37 |
+
mkdir_or_exist(mmcv_home)
|
38 |
+
return mmcv_home
|
39 |
+
|
40 |
+
|
41 |
+
def load_state_dict(module, state_dict, strict=False, logger=None):
|
42 |
+
"""Load state_dict to a module.
|
43 |
+
|
44 |
+
This method is modified from :meth:`torch.nn.Module.load_state_dict`.
|
45 |
+
Default value for ``strict`` is set to ``False`` and the message for
|
46 |
+
param mismatch will be shown even if strict is False.
|
47 |
+
|
48 |
+
Args:
|
49 |
+
module (Module): Module that receives the state_dict.
|
50 |
+
state_dict (OrderedDict): Weights.
|
51 |
+
strict (bool): whether to strictly enforce that the keys
|
52 |
+
in :attr:`state_dict` match the keys returned by this module's
|
53 |
+
:meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
|
54 |
+
logger (:obj:`logging.Logger`, optional): Logger to log the error
|
55 |
+
message. If not specified, print function will be used.
|
56 |
+
"""
|
57 |
+
unexpected_keys = []
|
58 |
+
all_missing_keys = []
|
59 |
+
err_msg = []
|
60 |
+
|
61 |
+
metadata = getattr(state_dict, '_metadata', None)
|
62 |
+
state_dict = state_dict.copy()
|
63 |
+
if metadata is not None:
|
64 |
+
state_dict._metadata = metadata
|
65 |
+
|
66 |
+
# use _load_from_state_dict to enable checkpoint version control
|
67 |
+
def load(module, prefix=''):
|
68 |
+
# recursively check parallel module in case that the model has a
|
69 |
+
# complicated structure, e.g., nn.Module(nn.Module(DDP))
|
70 |
+
if is_module_wrapper(module):
|
71 |
+
module = module.module
|
72 |
+
local_metadata = {} if metadata is None else metadata.get(
|
73 |
+
prefix[:-1], {})
|
74 |
+
module._load_from_state_dict(state_dict, prefix, local_metadata, True,
|
75 |
+
all_missing_keys, unexpected_keys,
|
76 |
+
err_msg)
|
77 |
+
for name, child in module._modules.items():
|
78 |
+
if child is not None:
|
79 |
+
load(child, prefix + name + '.')
|
80 |
+
|
81 |
+
load(module)
|
82 |
+
load = None # break load->load reference cycle
|
83 |
+
|
84 |
+
# ignore "num_batches_tracked" of BN layers
|
85 |
+
missing_keys = [
|
86 |
+
key for key in all_missing_keys if 'num_batches_tracked' not in key
|
87 |
+
]
|
88 |
+
|
89 |
+
if unexpected_keys:
|
90 |
+
err_msg.append('unexpected key in source '
|
91 |
+
f'state_dict: {", ".join(unexpected_keys)}\n')
|
92 |
+
if missing_keys:
|
93 |
+
err_msg.append(
|
94 |
+
f'missing keys in source state_dict: {", ".join(missing_keys)}\n')
|
95 |
+
|
96 |
+
rank, _ = get_dist_info()
|
97 |
+
if len(err_msg) > 0 and rank == 0:
|
98 |
+
err_msg.insert(
|
99 |
+
0, 'The model and loaded state dict do not match exactly\n')
|
100 |
+
err_msg = '\n'.join(err_msg)
|
101 |
+
if strict:
|
102 |
+
raise RuntimeError(err_msg)
|
103 |
+
elif logger is not None:
|
104 |
+
logger.warning(err_msg)
|
105 |
+
else:
|
106 |
+
print(err_msg)
|
107 |
+
|
108 |
+
|
109 |
+
def load_url_dist(url, model_dir=None):
|
110 |
+
"""In distributed setting, this function only download checkpoint at local
|
111 |
+
rank 0."""
|
112 |
+
rank, world_size = get_dist_info()
|
113 |
+
rank = int(os.environ.get('LOCAL_RANK', rank))
|
114 |
+
if rank == 0:
|
115 |
+
checkpoint = model_zoo.load_url(url, model_dir=model_dir)
|
116 |
+
if world_size > 1:
|
117 |
+
torch.distributed.barrier()
|
118 |
+
if rank > 0:
|
119 |
+
checkpoint = model_zoo.load_url(url, model_dir=model_dir)
|
120 |
+
return checkpoint
|
121 |
+
|
122 |
+
|
123 |
+
def load_pavimodel_dist(model_path, map_location=None):
|
124 |
+
"""In distributed setting, this function only download checkpoint at local
|
125 |
+
rank 0."""
|
126 |
+
try:
|
127 |
+
from pavi import modelcloud
|
128 |
+
except ImportError:
|
129 |
+
raise ImportError(
|
130 |
+
'Please install pavi to load checkpoint from modelcloud.')
|
131 |
+
rank, world_size = get_dist_info()
|
132 |
+
rank = int(os.environ.get('LOCAL_RANK', rank))
|
133 |
+
if rank == 0:
|
134 |
+
model = modelcloud.get(model_path)
|
135 |
+
with TemporaryDirectory() as tmp_dir:
|
136 |
+
downloaded_file = osp.join(tmp_dir, model.name)
|
137 |
+
model.download(downloaded_file)
|
138 |
+
checkpoint = torch.load(downloaded_file, map_location=map_location)
|
139 |
+
if world_size > 1:
|
140 |
+
torch.distributed.barrier()
|
141 |
+
if rank > 0:
|
142 |
+
model = modelcloud.get(model_path)
|
143 |
+
with TemporaryDirectory() as tmp_dir:
|
144 |
+
downloaded_file = osp.join(tmp_dir, model.name)
|
145 |
+
model.download(downloaded_file)
|
146 |
+
checkpoint = torch.load(
|
147 |
+
downloaded_file, map_location=map_location)
|
148 |
+
return checkpoint
|
149 |
+
|
150 |
+
|
151 |
+
def load_fileclient_dist(filename, backend, map_location):
|
152 |
+
"""In distributed setting, this function only download checkpoint at local
|
153 |
+
rank 0."""
|
154 |
+
rank, world_size = get_dist_info()
|
155 |
+
rank = int(os.environ.get('LOCAL_RANK', rank))
|
156 |
+
allowed_backends = ['ceph']
|
157 |
+
if backend not in allowed_backends:
|
158 |
+
raise ValueError(f'Load from Backend {backend} is not supported.')
|
159 |
+
if rank == 0:
|
160 |
+
fileclient = FileClient(backend=backend)
|
161 |
+
buffer = io.BytesIO(fileclient.get(filename))
|
162 |
+
checkpoint = torch.load(buffer, map_location=map_location)
|
163 |
+
if world_size > 1:
|
164 |
+
torch.distributed.barrier()
|
165 |
+
if rank > 0:
|
166 |
+
fileclient = FileClient(backend=backend)
|
167 |
+
buffer = io.BytesIO(fileclient.get(filename))
|
168 |
+
checkpoint = torch.load(buffer, map_location=map_location)
|
169 |
+
return checkpoint
|
170 |
+
|
171 |
+
|
172 |
+
def get_torchvision_models():
|
173 |
+
model_urls = dict()
|
174 |
+
for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
|
175 |
+
if ispkg:
|
176 |
+
continue
|
177 |
+
_zoo = import_module(f'torchvision.models.{name}')
|
178 |
+
if hasattr(_zoo, 'model_urls'):
|
179 |
+
_urls = getattr(_zoo, 'model_urls')
|
180 |
+
model_urls.update(_urls)
|
181 |
+
return model_urls
|
182 |
+
|
183 |
+
|
184 |
+
def get_external_models():
|
185 |
+
mmcv_home = _get_mmcv_home()
|
186 |
+
default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json')
|
187 |
+
default_urls = load_file(default_json_path)
|
188 |
+
assert isinstance(default_urls, dict)
|
189 |
+
external_json_path = osp.join(mmcv_home, 'open_mmlab.json')
|
190 |
+
if osp.exists(external_json_path):
|
191 |
+
external_urls = load_file(external_json_path)
|
192 |
+
assert isinstance(external_urls, dict)
|
193 |
+
default_urls.update(external_urls)
|
194 |
+
|
195 |
+
return default_urls
|
196 |
+
|
197 |
+
|
198 |
+
def get_mmcls_models():
|
199 |
+
mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json')
|
200 |
+
mmcls_urls = load_file(mmcls_json_path)
|
201 |
+
|
202 |
+
return mmcls_urls
|
203 |
+
|
204 |
+
|
205 |
+
def get_deprecated_model_names():
|
206 |
+
deprecate_json_path = osp.join(mmcv.__path__[0],
|
207 |
+
'model_zoo/deprecated.json')
|
208 |
+
deprecate_urls = load_file(deprecate_json_path)
|
209 |
+
assert isinstance(deprecate_urls, dict)
|
210 |
+
|
211 |
+
return deprecate_urls
|
212 |
+
|
213 |
+
|
214 |
+
def _process_mmcls_checkpoint(checkpoint):
|
215 |
+
state_dict = checkpoint['state_dict']
|
216 |
+
new_state_dict = OrderedDict()
|
217 |
+
for k, v in state_dict.items():
|
218 |
+
if k.startswith('backbone.'):
|
219 |
+
new_state_dict[k[9:]] = v
|
220 |
+
new_checkpoint = dict(state_dict=new_state_dict)
|
221 |
+
|
222 |
+
return new_checkpoint
|
223 |
+
|
224 |
+
|
225 |
+
def _load_checkpoint(filename, map_location=None):
|
226 |
+
"""Load checkpoint from somewhere (modelzoo, file, url).
|
227 |
+
|
228 |
+
Args:
|
229 |
+
filename (str): Accept local filepath, URL, ``torchvision://xxx``,
|
230 |
+
``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
|
231 |
+
details.
|
232 |
+
map_location (str | None): Same as :func:`torch.load`. Default: None.
|
233 |
+
|
234 |
+
Returns:
|
235 |
+
dict | OrderedDict: The loaded checkpoint. It can be either an
|
236 |
+
OrderedDict storing model weights or a dict containing other
|
237 |
+
information, which depends on the checkpoint.
|
238 |
+
"""
|
239 |
+
if filename.startswith('modelzoo://'):
|
240 |
+
warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
|
241 |
+
'use "torchvision://" instead')
|
242 |
+
model_urls = get_torchvision_models()
|
243 |
+
model_name = filename[11:]
|
244 |
+
checkpoint = load_url_dist(model_urls[model_name])
|
245 |
+
elif filename.startswith('torchvision://'):
|
246 |
+
model_urls = get_torchvision_models()
|
247 |
+
model_name = filename[14:]
|
248 |
+
checkpoint = load_url_dist(model_urls[model_name])
|
249 |
+
elif filename.startswith('open-mmlab://'):
|
250 |
+
model_urls = get_external_models()
|
251 |
+
model_name = filename[13:]
|
252 |
+
deprecated_urls = get_deprecated_model_names()
|
253 |
+
if model_name in deprecated_urls:
|
254 |
+
warnings.warn(f'open-mmlab://{model_name} is deprecated in favor '
|
255 |
+
f'of open-mmlab://{deprecated_urls[model_name]}')
|
256 |
+
model_name = deprecated_urls[model_name]
|
257 |
+
model_url = model_urls[model_name]
|
258 |
+
# check if is url
|
259 |
+
if model_url.startswith(('http://', 'https://')):
|
260 |
+
checkpoint = load_url_dist(model_url)
|
261 |
+
else:
|
262 |
+
filename = osp.join(_get_mmcv_home(), model_url)
|
263 |
+
if not osp.isfile(filename):
|
264 |
+
raise IOError(f'{filename} is not a checkpoint file')
|
265 |
+
checkpoint = torch.load(filename, map_location=map_location)
|
266 |
+
elif filename.startswith('mmcls://'):
|
267 |
+
model_urls = get_mmcls_models()
|
268 |
+
model_name = filename[8:]
|
269 |
+
checkpoint = load_url_dist(model_urls[model_name])
|
270 |
+
checkpoint = _process_mmcls_checkpoint(checkpoint)
|
271 |
+
elif filename.startswith(('http://', 'https://')):
|
272 |
+
checkpoint = load_url_dist(filename)
|
273 |
+
elif filename.startswith('pavi://'):
|
274 |
+
model_path = filename[7:]
|
275 |
+
checkpoint = load_pavimodel_dist(model_path, map_location=map_location)
|
276 |
+
elif filename.startswith('s3://'):
|
277 |
+
checkpoint = load_fileclient_dist(
|
278 |
+
filename, backend='ceph', map_location=map_location)
|
279 |
+
else:
|
280 |
+
if not osp.isfile(filename):
|
281 |
+
raise IOError(f'{filename} is not a checkpoint file')
|
282 |
+
checkpoint = torch.load(filename, map_location=map_location)
|
283 |
+
return checkpoint
|
284 |
+
|
285 |
+
|
286 |
+
def load_checkpoint(model,
|
287 |
+
filename,
|
288 |
+
map_location='cpu',
|
289 |
+
strict=False,
|
290 |
+
logger=None):
|
291 |
+
"""Load checkpoint from a file or URI.
|
292 |
+
|
293 |
+
Args:
|
294 |
+
model (Module): Module to load checkpoint.
|
295 |
+
filename (str): Accept local filepath, URL, ``torchvision://xxx``,
|
296 |
+
``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
|
297 |
+
details.
|
298 |
+
map_location (str): Same as :func:`torch.load`.
|
299 |
+
strict (bool): Whether to allow different params for the model and
|
300 |
+
checkpoint.
|
301 |
+
logger (:mod:`logging.Logger` or None): The logger for error message.
|
302 |
+
|
303 |
+
Returns:
|
304 |
+
dict or OrderedDict: The loaded checkpoint.
|
305 |
+
"""
|
306 |
+
checkpoint = _load_checkpoint(filename, map_location)
|
307 |
+
# OrderedDict is a subclass of dict
|
308 |
+
if not isinstance(checkpoint, dict):
|
309 |
+
raise RuntimeError(
|
310 |
+
f'No state_dict found in checkpoint file {filename}')
|
311 |
+
# get state_dict from checkpoint
|
312 |
+
if 'state_dict' in checkpoint:
|
313 |
+
state_dict = checkpoint['state_dict']
|
314 |
+
elif 'model' in checkpoint:
|
315 |
+
state_dict = checkpoint['model']
|
316 |
+
else:
|
317 |
+
state_dict = checkpoint
|
318 |
+
# strip prefix of state_dict
|
319 |
+
if list(state_dict.keys())[0].startswith('module.'):
|
320 |
+
state_dict = {k[7:]: v for k, v in state_dict.items()}
|
321 |
+
|
322 |
+
# for MoBY, load model of online branch
|
323 |
+
if sorted(list(state_dict.keys()))[0].startswith('encoder'):
|
324 |
+
state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')}
|
325 |
+
|
326 |
+
# reshape absolute position embedding
|
327 |
+
if state_dict.get('absolute_pos_embed') is not None:
|
328 |
+
absolute_pos_embed = state_dict['absolute_pos_embed']
|
329 |
+
N1, L, C1 = absolute_pos_embed.size()
|
330 |
+
N2, C2, H, W = model.absolute_pos_embed.size()
|
331 |
+
if N1 != N2 or C1 != C2 or L != H*W:
|
332 |
+
logger.warning("Error in loading absolute_pos_embed, pass")
|
333 |
+
else:
|
334 |
+
state_dict['absolute_pos_embed'] = absolute_pos_embed.view(N2, H, W, C2).permute(0, 3, 1, 2)
|
335 |
+
|
336 |
+
# interpolate position bias table if needed
|
337 |
+
relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
|
338 |
+
for table_key in relative_position_bias_table_keys:
|
339 |
+
table_pretrained = state_dict[table_key]
|
340 |
+
table_current = model.state_dict()[table_key]
|
341 |
+
L1, nH1 = table_pretrained.size()
|
342 |
+
L2, nH2 = table_current.size()
|
343 |
+
if nH1 != nH2:
|
344 |
+
logger.warning(f"Error in loading {table_key}, pass")
|
345 |
+
else:
|
346 |
+
if L1 != L2:
|
347 |
+
S1 = int(L1 ** 0.5)
|
348 |
+
S2 = int(L2 ** 0.5)
|
349 |
+
table_pretrained_resized = F.interpolate(
|
350 |
+
table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
|
351 |
+
size=(S2, S2), mode='bicubic')
|
352 |
+
state_dict[table_key] = table_pretrained_resized.view(nH2, L2).permute(1, 0)
|
353 |
+
|
354 |
+
# load state_dict
|
355 |
+
load_state_dict(model, state_dict, strict, logger)
|
356 |
+
return checkpoint
|
357 |
+
|
358 |
+
|
359 |
+
def weights_to_cpu(state_dict):
|
360 |
+
"""Copy a model state_dict to cpu.
|
361 |
+
|
362 |
+
Args:
|
363 |
+
state_dict (OrderedDict): Model weights on GPU.
|
364 |
+
|
365 |
+
Returns:
|
366 |
+
OrderedDict: Model weights on GPU.
|
367 |
+
"""
|
368 |
+
state_dict_cpu = OrderedDict()
|
369 |
+
for key, val in state_dict.items():
|
370 |
+
state_dict_cpu[key] = val.cpu()
|
371 |
+
return state_dict_cpu
|
372 |
+
|
373 |
+
|
374 |
+
def _save_to_state_dict(module, destination, prefix, keep_vars):
|
375 |
+
"""Saves module state to `destination` dictionary.
|
376 |
+
|
377 |
+
This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.
|
378 |
+
|
379 |
+
Args:
|
380 |
+
module (nn.Module): The module to generate state_dict.
|
381 |
+
destination (dict): A dict where state will be stored.
|
382 |
+
prefix (str): The prefix for parameters and buffers used in this
|
383 |
+
module.
|
384 |
+
"""
|
385 |
+
for name, param in module._parameters.items():
|
386 |
+
if param is not None:
|
387 |
+
destination[prefix + name] = param if keep_vars else param.detach()
|
388 |
+
for name, buf in module._buffers.items():
|
389 |
+
# remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d
|
390 |
+
if buf is not None:
|
391 |
+
destination[prefix + name] = buf if keep_vars else buf.detach()
|
392 |
+
|
393 |
+
|
394 |
+
def get_state_dict(module, destination=None, prefix='', keep_vars=False):
|
395 |
+
"""Returns a dictionary containing a whole state of the module.
|
396 |
+
|
397 |
+
Both parameters and persistent buffers (e.g. running averages) are
|
398 |
+
included. Keys are corresponding parameter and buffer names.
|
399 |
+
|
400 |
+
This method is modified from :meth:`torch.nn.Module.state_dict` to
|
401 |
+
recursively check parallel module in case that the model has a complicated
|
402 |
+
structure, e.g., nn.Module(nn.Module(DDP)).
|
403 |
+
|
404 |
+
Args:
|
405 |
+
module (nn.Module): The module to generate state_dict.
|
406 |
+
destination (OrderedDict): Returned dict for the state of the
|
407 |
+
module.
|
408 |
+
prefix (str): Prefix of the key.
|
409 |
+
keep_vars (bool): Whether to keep the variable property of the
|
410 |
+
parameters. Default: False.
|
411 |
+
|
412 |
+
Returns:
|
413 |
+
dict: A dictionary containing a whole state of the module.
|
414 |
+
"""
|
415 |
+
# recursively check parallel module in case that the model has a
|
416 |
+
# complicated structure, e.g., nn.Module(nn.Module(DDP))
|
417 |
+
if is_module_wrapper(module):
|
418 |
+
module = module.module
|
419 |
+
|
420 |
+
# below is the same as torch.nn.Module.state_dict()
|
421 |
+
if destination is None:
|
422 |
+
destination = OrderedDict()
|
423 |
+
destination._metadata = OrderedDict()
|
424 |
+
destination._metadata[prefix[:-1]] = local_metadata = dict(
|
425 |
+
version=module._version)
|
426 |
+
_save_to_state_dict(module, destination, prefix, keep_vars)
|
427 |
+
for name, child in module._modules.items():
|
428 |
+
if child is not None:
|
429 |
+
get_state_dict(
|
430 |
+
child, destination, prefix + name + '.', keep_vars=keep_vars)
|
431 |
+
for hook in module._state_dict_hooks.values():
|
432 |
+
hook_result = hook(module, destination, prefix, local_metadata)
|
433 |
+
if hook_result is not None:
|
434 |
+
destination = hook_result
|
435 |
+
return destination
|
436 |
+
|
437 |
+
|
438 |
+
def save_checkpoint(model, filename, optimizer=None, meta=None):
|
439 |
+
"""Save checkpoint to file.
|
440 |
+
|
441 |
+
The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
|
442 |
+
``optimizer``. By default ``meta`` will contain version and time info.
|
443 |
+
|
444 |
+
Args:
|
445 |
+
model (Module): Module whose params are to be saved.
|
446 |
+
filename (str): Checkpoint filename.
|
447 |
+
optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
|
448 |
+
meta (dict, optional): Metadata to be saved in checkpoint.
|
449 |
+
"""
|
450 |
+
if meta is None:
|
451 |
+
meta = {}
|
452 |
+
elif not isinstance(meta, dict):
|
453 |
+
raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
|
454 |
+
meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
|
455 |
+
|
456 |
+
if is_module_wrapper(model):
|
457 |
+
model = model.module
|
458 |
+
|
459 |
+
if hasattr(model, 'CLASSES') and model.CLASSES is not None:
|
460 |
+
# save class name to the meta
|
461 |
+
meta.update(CLASSES=model.CLASSES)
|
462 |
+
|
463 |
+
checkpoint = {
|
464 |
+
'meta': meta,
|
465 |
+
'state_dict': weights_to_cpu(get_state_dict(model))
|
466 |
+
}
|
467 |
+
# save optimizer state dict in the checkpoint
|
468 |
+
if isinstance(optimizer, Optimizer):
|
469 |
+
checkpoint['optimizer'] = optimizer.state_dict()
|
470 |
+
elif isinstance(optimizer, dict):
|
471 |
+
checkpoint['optimizer'] = {}
|
472 |
+
for name, optim in optimizer.items():
|
473 |
+
checkpoint['optimizer'][name] = optim.state_dict()
|
474 |
+
|
475 |
+
if filename.startswith('pavi://'):
|
476 |
+
try:
|
477 |
+
from pavi import modelcloud
|
478 |
+
from pavi.exception import NodeNotFoundError
|
479 |
+
except ImportError:
|
480 |
+
raise ImportError(
|
481 |
+
'Please install pavi to load checkpoint from modelcloud.')
|
482 |
+
model_path = filename[7:]
|
483 |
+
root = modelcloud.Folder()
|
484 |
+
model_dir, model_name = osp.split(model_path)
|
485 |
+
try:
|
486 |
+
model = modelcloud.get(model_dir)
|
487 |
+
except NodeNotFoundError:
|
488 |
+
model = root.create_training_model(model_dir)
|
489 |
+
with TemporaryDirectory() as tmp_dir:
|
490 |
+
checkpoint_file = osp.join(tmp_dir, model_name)
|
491 |
+
with open(checkpoint_file, 'wb') as f:
|
492 |
+
torch.save(checkpoint, f)
|
493 |
+
f.flush()
|
494 |
+
model.create_file(checkpoint_file, name=model_name)
|
495 |
+
else:
|
496 |
+
mmcv.mkdir_or_exist(osp.dirname(filename))
|
497 |
+
# immediately flush buffer
|
498 |
+
with open(filename, 'wb') as f:
|
499 |
+
torch.save(checkpoint, f)
|
500 |
+
f.flush()
|
mmcv_custom/runner/__init__.py
ADDED
@@ -0,0 +1,8 @@
+# Copyright (c) Open-MMLab. All rights reserved.
+from .checkpoint import save_checkpoint
+from .epoch_based_runner import EpochBasedRunnerAmp
+
+
+__all__ = [
+    'EpochBasedRunnerAmp', 'save_checkpoint'
+]
mmcv_custom/runner/checkpoint.py
ADDED
@@ -0,0 +1,85 @@
1 |
+
# Copyright (c) Open-MMLab. All rights reserved.
|
2 |
+
import os.path as osp
|
3 |
+
import time
|
4 |
+
from tempfile import TemporaryDirectory
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from torch.optim import Optimizer
|
8 |
+
|
9 |
+
import mmcv
|
10 |
+
from mmcv.parallel import is_module_wrapper
|
11 |
+
from mmcv.runner.checkpoint import weights_to_cpu, get_state_dict
|
12 |
+
|
13 |
+
try:
|
14 |
+
import apex
|
15 |
+
except:
|
16 |
+
print('apex is not installed')
|
17 |
+
|
18 |
+
|
19 |
+
def save_checkpoint(model, filename, optimizer=None, meta=None):
|
20 |
+
"""Save checkpoint to file.
|
21 |
+
|
22 |
+
The checkpoint will have 4 fields: ``meta``, ``state_dict`` and
|
23 |
+
``optimizer``, ``amp``. By default ``meta`` will contain version
|
24 |
+
and time info.
|
25 |
+
|
26 |
+
Args:
|
27 |
+
model (Module): Module whose params are to be saved.
|
28 |
+
filename (str): Checkpoint filename.
|
29 |
+
optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
|
30 |
+
meta (dict, optional): Metadata to be saved in checkpoint.
|
31 |
+
"""
|
32 |
+
if meta is None:
|
33 |
+
meta = {}
|
34 |
+
elif not isinstance(meta, dict):
|
35 |
+
raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
|
36 |
+
meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
|
37 |
+
|
38 |
+
if is_module_wrapper(model):
|
39 |
+
model = model.module
|
40 |
+
|
41 |
+
if hasattr(model, 'CLASSES') and model.CLASSES is not None:
|
42 |
+
# save class name to the meta
|
43 |
+
meta.update(CLASSES=model.CLASSES)
|
44 |
+
|
45 |
+
checkpoint = {
|
46 |
+
'meta': meta,
|
47 |
+
'state_dict': weights_to_cpu(get_state_dict(model))
|
48 |
+
}
|
49 |
+
# save optimizer state dict in the checkpoint
|
50 |
+
if isinstance(optimizer, Optimizer):
|
51 |
+
checkpoint['optimizer'] = optimizer.state_dict()
|
52 |
+
elif isinstance(optimizer, dict):
|
53 |
+
checkpoint['optimizer'] = {}
|
54 |
+
for name, optim in optimizer.items():
|
55 |
+
checkpoint['optimizer'][name] = optim.state_dict()
|
56 |
+
|
57 |
+
# save amp state dict in the checkpoint
|
58 |
+
checkpoint['amp'] = apex.amp.state_dict()
|
59 |
+
|
60 |
+
if filename.startswith('pavi://'):
|
61 |
+
try:
|
62 |
+
from pavi import modelcloud
|
63 |
+
from pavi.exception import NodeNotFoundError
|
64 |
+
except ImportError:
|
65 |
+
raise ImportError(
|
66 |
+
'Please install pavi to load checkpoint from modelcloud.')
|
67 |
+
model_path = filename[7:]
|
68 |
+
root = modelcloud.Folder()
|
69 |
+
model_dir, model_name = osp.split(model_path)
|
70 |
+
try:
|
71 |
+
model = modelcloud.get(model_dir)
|
72 |
+
except NodeNotFoundError:
|
73 |
+
model = root.create_training_model(model_dir)
|
74 |
+
with TemporaryDirectory() as tmp_dir:
|
75 |
+
checkpoint_file = osp.join(tmp_dir, model_name)
|
76 |
+
with open(checkpoint_file, 'wb') as f:
|
77 |
+
torch.save(checkpoint, f)
|
78 |
+
f.flush()
|
79 |
+
model.create_file(checkpoint_file, name=model_name)
|
80 |
+
else:
|
81 |
+
mmcv.mkdir_or_exist(osp.dirname(filename))
|
82 |
+
# immediately flush buffer
|
83 |
+
with open(filename, 'wb') as f:
|
84 |
+
torch.save(checkpoint, f)
|
85 |
+
f.flush()
|
mmcv_custom/runner/epoch_based_runner.py
ADDED
@@ -0,0 +1,104 @@
1 |
+
# Copyright (c) Open-MMLab. All rights reserved.
|
2 |
+
import os.path as osp
|
3 |
+
import platform
|
4 |
+
import shutil
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from torch.optim import Optimizer
|
8 |
+
|
9 |
+
import mmcv
|
10 |
+
from mmcv.runner import RUNNERS, EpochBasedRunner
|
11 |
+
from .checkpoint import save_checkpoint
|
12 |
+
|
13 |
+
try:
|
14 |
+
import apex
|
15 |
+
except:
|
16 |
+
print('apex is not installed')
|
17 |
+
|
18 |
+
|
19 |
+
@RUNNERS.register_module()
|
20 |
+
class EpochBasedRunnerAmp(EpochBasedRunner):
|
21 |
+
"""Epoch-based Runner with AMP support.
|
22 |
+
|
23 |
+
This runner train models epoch by epoch.
|
24 |
+
"""
|
25 |
+
|
26 |
+
def save_checkpoint(self,
|
27 |
+
out_dir,
|
28 |
+
filename_tmpl='epoch_{}.pth',
|
29 |
+
save_optimizer=True,
|
30 |
+
meta=None,
|
31 |
+
create_symlink=True):
|
32 |
+
"""Save the checkpoint.
|
33 |
+
|
34 |
+
Args:
|
35 |
+
out_dir (str): The directory that checkpoints are saved.
|
36 |
+
filename_tmpl (str, optional): The checkpoint filename template,
|
37 |
+
which contains a placeholder for the epoch number.
|
38 |
+
Defaults to 'epoch_{}.pth'.
|
39 |
+
save_optimizer (bool, optional): Whether to save the optimizer to
|
40 |
+
the checkpoint. Defaults to True.
|
41 |
+
meta (dict, optional): The meta information to be saved in the
|
42 |
+
checkpoint. Defaults to None.
|
43 |
+
create_symlink (bool, optional): Whether to create a symlink
|
44 |
+
"latest.pth" to point to the latest checkpoint.
|
45 |
+
Defaults to True.
|
46 |
+
"""
|
47 |
+
if meta is None:
|
48 |
+
meta = dict(epoch=self.epoch + 1, iter=self.iter)
|
49 |
+
elif isinstance(meta, dict):
|
50 |
+
meta.update(epoch=self.epoch + 1, iter=self.iter)
|
51 |
+
else:
|
52 |
+
raise TypeError(
|
53 |
+
f'meta should be a dict or None, but got {type(meta)}')
|
54 |
+
if self.meta is not None:
|
55 |
+
meta.update(self.meta)
|
56 |
+
|
57 |
+
filename = filename_tmpl.format(self.epoch + 1)
|
58 |
+
filepath = osp.join(out_dir, filename)
|
59 |
+
optimizer = self.optimizer if save_optimizer else None
|
60 |
+
save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
|
61 |
+
# in some environments, `os.symlink` is not supported, you may need to
|
62 |
+
# set `create_symlink` to False
|
63 |
+
if create_symlink:
|
64 |
+
dst_file = osp.join(out_dir, 'latest.pth')
|
65 |
+
if platform.system() != 'Windows':
|
66 |
+
mmcv.symlink(filename, dst_file)
|
67 |
+
else:
|
68 |
+
shutil.copy(filepath, dst_file)
|
69 |
+
|
70 |
+
def resume(self,
|
71 |
+
checkpoint,
|
72 |
+
resume_optimizer=True,
|
73 |
+
map_location='default'):
|
74 |
+
if map_location == 'default':
|
75 |
+
if torch.cuda.is_available():
|
76 |
+
device_id = torch.cuda.current_device()
|
77 |
+
checkpoint = self.load_checkpoint(
|
78 |
+
checkpoint,
|
79 |
+
map_location=lambda storage, loc: storage.cuda(device_id))
|
80 |
+
else:
|
81 |
+
checkpoint = self.load_checkpoint(checkpoint)
|
82 |
+
else:
|
83 |
+
checkpoint = self.load_checkpoint(
|
84 |
+
checkpoint, map_location=map_location)
|
85 |
+
|
86 |
+
self._epoch = checkpoint['meta']['epoch']
|
87 |
+
self._iter = checkpoint['meta']['iter']
|
88 |
+
if 'optimizer' in checkpoint and resume_optimizer:
|
89 |
+
if isinstance(self.optimizer, Optimizer):
|
90 |
+
self.optimizer.load_state_dict(checkpoint['optimizer'])
|
91 |
+
elif isinstance(self.optimizer, dict):
|
92 |
+
for k in self.optimizer.keys():
|
93 |
+
self.optimizer[k].load_state_dict(
|
94 |
+
checkpoint['optimizer'][k])
|
95 |
+
else:
|
96 |
+
raise TypeError(
|
97 |
+
'Optimizer should be dict or torch.optim.Optimizer '
|
98 |
+
f'but got {type(self.optimizer)}')
|
99 |
+
|
100 |
+
if 'amp' in checkpoint:
|
101 |
+
apex.amp.load_state_dict(checkpoint['amp'])
|
102 |
+
self.logger.info('load amp state dict')
|
103 |
+
|
104 |
+
self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter)
|
mmdet/__init__.py
ADDED
@@ -0,0 +1,28 @@
+import mmcv
+
+from .version import __version__, short_version
+
+
+def digit_version(version_str):
+    digit_version = []
+    for x in version_str.split('.'):
+        if x.isdigit():
+            digit_version.append(int(x))
+        elif x.find('rc') != -1:
+            patch_version = x.split('rc')
+            digit_version.append(int(patch_version[0]) - 1)
+            digit_version.append(int(patch_version[1]))
+    return digit_version
+
+
+mmcv_minimum_version = '1.2.4'
+mmcv_maximum_version = '1.4.0'
+mmcv_version = digit_version(mmcv.__version__)
+
+
+assert (mmcv_version >= digit_version(mmcv_minimum_version)
+        and mmcv_version <= digit_version(mmcv_maximum_version)), \
+    f'MMCV=={mmcv.__version__} is used but incompatible. ' \
+    f'Please install mmcv>={mmcv_minimum_version}, <={mmcv_maximum_version}.'
+
+__all__ = ['__version__', 'short_version']
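`digit_version` turns a version string into a list of integers so the assert above can compare versions lexicographically, with release candidates folded below the corresponding final release. A small illustration of the values it produces:

    from mmdet import digit_version  # assumes mmcv is installed so the package imports cleanly

    print(digit_version('1.2.4'))     # [1, 2, 4]
    print(digit_version('1.4.0'))     # [1, 4, 0]
    print(digit_version('1.3.0rc1'))  # [1, 3, -1, 1] -- rc builds sort below 1.3.0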
mmdet/apis/__init__.py
ADDED
@@ -0,0 +1,10 @@
+from .inference import (async_inference_detector, inference_detector,
+                        init_detector, show_result_pyplot)
+from .test import multi_gpu_test, single_gpu_test
+from .train import get_root_logger, set_random_seed, train_detector
+
+__all__ = [
+    'get_root_logger', 'set_random_seed', 'train_detector', 'init_detector',
+    'async_inference_detector', 'inference_detector', 'show_result_pyplot',
+    'multi_gpu_test', 'single_gpu_test'
+]
mmdet/apis/inference.py
ADDED
@@ -0,0 +1,217 @@
import warnings

import mmcv
import numpy as np
import torch
from mmcv.ops import RoIPool
from mmcv.parallel import collate, scatter
from mmcv.runner import load_checkpoint

from mmdet.core import get_classes
from mmdet.datasets import replace_ImageToTensor
from mmdet.datasets.pipelines import Compose
from mmdet.models import build_detector


def init_detector(config, checkpoint=None, device='cuda:0', cfg_options=None):
    """Initialize a detector from config file.

    Args:
        config (str or :obj:`mmcv.Config`): Config file path or the config
            object.
        checkpoint (str, optional): Checkpoint path. If left as None, the model
            will not load any weights.
        cfg_options (dict): Options to override some settings in the used
            config.

    Returns:
        nn.Module: The constructed detector.
    """
    if isinstance(config, str):
        config = mmcv.Config.fromfile(config)
    elif not isinstance(config, mmcv.Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')
    if cfg_options is not None:
        config.merge_from_dict(cfg_options)
    config.model.pretrained = None
    config.model.train_cfg = None
    model = build_detector(config.model, test_cfg=config.get('test_cfg'))
    if checkpoint is not None:
        map_loc = 'cpu' if device == 'cpu' else None
        checkpoint = load_checkpoint(model, checkpoint, map_location=map_loc)
        if 'CLASSES' in checkpoint.get('meta', {}):
            model.CLASSES = checkpoint['meta']['CLASSES']
        else:
            warnings.simplefilter('once')
            warnings.warn('Class names are not saved in the checkpoint\'s '
                          'meta data, use COCO classes by default.')
            model.CLASSES = get_classes('coco')
    model.cfg = config  # save the config in the model for convenience
    model.to(device)
    model.eval()
    return model


class LoadImage(object):
    """Deprecated.

    A simple pipeline to load image.
    """

    def __call__(self, results):
        """Call function to load images into results.

        Args:
            results (dict): A result dict contains the file name
                of the image to be read.
        Returns:
            dict: ``results`` will be returned containing loaded image.
        """
        warnings.simplefilter('once')
        warnings.warn('`LoadImage` is deprecated and will be removed in '
                      'future releases. You may use `LoadImageFromWebcam` '
                      'from `mmdet.datasets.pipelines.` instead.')
        if isinstance(results['img'], str):
            results['filename'] = results['img']
            results['ori_filename'] = results['img']
        else:
            results['filename'] = None
            results['ori_filename'] = None
        img = mmcv.imread(results['img'])
        results['img'] = img
        results['img_fields'] = ['img']
        results['img_shape'] = img.shape
        results['ori_shape'] = img.shape
        return results


def inference_detector(model, imgs):
    """Inference image(s) with the detector.

    Args:
        model (nn.Module): The loaded detector.
        imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]):
            Either image files or loaded images.

    Returns:
        If imgs is a list or tuple, the same length list type results
        will be returned, otherwise return the detection results directly.
    """

    if isinstance(imgs, (list, tuple)):
        is_batch = True
    else:
        imgs = [imgs]
        is_batch = False

    cfg = model.cfg
    device = next(model.parameters()).device  # model device

    if isinstance(imgs[0], np.ndarray):
        cfg = cfg.copy()
        # set loading pipeline type
        cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam'

    cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
    test_pipeline = Compose(cfg.data.test.pipeline)

    datas = []
    for img in imgs:
        # prepare data
        if isinstance(img, np.ndarray):
            # directly add img
            data = dict(img=img)
        else:
            # add information into dict
            data = dict(img_info=dict(filename=img), img_prefix=None)
        # build the data pipeline
        data = test_pipeline(data)
        datas.append(data)

    data = collate(datas, samples_per_gpu=len(imgs))
    # just get the actual data from DataContainer
    data['img_metas'] = [img_metas.data[0] for img_metas in data['img_metas']]
    data['img'] = [img.data[0] for img in data['img']]
    if next(model.parameters()).is_cuda:
        # scatter to specified GPU
        data = scatter(data, [device])[0]
    else:
        for m in model.modules():
            assert not isinstance(
                m, RoIPool
            ), 'CPU inference with RoIPool is not supported currently.'

    # forward the model
    with torch.no_grad():
        results = model(return_loss=False, rescale=True, **data)

    if not is_batch:
        return results[0]
    else:
        return results


async def async_inference_detector(model, img):
    """Async inference image(s) with the detector.

    Args:
        model (nn.Module): The loaded detector.
        img (str | ndarray): Either image files or loaded images.

    Returns:
        Awaitable detection results.
    """
    cfg = model.cfg
    device = next(model.parameters()).device  # model device
    # prepare data
    if isinstance(img, np.ndarray):
        # directly add img
        data = dict(img=img)
        cfg = cfg.copy()
        # set loading pipeline type
        cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam'
    else:
        # add information into dict
        data = dict(img_info=dict(filename=img), img_prefix=None)
    # build the data pipeline
    test_pipeline = Compose(cfg.data.test.pipeline)
    data = test_pipeline(data)
    data = scatter(collate([data], samples_per_gpu=1), [device])[0]

    # We don't restore `torch.is_grad_enabled()` value during concurrent
    # inference since execution can overlap
    torch.set_grad_enabled(False)
    result = await model.aforward_test(rescale=True, **data)
    return result


def show_result_pyplot(model,
                       img,
                       result,
                       score_thr=0.3,
                       title='result',
                       wait_time=0):
    """Visualize the detection results on the image.

    Args:
        model (nn.Module): The loaded detector.
        img (str or np.ndarray): Image filename or loaded image.
        result (tuple[list] or list): The detection result, can be either
            (bbox, segm) or just bbox.
        score_thr (float): The threshold to visualize the bboxes and masks.
        title (str): Title of the pyplot figure.
        wait_time (float): Value of waitKey param.
            Default: 0.
    """
    if hasattr(model, 'module'):
        model = model.module
    model.show_result(
        img,
        result,
        score_thr=score_thr,
        show=True,
        wait_time=wait_time,
        win_name=title,
        bbox_color=(72, 101, 241),
        text_color=(72, 101, 241))
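Taken together, this inference API is typically driven as follows. This is a minimal sketch, assuming a config from this repo and a locally available checkpoint (the checkpoint path below is a placeholder, not a file shipped with the Space):

    from mmdet.apis import inference_detector, init_detector, show_result_pyplot

    config_file = 'configs/walt/walt_vehicle.py'       # shipped in this repo
    checkpoint_file = 'data/walt_vehicle.pth'          # placeholder path

    model = init_detector(config_file, checkpoint_file, device='cuda:0')
    result = inference_detector(model, 'demo.jpg')     # single image -> single result
    show_result_pyplot(model, 'demo.jpg', result, score_thr=0.3)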
mmdet/apis/test.py
ADDED
@@ -0,0 +1,189 @@
import os.path as osp
import pickle
import shutil
import tempfile
import time

import mmcv
import torch
import torch.distributed as dist
from mmcv.image import tensor2imgs
from mmcv.runner import get_dist_info

from mmdet.core import encode_mask_results


def single_gpu_test(model,
                    data_loader,
                    show=False,
                    out_dir=None,
                    show_score_thr=0.3):
    model.eval()
    results = []
    dataset = data_loader.dataset
    prog_bar = mmcv.ProgressBar(len(dataset))
    for i, data in enumerate(data_loader):
        with torch.no_grad():
            result = model(return_loss=False, rescale=True, **data)

        batch_size = len(result)
        if show or out_dir:
            if batch_size == 1 and isinstance(data['img'][0], torch.Tensor):
                img_tensor = data['img'][0]
            else:
                img_tensor = data['img'][0].data[0]
            img_metas = data['img_metas'][0].data[0]
            imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
            assert len(imgs) == len(img_metas)

            for i, (img, img_meta) in enumerate(zip(imgs, img_metas)):
                h, w, _ = img_meta['img_shape']
                img_show = img[:h, :w, :]

                ori_h, ori_w = img_meta['ori_shape'][:-1]
                img_show = mmcv.imresize(img_show, (ori_w, ori_h))

                if out_dir:
                    out_file = osp.join(out_dir, img_meta['ori_filename'])
                else:
                    out_file = None
                model.module.show_result(
                    img_show,
                    result[i],
                    show=show,
                    out_file=out_file,
                    score_thr=show_score_thr)

        # encode mask results
        if isinstance(result[0], tuple):
            result = [(bbox_results, encode_mask_results(mask_results))
                      for bbox_results, mask_results in result]
        results.extend(result)

        for _ in range(batch_size):
            prog_bar.update()
    return results


def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
    """Test model with multiple gpus.

    This method tests model with multiple gpus and collects the results
    under two different modes: gpu and cpu modes. By setting 'gpu_collect=True'
    it encodes results to gpu tensors and use gpu communication for results
    collection. On cpu mode it saves the results on different gpus to 'tmpdir'
    and collects them by the rank 0 worker.

    Args:
        model (nn.Module): Model to be tested.
        data_loader (nn.Dataloader): Pytorch data loader.
        tmpdir (str): Path of directory to save the temporary results from
            different gpus under cpu mode.
        gpu_collect (bool): Option to use either gpu or cpu to collect results.

    Returns:
        list: The prediction results.
    """
    model.eval()
    results = []
    dataset = data_loader.dataset
    rank, world_size = get_dist_info()
    if rank == 0:
        prog_bar = mmcv.ProgressBar(len(dataset))
    time.sleep(2)  # This line can prevent deadlock problem in some cases.
    for i, data in enumerate(data_loader):
        with torch.no_grad():
            result = model(return_loss=False, rescale=True, **data)
            # encode mask results
            if isinstance(result[0], tuple):
                result = [(bbox_results, encode_mask_results(mask_results))
                          for bbox_results, mask_results in result]
        results.extend(result)

        if rank == 0:
            batch_size = len(result)
            for _ in range(batch_size * world_size):
                prog_bar.update()

    # collect results from all ranks
    if gpu_collect:
        results = collect_results_gpu(results, len(dataset))
    else:
        results = collect_results_cpu(results, len(dataset), tmpdir)
    return results


def collect_results_cpu(result_part, size, tmpdir=None):
    rank, world_size = get_dist_info()
    # create a tmp dir if it is not specified
    if tmpdir is None:
        MAX_LEN = 512
        # 32 is whitespace
        dir_tensor = torch.full((MAX_LEN, ),
                                32,
                                dtype=torch.uint8,
                                device='cuda')
        if rank == 0:
            mmcv.mkdir_or_exist('.dist_test')
            tmpdir = tempfile.mkdtemp(dir='.dist_test')
            tmpdir = torch.tensor(
                bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
            dir_tensor[:len(tmpdir)] = tmpdir
        dist.broadcast(dir_tensor, 0)
        tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
    else:
        mmcv.mkdir_or_exist(tmpdir)
    # dump the part result to the dir
    mmcv.dump(result_part, osp.join(tmpdir, f'part_{rank}.pkl'))
    dist.barrier()
    # collect all parts
    if rank != 0:
        return None
    else:
        # load results of all parts from tmp dir
        part_list = []
        for i in range(world_size):
            part_file = osp.join(tmpdir, f'part_{i}.pkl')
            part_list.append(mmcv.load(part_file))
        # sort the results
        ordered_results = []
        for res in zip(*part_list):
            ordered_results.extend(list(res))
        # the dataloader may pad some samples
        ordered_results = ordered_results[:size]
        # remove tmp dir
        shutil.rmtree(tmpdir)
        return ordered_results


def collect_results_gpu(result_part, size):
    rank, world_size = get_dist_info()
    # dump result part to tensor with pickle
    part_tensor = torch.tensor(
        bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
    # gather all result part tensor shape
    shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
    shape_list = [shape_tensor.clone() for _ in range(world_size)]
    dist.all_gather(shape_list, shape_tensor)
    # padding result part tensor to max length
    shape_max = torch.tensor(shape_list).max()
    part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
    part_send[:shape_tensor[0]] = part_tensor
    part_recv_list = [
        part_tensor.new_zeros(shape_max) for _ in range(world_size)
    ]
    # gather all result part
    dist.all_gather(part_recv_list, part_send)

    if rank == 0:
        part_list = []
        for recv, shape in zip(part_recv_list, shape_list):
            part_list.append(
                pickle.loads(recv[:shape[0]].cpu().numpy().tobytes()))
        # sort the results
        ordered_results = []
        for res in zip(*part_list):
            ordered_results.extend(list(res))
        # the dataloader may pad some samples
        ordered_results = ordered_results[:size]
        return ordered_results
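single_gpu_test expects the detector to already be wrapped in MMDataParallel and the loader to come from build_dataloader in test mode. A rough sketch under those assumptions (cfg is an already-loaded mmcv.Config, model an already-built detector; both are assumed here, not shown):

    from mmcv.parallel import MMDataParallel
    from mmdet.apis import single_gpu_test
    from mmdet.datasets import build_dataloader, build_dataset

    dataset = build_dataset(cfg.data.test)
    data_loader = build_dataloader(dataset, samples_per_gpu=1,
                                   workers_per_gpu=2, dist=False, shuffle=False)
    outputs = single_gpu_test(MMDataParallel(model, device_ids=[0]),
                              data_loader, show=False, out_dir=None)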
mmdet/apis/train.py
ADDED
@@ -0,0 +1,185 @@
import random
import warnings

import numpy as np
import torch
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner,
                         Fp16OptimizerHook, OptimizerHook, build_optimizer,
                         build_runner)
from mmcv.utils import build_from_cfg

from mmdet.core import DistEvalHook, EvalHook
from mmdet.datasets import (build_dataloader, build_dataset,
                            replace_ImageToTensor)
from mmdet.utils import get_root_logger
from mmcv_custom.runner import EpochBasedRunnerAmp
try:
    import apex
except:
    print('apex is not installed')


def set_random_seed(seed, deterministic=False):
    """Set random seed.

    Args:
        seed (int): Seed to be used.
        deterministic (bool): Whether to set the deterministic option for
            CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
            to True and `torch.backends.cudnn.benchmark` to False.
            Default: False.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def train_detector(model,
                   dataset,
                   cfg,
                   distributed=False,
                   validate=False,
                   timestamp=None,
                   meta=None):
    logger = get_root_logger(cfg.log_level)

    # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
    if 'imgs_per_gpu' in cfg.data:
        logger.warning('"imgs_per_gpu" is deprecated in MMDet V2.0. '
                       'Please use "samples_per_gpu" instead')
        if 'samples_per_gpu' in cfg.data:
            logger.warning(
                f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and '
                f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"'
                f'={cfg.data.imgs_per_gpu} is used in this experiments')
        else:
            logger.warning(
                'Automatically set "samples_per_gpu"="imgs_per_gpu"='
                f'{cfg.data.imgs_per_gpu} in this experiments')
        cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu

    data_loaders = [
        build_dataloader(
            ds,
            cfg.data.samples_per_gpu,
            cfg.data.workers_per_gpu,
            # cfg.gpus will be ignored if distributed
            len(cfg.gpu_ids),
            dist=distributed,
            seed=cfg.seed) for ds in dataset
    ]

    # build optimizer
    optimizer = build_optimizer(model, cfg.optimizer)

    # use apex fp16 optimizer
    if cfg.optimizer_config.get("type", None) and cfg.optimizer_config["type"] == "DistOptimizerHook":
        if cfg.optimizer_config.get("use_fp16", False):
            model, optimizer = apex.amp.initialize(
                model.cuda(), optimizer, opt_level="O1")
            for m in model.modules():
                if hasattr(m, "fp16_enabled"):
                    m.fp16_enabled = True

    # put model on gpus
    if distributed:
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        # Sets the `find_unused_parameters` parameter in
        # torch.nn.parallel.DistributedDataParallel
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters)
    else:
        model = MMDataParallel(
            model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)

    if 'runner' not in cfg:
        cfg.runner = {
            'type': 'EpochBasedRunner',
            'max_epochs': cfg.total_epochs
        }
        warnings.warn(
            'config is now expected to have a `runner` section, '
            'please set `runner` in your config.', UserWarning)
    else:
        if 'total_epochs' in cfg:
            assert cfg.total_epochs == cfg.runner.max_epochs

    # build runner
    runner = build_runner(
        cfg.runner,
        default_args=dict(
            model=model,
            optimizer=optimizer,
            work_dir=cfg.work_dir,
            logger=logger,
            meta=meta))

    # an ugly workaround to make .log and .log.json filenames the same
    runner.timestamp = timestamp

    # fp16 setting
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        optimizer_config = Fp16OptimizerHook(
            **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
    elif distributed and 'type' not in cfg.optimizer_config:
        optimizer_config = OptimizerHook(**cfg.optimizer_config)
    else:
        optimizer_config = cfg.optimizer_config

    # register hooks
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config,
                                   cfg.get('momentum_config', None))
    if distributed:
        if isinstance(runner, EpochBasedRunner):
            runner.register_hook(DistSamplerSeedHook())

    # register eval hooks
    if validate:
        # Support batch_size > 1 in validation
        val_samples_per_gpu = cfg.data.val.pop('samples_per_gpu', 1)
        if val_samples_per_gpu > 1:
            # Replace 'ImageToTensor' to 'DefaultFormatBundle'
            cfg.data.val.pipeline = replace_ImageToTensor(
                cfg.data.val.pipeline)
        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
        val_dataloader = build_dataloader(
            val_dataset,
            samples_per_gpu=val_samples_per_gpu,
            workers_per_gpu=cfg.data.workers_per_gpu,
            dist=distributed,
            shuffle=False)
        eval_cfg = cfg.get('evaluation', {})
        eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
        eval_hook = DistEvalHook if distributed else EvalHook
        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))

    # user-defined hooks
    if cfg.get('custom_hooks', None):
        custom_hooks = cfg.custom_hooks
        assert isinstance(custom_hooks, list), \
            f'custom_hooks expect list type, but got {type(custom_hooks)}'
        for hook_cfg in cfg.custom_hooks:
            assert isinstance(hook_cfg, dict), \
                'Each item in custom_hooks expects dict type, but got ' \
                f'{type(hook_cfg)}'
            hook_cfg = hook_cfg.copy()
            priority = hook_cfg.pop('priority', 'NORMAL')
            hook = build_from_cfg(hook_cfg, HOOKS)
            runner.register_hook(hook, priority=priority)

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow)
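A typical call sequence into train_detector looks roughly as follows. This is a sketch, assuming cfg is an mmcv.Config carrying the usual MMDetection fields (gpu_ids, seed, work_dir, lr_config, checkpoint_config, log_config, and so on); it is not a training recipe shipped with this Space:

    from mmcv import Config
    from mmdet.apis import set_random_seed, train_detector
    from mmdet.datasets import build_dataset
    from mmdet.models import build_detector

    cfg = Config.fromfile('configs/walt/walt_people.py')  # config shipped in this repo
    cfg.gpu_ids = [0]                                     # assumed, set by the launcher normally
    cfg.seed = 0
    set_random_seed(cfg.seed, deterministic=False)

    datasets = [build_dataset(cfg.data.train)]
    model = build_detector(cfg.model, train_cfg=cfg.get('train_cfg'),
                           test_cfg=cfg.get('test_cfg'))
    train_detector(model, datasets, cfg, distributed=False, validate=True)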
mmdet/core/__init__.py
ADDED
@@ -0,0 +1,7 @@
from .anchor import *  # noqa: F401, F403
from .bbox import *  # noqa: F401, F403
from .evaluation import *  # noqa: F401, F403
from .export import *  # noqa: F401, F403
from .mask import *  # noqa: F401, F403
from .post_processing import *  # noqa: F401, F403
from .utils import *  # noqa: F401, F403
mmdet/core/anchor/__init__.py
ADDED
@@ -0,0 +1,11 @@
from .anchor_generator import (AnchorGenerator, LegacyAnchorGenerator,
                               YOLOAnchorGenerator)
from .builder import ANCHOR_GENERATORS, build_anchor_generator
from .point_generator import PointGenerator
from .utils import anchor_inside_flags, calc_region, images_to_levels

__all__ = [
    'AnchorGenerator', 'LegacyAnchorGenerator', 'anchor_inside_flags',
    'PointGenerator', 'images_to_levels', 'calc_region',
    'build_anchor_generator', 'ANCHOR_GENERATORS', 'YOLOAnchorGenerator'
]
mmdet/core/anchor/anchor_generator.py
ADDED
@@ -0,0 +1,727 @@
1 |
+
import mmcv
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from torch.nn.modules.utils import _pair
|
5 |
+
|
6 |
+
from .builder import ANCHOR_GENERATORS
|
7 |
+
|
8 |
+
|
9 |
+
@ANCHOR_GENERATORS.register_module()
|
10 |
+
class AnchorGenerator(object):
|
11 |
+
"""Standard anchor generator for 2D anchor-based detectors.
|
12 |
+
|
13 |
+
Args:
|
14 |
+
strides (list[int] | list[tuple[int, int]]): Strides of anchors
|
15 |
+
in multiple feature levels in order (w, h).
|
16 |
+
ratios (list[float]): The list of ratios between the height and width
|
17 |
+
of anchors in a single level.
|
18 |
+
scales (list[int] | None): Anchor scales for anchors in a single level.
|
19 |
+
It cannot be set at the same time if `octave_base_scale` and
|
20 |
+
`scales_per_octave` are set.
|
21 |
+
base_sizes (list[int] | None): The basic sizes
|
22 |
+
of anchors in multiple levels.
|
23 |
+
If None is given, strides will be used as base_sizes.
|
24 |
+
(If strides are non square, the shortest stride is taken.)
|
25 |
+
scale_major (bool): Whether to multiply scales first when generating
|
26 |
+
base anchors. If true, the anchors in the same row will have the
|
27 |
+
same scales. By default it is True in V2.0
|
28 |
+
octave_base_scale (int): The base scale of octave.
|
29 |
+
scales_per_octave (int): Number of scales for each octave.
|
30 |
+
`octave_base_scale` and `scales_per_octave` are usually used in
|
31 |
+
retinanet and the `scales` should be None when they are set.
|
32 |
+
centers (list[tuple[float, float]] | None): The centers of the anchor
|
33 |
+
relative to the feature grid center in multiple feature levels.
|
34 |
+
By default it is set to be None and not used. If a list of tuple of
|
35 |
+
float is given, they will be used to shift the centers of anchors.
|
36 |
+
center_offset (float): The offset of center in proportion to anchors'
|
37 |
+
width and height. By default it is 0 in V2.0.
|
38 |
+
|
39 |
+
Examples:
|
40 |
+
>>> from mmdet.core import AnchorGenerator
|
41 |
+
>>> self = AnchorGenerator([16], [1.], [1.], [9])
|
42 |
+
>>> all_anchors = self.grid_anchors([(2, 2)], device='cpu')
|
43 |
+
>>> print(all_anchors)
|
44 |
+
[tensor([[-4.5000, -4.5000, 4.5000, 4.5000],
|
45 |
+
[11.5000, -4.5000, 20.5000, 4.5000],
|
46 |
+
[-4.5000, 11.5000, 4.5000, 20.5000],
|
47 |
+
[11.5000, 11.5000, 20.5000, 20.5000]])]
|
48 |
+
>>> self = AnchorGenerator([16, 32], [1.], [1.], [9, 18])
|
49 |
+
>>> all_anchors = self.grid_anchors([(2, 2), (1, 1)], device='cpu')
|
50 |
+
>>> print(all_anchors)
|
51 |
+
[tensor([[-4.5000, -4.5000, 4.5000, 4.5000],
|
52 |
+
[11.5000, -4.5000, 20.5000, 4.5000],
|
53 |
+
[-4.5000, 11.5000, 4.5000, 20.5000],
|
54 |
+
[11.5000, 11.5000, 20.5000, 20.5000]]), \
|
55 |
+
tensor([[-9., -9., 9., 9.]])]
|
56 |
+
"""
|
57 |
+
|
58 |
+
def __init__(self,
|
59 |
+
strides,
|
60 |
+
ratios,
|
61 |
+
scales=None,
|
62 |
+
base_sizes=None,
|
63 |
+
scale_major=True,
|
64 |
+
octave_base_scale=None,
|
65 |
+
scales_per_octave=None,
|
66 |
+
centers=None,
|
67 |
+
center_offset=0.):
|
68 |
+
# check center and center_offset
|
69 |
+
if center_offset != 0:
|
70 |
+
assert centers is None, 'center cannot be set when center_offset' \
|
71 |
+
f'!=0, {centers} is given.'
|
72 |
+
if not (0 <= center_offset <= 1):
|
73 |
+
raise ValueError('center_offset should be in range [0, 1], '
|
74 |
+
f'{center_offset} is given.')
|
75 |
+
if centers is not None:
|
76 |
+
assert len(centers) == len(strides), \
|
77 |
+
'The number of strides should be the same as centers, got ' \
|
78 |
+
f'{strides} and {centers}'
|
79 |
+
|
80 |
+
# calculate base sizes of anchors
|
81 |
+
self.strides = [_pair(stride) for stride in strides]
|
82 |
+
self.base_sizes = [min(stride) for stride in self.strides
|
83 |
+
] if base_sizes is None else base_sizes
|
84 |
+
assert len(self.base_sizes) == len(self.strides), \
|
85 |
+
'The number of strides should be the same as base sizes, got ' \
|
86 |
+
f'{self.strides} and {self.base_sizes}'
|
87 |
+
|
88 |
+
# calculate scales of anchors
|
89 |
+
assert ((octave_base_scale is not None
|
90 |
+
and scales_per_octave is not None) ^ (scales is not None)), \
|
91 |
+
'scales and octave_base_scale with scales_per_octave cannot' \
|
92 |
+
' be set at the same time'
|
93 |
+
if scales is not None:
|
94 |
+
self.scales = torch.Tensor(scales)
|
95 |
+
elif octave_base_scale is not None and scales_per_octave is not None:
|
96 |
+
octave_scales = np.array(
|
97 |
+
[2**(i / scales_per_octave) for i in range(scales_per_octave)])
|
98 |
+
scales = octave_scales * octave_base_scale
|
99 |
+
self.scales = torch.Tensor(scales)
|
100 |
+
else:
|
101 |
+
raise ValueError('Either scales or octave_base_scale with '
|
102 |
+
'scales_per_octave should be set')
|
103 |
+
|
104 |
+
self.octave_base_scale = octave_base_scale
|
105 |
+
self.scales_per_octave = scales_per_octave
|
106 |
+
self.ratios = torch.Tensor(ratios)
|
107 |
+
self.scale_major = scale_major
|
108 |
+
self.centers = centers
|
109 |
+
self.center_offset = center_offset
|
110 |
+
self.base_anchors = self.gen_base_anchors()
|
111 |
+
|
112 |
+
@property
|
113 |
+
def num_base_anchors(self):
|
114 |
+
"""list[int]: total number of base anchors in a feature grid"""
|
115 |
+
return [base_anchors.size(0) for base_anchors in self.base_anchors]
|
116 |
+
|
117 |
+
@property
|
118 |
+
def num_levels(self):
|
119 |
+
"""int: number of feature levels that the generator will be applied"""
|
120 |
+
return len(self.strides)
|
121 |
+
|
122 |
+
def gen_base_anchors(self):
|
123 |
+
"""Generate base anchors.
|
124 |
+
|
125 |
+
Returns:
|
126 |
+
list(torch.Tensor): Base anchors of a feature grid in multiple \
|
127 |
+
feature levels.
|
128 |
+
"""
|
129 |
+
multi_level_base_anchors = []
|
130 |
+
for i, base_size in enumerate(self.base_sizes):
|
131 |
+
center = None
|
132 |
+
if self.centers is not None:
|
133 |
+
center = self.centers[i]
|
134 |
+
multi_level_base_anchors.append(
|
135 |
+
self.gen_single_level_base_anchors(
|
136 |
+
base_size,
|
137 |
+
scales=self.scales,
|
138 |
+
ratios=self.ratios,
|
139 |
+
center=center))
|
140 |
+
return multi_level_base_anchors
|
141 |
+
|
142 |
+
def gen_single_level_base_anchors(self,
|
143 |
+
base_size,
|
144 |
+
scales,
|
145 |
+
ratios,
|
146 |
+
center=None):
|
147 |
+
"""Generate base anchors of a single level.
|
148 |
+
|
149 |
+
Args:
|
150 |
+
base_size (int | float): Basic size of an anchor.
|
151 |
+
scales (torch.Tensor): Scales of the anchor.
|
152 |
+
ratios (torch.Tensor): The ratio between between the height
|
153 |
+
and width of anchors in a single level.
|
154 |
+
center (tuple[float], optional): The center of the base anchor
|
155 |
+
related to a single feature grid. Defaults to None.
|
156 |
+
|
157 |
+
Returns:
|
158 |
+
torch.Tensor: Anchors in a single-level feature maps.
|
159 |
+
"""
|
160 |
+
w = base_size
|
161 |
+
h = base_size
|
162 |
+
if center is None:
|
163 |
+
x_center = self.center_offset * w
|
164 |
+
y_center = self.center_offset * h
|
165 |
+
else:
|
166 |
+
x_center, y_center = center
|
167 |
+
|
168 |
+
h_ratios = torch.sqrt(ratios)
|
169 |
+
w_ratios = 1 / h_ratios
|
170 |
+
if self.scale_major:
|
171 |
+
ws = (w * w_ratios[:, None] * scales[None, :]).view(-1)
|
172 |
+
hs = (h * h_ratios[:, None] * scales[None, :]).view(-1)
|
173 |
+
else:
|
174 |
+
ws = (w * scales[:, None] * w_ratios[None, :]).view(-1)
|
175 |
+
hs = (h * scales[:, None] * h_ratios[None, :]).view(-1)
|
176 |
+
|
177 |
+
# use float anchor and the anchor's center is aligned with the
|
178 |
+
# pixel center
|
179 |
+
base_anchors = [
|
180 |
+
x_center - 0.5 * ws, y_center - 0.5 * hs, x_center + 0.5 * ws,
|
181 |
+
y_center + 0.5 * hs
|
182 |
+
]
|
183 |
+
base_anchors = torch.stack(base_anchors, dim=-1)
|
184 |
+
|
185 |
+
return base_anchors
|
186 |
+
|
187 |
+
def _meshgrid(self, x, y, row_major=True):
|
188 |
+
"""Generate mesh grid of x and y.
|
189 |
+
|
190 |
+
Args:
|
191 |
+
x (torch.Tensor): Grids of x dimension.
|
192 |
+
y (torch.Tensor): Grids of y dimension.
|
193 |
+
row_major (bool, optional): Whether to return y grids first.
|
194 |
+
Defaults to True.
|
195 |
+
|
196 |
+
Returns:
|
197 |
+
tuple[torch.Tensor]: The mesh grids of x and y.
|
198 |
+
"""
|
199 |
+
# use shape instead of len to keep tracing while exporting to onnx
|
200 |
+
xx = x.repeat(y.shape[0])
|
201 |
+
yy = y.view(-1, 1).repeat(1, x.shape[0]).view(-1)
|
202 |
+
if row_major:
|
203 |
+
return xx, yy
|
204 |
+
else:
|
205 |
+
return yy, xx
|
206 |
+
|
207 |
+
def grid_anchors(self, featmap_sizes, device='cuda'):
|
208 |
+
"""Generate grid anchors in multiple feature levels.
|
209 |
+
|
210 |
+
Args:
|
211 |
+
featmap_sizes (list[tuple]): List of feature map sizes in
|
212 |
+
multiple feature levels.
|
213 |
+
device (str): Device where the anchors will be put on.
|
214 |
+
|
215 |
+
Return:
|
216 |
+
list[torch.Tensor]: Anchors in multiple feature levels. \
|
217 |
+
The sizes of each tensor should be [N, 4], where \
|
218 |
+
N = width * height * num_base_anchors, width and height \
|
219 |
+
are the sizes of the corresponding feature level, \
|
220 |
+
num_base_anchors is the number of anchors for that level.
|
221 |
+
"""
|
222 |
+
assert self.num_levels == len(featmap_sizes)
|
223 |
+
multi_level_anchors = []
|
224 |
+
for i in range(self.num_levels):
|
225 |
+
anchors = self.single_level_grid_anchors(
|
226 |
+
self.base_anchors[i].to(device),
|
227 |
+
featmap_sizes[i],
|
228 |
+
self.strides[i],
|
229 |
+
device=device)
|
230 |
+
multi_level_anchors.append(anchors)
|
231 |
+
return multi_level_anchors
|
232 |
+
|
233 |
+
def single_level_grid_anchors(self,
|
234 |
+
base_anchors,
|
235 |
+
featmap_size,
|
236 |
+
stride=(16, 16),
|
237 |
+
device='cuda'):
|
238 |
+
"""Generate grid anchors of a single level.
|
239 |
+
|
240 |
+
Note:
|
241 |
+
This function is usually called by method ``self.grid_anchors``.
|
242 |
+
|
243 |
+
Args:
|
244 |
+
base_anchors (torch.Tensor): The base anchors of a feature grid.
|
245 |
+
featmap_size (tuple[int]): Size of the feature maps.
|
246 |
+
stride (tuple[int], optional): Stride of the feature map in order
|
247 |
+
(w, h). Defaults to (16, 16).
|
248 |
+
device (str, optional): Device the tensor will be put on.
|
249 |
+
Defaults to 'cuda'.
|
250 |
+
|
251 |
+
Returns:
|
252 |
+
torch.Tensor: Anchors in the overall feature maps.
|
253 |
+
"""
|
254 |
+
# keep as Tensor, so that we can covert to ONNX correctly
|
255 |
+
feat_h, feat_w = featmap_size
|
256 |
+
shift_x = torch.arange(0, feat_w, device=device) * stride[0]
|
257 |
+
shift_y = torch.arange(0, feat_h, device=device) * stride[1]
|
258 |
+
|
259 |
+
shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
|
260 |
+
shifts = torch.stack([shift_xx, shift_yy, shift_xx, shift_yy], dim=-1)
|
261 |
+
shifts = shifts.type_as(base_anchors)
|
262 |
+
# first feat_w elements correspond to the first row of shifts
|
263 |
+
# add A anchors (1, A, 4) to K shifts (K, 1, 4) to get
|
264 |
+
# shifted anchors (K, A, 4), reshape to (K*A, 4)
|
265 |
+
|
266 |
+
all_anchors = base_anchors[None, :, :] + shifts[:, None, :]
|
267 |
+
all_anchors = all_anchors.view(-1, 4)
|
268 |
+
# first A rows correspond to A anchors of (0, 0) in feature map,
|
269 |
+
# then (0, 1), (0, 2), ...
|
270 |
+
return all_anchors
|
271 |
+
|
272 |
+
def valid_flags(self, featmap_sizes, pad_shape, device='cuda'):
|
273 |
+
"""Generate valid flags of anchors in multiple feature levels.
|
274 |
+
|
275 |
+
Args:
|
276 |
+
featmap_sizes (list(tuple)): List of feature map sizes in
|
277 |
+
multiple feature levels.
|
278 |
+
pad_shape (tuple): The padded shape of the image.
|
279 |
+
device (str): Device where the anchors will be put on.
|
280 |
+
|
281 |
+
Return:
|
282 |
+
list(torch.Tensor): Valid flags of anchors in multiple levels.
|
283 |
+
"""
|
284 |
+
assert self.num_levels == len(featmap_sizes)
|
285 |
+
multi_level_flags = []
|
286 |
+
for i in range(self.num_levels):
|
287 |
+
anchor_stride = self.strides[i]
|
288 |
+
feat_h, feat_w = featmap_sizes[i]
|
289 |
+
h, w = pad_shape[:2]
|
290 |
+
valid_feat_h = min(int(np.ceil(h / anchor_stride[1])), feat_h)
|
291 |
+
valid_feat_w = min(int(np.ceil(w / anchor_stride[0])), feat_w)
|
292 |
+
flags = self.single_level_valid_flags((feat_h, feat_w),
|
293 |
+
(valid_feat_h, valid_feat_w),
|
294 |
+
self.num_base_anchors[i],
|
295 |
+
device=device)
|
296 |
+
multi_level_flags.append(flags)
|
297 |
+
return multi_level_flags
|
298 |
+
|
299 |
+
def single_level_valid_flags(self,
|
300 |
+
featmap_size,
|
301 |
+
valid_size,
|
302 |
+
num_base_anchors,
|
303 |
+
device='cuda'):
|
304 |
+
"""Generate the valid flags of anchor in a single feature map.
|
305 |
+
|
306 |
+
Args:
|
307 |
+
featmap_size (tuple[int]): The size of feature maps.
|
308 |
+
valid_size (tuple[int]): The valid size of the feature maps.
|
309 |
+
num_base_anchors (int): The number of base anchors.
|
310 |
+
device (str, optional): Device where the flags will be put on.
|
311 |
+
Defaults to 'cuda'.
|
312 |
+
|
313 |
+
Returns:
|
314 |
+
torch.Tensor: The valid flags of each anchor in a single level \
|
315 |
+
feature map.
|
316 |
+
"""
|
317 |
+
feat_h, feat_w = featmap_size
|
318 |
+
valid_h, valid_w = valid_size
|
319 |
+
assert valid_h <= feat_h and valid_w <= feat_w
|
320 |
+
valid_x = torch.zeros(feat_w, dtype=torch.bool, device=device)
|
321 |
+
valid_y = torch.zeros(feat_h, dtype=torch.bool, device=device)
|
322 |
+
valid_x[:valid_w] = 1
|
323 |
+
valid_y[:valid_h] = 1
|
324 |
+
valid_xx, valid_yy = self._meshgrid(valid_x, valid_y)
|
325 |
+
valid = valid_xx & valid_yy
|
326 |
+
valid = valid[:, None].expand(valid.size(0),
|
327 |
+
num_base_anchors).contiguous().view(-1)
|
328 |
+
return valid
|
329 |
+
|
330 |
+
def __repr__(self):
|
331 |
+
"""str: a string that describes the module"""
|
332 |
+
indent_str = ' '
|
333 |
+
repr_str = self.__class__.__name__ + '(\n'
|
334 |
+
repr_str += f'{indent_str}strides={self.strides},\n'
|
335 |
+
repr_str += f'{indent_str}ratios={self.ratios},\n'
|
336 |
+
repr_str += f'{indent_str}scales={self.scales},\n'
|
337 |
+
repr_str += f'{indent_str}base_sizes={self.base_sizes},\n'
|
338 |
+
repr_str += f'{indent_str}scale_major={self.scale_major},\n'
|
339 |
+
repr_str += f'{indent_str}octave_base_scale='
|
340 |
+
repr_str += f'{self.octave_base_scale},\n'
|
341 |
+
repr_str += f'{indent_str}scales_per_octave='
|
342 |
+
repr_str += f'{self.scales_per_octave},\n'
|
343 |
+
repr_str += f'{indent_str}num_levels={self.num_levels}\n'
|
344 |
+
repr_str += f'{indent_str}centers={self.centers},\n'
|
345 |
+
repr_str += f'{indent_str}center_offset={self.center_offset})'
|
346 |
+
return repr_str
|
347 |
+
|
348 |
+
|
349 |
+
@ANCHOR_GENERATORS.register_module()
|
350 |
+
class SSDAnchorGenerator(AnchorGenerator):
|
351 |
+
"""Anchor generator for SSD.
|
352 |
+
|
353 |
+
Args:
|
354 |
+
strides (list[int] | list[tuple[int, int]]): Strides of anchors
|
355 |
+
in multiple feature levels.
|
356 |
+
ratios (list[float]): The list of ratios between the height and width
|
357 |
+
of anchors in a single level.
|
358 |
+
basesize_ratio_range (tuple(float)): Ratio range of anchors.
|
359 |
+
input_size (int): Size of feature map, 300 for SSD300,
|
360 |
+
512 for SSD512.
|
361 |
+
scale_major (bool): Whether to multiply scales first when generating
|
362 |
+
base anchors. If true, the anchors in the same row will have the
|
363 |
+
same scales. It is always set to be False in SSD.
|
364 |
+
"""
|
365 |
+
|
366 |
+
def __init__(self,
|
367 |
+
strides,
|
368 |
+
ratios,
|
369 |
+
basesize_ratio_range,
|
370 |
+
input_size=300,
|
371 |
+
scale_major=True):
|
372 |
+
assert len(strides) == len(ratios)
|
373 |
+
assert mmcv.is_tuple_of(basesize_ratio_range, float)
|
374 |
+
|
375 |
+
self.strides = [_pair(stride) for stride in strides]
|
376 |
+
self.input_size = input_size
|
377 |
+
self.centers = [(stride[0] / 2., stride[1] / 2.)
|
378 |
+
for stride in self.strides]
|
379 |
+
self.basesize_ratio_range = basesize_ratio_range
|
380 |
+
|
381 |
+
# calculate anchor ratios and sizes
|
382 |
+
min_ratio, max_ratio = basesize_ratio_range
|
383 |
+
min_ratio = int(min_ratio * 100)
|
384 |
+
max_ratio = int(max_ratio * 100)
|
385 |
+
step = int(np.floor(max_ratio - min_ratio) / (self.num_levels - 2))
|
386 |
+
min_sizes = []
|
387 |
+
max_sizes = []
|
388 |
+
for ratio in range(int(min_ratio), int(max_ratio) + 1, step):
|
389 |
+
min_sizes.append(int(self.input_size * ratio / 100))
|
390 |
+
max_sizes.append(int(self.input_size * (ratio + step) / 100))
|
391 |
+
if self.input_size == 300:
|
392 |
+
if basesize_ratio_range[0] == 0.15: # SSD300 COCO
|
393 |
+
min_sizes.insert(0, int(self.input_size * 7 / 100))
|
394 |
+
max_sizes.insert(0, int(self.input_size * 15 / 100))
|
395 |
+
elif basesize_ratio_range[0] == 0.2: # SSD300 VOC
|
396 |
+
min_sizes.insert(0, int(self.input_size * 10 / 100))
|
397 |
+
max_sizes.insert(0, int(self.input_size * 20 / 100))
|
398 |
+
else:
|
399 |
+
raise ValueError(
|
400 |
+
'basesize_ratio_range[0] should be either 0.15'
|
401 |
+
'or 0.2 when input_size is 300, got '
|
402 |
+
f'{basesize_ratio_range[0]}.')
|
403 |
+
elif self.input_size == 512:
|
404 |
+
if basesize_ratio_range[0] == 0.1: # SSD512 COCO
|
405 |
+
min_sizes.insert(0, int(self.input_size * 4 / 100))
|
406 |
+
max_sizes.insert(0, int(self.input_size * 10 / 100))
|
407 |
+
elif basesize_ratio_range[0] == 0.15: # SSD512 VOC
|
408 |
+
min_sizes.insert(0, int(self.input_size * 7 / 100))
|
409 |
+
max_sizes.insert(0, int(self.input_size * 15 / 100))
|
410 |
+
else:
|
411 |
+
raise ValueError('basesize_ratio_range[0] should be either 0.1'
|
412 |
+
'or 0.15 when input_size is 512, got'
|
413 |
+
f' {basesize_ratio_range[0]}.')
|
414 |
+
else:
|
415 |
+
raise ValueError('Only support 300 or 512 in SSDAnchorGenerator'
|
416 |
+
f', got {self.input_size}.')
|
417 |
+
|
418 |
+
anchor_ratios = []
|
419 |
+
anchor_scales = []
|
420 |
+
for k in range(len(self.strides)):
|
421 |
+
scales = [1., np.sqrt(max_sizes[k] / min_sizes[k])]
|
422 |
+
anchor_ratio = [1.]
|
423 |
+
for r in ratios[k]:
|
424 |
+
anchor_ratio += [1 / r, r] # 4 or 6 ratio
|
425 |
+
anchor_ratios.append(torch.Tensor(anchor_ratio))
|
426 |
+
anchor_scales.append(torch.Tensor(scales))
|
427 |
+
|
428 |
+
self.base_sizes = min_sizes
|
429 |
+
self.scales = anchor_scales
|
430 |
+
self.ratios = anchor_ratios
|
431 |
+
self.scale_major = scale_major
|
432 |
+
self.center_offset = 0
|
433 |
+
self.base_anchors = self.gen_base_anchors()
|
434 |
+
|
435 |
+
def gen_base_anchors(self):
|
436 |
+
"""Generate base anchors.
|
437 |
+
|
438 |
+
Returns:
|
439 |
+
list(torch.Tensor): Base anchors of a feature grid in multiple \
|
440 |
+
feature levels.
|
441 |
+
"""
|
442 |
+
multi_level_base_anchors = []
|
443 |
+
for i, base_size in enumerate(self.base_sizes):
|
444 |
+
base_anchors = self.gen_single_level_base_anchors(
|
445 |
+
base_size,
|
446 |
+
scales=self.scales[i],
|
447 |
+
ratios=self.ratios[i],
|
448 |
+
center=self.centers[i])
|
449 |
+
indices = list(range(len(self.ratios[i])))
|
450 |
+
indices.insert(1, len(indices))
|
451 |
+
base_anchors = torch.index_select(base_anchors, 0,
|
452 |
+
torch.LongTensor(indices))
|
453 |
+
multi_level_base_anchors.append(base_anchors)
|
454 |
+
return multi_level_base_anchors
|
455 |
+
|
456 |
+
def __repr__(self):
|
457 |
+
"""str: a string that describes the module"""
|
458 |
+
indent_str = ' '
|
459 |
+
repr_str = self.__class__.__name__ + '(\n'
|
460 |
+
repr_str += f'{indent_str}strides={self.strides},\n'
|
461 |
+
repr_str += f'{indent_str}scales={self.scales},\n'
|
462 |
+
repr_str += f'{indent_str}scale_major={self.scale_major},\n'
|
463 |
+
repr_str += f'{indent_str}input_size={self.input_size},\n'
|
464 |
+
repr_str += f'{indent_str}scales={self.scales},\n'
|
465 |
+
repr_str += f'{indent_str}ratios={self.ratios},\n'
|
466 |
+
repr_str += f'{indent_str}num_levels={self.num_levels},\n'
|
467 |
+
repr_str += f'{indent_str}base_sizes={self.base_sizes},\n'
|
468 |
+
repr_str += f'{indent_str}basesize_ratio_range='
|
469 |
+
repr_str += f'{self.basesize_ratio_range})'
|
470 |
+
return repr_str
|
471 |
+
|
472 |
+
|
473 |
+
@ANCHOR_GENERATORS.register_module()
|
474 |
+
class LegacyAnchorGenerator(AnchorGenerator):
|
475 |
+
"""Legacy anchor generator used in MMDetection V1.x.
|
476 |
+
|
477 |
+
Note:
|
478 |
+
Difference to the V2.0 anchor generator:
|
479 |
+
|
480 |
+
1. The center offset of V1.x anchors are set to be 0.5 rather than 0.
|
481 |
+
2. The width/height are minused by 1 when calculating the anchors' \
|
482 |
+
centers and corners to meet the V1.x coordinate system.
|
483 |
+
3. The anchors' corners are quantized.
|
484 |
+
|
485 |
+
Args:
|
486 |
+
strides (list[int] | list[tuple[int]]): Strides of anchors
|
487 |
+
in multiple feature levels.
|
488 |
+
ratios (list[float]): The list of ratios between the height and width
|
489 |
+
of anchors in a single level.
|
490 |
+
scales (list[int] | None): Anchor scales for anchors in a single level.
|
491 |
+
It cannot be set at the same time if `octave_base_scale` and
|
492 |
+
`scales_per_octave` are set.
|
493 |
+
base_sizes (list[int]): The basic sizes of anchors in multiple levels.
|
494 |
+
If None is given, strides will be used to generate base_sizes.
|
495 |
+
scale_major (bool): Whether to multiply scales first when generating
|
496 |
+
base anchors. If true, the anchors in the same row will have the
|
497 |
+
same scales. By default it is True in V2.0
|
498 |
+
octave_base_scale (int): The base scale of octave.
|
499 |
+
scales_per_octave (int): Number of scales for each octave.
|
500 |
+
`octave_base_scale` and `scales_per_octave` are usually used in
|
501 |
+
retinanet and the `scales` should be None when they are set.
|
502 |
+
centers (list[tuple[float, float]] | None): The centers of the anchor
|
503 |
+
relative to the feature grid center in multiple feature levels.
|
504 |
+
By default it is set to be None and not used. It a list of float
|
505 |
+
is given, this list will be used to shift the centers of anchors.
|
506 |
+
center_offset (float): The offset of center in propotion to anchors'
|
507 |
+
width and height. By default it is 0.5 in V2.0 but it should be 0.5
|
508 |
+
in v1.x models.
|
509 |
+
|
510 |
+
Examples:
|
511 |
+
>>> from mmdet.core import LegacyAnchorGenerator
|
512 |
+
>>> self = LegacyAnchorGenerator(
|
513 |
+
>>> [16], [1.], [1.], [9], center_offset=0.5)
|
514 |
+
>>> all_anchors = self.grid_anchors(((2, 2),), device='cpu')
|
515 |
+
>>> print(all_anchors)
|
516 |
+
[tensor([[ 0., 0., 8., 8.],
|
517 |
+
[16., 0., 24., 8.],
|
518 |
+
[ 0., 16., 8., 24.],
|
519 |
+
[16., 16., 24., 24.]])]
|
520 |
+
"""
|
521 |
+
|
522 |
+
def gen_single_level_base_anchors(self,
|
523 |
+
base_size,
|
524 |
+
scales,
|
525 |
+
ratios,
|
526 |
+
center=None):
|
527 |
+
"""Generate base anchors of a single level.
|
528 |
+
|
529 |
+
Note:
|
530 |
+
The width/height of anchors are minused by 1 when calculating \
|
531 |
+
the centers and corners to meet the V1.x coordinate system.
|
532 |
+
|
533 |
+
Args:
|
534 |
+
base_size (int | float): Basic size of an anchor.
|
535 |
+
scales (torch.Tensor): Scales of the anchor.
|
536 |
+
ratios (torch.Tensor): The ratio between between the height.
|
537 |
+
and width of anchors in a single level.
|
538 |
+
center (tuple[float], optional): The center of the base anchor
|
539 |
+
related to a single feature grid. Defaults to None.
|
540 |
+
|
541 |
+
Returns:
|
542 |
+
torch.Tensor: Anchors in a single-level feature map.
|
543 |
+
"""
|
544 |
+
w = base_size
|
545 |
+
h = base_size
|
546 |
+
if center is None:
|
547 |
+
x_center = self.center_offset * (w - 1)
|
548 |
+
y_center = self.center_offset * (h - 1)
|
549 |
+
else:
|
550 |
+
x_center, y_center = center
|
551 |
+
|
552 |
+
h_ratios = torch.sqrt(ratios)
|
553 |
+
w_ratios = 1 / h_ratios
|
554 |
+
if self.scale_major:
|
555 |
+
ws = (w * w_ratios[:, None] * scales[None, :]).view(-1)
|
556 |
+
hs = (h * h_ratios[:, None] * scales[None, :]).view(-1)
|
557 |
+
else:
|
558 |
+
ws = (w * scales[:, None] * w_ratios[None, :]).view(-1)
|
559 |
+
hs = (h * scales[:, None] * h_ratios[None, :]).view(-1)
|
560 |
+
|
561 |
+
# use float anchor and the anchor's center is aligned with the
|
562 |
+
# pixel center
|
563 |
+
base_anchors = [
|
564 |
+
x_center - 0.5 * (ws - 1), y_center - 0.5 * (hs - 1),
|
565 |
+
x_center + 0.5 * (ws - 1), y_center + 0.5 * (hs - 1)
|
566 |
+
]
|
567 |
+
base_anchors = torch.stack(base_anchors, dim=-1).round()
|
568 |
+
|
569 |
+
return base_anchors
|
570 |
+
|
571 |
+
|
572 |
+
@ANCHOR_GENERATORS.register_module()
|
573 |
+
class LegacySSDAnchorGenerator(SSDAnchorGenerator, LegacyAnchorGenerator):
|
574 |
+
"""Legacy anchor generator used in MMDetection V1.x.
|
575 |
+
|
576 |
+
The difference between `LegacySSDAnchorGenerator` and `SSDAnchorGenerator`
|
577 |
+
can be found in `LegacyAnchorGenerator`.
|
578 |
+
"""
|
579 |
+
|
580 |
+
def __init__(self,
|
581 |
+
strides,
|
582 |
+
ratios,
|
583 |
+
basesize_ratio_range,
|
584 |
+
input_size=300,
|
585 |
+
scale_major=True):
|
586 |
+
super(LegacySSDAnchorGenerator,
|
587 |
+
self).__init__(strides, ratios, basesize_ratio_range, input_size,
|
588 |
+
scale_major)
|
589 |
+
self.centers = [((stride - 1) / 2., (stride - 1) / 2.)
|
590 |
+
for stride in strides]
|
591 |
+
self.base_anchors = self.gen_base_anchors()
|
592 |
+
|
593 |
+
|
594 |
+
@ANCHOR_GENERATORS.register_module()
|
595 |
+
class YOLOAnchorGenerator(AnchorGenerator):
|
596 |
+
"""Anchor generator for YOLO.
|
597 |
+
|
598 |
+
Args:
|
599 |
+
strides (list[int] | list[tuple[int, int]]): Strides of anchors
|
600 |
+
in multiple feature levels.
|
601 |
+
base_sizes (list[list[tuple[int, int]]]): The basic sizes
|
602 |
+
of anchors in multiple levels.
|
603 |
+
"""
|
604 |
+
|
605 |
+
def __init__(self, strides, base_sizes):
|
606 |
+
self.strides = [_pair(stride) for stride in strides]
|
607 |
+
self.centers = [(stride[0] / 2., stride[1] / 2.)
|
608 |
+
for stride in self.strides]
|
609 |
+
self.base_sizes = []
|
610 |
+
num_anchor_per_level = len(base_sizes[0])
|
611 |
+
for base_sizes_per_level in base_sizes:
|
612 |
+
assert num_anchor_per_level == len(base_sizes_per_level)
|
613 |
+
self.base_sizes.append(
|
614 |
+
                [_pair(base_size) for base_size in base_sizes_per_level])
        self.base_anchors = self.gen_base_anchors()

    @property
    def num_levels(self):
        """int: number of feature levels that the generator will be applied"""
        return len(self.base_sizes)

    def gen_base_anchors(self):
        """Generate base anchors.

        Returns:
            list(torch.Tensor): Base anchors of a feature grid in multiple \
                feature levels.
        """
        multi_level_base_anchors = []
        for i, base_sizes_per_level in enumerate(self.base_sizes):
            center = None
            if self.centers is not None:
                center = self.centers[i]
            multi_level_base_anchors.append(
                self.gen_single_level_base_anchors(base_sizes_per_level,
                                                   center))
        return multi_level_base_anchors

    def gen_single_level_base_anchors(self, base_sizes_per_level, center=None):
        """Generate base anchors of a single level.

        Args:
            base_sizes_per_level (list[tuple[int, int]]): Basic sizes of
                anchors.
            center (tuple[float], optional): The center of the base anchor
                related to a single feature grid. Defaults to None.

        Returns:
            torch.Tensor: Anchors in a single-level feature maps.
        """
        x_center, y_center = center
        base_anchors = []
        for base_size in base_sizes_per_level:
            w, h = base_size

            # use float anchor and the anchor's center is aligned with the
            # pixel center
            base_anchor = torch.Tensor([
                x_center - 0.5 * w, y_center - 0.5 * h, x_center + 0.5 * w,
                y_center + 0.5 * h
            ])
            base_anchors.append(base_anchor)
        base_anchors = torch.stack(base_anchors, dim=0)

        return base_anchors

    def responsible_flags(self, featmap_sizes, gt_bboxes, device='cuda'):
        """Generate responsible anchor flags of grid cells in multiple scales.

        Args:
            featmap_sizes (list(tuple)): List of feature map sizes in multiple
                feature levels.
            gt_bboxes (Tensor): Ground truth boxes, shape (n, 4).
            device (str): Device where the anchors will be put on.

        Return:
            list(torch.Tensor): responsible flags of anchors in multiple level
        """
        assert self.num_levels == len(featmap_sizes)
        multi_level_responsible_flags = []
        for i in range(self.num_levels):
            anchor_stride = self.strides[i]
            flags = self.single_level_responsible_flags(
                featmap_sizes[i],
                gt_bboxes,
                anchor_stride,
                self.num_base_anchors[i],
                device=device)
            multi_level_responsible_flags.append(flags)
        return multi_level_responsible_flags

    def single_level_responsible_flags(self,
                                       featmap_size,
                                       gt_bboxes,
                                       stride,
                                       num_base_anchors,
                                       device='cuda'):
        """Generate the responsible flags of anchor in a single feature map.

        Args:
            featmap_size (tuple[int]): The size of feature maps.
            gt_bboxes (Tensor): Ground truth boxes, shape (n, 4).
            stride (tuple(int)): stride of current level
            num_base_anchors (int): The number of base anchors.
            device (str, optional): Device where the flags will be put on.
                Defaults to 'cuda'.

        Returns:
            torch.Tensor: The valid flags of each anchor in a single level \
                feature map.
        """
        feat_h, feat_w = featmap_size
        gt_bboxes_cx = ((gt_bboxes[:, 0] + gt_bboxes[:, 2]) * 0.5).to(device)
        gt_bboxes_cy = ((gt_bboxes[:, 1] + gt_bboxes[:, 3]) * 0.5).to(device)
        gt_bboxes_grid_x = torch.floor(gt_bboxes_cx / stride[0]).long()
        gt_bboxes_grid_y = torch.floor(gt_bboxes_cy / stride[1]).long()

        # row major indexing
        gt_bboxes_grid_idx = gt_bboxes_grid_y * feat_w + gt_bboxes_grid_x

        responsible_grid = torch.zeros(
            feat_h * feat_w, dtype=torch.uint8, device=device)
        responsible_grid[gt_bboxes_grid_idx] = 1

        responsible_grid = responsible_grid[:, None].expand(
            responsible_grid.size(0), num_base_anchors).contiguous().view(-1)
        return responsible_grid
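Note (not part of the diff): a minimal sketch of how the responsible-flag logic above is typically exercised. It assumes the standard YOLOv3 strides and base sizes purely as illustrative values, and that the registry helpers below are re-exported from mmdet.core.anchor as in upstream mmdetection.

import torch
from mmdet.core.anchor import build_anchor_generator

# Build the generator through the registry; values are illustrative.
anchor_generator = build_anchor_generator(
    dict(
        type='YOLOAnchorGenerator',
        strides=[32, 16, 8],
        base_sizes=[[(116, 90), (156, 198), (373, 326)],
                    [(30, 61), (62, 45), (59, 119)],
                    [(10, 13), (16, 30), (33, 23)]]))
gt_bboxes = torch.Tensor([[10., 10., 100., 100.]])
flags = anchor_generator.responsible_flags(
    featmap_sizes=[(13, 13), (26, 26), (52, 52)],
    gt_bboxes=gt_bboxes,
    device='cpu')
# One flag per (cell, base anchor) pair at each level; only the cell that
# contains the gt center is marked responsible.
print([f.shape for f in flags])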
mmdet/core/anchor/builder.py
ADDED
@@ -0,0 +1,7 @@
from mmcv.utils import Registry, build_from_cfg

ANCHOR_GENERATORS = Registry('Anchor generator')


def build_anchor_generator(cfg, default_args=None):
    return build_from_cfg(cfg, ANCHOR_GENERATORS, default_args)
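Note (not part of the diff): a short sketch of what this registry enables. The class below is hypothetical; registering it lets a config refer to it by `type='MyAnchorGenerator'`.

from mmdet.core.anchor.builder import ANCHOR_GENERATORS, build_anchor_generator

@ANCHOR_GENERATORS.register_module()
class MyAnchorGenerator:
    # Hypothetical generator used only to illustrate registration.
    def __init__(self, strides):
        self.strides = strides

my_gen = build_anchor_generator(dict(type='MyAnchorGenerator', strides=[8, 16]))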
mmdet/core/anchor/point_generator.py
ADDED
@@ -0,0 +1,37 @@
import torch

from .builder import ANCHOR_GENERATORS


@ANCHOR_GENERATORS.register_module()
class PointGenerator(object):

    def _meshgrid(self, x, y, row_major=True):
        xx = x.repeat(len(y))
        yy = y.view(-1, 1).repeat(1, len(x)).view(-1)
        if row_major:
            return xx, yy
        else:
            return yy, xx

    def grid_points(self, featmap_size, stride=16, device='cuda'):
        feat_h, feat_w = featmap_size
        shift_x = torch.arange(0., feat_w, device=device) * stride
        shift_y = torch.arange(0., feat_h, device=device) * stride
        shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
        stride = shift_x.new_full((shift_xx.shape[0], ), stride)
        shifts = torch.stack([shift_xx, shift_yy, stride], dim=-1)
        all_points = shifts.to(device)
        return all_points

    def valid_flags(self, featmap_size, valid_size, device='cuda'):
        feat_h, feat_w = featmap_size
        valid_h, valid_w = valid_size
        assert valid_h <= feat_h and valid_w <= feat_w
        valid_x = torch.zeros(feat_w, dtype=torch.bool, device=device)
        valid_y = torch.zeros(feat_h, dtype=torch.bool, device=device)
        valid_x[:valid_w] = 1
        valid_y[:valid_h] = 1
        valid_xx, valid_yy = self._meshgrid(valid_x, valid_y)
        valid = valid_xx & valid_yy
        return valid
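Note (not part of the diff): a minimal usage sketch. grid_points returns one (x, y, stride) triple per feature-map cell in row-major order, and valid_flags masks cells outside the un-padded region; the sizes used here are illustrative.

import torch
from mmdet.core.anchor import PointGenerator

pg = PointGenerator()
points = pg.grid_points((2, 3), stride=16, device='cpu')            # shape (6, 3)
flags = pg.valid_flags((2, 3), valid_size=(2, 2), device='cpu')     # 4 of 6 True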
mmdet/core/anchor/utils.py
ADDED
@@ -0,0 +1,71 @@
import torch


def images_to_levels(target, num_levels):
    """Convert targets by image to targets by feature level.

    [target_img0, target_img1] -> [target_level0, target_level1, ...]
    """
    target = torch.stack(target, 0)
    level_targets = []
    start = 0
    for n in num_levels:
        end = start + n
        # level_targets.append(target[:, start:end].squeeze(0))
        level_targets.append(target[:, start:end])
        start = end
    return level_targets


def anchor_inside_flags(flat_anchors,
                        valid_flags,
                        img_shape,
                        allowed_border=0):
    """Check whether the anchors are inside the border.

    Args:
        flat_anchors (torch.Tensor): Flatten anchors, shape (n, 4).
        valid_flags (torch.Tensor): An existing valid flags of anchors.
        img_shape (tuple(int)): Shape of current image.
        allowed_border (int, optional): The border to allow the valid anchor.
            Defaults to 0.

    Returns:
        torch.Tensor: Flags indicating whether the anchors are inside a \
            valid range.
    """
    img_h, img_w = img_shape[:2]
    if allowed_border >= 0:
        inside_flags = valid_flags & \
            (flat_anchors[:, 0] >= -allowed_border) & \
            (flat_anchors[:, 1] >= -allowed_border) & \
            (flat_anchors[:, 2] < img_w + allowed_border) & \
            (flat_anchors[:, 3] < img_h + allowed_border)
    else:
        inside_flags = valid_flags
    return inside_flags


def calc_region(bbox, ratio, featmap_size=None):
    """Calculate a proportional bbox region.

    The bbox center are fixed and the new h' and w' is h * ratio and w * ratio.

    Args:
        bbox (Tensor): Bboxes to calculate regions, shape (n, 4).
        ratio (float): Ratio of the output region.
        featmap_size (tuple): Feature map size used for clipping the boundary.

    Returns:
        tuple: x1, y1, x2, y2
    """
    x1 = torch.round((1 - ratio) * bbox[0] + ratio * bbox[2]).long()
    y1 = torch.round((1 - ratio) * bbox[1] + ratio * bbox[3]).long()
    x2 = torch.round(ratio * bbox[0] + (1 - ratio) * bbox[2]).long()
    y2 = torch.round(ratio * bbox[1] + (1 - ratio) * bbox[3]).long()
    if featmap_size is not None:
        x1 = x1.clamp(min=0, max=featmap_size[1])
        y1 = y1.clamp(min=0, max=featmap_size[0])
        x2 = x2.clamp(min=0, max=featmap_size[1])
        y2 = y2.clamp(min=0, max=featmap_size[0])
    return (x1, y1, x2, y2)
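Note (not part of the diff): a worked example for calc_region. With ratio=0.2 each corner is interpolated 20% of the way toward the opposite one, so the center is preserved while the extent shrinks; the bbox values are illustrative.

import torch
from mmdet.core.anchor.utils import calc_region

bbox = torch.Tensor([0., 0., 100., 100.])
x1, y1, x2, y2 = calc_region(bbox, ratio=0.2)
# x1 = round(0.8*0 + 0.2*100) = 20, x2 = round(0.2*0 + 0.8*100) = 80,
# so the result is the central (20, 20, 80, 80) sub-box.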
mmdet/core/bbox/__init__.py
ADDED
@@ -0,0 +1,27 @@
from .assigners import (AssignResult, BaseAssigner, CenterRegionAssigner,
                        MaxIoUAssigner, RegionAssigner)
from .builder import build_assigner, build_bbox_coder, build_sampler
from .coder import (BaseBBoxCoder, DeltaXYWHBBoxCoder, PseudoBBoxCoder,
                    TBLRBBoxCoder)
from .iou_calculators import BboxOverlaps2D, bbox_overlaps
from .samplers import (BaseSampler, CombinedSampler,
                       InstanceBalancedPosSampler, IoUBalancedNegSampler,
                       OHEMSampler, PseudoSampler, RandomSampler,
                       SamplingResult, ScoreHLRSampler)
from .transforms import (bbox2distance, bbox2result, bbox2roi,
                         bbox_cxcywh_to_xyxy, bbox_flip, bbox_mapping,
                         bbox_mapping_back, bbox_rescale, bbox_xyxy_to_cxcywh,
                         distance2bbox, roi2bbox)

__all__ = [
    'bbox_overlaps', 'BboxOverlaps2D', 'BaseAssigner', 'MaxIoUAssigner',
    'AssignResult', 'BaseSampler', 'PseudoSampler', 'RandomSampler',
    'InstanceBalancedPosSampler', 'IoUBalancedNegSampler', 'CombinedSampler',
    'OHEMSampler', 'SamplingResult', 'ScoreHLRSampler', 'build_assigner',
    'build_sampler', 'bbox_flip', 'bbox_mapping', 'bbox_mapping_back',
    'bbox2roi', 'roi2bbox', 'bbox2result', 'distance2bbox', 'bbox2distance',
    'build_bbox_coder', 'BaseBBoxCoder', 'PseudoBBoxCoder',
    'DeltaXYWHBBoxCoder', 'TBLRBBoxCoder', 'CenterRegionAssigner',
    'bbox_rescale', 'bbox_cxcywh_to_xyxy', 'bbox_xyxy_to_cxcywh',
    'RegionAssigner'
]
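Note (not part of the diff): a sketch of how the exported builders are used so a config can pick an assigner/sampler by name. The thresholds below follow the common Faster R-CNN RPN setting and are only illustrative.

from mmdet.core.bbox import build_assigner, build_sampler

assigner = build_assigner(
    dict(type='MaxIoUAssigner', pos_iou_thr=0.7, neg_iou_thr=0.3,
         min_pos_iou=0.3, match_low_quality=True, ignore_iof_thr=-1))
sampler = build_sampler(
    dict(type='RandomSampler', num=256, pos_fraction=0.5,
         neg_pos_ub=-1, add_gt_as_proposals=False))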
mmdet/core/bbox/assigners/__init__.py
ADDED
@@ -0,0 +1,16 @@
from .approx_max_iou_assigner import ApproxMaxIoUAssigner
from .assign_result import AssignResult
from .atss_assigner import ATSSAssigner
from .base_assigner import BaseAssigner
from .center_region_assigner import CenterRegionAssigner
from .grid_assigner import GridAssigner
from .hungarian_assigner import HungarianAssigner
from .max_iou_assigner import MaxIoUAssigner
from .point_assigner import PointAssigner
from .region_assigner import RegionAssigner

__all__ = [
    'BaseAssigner', 'MaxIoUAssigner', 'ApproxMaxIoUAssigner', 'AssignResult',
    'PointAssigner', 'ATSSAssigner', 'CenterRegionAssigner', 'GridAssigner',
    'HungarianAssigner', 'RegionAssigner'
]
mmdet/core/bbox/assigners/approx_max_iou_assigner.py
ADDED
@@ -0,0 +1,145 @@
import torch

from ..builder import BBOX_ASSIGNERS
from ..iou_calculators import build_iou_calculator
from .max_iou_assigner import MaxIoUAssigner


@BBOX_ASSIGNERS.register_module()
class ApproxMaxIoUAssigner(MaxIoUAssigner):
    """Assign a corresponding gt bbox or background to each bbox.

    Each proposal will be assigned with an integer indicating the ground truth
    index. (semi-positive index: gt label (0-based), -1: background)

    - -1: negative sample, no assigned gt
    - semi-positive integer: positive sample, index (0-based) of assigned gt

    Args:
        pos_iou_thr (float): IoU threshold for positive bboxes.
        neg_iou_thr (float or tuple): IoU threshold for negative bboxes.
        min_pos_iou (float): Minimum iou for a bbox to be considered as a
            positive bbox. Positive samples can have smaller IoU than
            pos_iou_thr due to the 4th step (assign max IoU sample to each gt).
        gt_max_assign_all (bool): Whether to assign all bboxes with the same
            highest overlap with some gt to that gt.
        ignore_iof_thr (float): IoF threshold for ignoring bboxes (if
            `gt_bboxes_ignore` is specified). Negative values mean not
            ignoring any bboxes.
        ignore_wrt_candidates (bool): Whether to compute the iof between
            `bboxes` and `gt_bboxes_ignore`, or the contrary.
        match_low_quality (bool): Whether to allow low quality matches. This is
            usually allowed for RPN and single stage detectors, but not allowed
            in the second stage.
        gpu_assign_thr (int): The upper bound of the number of GT for GPU
            assign. When the number of gt is above this threshold, will assign
            on CPU device. Negative values mean not assign on CPU.
    """

    def __init__(self,
                 pos_iou_thr,
                 neg_iou_thr,
                 min_pos_iou=.0,
                 gt_max_assign_all=True,
                 ignore_iof_thr=-1,
                 ignore_wrt_candidates=True,
                 match_low_quality=True,
                 gpu_assign_thr=-1,
                 iou_calculator=dict(type='BboxOverlaps2D')):
        self.pos_iou_thr = pos_iou_thr
        self.neg_iou_thr = neg_iou_thr
        self.min_pos_iou = min_pos_iou
        self.gt_max_assign_all = gt_max_assign_all
        self.ignore_iof_thr = ignore_iof_thr
        self.ignore_wrt_candidates = ignore_wrt_candidates
        self.gpu_assign_thr = gpu_assign_thr
        self.match_low_quality = match_low_quality
        self.iou_calculator = build_iou_calculator(iou_calculator)

    def assign(self,
               approxs,
               squares,
               approxs_per_octave,
               gt_bboxes,
               gt_bboxes_ignore=None,
               gt_labels=None):
        """Assign gt to approxs.

        This method assigns a gt bbox to each group of approxs (bboxes);
        each group of approxs is represented by a base approx (bbox) and
        will be assigned with -1, or a semi-positive number.
        background_label (-1) means negative sample,
        semi-positive number is the index (0-based) of assigned gt.
        The assignment is done in following steps, the order matters.

        1. assign every bbox to background_label (-1)
        2. use the max IoU of each group of approxs to assign
        3. assign proposals whose iou with all gts < neg_iou_thr to background
        4. for each bbox, if the iou with its nearest gt >= pos_iou_thr,
           assign it to that bbox
        5. for each gt bbox, assign its nearest proposals (may be more than
           one) to itself

        Args:
            approxs (Tensor): Bounding boxes to be assigned,
                shape(approxs_per_octave*n, 4).
            squares (Tensor): Base Bounding boxes to be assigned,
                shape(n, 4).
            approxs_per_octave (int): number of approxs per octave
            gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
            gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
                labelled as `ignored`, e.g., crowd boxes in COCO.
            gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).

        Returns:
            :obj:`AssignResult`: The assign result.
        """
        num_squares = squares.size(0)
        num_gts = gt_bboxes.size(0)

        if num_squares == 0 or num_gts == 0:
            # No predictions and/or truth, return empty assignment
            overlaps = approxs.new(num_gts, num_squares)
            assign_result = self.assign_wrt_overlaps(overlaps, gt_labels)
            return assign_result

        # re-organize anchors by approxs_per_octave x num_squares
        approxs = torch.transpose(
            approxs.view(num_squares, approxs_per_octave, 4), 0,
            1).contiguous().view(-1, 4)
        assign_on_cpu = True if (self.gpu_assign_thr > 0) and (
            num_gts > self.gpu_assign_thr) else False
        # compute overlap and assign gt on CPU when number of GT is large
        if assign_on_cpu:
            device = approxs.device
            approxs = approxs.cpu()
            gt_bboxes = gt_bboxes.cpu()
            if gt_bboxes_ignore is not None:
                gt_bboxes_ignore = gt_bboxes_ignore.cpu()
            if gt_labels is not None:
                gt_labels = gt_labels.cpu()
        all_overlaps = self.iou_calculator(approxs, gt_bboxes)

        overlaps, _ = all_overlaps.view(approxs_per_octave, num_squares,
                                        num_gts).max(dim=0)
        overlaps = torch.transpose(overlaps, 0, 1)

        if (self.ignore_iof_thr > 0 and gt_bboxes_ignore is not None
                and gt_bboxes_ignore.numel() > 0 and squares.numel() > 0):
            if self.ignore_wrt_candidates:
                ignore_overlaps = self.iou_calculator(
                    squares, gt_bboxes_ignore, mode='iof')
                ignore_max_overlaps, _ = ignore_overlaps.max(dim=1)
            else:
                ignore_overlaps = self.iou_calculator(
                    gt_bboxes_ignore, squares, mode='iof')
                ignore_max_overlaps, _ = ignore_overlaps.max(dim=0)
            overlaps[:, ignore_max_overlaps > self.ignore_iof_thr] = -1

        assign_result = self.assign_wrt_overlaps(overlaps, gt_labels)
        if assign_on_cpu:
            assign_result.gt_inds = assign_result.gt_inds.to(device)
            assign_result.max_overlaps = assign_result.max_overlaps.to(device)
            if assign_result.labels is not None:
                assign_result.labels = assign_result.labels.to(device)
        return assign_result
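Note (not part of the diff): a small torch-only sketch of the re-organisation step above. The approxs arrive grouped per square and the transpose/view puts all approxs at the same octave position together before the IoU computation; shapes are illustrative.

import torch

num_squares, approxs_per_octave = 3, 2
approxs = torch.arange(num_squares * approxs_per_octave * 4,
                       dtype=torch.float32).view(-1, 4)
regrouped = torch.transpose(
    approxs.view(num_squares, approxs_per_octave, 4), 0,
    1).contiguous().view(-1, 4)
# regrouped[:num_squares] now holds the first approx of every square,
# regrouped[num_squares:] the second, matching the later
# view(approxs_per_octave, num_squares, num_gts).max(dim=0) reduction.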
mmdet/core/bbox/assigners/assign_result.py
ADDED
@@ -0,0 +1,204 @@
import torch

from mmdet.utils import util_mixins


class AssignResult(util_mixins.NiceRepr):
    """Stores assignments between predicted and truth boxes.

    Attributes:
        num_gts (int): the number of truth boxes considered when computing this
            assignment

        gt_inds (LongTensor): for each predicted box indicates the 1-based
            index of the assigned truth box. 0 means unassigned and -1 means
            ignore.

        max_overlaps (FloatTensor): the iou between the predicted box and its
            assigned truth box.

        labels (None | LongTensor): If specified, for each predicted box
            indicates the category label of the assigned truth box.

    Example:
        >>> # An assign result between 4 predicted boxes and 9 true boxes
        >>> # where only two boxes were assigned.
        >>> num_gts = 9
        >>> max_overlaps = torch.LongTensor([0, .5, .9, 0])
        >>> gt_inds = torch.LongTensor([-1, 1, 2, 0])
        >>> labels = torch.LongTensor([0, 3, 4, 0])
        >>> self = AssignResult(num_gts, gt_inds, max_overlaps, labels)
        >>> print(str(self))  # xdoctest: +IGNORE_WANT
        <AssignResult(num_gts=9, gt_inds.shape=(4,), max_overlaps.shape=(4,),
                      labels.shape=(4,))>
        >>> # Force addition of gt labels (when adding gt as proposals)
        >>> new_labels = torch.LongTensor([3, 4, 5])
        >>> self.add_gt_(new_labels)
        >>> print(str(self))  # xdoctest: +IGNORE_WANT
        <AssignResult(num_gts=9, gt_inds.shape=(7,), max_overlaps.shape=(7,),
                      labels.shape=(7,))>
    """

    def __init__(self, num_gts, gt_inds, max_overlaps, labels=None):
        self.num_gts = num_gts
        self.gt_inds = gt_inds
        self.max_overlaps = max_overlaps
        self.labels = labels
        # Interface for possible user-defined properties
        self._extra_properties = {}

    @property
    def num_preds(self):
        """int: the number of predictions in this assignment"""
        return len(self.gt_inds)

    def set_extra_property(self, key, value):
        """Set user-defined new property."""
        assert key not in self.info
        self._extra_properties[key] = value

    def get_extra_property(self, key):
        """Get user-defined property."""
        return self._extra_properties.get(key, None)

    @property
    def info(self):
        """dict: a dictionary of info about the object"""
        basic_info = {
            'num_gts': self.num_gts,
            'num_preds': self.num_preds,
            'gt_inds': self.gt_inds,
            'max_overlaps': self.max_overlaps,
            'labels': self.labels,
        }
        basic_info.update(self._extra_properties)
        return basic_info

    def __nice__(self):
        """str: a "nice" summary string describing this assign result"""
        parts = []
        parts.append(f'num_gts={self.num_gts!r}')
        if self.gt_inds is None:
            parts.append(f'gt_inds={self.gt_inds!r}')
        else:
            parts.append(f'gt_inds.shape={tuple(self.gt_inds.shape)!r}')
        if self.max_overlaps is None:
            parts.append(f'max_overlaps={self.max_overlaps!r}')
        else:
            parts.append('max_overlaps.shape='
                         f'{tuple(self.max_overlaps.shape)!r}')
        if self.labels is None:
            parts.append(f'labels={self.labels!r}')
        else:
            parts.append(f'labels.shape={tuple(self.labels.shape)!r}')
        return ', '.join(parts)

    @classmethod
    def random(cls, **kwargs):
        """Create random AssignResult for tests or debugging.

        Args:
            num_preds: number of predicted boxes
            num_gts: number of true boxes
            p_ignore (float): probability of a predicted box assigned to an
                ignored truth
            p_assigned (float): probability of a predicted box not being
                assigned
            p_use_label (float | bool): with labels or not
            rng (None | int | numpy.random.RandomState): seed or state

        Returns:
            :obj:`AssignResult`: Randomly generated assign results.

        Example:
            >>> from mmdet.core.bbox.assigners.assign_result import *  # NOQA
            >>> self = AssignResult.random()
            >>> print(self.info)
        """
        from mmdet.core.bbox import demodata
        rng = demodata.ensure_rng(kwargs.get('rng', None))

        num_gts = kwargs.get('num_gts', None)
        num_preds = kwargs.get('num_preds', None)
        p_ignore = kwargs.get('p_ignore', 0.3)
        p_assigned = kwargs.get('p_assigned', 0.7)
        p_use_label = kwargs.get('p_use_label', 0.5)
        num_classes = kwargs.get('p_use_label', 3)

        if num_gts is None:
            num_gts = rng.randint(0, 8)
        if num_preds is None:
            num_preds = rng.randint(0, 16)

        if num_gts == 0:
            max_overlaps = torch.zeros(num_preds, dtype=torch.float32)
            gt_inds = torch.zeros(num_preds, dtype=torch.int64)
            if p_use_label is True or p_use_label < rng.rand():
                labels = torch.zeros(num_preds, dtype=torch.int64)
            else:
                labels = None
        else:
            import numpy as np
            # Create an overlap for each predicted box
            max_overlaps = torch.from_numpy(rng.rand(num_preds))

            # Construct gt_inds for each predicted box
            is_assigned = torch.from_numpy(rng.rand(num_preds) < p_assigned)
            # maximum number of assignments constraints
            n_assigned = min(num_preds, min(num_gts, is_assigned.sum()))

            assigned_idxs = np.where(is_assigned)[0]
            rng.shuffle(assigned_idxs)
            assigned_idxs = assigned_idxs[0:n_assigned]
            assigned_idxs.sort()

            is_assigned[:] = 0
            is_assigned[assigned_idxs] = True

            is_ignore = torch.from_numpy(
                rng.rand(num_preds) < p_ignore) & is_assigned

            gt_inds = torch.zeros(num_preds, dtype=torch.int64)

            true_idxs = np.arange(num_gts)
            rng.shuffle(true_idxs)
            true_idxs = torch.from_numpy(true_idxs)
            gt_inds[is_assigned] = true_idxs[:n_assigned]

            gt_inds = torch.from_numpy(
                rng.randint(1, num_gts + 1, size=num_preds))
            gt_inds[is_ignore] = -1
            gt_inds[~is_assigned] = 0
            max_overlaps[~is_assigned] = 0

            if p_use_label is True or p_use_label < rng.rand():
                if num_classes == 0:
                    labels = torch.zeros(num_preds, dtype=torch.int64)
                else:
                    labels = torch.from_numpy(
                        # remind that we set FG labels to [0, num_class-1]
                        # since mmdet v2.0
                        # BG cat_id: num_class
                        rng.randint(0, num_classes, size=num_preds))
                    labels[~is_assigned] = 0
            else:
                labels = None

        self = cls(num_gts, gt_inds, max_overlaps, labels)
        return self

    def add_gt_(self, gt_labels):
        """Add ground truth as assigned results.

        Args:
            gt_labels (torch.Tensor): Labels of gt boxes
        """
        self_inds = torch.arange(
            1, len(gt_labels) + 1, dtype=torch.long, device=gt_labels.device)
        self.gt_inds = torch.cat([self_inds, self.gt_inds])

        self.max_overlaps = torch.cat(
            [self.max_overlaps.new_ones(len(gt_labels)), self.max_overlaps])

        if self.labels is not None:
            self.labels = torch.cat([gt_labels, self.labels])
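Note (not part of the diff): a hand-built AssignResult and the effect of add_gt_ when gt boxes are appended in front of the proposal list; all tensor values are illustrative.

import torch
from mmdet.core.bbox import AssignResult

gt_inds = torch.LongTensor([0, 1, 2, 0])         # 0 = unassigned
max_overlaps = torch.Tensor([0., .5, .9, 0.])
labels = torch.LongTensor([0, 3, 4, 0])
result = AssignResult(num_gts=2, gt_inds=gt_inds,
                      max_overlaps=max_overlaps, labels=labels)
result.add_gt_(torch.LongTensor([3, 4]))          # prepends one entry per gt
assert result.num_preds == 6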
mmdet/core/bbox/assigners/atss_assigner.py
ADDED
@@ -0,0 +1,178 @@
import torch

from ..builder import BBOX_ASSIGNERS
from ..iou_calculators import build_iou_calculator
from .assign_result import AssignResult
from .base_assigner import BaseAssigner


@BBOX_ASSIGNERS.register_module()
class ATSSAssigner(BaseAssigner):
    """Assign a corresponding gt bbox or background to each bbox.

    Each proposal will be assigned with `0` or a positive integer
    indicating the ground truth index.

    - 0: negative sample, no assigned gt
    - positive integer: positive sample, index (1-based) of assigned gt

    Args:
        topk (float): number of bbox selected in each level
    """

    def __init__(self,
                 topk,
                 iou_calculator=dict(type='BboxOverlaps2D'),
                 ignore_iof_thr=-1):
        self.topk = topk
        self.iou_calculator = build_iou_calculator(iou_calculator)
        self.ignore_iof_thr = ignore_iof_thr

    # https://github.com/sfzhang15/ATSS/blob/master/atss_core/modeling/rpn/atss/loss.py

    def assign(self,
               bboxes,
               num_level_bboxes,
               gt_bboxes,
               gt_bboxes_ignore=None,
               gt_labels=None):
        """Assign gt to bboxes.

        The assignment is done in following steps

        1. compute iou between all bbox (bbox of all pyramid levels) and gt
        2. compute center distance between all bbox and gt
        3. on each pyramid level, for each gt, select k bbox whose center
           are closest to the gt center, so we select k*l bbox in total as
           candidates for each gt
        4. get corresponding iou for these candidates, and compute the
           mean and std, set mean + std as the iou threshold
        5. select these candidates whose iou are greater than or equal to
           the threshold as positive
        6. limit the positive sample's center in gt


        Args:
            bboxes (Tensor): Bounding boxes to be assigned, shape(n, 4).
            num_level_bboxes (List): num of bboxes in each level
            gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
            gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
                labelled as `ignored`, e.g., crowd boxes in COCO.
            gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).

        Returns:
            :obj:`AssignResult`: The assign result.
        """
        INF = 100000000
        bboxes = bboxes[:, :4]
        num_gt, num_bboxes = gt_bboxes.size(0), bboxes.size(0)

        # compute iou between all bbox and gt
        overlaps = self.iou_calculator(bboxes, gt_bboxes)

        # assign 0 by default
        assigned_gt_inds = overlaps.new_full((num_bboxes, ),
                                             0,
                                             dtype=torch.long)

        if num_gt == 0 or num_bboxes == 0:
            # No ground truth or boxes, return empty assignment
            max_overlaps = overlaps.new_zeros((num_bboxes, ))
            if num_gt == 0:
                # No truth, assign everything to background
                assigned_gt_inds[:] = 0
            if gt_labels is None:
                assigned_labels = None
            else:
                assigned_labels = overlaps.new_full((num_bboxes, ),
                                                    -1,
                                                    dtype=torch.long)
            return AssignResult(
                num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)

        # compute center distance between all bbox and gt
        gt_cx = (gt_bboxes[:, 0] + gt_bboxes[:, 2]) / 2.0
        gt_cy = (gt_bboxes[:, 1] + gt_bboxes[:, 3]) / 2.0
        gt_points = torch.stack((gt_cx, gt_cy), dim=1)

        bboxes_cx = (bboxes[:, 0] + bboxes[:, 2]) / 2.0
        bboxes_cy = (bboxes[:, 1] + bboxes[:, 3]) / 2.0
        bboxes_points = torch.stack((bboxes_cx, bboxes_cy), dim=1)

        distances = (bboxes_points[:, None, :] -
                     gt_points[None, :, :]).pow(2).sum(-1).sqrt()

        if (self.ignore_iof_thr > 0 and gt_bboxes_ignore is not None
                and gt_bboxes_ignore.numel() > 0 and bboxes.numel() > 0):
            ignore_overlaps = self.iou_calculator(
                bboxes, gt_bboxes_ignore, mode='iof')
            ignore_max_overlaps, _ = ignore_overlaps.max(dim=1)
            ignore_idxs = ignore_max_overlaps > self.ignore_iof_thr
            distances[ignore_idxs, :] = INF
            assigned_gt_inds[ignore_idxs] = -1

        # Selecting candidates based on the center distance
        candidate_idxs = []
        start_idx = 0
        for level, bboxes_per_level in enumerate(num_level_bboxes):
            # on each pyramid level, for each gt,
            # select k bbox whose center are closest to the gt center
            end_idx = start_idx + bboxes_per_level
            distances_per_level = distances[start_idx:end_idx, :]
            selectable_k = min(self.topk, bboxes_per_level)
            _, topk_idxs_per_level = distances_per_level.topk(
                selectable_k, dim=0, largest=False)
            candidate_idxs.append(topk_idxs_per_level + start_idx)
            start_idx = end_idx
        candidate_idxs = torch.cat(candidate_idxs, dim=0)

        # get corresponding iou for these candidates, and compute the
        # mean and std, set mean + std as the iou threshold
        candidate_overlaps = overlaps[candidate_idxs, torch.arange(num_gt)]
        overlaps_mean_per_gt = candidate_overlaps.mean(0)
        overlaps_std_per_gt = candidate_overlaps.std(0)
        overlaps_thr_per_gt = overlaps_mean_per_gt + overlaps_std_per_gt

        is_pos = candidate_overlaps >= overlaps_thr_per_gt[None, :]

        # limit the positive sample's center in gt
        for gt_idx in range(num_gt):
            candidate_idxs[:, gt_idx] += gt_idx * num_bboxes
        ep_bboxes_cx = bboxes_cx.view(1, -1).expand(
            num_gt, num_bboxes).contiguous().view(-1)
        ep_bboxes_cy = bboxes_cy.view(1, -1).expand(
            num_gt, num_bboxes).contiguous().view(-1)
        candidate_idxs = candidate_idxs.view(-1)

        # calculate the left, top, right, bottom distance between positive
        # bbox center and gt side
        l_ = ep_bboxes_cx[candidate_idxs].view(-1, num_gt) - gt_bboxes[:, 0]
        t_ = ep_bboxes_cy[candidate_idxs].view(-1, num_gt) - gt_bboxes[:, 1]
        r_ = gt_bboxes[:, 2] - ep_bboxes_cx[candidate_idxs].view(-1, num_gt)
        b_ = gt_bboxes[:, 3] - ep_bboxes_cy[candidate_idxs].view(-1, num_gt)
        is_in_gts = torch.stack([l_, t_, r_, b_], dim=1).min(dim=1)[0] > 0.01
        is_pos = is_pos & is_in_gts

        # if an anchor box is assigned to multiple gts,
        # the one with the highest IoU will be selected.
        overlaps_inf = torch.full_like(overlaps,
                                       -INF).t().contiguous().view(-1)
        index = candidate_idxs.view(-1)[is_pos.view(-1)]
        overlaps_inf[index] = overlaps.t().contiguous().view(-1)[index]
        overlaps_inf = overlaps_inf.view(num_gt, -1).t()

        max_overlaps, argmax_overlaps = overlaps_inf.max(dim=1)
        assigned_gt_inds[
            max_overlaps != -INF] = argmax_overlaps[max_overlaps != -INF] + 1

        if gt_labels is not None:
            assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1)
            pos_inds = torch.nonzero(
                assigned_gt_inds > 0, as_tuple=False).squeeze()
            if pos_inds.numel() > 0:
                assigned_labels[pos_inds] = gt_labels[
                    assigned_gt_inds[pos_inds] - 1]
        else:
            assigned_labels = None
        return AssignResult(
            num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)
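Note (not part of the diff): a toy, torch-only illustration of the adaptive threshold above. Per gt, the candidate IoUs are reduced to mean + std, and only candidates at or above that value (and whose center lies inside the gt) remain positive; numbers are illustrative.

import torch

candidate_overlaps = torch.tensor([[0.55, 0.20],
                                   [0.45, 0.30],
                                   [0.35, 0.10]])        # (k*l, num_gt)
thr = candidate_overlaps.mean(0) + candidate_overlaps.std(0)
is_pos = candidate_overlaps >= thr[None, :]
# Each gt column gets its own threshold, so dense levels do not dominate.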
mmdet/core/bbox/assigners/base_assigner.py
ADDED
@@ -0,0 +1,9 @@
from abc import ABCMeta, abstractmethod


class BaseAssigner(metaclass=ABCMeta):
    """Base assigner that assigns boxes to ground truth boxes."""

    @abstractmethod
    def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
        """Assign boxes to either a ground truth box or a negative sample."""
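Note (not part of the diff): a hypothetical subclass, only to show the interface. It assumes the BBOX_ASSIGNERS registry lives in mmdet.core.bbox.builder as in upstream mmdetection; the class itself is illustrative and assigns everything to background.

import torch
from mmdet.core.bbox.assigners import AssignResult, BaseAssigner
from mmdet.core.bbox.builder import BBOX_ASSIGNERS

@BBOX_ASSIGNERS.register_module()
class EverythingBackgroundAssigner(BaseAssigner):
    # Hypothetical assigner: every box gets gt index 0 (background).
    def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
        n = bboxes.size(0)
        return AssignResult(gt_bboxes.size(0),
                            torch.zeros(n, dtype=torch.long),
                            bboxes.new_zeros(n), labels=None)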
mmdet/core/bbox/assigners/center_region_assigner.py
ADDED
@@ -0,0 +1,335 @@
import torch

from ..builder import BBOX_ASSIGNERS
from ..iou_calculators import build_iou_calculator
from .assign_result import AssignResult
from .base_assigner import BaseAssigner


def scale_boxes(bboxes, scale):
    """Expand an array of boxes by a given scale.

    Args:
        bboxes (Tensor): Shape (m, 4)
        scale (float): The scale factor of bboxes

    Returns:
        (Tensor): Shape (m, 4). Scaled bboxes
    """
    assert bboxes.size(1) == 4
    w_half = (bboxes[:, 2] - bboxes[:, 0]) * .5
    h_half = (bboxes[:, 3] - bboxes[:, 1]) * .5
    x_c = (bboxes[:, 2] + bboxes[:, 0]) * .5
    y_c = (bboxes[:, 3] + bboxes[:, 1]) * .5

    w_half *= scale
    h_half *= scale

    boxes_scaled = torch.zeros_like(bboxes)
    boxes_scaled[:, 0] = x_c - w_half
    boxes_scaled[:, 2] = x_c + w_half
    boxes_scaled[:, 1] = y_c - h_half
    boxes_scaled[:, 3] = y_c + h_half
    return boxes_scaled


def is_located_in(points, bboxes):
    """Are points located in bboxes.

    Args:
        points (Tensor): Points, shape: (m, 2).
        bboxes (Tensor): Bounding boxes, shape: (n, 4).

    Return:
        Tensor: Flags indicating if points are located in bboxes, shape: (m, n).
    """
    assert points.size(1) == 2
    assert bboxes.size(1) == 4
    return (points[:, 0].unsqueeze(1) > bboxes[:, 0].unsqueeze(0)) & \
           (points[:, 0].unsqueeze(1) < bboxes[:, 2].unsqueeze(0)) & \
           (points[:, 1].unsqueeze(1) > bboxes[:, 1].unsqueeze(0)) & \
           (points[:, 1].unsqueeze(1) < bboxes[:, 3].unsqueeze(0))


def bboxes_area(bboxes):
    """Compute the area of an array of bboxes.

    Args:
        bboxes (Tensor): The coordinates of bboxes. Shape: (m, 4)

    Returns:
        Tensor: Area of the bboxes. Shape: (m, )
    """
    assert bboxes.size(1) == 4
    w = (bboxes[:, 2] - bboxes[:, 0])
    h = (bboxes[:, 3] - bboxes[:, 1])
    areas = w * h
    return areas


@BBOX_ASSIGNERS.register_module()
class CenterRegionAssigner(BaseAssigner):
    """Assign pixels at the center region of a bbox as positive.

    Each proposal will be assigned with `-1`, `0`, or a positive integer
    indicating the ground truth index.
    - -1: negative samples
    - semi-positive numbers: positive sample, index (0-based) of assigned gt

    Args:
        pos_scale (float): Threshold within which pixels are
            labelled as positive.
        neg_scale (float): Threshold above which pixels are
            labelled as negative.
        min_pos_iof (float): Minimum iof of a pixel with a gt to be
            labelled as positive. Default: 1e-2
        ignore_gt_scale (float): Threshold within which the pixels
            are ignored when the gt is labelled as shadowed. Default: 0.5
        foreground_dominate (bool): If True, the bbox will be assigned as
            positive when a gt's kernel region overlaps with another's shadowed
            (ignored) region, otherwise it is set as ignored. Default to False.
    """

    def __init__(self,
                 pos_scale,
                 neg_scale,
                 min_pos_iof=1e-2,
                 ignore_gt_scale=0.5,
                 foreground_dominate=False,
                 iou_calculator=dict(type='BboxOverlaps2D')):
        self.pos_scale = pos_scale
        self.neg_scale = neg_scale
        self.min_pos_iof = min_pos_iof
        self.ignore_gt_scale = ignore_gt_scale
        self.foreground_dominate = foreground_dominate
        self.iou_calculator = build_iou_calculator(iou_calculator)

    def get_gt_priorities(self, gt_bboxes):
        """Get gt priorities according to their areas.

        Smaller gt has higher priority.

        Args:
            gt_bboxes (Tensor): Ground truth boxes, shape (k, 4).

        Returns:
            Tensor: The priority of gts so that gts with larger priority is \
                more likely to be assigned. Shape (k, )
        """
        gt_areas = bboxes_area(gt_bboxes)
        # Rank all gt bbox areas. Smaller objects have larger priority
        _, sort_idx = gt_areas.sort(descending=True)
        sort_idx = sort_idx.argsort()
        return sort_idx

    def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
        """Assign gt to bboxes.

        This method assigns gts to every bbox (proposal/anchor); each bbox
        will be assigned with -1, or a semi-positive number. -1 means
        negative sample, semi-positive number is the index (0-based) of
        assigned gt.

        Args:
            bboxes (Tensor): Bounding boxes to be assigned, shape(n, 4).
            gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
            gt_bboxes_ignore (tensor, optional): Ground truth bboxes that are
                labelled as `ignored`, e.g., crowd boxes in COCO.
            gt_labels (tensor, optional): Label of gt_bboxes, shape (num_gts,).

        Returns:
            :obj:`AssignResult`: The assigned result. Note that \
                shadowed_labels of shape (N, 2) is also added as an \
                `assign_result` attribute. `shadowed_labels` is a tensor \
                composed of N pairs of [anchor_ind, class_label], where N \
                is the number of anchors that lie in the outer region of a \
                gt, anchor_ind is the shadowed anchor index and class_label \
                is the shadowed class label.

        Example:
            >>> self = CenterRegionAssigner(0.2, 0.2)
            >>> bboxes = torch.Tensor([[0, 0, 10, 10], [10, 10, 20, 20]])
            >>> gt_bboxes = torch.Tensor([[0, 0, 10, 10]])
            >>> assign_result = self.assign(bboxes, gt_bboxes)
            >>> expected_gt_inds = torch.LongTensor([1, 0])
            >>> assert torch.all(assign_result.gt_inds == expected_gt_inds)
        """
        # There are in total 5 steps in the pixel assignment
        # 1. Find core (the center region, say inner 0.2)
        #    and shadow (the relatively outer part, say inner 0.2-0.5)
        #    regions of every gt.
        # 2. Find all prior bboxes that lie in gt_core and gt_shadow regions
        # 3. Assign prior bboxes in gt_core with a one-hot id of the gt in
        #    the image.
        #    3.1. For overlapping objects, the prior bboxes in gt_core is
        #         assigned with the object with smallest area
        # 4. Assign prior bboxes with class label according to its gt id.
        #    4.1. Assign -1 to prior bboxes lying in shadowed gts
        #    4.2. Assign positive prior boxes with the corresponding label
        # 5. Find pixels lying in the shadow of an object and assign them with
        #    background label, but set the loss weight of its corresponding
        #    gt to zero.
        assert bboxes.size(1) == 4, 'bboxes must have size of 4'
        # 1. Find core positive and shadow region of every gt
        gt_core = scale_boxes(gt_bboxes, self.pos_scale)
        gt_shadow = scale_boxes(gt_bboxes, self.neg_scale)

        # 2. Find prior bboxes that lie in gt_core and gt_shadow regions
        bbox_centers = (bboxes[:, 2:4] + bboxes[:, 0:2]) / 2
        # The center points lie within the gt boxes
        is_bbox_in_gt = is_located_in(bbox_centers, gt_bboxes)
        # Only calculate bbox and gt_core IoF. This enables small prior bboxes
        #   to match large gts
        bbox_and_gt_core_overlaps = self.iou_calculator(
            bboxes, gt_core, mode='iof')
        # The center point of effective priors should be within the gt box
        is_bbox_in_gt_core = is_bbox_in_gt & (
            bbox_and_gt_core_overlaps > self.min_pos_iof)  # shape (n, k)

        is_bbox_in_gt_shadow = (
            self.iou_calculator(bboxes, gt_shadow, mode='iof') >
            self.min_pos_iof)
        # Rule out center effective positive pixels
        is_bbox_in_gt_shadow &= (~is_bbox_in_gt_core)

        num_gts, num_bboxes = gt_bboxes.size(0), bboxes.size(0)
        if num_gts == 0 or num_bboxes == 0:
            # If no gts exist, assign all pixels to negative
            assigned_gt_ids = \
                is_bbox_in_gt_core.new_zeros((num_bboxes,),
                                             dtype=torch.long)
            pixels_in_gt_shadow = assigned_gt_ids.new_empty((0, 2))
        else:
            # Step 3: assign a one-hot gt id to each pixel, and smaller objects
            #    have high priority to assign the pixel.
            sort_idx = self.get_gt_priorities(gt_bboxes)
            assigned_gt_ids, pixels_in_gt_shadow = \
                self.assign_one_hot_gt_indices(is_bbox_in_gt_core,
                                               is_bbox_in_gt_shadow,
                                               gt_priority=sort_idx)

        if gt_bboxes_ignore is not None and gt_bboxes_ignore.numel() > 0:
            # No ground truth or boxes, return empty assignment
            gt_bboxes_ignore = scale_boxes(
                gt_bboxes_ignore, scale=self.ignore_gt_scale)
            is_bbox_in_ignored_gts = is_located_in(bbox_centers,
                                                   gt_bboxes_ignore)
            is_bbox_in_ignored_gts = is_bbox_in_ignored_gts.any(dim=1)
            assigned_gt_ids[is_bbox_in_ignored_gts] = -1

        # 4. Assign prior bboxes with class label according to its gt id.
        assigned_labels = None
        shadowed_pixel_labels = None
        if gt_labels is not None:
            # Default assigned label is the background (-1)
            assigned_labels = assigned_gt_ids.new_full((num_bboxes, ), -1)
            pos_inds = torch.nonzero(
                assigned_gt_ids > 0, as_tuple=False).squeeze()
            if pos_inds.numel() > 0:
                assigned_labels[pos_inds] = gt_labels[assigned_gt_ids[pos_inds]
                                                      - 1]
            # 5. Find pixels lying in the shadow of an object
            shadowed_pixel_labels = pixels_in_gt_shadow.clone()
            if pixels_in_gt_shadow.numel() > 0:
                pixel_idx, gt_idx =\
                    pixels_in_gt_shadow[:, 0], pixels_in_gt_shadow[:, 1]
                assert (assigned_gt_ids[pixel_idx] != gt_idx).all(), \
                    'Some pixels are dually assigned to ignore and gt!'
                shadowed_pixel_labels[:, 1] = gt_labels[gt_idx - 1]
                override = (
                    assigned_labels[pixel_idx] == shadowed_pixel_labels[:, 1])
                if self.foreground_dominate:
                    # When a pixel is both positive and shadowed, set it as pos
                    shadowed_pixel_labels = shadowed_pixel_labels[~override]
                else:
                    # When a pixel is both pos and shadowed, set it as shadowed
                    assigned_labels[pixel_idx[override]] = -1
                    assigned_gt_ids[pixel_idx[override]] = 0

        assign_result = AssignResult(
            num_gts, assigned_gt_ids, None, labels=assigned_labels)
        # Add shadowed_labels as assign_result property. Shape: (num_shadow, 2)
        assign_result.set_extra_property('shadowed_labels',
                                         shadowed_pixel_labels)
        return assign_result

    def assign_one_hot_gt_indices(self,
                                  is_bbox_in_gt_core,
                                  is_bbox_in_gt_shadow,
                                  gt_priority=None):
        """Assign only one gt index to each prior box.

        Gts with large gt_priority are more likely to be assigned.

        Args:
            is_bbox_in_gt_core (Tensor): Bool tensor indicating the bbox center
                is in the core area of a gt (e.g. 0-0.2).
                Shape: (num_prior, num_gt).
            is_bbox_in_gt_shadow (Tensor): Bool tensor indicating the bbox
                center is in the shadowed area of a gt (e.g. 0.2-0.5).
                Shape: (num_prior, num_gt).
            gt_priority (Tensor): Priorities of gts. The gt with a higher
                priority is more likely to be assigned to the bbox when the
                bbox matches with multiple gts. Shape: (num_gt, ).

        Returns:
            tuple: Returns (assigned_gt_inds, shadowed_gt_inds).

                - assigned_gt_inds: The assigned gt index of each prior bbox \
                    (i.e. index from 1 to num_gts). Shape: (num_prior, ).
                - shadowed_gt_inds: shadowed gt indices. It is a tensor of \
                    shape (num_ignore, 2) with first column being the \
                    shadowed prior bbox indices and the second column the \
                    shadowed gt indices (1-based).
        """
        num_bboxes, num_gts = is_bbox_in_gt_core.shape

        if gt_priority is None:
            gt_priority = torch.arange(
                num_gts, device=is_bbox_in_gt_core.device)
        assert gt_priority.size(0) == num_gts
        # The bigger gt_priority, the more preferable to be assigned
        # The assigned inds are by default 0 (background)
        assigned_gt_inds = is_bbox_in_gt_core.new_zeros((num_bboxes, ),
                                                        dtype=torch.long)
        # Shadowed bboxes are assigned to be background. But the corresponding
        #   label is ignored during loss calculation, which is done through
        #   shadowed_gt_inds
        shadowed_gt_inds = torch.nonzero(is_bbox_in_gt_shadow, as_tuple=False)
        if is_bbox_in_gt_core.sum() == 0:  # No gt match
            shadowed_gt_inds[:, 1] += 1  # 1-based. For consistency issue
            return assigned_gt_inds, shadowed_gt_inds

        # The priority of each prior box and gt pair. If one prior box is
        #   matched to multiple gts, only the pair with the highest priority
        #   is saved
        pair_priority = is_bbox_in_gt_core.new_full((num_bboxes, num_gts),
                                                    -1,
                                                    dtype=torch.long)

        # Each bbox could match with multiple gts.
        # The following codes deal with this situation
        # Matched bboxes (to any gt). Shape: (num_pos_anchor, )
        inds_of_match = torch.any(is_bbox_in_gt_core, dim=1)
        # The matched gt index of each positive bbox. Length >= num_pos_anchor
        #   , since one bbox could match multiple gts
        matched_bbox_gt_inds = torch.nonzero(
            is_bbox_in_gt_core, as_tuple=False)[:, 1]
        # Assign priority to each bbox-gt pair.
        pair_priority[is_bbox_in_gt_core] = gt_priority[matched_bbox_gt_inds]
        _, argmax_priority = pair_priority[inds_of_match].max(dim=1)
        assigned_gt_inds[inds_of_match] = argmax_priority + 1  # 1-based
        # Zero-out the assigned anchor box to filter the shadowed gt indices
        is_bbox_in_gt_core[inds_of_match, argmax_priority] = 0
        # Concat the shadowed indices due to overlapping with that out side of
        #   effective scale. shape: (total_num_ignore, 2)
        shadowed_gt_inds = torch.cat(
            (shadowed_gt_inds, torch.nonzero(
                is_bbox_in_gt_core, as_tuple=False)),
            dim=0)
        # `is_bbox_in_gt_core` should be changed back to keep arguments intact.
        is_bbox_in_gt_core[inds_of_match, argmax_priority] = 1
        # 1-based shadowed gt indices, to be consistent with `assigned_gt_inds`
        if shadowed_gt_inds.numel() > 0:
            shadowed_gt_inds[:, 1] += 1
        return assigned_gt_inds, shadowed_gt_inds
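Note (not part of the diff): a worked example for scale_boxes, which is how the core (pos_scale) and shadow (neg_scale) regions above are produced; the gt box and scales are illustrative.

import torch
from mmdet.core.bbox.assigners.center_region_assigner import scale_boxes

gt = torch.Tensor([[0., 0., 10., 10.]])
core = scale_boxes(gt, scale=0.2)     # tensor([[4., 4., 6., 6.]])
shadow = scale_boxes(gt, scale=0.5)   # tensor([[2.5, 2.5, 7.5, 7.5]])
# The center (5, 5) is preserved; only the width/height are rescaled.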
mmdet/core/bbox/assigners/grid_assigner.py
ADDED
@@ -0,0 +1,155 @@
import torch

from ..builder import BBOX_ASSIGNERS
from ..iou_calculators import build_iou_calculator
from .assign_result import AssignResult
from .base_assigner import BaseAssigner


@BBOX_ASSIGNERS.register_module()
class GridAssigner(BaseAssigner):
    """Assign a corresponding gt bbox or background to each bbox.

    Each proposal will be assigned with `-1`, `0`, or a positive integer
    indicating the ground truth index.

    - -1: don't care
    - 0: negative sample, no assigned gt
    - positive integer: positive sample, index (1-based) of assigned gt

    Args:
        pos_iou_thr (float): IoU threshold for positive bboxes.
        neg_iou_thr (float or tuple): IoU threshold for negative bboxes.
        min_pos_iou (float): Minimum iou for a bbox to be considered as a
            positive bbox. Positive samples can have smaller IoU than
            pos_iou_thr due to the 4th step (assign max IoU sample to each gt).
        gt_max_assign_all (bool): Whether to assign all bboxes with the same
            highest overlap with some gt to that gt.
    """

    def __init__(self,
                 pos_iou_thr,
                 neg_iou_thr,
                 min_pos_iou=.0,
                 gt_max_assign_all=True,
                 iou_calculator=dict(type='BboxOverlaps2D')):
        self.pos_iou_thr = pos_iou_thr
        self.neg_iou_thr = neg_iou_thr
        self.min_pos_iou = min_pos_iou
        self.gt_max_assign_all = gt_max_assign_all
        self.iou_calculator = build_iou_calculator(iou_calculator)

    def assign(self, bboxes, box_responsible_flags, gt_bboxes, gt_labels=None):
        """Assign gt to bboxes. The process is very much like the max iou
        assigner, except that positive samples are constrained within the cell
        that the gt boxes fell in.

        This method assigns a gt bbox to every bbox (proposal/anchor); each
        bbox will be assigned with -1, 0, or a positive number. -1 means don't
        care, 0 means negative sample, positive number is the index (1-based)
        of assigned gt.
        The assignment is done in following steps, the order matters.

        1. assign every bbox to -1
        2. assign proposals whose iou with all gts <= neg_iou_thr to 0
        3. for each bbox within a cell, if the iou with its nearest gt >
           pos_iou_thr and the center of that gt falls inside the cell,
           assign it to that bbox
        4. for each gt bbox, assign its nearest proposals within the cell the
           gt bbox falls in to itself.

        Args:
            bboxes (Tensor): Bounding boxes to be assigned, shape(n, 4).
            box_responsible_flags (Tensor): flag to indicate whether box is
                responsible for prediction, shape(n, )
            gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
            gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).

        Returns:
            :obj:`AssignResult`: The assign result.
        """
        num_gts, num_bboxes = gt_bboxes.size(0), bboxes.size(0)

        # compute iou between all gt and bboxes
        overlaps = self.iou_calculator(gt_bboxes, bboxes)

        # 1. assign -1 by default
        assigned_gt_inds = overlaps.new_full((num_bboxes, ),
                                             -1,
                                             dtype=torch.long)

        if num_gts == 0 or num_bboxes == 0:
            # No ground truth or boxes, return empty assignment
            max_overlaps = overlaps.new_zeros((num_bboxes, ))
            if num_gts == 0:
                # No truth, assign everything to background
                assigned_gt_inds[:] = 0
            if gt_labels is None:
                assigned_labels = None
            else:
                assigned_labels = overlaps.new_full((num_bboxes, ),
                                                    -1,
                                                    dtype=torch.long)
            return AssignResult(
                num_gts,
                assigned_gt_inds,
                max_overlaps,
                labels=assigned_labels)

        # 2. assign negative: below
        # for each anchor, which gt best overlaps with it
        # for each anchor, the max iou of all gts
        # shape of max_overlaps == argmax_overlaps == num_bboxes
        max_overlaps, argmax_overlaps = overlaps.max(dim=0)

        if isinstance(self.neg_iou_thr, float):
            assigned_gt_inds[(max_overlaps >= 0)
                             & (max_overlaps <= self.neg_iou_thr)] = 0
        elif isinstance(self.neg_iou_thr, (tuple, list)):
            assert len(self.neg_iou_thr) == 2
            assigned_gt_inds[(max_overlaps > self.neg_iou_thr[0])
                             & (max_overlaps <= self.neg_iou_thr[1])] = 0

        # 3. assign positive: falls into responsible cell and above
        # positive IOU threshold, the order matters.
        # the prior condition of comparison is to filter out all
        # unrelated anchors, i.e. not box_responsible_flags
        overlaps[:, ~box_responsible_flags.type(torch.bool)] = -1.

        # calculate max_overlaps again, but this time we only consider IOUs
        # for anchors responsible for prediction
        max_overlaps, argmax_overlaps = overlaps.max(dim=0)

        # for each gt, which anchor best overlaps with it
        # for each gt, the max iou of all proposals
        # shape of gt_max_overlaps == gt_argmax_overlaps == num_gts
        gt_max_overlaps, gt_argmax_overlaps = overlaps.max(dim=1)

        pos_inds = (max_overlaps >
                    self.pos_iou_thr) & box_responsible_flags.type(torch.bool)
        assigned_gt_inds[pos_inds] = argmax_overlaps[pos_inds] + 1

        # 4. assign positive to max overlapped anchors within responsible cell
        for i in range(num_gts):
            if gt_max_overlaps[i] > self.min_pos_iou:
                if self.gt_max_assign_all:
                    max_iou_inds = (overlaps[i, :] == gt_max_overlaps[i]) & \
                        box_responsible_flags.type(torch.bool)
                    assigned_gt_inds[max_iou_inds] = i + 1
                elif box_responsible_flags[gt_argmax_overlaps[i]]:
                    assigned_gt_inds[gt_argmax_overlaps[i]] = i + 1

        # assign labels of positive anchors
        if gt_labels is not None:
            assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1)
            pos_inds = torch.nonzero(
                assigned_gt_inds > 0, as_tuple=False).squeeze()
            if pos_inds.numel() > 0:
                assigned_labels[pos_inds] = gt_labels[
                    assigned_gt_inds[pos_inds] - 1]

        else:
            assigned_labels = None

        return AssignResult(
            num_gts, assigned_gt_inds, max_overlaps, labels=assigned_labels)
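Note (not part of the diff): a minimal sketch of GridAssigner on toy data, showing that only boxes whose responsible flag is set can become positive; the boxes and thresholds are illustrative.

import torch
from mmdet.core.bbox.assigners import GridAssigner

assigner = GridAssigner(pos_iou_thr=0.5, neg_iou_thr=0.5)
bboxes = torch.Tensor([[0., 0., 10., 10.], [20., 20., 30., 30.]])
flags = torch.tensor([1, 0], dtype=torch.uint8)   # only the first box is responsible
gt_bboxes = torch.Tensor([[0., 0., 10., 10.]])
result = assigner.assign(bboxes, flags, gt_bboxes)
# result.gt_inds -> tensor([1, 0]): the responsible box is assigned to gt 1,
# the far-away box falls below neg_iou_thr and is marked negative.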