FateScript
committed on
Commit · 21e3f5a
1 Parent(s): 84a8937
feat(YOLOX): support YOLOX
This view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +214 -0
- LICENSE +2 -2
- demo/ncnn/yolox.cpp +419 -0
- requirements.txt +14 -0
- setup.cfg +18 -0
- setup.py +64 -0
- tools/demo.py +278 -0
- tools/eval.py +195 -0
- tools/export_onnx.py +92 -0
- tools/train.py +112 -0
- tools/trt.py +67 -0
- yolox/__init__.py +8 -0
- yolox/core/__init__.py +5 -0
- yolox/core/launch.py +105 -0
- yolox/core/trainer.py +318 -0
- yolox/data/__init__.py +8 -0
- yolox/data/data_augment.py +389 -0
- yolox/data/data_prefetcher.py +76 -0
- yolox/data/dataloading.py +172 -0
- yolox/data/datasets/__init__.py +7 -0
- yolox/data/datasets/coco.py +119 -0
- yolox/data/datasets/coco_classes.py +86 -0
- yolox/data/datasets/datasets_wrapper.py +126 -0
- yolox/data/datasets/mosaicdetection.py +195 -0
- yolox/data/datasets/voc.py +313 -0
- yolox/data/datasets/voc_classes.py +27 -0
- yolox/data/samplers.py +94 -0
- yolox/evalutors/__init__.py +4 -0
- yolox/evalutors/coco_evaluator.py +217 -0
- yolox/evalutors/voc_eval.py +184 -0
- yolox/evalutors/voc_evaluator.py +200 -0
- yolox/exp/__init__.py +7 -0
- yolox/exp/base_exp.py +73 -0
- yolox/exp/build.py +48 -0
- yolox/exp/yolox_base.py +251 -0
- yolox/layers/__init__.py +4 -0
- yolox/layers/csrc/cocoeval/cocoeval.cpp +502 -0
- yolox/layers/csrc/cocoeval/cocoeval.h +85 -0
- yolox/layers/csrc/vision.cpp +13 -0
- yolox/layers/fast_coco_eval_api.py +147 -0
- yolox/models/__init__.py +9 -0
- yolox/models/darknet.py +140 -0
- yolox/models/losses.py +48 -0
- yolox/models/network_blocks.py +166 -0
- yolox/models/yolo_fpn.py +80 -0
- yolox/models/yolo_head.py +525 -0
- yolox/models/yolo_pafpn.py +99 -0
- yolox/models/yolox.py +45 -0
- yolox/utils/__init__.py +14 -0
- yolox/utils/allreduce_norm.py +99 -0
.gitignore
ADDED
@@ -0,0 +1,214 @@
### Linux ###
*~

# temporary files which can be created if a process still has a handle open of a deleted file
.fuse_hidden*

# KDE directory preferences
.directory

# Linux trash folder which might appear on any partition or disk
.Trash-*

# .nfs files are created when an open file is removed but is still being accessed
.nfs*

### PyCharm ###
# User-specific stuff
.idea

# CMake
cmake-build-*/

# Mongo Explorer plugin
.idea/**/mongoSettings.xml

# File-based project format
*.iws

# IntelliJ
out/

# mpeltonen/sbt-idea plugin
.idea_modules/

# JIRA plugin
atlassian-ide-plugin.xml

# Cursive Clojure plugin
.idea/replstate.xml

# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties

# Editor-based Rest Client
.idea/httpRequests

# Android studio 3.1+ serialized cache file
.idea/caches/build_file_checksums.ser

# JetBrains templates
**___jb_tmp___

### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/
docs/build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

### Vim ###
# Swap
[._]*.s[a-v][a-z]
[._]*.sw[a-p]
[._]s[a-rt-v][a-z]
[._]ss[a-gi-z]
[._]sw[a-p]

# Session
Session.vim

# Temporary
.netrwhist
# Auto-generated tag files
tags
# Persistent undo
[._]*.un~

# output
docs/api
.code-workspace.code-workspace
*.pkl
*.npy
*.pth
*.onnx
events.out.tfevents*

# vscode
*.code-workspace
.vscode

# vim
.vim
LICENSE
CHANGED
@@ -178,7 +178,7 @@
      APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
-     boilerplate notice, with the fields enclosed by brackets "[]"
+     boilerplate notice, with the fields enclosed by brackets "{}"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
@@ -186,7 +186,7 @@
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

-  Copyright [yyyy] [name of copyright owner]
+  Copyright 2021 Megvii, Base Detection

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
demo/ncnn/yolox.cpp
ADDED
@@ -0,0 +1,419 @@
// This file is written based on the following file:
// https://github.com/Tencent/ncnn/blob/master/examples/yolov5.cpp
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
// ------------------------------------------------------------------------------
// Copyright (C) 2020-2021, Megvii Inc. All rights reserved.

#include "layer.h"
#include "net.h"

#if defined(USE_NCNN_SIMPLEOCV)
#include "simpleocv.h"
#else
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#endif
#include <float.h>
#include <stdio.h>
#include <vector>

// YOLOX uses the same Focus layer as yolov5
class YoloV5Focus : public ncnn::Layer
{
public:
    YoloV5Focus()
    {
        one_blob_only = true;
    }

    virtual int forward(const ncnn::Mat& bottom_blob, ncnn::Mat& top_blob, const ncnn::Option& opt) const
    {
        int w = bottom_blob.w;
        int h = bottom_blob.h;
        int channels = bottom_blob.c;

        int outw = w / 2;
        int outh = h / 2;
        int outc = channels * 4;

        top_blob.create(outw, outh, outc, 4u, 1, opt.blob_allocator);
        if (top_blob.empty())
            return -100;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int p = 0; p < outc; p++)
        {
            const float* ptr = bottom_blob.channel(p % channels).row((p / channels) % 2) + ((p / channels) / 2);
            float* outptr = top_blob.channel(p);

            for (int i = 0; i < outh; i++)
            {
                for (int j = 0; j < outw; j++)
                {
                    *outptr = *ptr;

                    outptr += 1;
                    ptr += 2;
                }

                ptr += w;
            }
        }

        return 0;
    }
};

DEFINE_LAYER_CREATOR(YoloV5Focus)

struct Object
{
    cv::Rect_<float> rect;
    int label;
    float prob;
};

struct GridAndStride
{
    int grid0;
    int grid1;
    int stride;
};

static inline float intersection_area(const Object& a, const Object& b)
{
    cv::Rect_<float> inter = a.rect & b.rect;
    return inter.area();
}

static void qsort_descent_inplace(std::vector<Object>& faceobjects, int left, int right)
{
    int i = left;
    int j = right;
    float p = faceobjects[(left + right) / 2].prob;

    while (i <= j)
    {
        while (faceobjects[i].prob > p)
            i++;

        while (faceobjects[j].prob < p)
            j--;

        if (i <= j)
        {
            // swap
            std::swap(faceobjects[i], faceobjects[j]);

            i++;
            j--;
        }
    }

    #pragma omp parallel sections
    {
        #pragma omp section
        {
            if (left < j) qsort_descent_inplace(faceobjects, left, j);
        }
        #pragma omp section
        {
            if (i < right) qsort_descent_inplace(faceobjects, i, right);
        }
    }
}

static void qsort_descent_inplace(std::vector<Object>& objects)
{
    if (objects.empty())
        return;

    qsort_descent_inplace(objects, 0, objects.size() - 1);
}

static void nms_sorted_bboxes(const std::vector<Object>& faceobjects, std::vector<int>& picked, float nms_threshold)
{
    picked.clear();

    const int n = faceobjects.size();

    std::vector<float> areas(n);
    for (int i = 0; i < n; i++)
    {
        areas[i] = faceobjects[i].rect.area();
    }

    for (int i = 0; i < n; i++)
    {
        const Object& a = faceobjects[i];

        int keep = 1;
        for (int j = 0; j < (int)picked.size(); j++)
        {
            const Object& b = faceobjects[picked[j]];

            // intersection over union
            float inter_area = intersection_area(a, b);
            float union_area = areas[i] + areas[picked[j]] - inter_area;
            // float IoU = inter_area / union_area
            if (inter_area / union_area > nms_threshold)
                keep = 0;
        }

        if (keep)
            picked.push_back(i);
    }
}

static int generate_grids_and_stride(const int target_size, std::vector<int>& strides, std::vector<GridAndStride>& grid_strides)
{
    for (auto stride : strides)
    {
        int num_grid = target_size / stride;
        for (int g1 = 0; g1 < num_grid; g1++)
        {
            for (int g0 = 0; g0 < num_grid; g0++)
            {
                grid_strides.push_back((GridAndStride){g0, g1, stride});
            }
        }
    }

    return 0;
}

static void generate_yolox_proposals(std::vector<GridAndStride> grid_strides, const ncnn::Mat& feat_blob, float prob_threshold, std::vector<Object>& objects)
{
    const int num_grid = feat_blob.h;
    fprintf(stderr, "output height: %d, width: %d, channels: %d, dims:%d\n", feat_blob.h, feat_blob.w, feat_blob.c, feat_blob.dims);

    const int num_class = feat_blob.w - 5;

    const int num_anchors = grid_strides.size();

    const float* feat_ptr = feat_blob.channel(0);
    for (int anchor_idx = 0; anchor_idx < num_anchors; anchor_idx++)
    {
        const int grid0 = grid_strides[anchor_idx].grid0;
        const int grid1 = grid_strides[anchor_idx].grid1;
        const int stride = grid_strides[anchor_idx].stride;

        // yolox/models/yolo_head.py decode logic
        // outputs[..., :2] = (outputs[..., :2] + grids) * strides
        // outputs[..., 2:4] = torch.exp(outputs[..., 2:4]) * strides
        float x_center = (feat_ptr[0] + grid0) * stride;
        float y_center = (feat_ptr[1] + grid1) * stride;
        float w = exp(feat_ptr[2]) * stride;
        float h = exp(feat_ptr[3]) * stride;
        float x0 = x_center - w * 0.5f;
        float y0 = y_center - h * 0.5f;

        float box_objectness = feat_ptr[4];
        for (int class_idx = 0; class_idx < num_class; class_idx++)
        {
            float box_cls_score = feat_ptr[5 + class_idx];
            float box_prob = box_objectness * box_cls_score;
            if (box_prob > prob_threshold)
            {
                Object obj;
                obj.rect.x = x0;
                obj.rect.y = y0;
                obj.rect.width = w;
                obj.rect.height = h;
                obj.label = class_idx;
                obj.prob = box_prob;

                objects.push_back(obj);
            }

        } // class loop
        feat_ptr += feat_blob.w;

    } // point anchor loop
}

static int detect_yolox(const cv::Mat& bgr, std::vector<Object>& objects)
{
    ncnn::Net yolox;

    yolox.opt.use_vulkan_compute = true;
    // yolox.opt.use_bf16_storage = true;

    yolox.register_custom_layer("YoloV5Focus", YoloV5Focus_layer_creator);

    // original pretrained model from https://github.com/yolox
    // TODO ncnn model https://github.com/nihui/ncnn-assets/tree/master/models
    yolox.load_param("yolox.param");
    yolox.load_model("yolox.bin");

    const int target_size = 416;
    const float prob_threshold = 0.3f;
    const float nms_threshold = 0.65f;

    int img_w = bgr.cols;
    int img_h = bgr.rows;

    int w = img_w;
    int h = img_h;
    float scale = 1.f;
    if (w > h)
    {
        scale = (float)target_size / w;
        w = target_size;
        h = h * scale;
    }
    else
    {
        scale = (float)target_size / h;
        h = target_size;
        w = w * scale;
    }
    ncnn::Mat in = ncnn::Mat::from_pixels_resize(bgr.data, ncnn::Mat::PIXEL_BGR2RGB, img_w, img_h, w, h);

    // pad to target_size rectangle
    int wpad = target_size - w;
    int hpad = target_size - h;
    ncnn::Mat in_pad;
    // different from yolov5, yolox only pads on the bottom and right side,
    // which means users don't need extra padding info to decode the box coordinates.
    ncnn::copy_make_border(in, in_pad, 0, hpad, 0, wpad, ncnn::BORDER_CONSTANT, 114.f);

    // the python pipeline feeds a 0-1 input tensor with rgb_means = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225),
    // so for a 0-255 input image, rgb_mean should be multiplied by 255 and norm should be divided by std.
    const float mean_vals[3] = {255.f * 0.485f, 255.f * 0.456f, 255.f * 0.406f};
    const float norm_vals[3] = {1 / (255.f * 0.229f), 1 / (255.f * 0.224f), 1 / (255.f * 0.225f)};

    in_pad.substract_mean_normalize(mean_vals, norm_vals);

    ncnn::Extractor ex = yolox.create_extractor();

    ex.input("images", in_pad);

    std::vector<Object> proposals;

    {
        ncnn::Mat out;
        ex.extract("output", out);

        std::vector<int> strides = {8, 16, 32}; // might have stride=64
        std::vector<GridAndStride> grid_strides;
        generate_grids_and_stride(target_size, strides, grid_strides);
        generate_yolox_proposals(grid_strides, out, prob_threshold, proposals);
    }

    // sort all proposals by score from highest to lowest
    qsort_descent_inplace(proposals);

    // apply nms with nms_threshold
    std::vector<int> picked;
    nms_sorted_bboxes(proposals, picked, nms_threshold);

    int count = picked.size();

    objects.resize(count);
    for (int i = 0; i < count; i++)
    {
        objects[i] = proposals[picked[i]];

        // adjust offset back to the original unpadded image
        float x0 = (objects[i].rect.x) / scale;
        float y0 = (objects[i].rect.y) / scale;
        float x1 = (objects[i].rect.x + objects[i].rect.width) / scale;
        float y1 = (objects[i].rect.y + objects[i].rect.height) / scale;

        // clip
        x0 = std::max(std::min(x0, (float)(img_w - 1)), 0.f);
        y0 = std::max(std::min(y0, (float)(img_h - 1)), 0.f);
        x1 = std::max(std::min(x1, (float)(img_w - 1)), 0.f);
        y1 = std::max(std::min(y1, (float)(img_h - 1)), 0.f);

        objects[i].rect.x = x0;
        objects[i].rect.y = y0;
        objects[i].rect.width = x1 - x0;
        objects[i].rect.height = y1 - y0;
    }

    return 0;
}

static void draw_objects(const cv::Mat& bgr, const std::vector<Object>& objects)
{
    static const char* class_names[] = {
        "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
        "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
        "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
        "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard",
        "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
        "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch",
        "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone",
        "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear",
        "hair drier", "toothbrush"
    };

    cv::Mat image = bgr.clone();

    for (size_t i = 0; i < objects.size(); i++)
    {
        const Object& obj = objects[i];

        fprintf(stderr, "%d = %.5f at %.2f %.2f %.2f x %.2f\n", obj.label, obj.prob,
                obj.rect.x, obj.rect.y, obj.rect.width, obj.rect.height);

        cv::rectangle(image, obj.rect, cv::Scalar(255, 0, 0));

        char text[256];
        sprintf(text, "%s %.1f%%", class_names[obj.label], obj.prob * 100);

        int baseLine = 0;
        cv::Size label_size = cv::getTextSize(text, cv::FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine);

        int x = obj.rect.x;
        int y = obj.rect.y - label_size.height - baseLine;
        if (y < 0)
            y = 0;
        if (x + label_size.width > image.cols)
            x = image.cols - label_size.width;

        cv::rectangle(image, cv::Rect(cv::Point(x, y), cv::Size(label_size.width, label_size.height + baseLine)),
                      cv::Scalar(255, 255, 255), -1);

        cv::putText(image, text, cv::Point(x, y + label_size.height),
                    cv::FONT_HERSHEY_SIMPLEX, 0.5, cv::Scalar(0, 0, 0));
    }

    cv::imshow("image", image);
    cv::waitKey(0);
}

int main(int argc, char** argv)
{
    if (argc != 2)
    {
        fprintf(stderr, "Usage: %s [imagepath]\n", argv[0]);
        return -1;
    }

    const char* imagepath = argv[1];

    cv::Mat m = cv::imread(imagepath, 1);
    if (m.empty())
    {
        fprintf(stderr, "cv::imread %s failed\n", imagepath);
        return -1;
    }

    std::vector<Object> objects;
    detect_yolox(m, objects);

    draw_objects(m, objects);

    return 0;
}
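The decode step in generate_yolox_proposals above mirrors the comment block quoting yolox/models/yolo_head.py. For reference, here is a rough NumPy paraphrase of the same grid/stride construction and per-anchor decode (an illustration of the math, not part of the commit; the prediction layout is assumed from the C++ code):

import numpy as np

def decode(preds, strides=(8, 16, 32), target_size=416):
    # preds: (num_anchors, 5 + num_classes) raw head output, rows ordered the
    # same way generate_grids_and_stride() orders anchors (g1 outer, g0 inner).
    grids, grid_strides = [], []
    for s in strides:
        n = target_size // s
        ys, xs = np.meshgrid(np.arange(n), np.arange(n), indexing="ij")
        grids.append(np.stack([xs, ys], axis=-1).reshape(-1, 2))
        grid_strides.append(np.full((n * n, 1), s, dtype=np.float32))
    grids = np.concatenate(grids).astype(np.float32)
    grid_strides = np.concatenate(grid_strides)

    xy = (preds[:, :2] + grids) * grid_strides   # box centers in pixels
    wh = np.exp(preds[:, 2:4]) * grid_strides    # box sizes in pixels
    scores = preds[:, 4:5] * preds[:, 5:]        # objectness * per-class score
    return xy, wh, scores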
requirements.txt
ADDED
@@ -0,0 +1,14 @@
numpy
torch
opencv_python
loguru
scikit_image
tqdm
apex
torchvision
pycocotools
apex
Pillow
skimage
thop
ninja
setup.cfg
ADDED
@@ -0,0 +1,18 @@
[isort]
line_length = 100
multi_line_output = 3
balanced_wrapping = True
known_standard_library = setuptools
known_third_party = tqdm,loguru
known_data_processing = cv2,numpy,scipy,PIL,matplotlib,scikit_image
known_datasets = pycocotools
known_deeplearning = torch,torchvision,caffe2,onnx,apex,timm,thop,torch2trt,tensorrt
known_myself = yolox
sections = FUTURE,STDLIB,THIRDPARTY,data_processing,datasets,deeplearning,myself,FIRSTPARTY,LOCALFOLDER
no_lines_before = STDLIB,THIRDPARTY,datasets
default_section = FIRSTPARTY

[flake8]
max-line-length = 100
max-complexity = 18
exclude = __init__.py
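The custom isort sections above split third-party imports into data-processing, dataset, and deep-learning groups, and the import blocks in the tools/ scripts below follow exactly this ordering. A condensed illustration of the resulting layout:

import os                           # STDLIB

from loguru import logger           # THIRDPARTY (tqdm, loguru)

import cv2                          # data_processing (cv2, numpy, ...)
import numpy as np

from pycocotools.coco import COCO   # datasets

import torch                        # deeplearning (torch, torchvision, ...)

from yolox.exp import get_exp       # myself (the yolox package itself)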
setup.py
ADDED
@@ -0,0 +1,64 @@
#!/usr/bin/env python
# Copyright (c) Megvii, Inc. and its affiliates. All Rights Reserved

import re
import setuptools
import glob
from os import path
import torch
from torch.utils.cpp_extension import CppExtension

torch_ver = [int(x) for x in torch.__version__.split(".")[:2]]
assert torch_ver >= [1, 3], "Requires PyTorch >= 1.3"


def get_extensions():
    this_dir = path.dirname(path.abspath(__file__))
    extensions_dir = path.join(this_dir, "yolox", "layers", "csrc")

    main_source = path.join(extensions_dir, "vision.cpp")
    sources = glob.glob(path.join(extensions_dir, "**", "*.cpp"))

    sources = [main_source] + sources
    extension = CppExtension

    extra_compile_args = {"cxx": ["-O3"]}
    define_macros = []

    include_dirs = [extensions_dir]

    ext_modules = [
        extension(
            "yolox._C",
            sources,
            include_dirs=include_dirs,
            define_macros=define_macros,
            extra_compile_args=extra_compile_args,
        )
    ]

    return ext_modules


with open("yolox/__init__.py", "r") as f:
    version = re.search(
        r'^__version__\s*=\s*[\'"]([^\'"]*)[\'"]',
        f.read(), re.MULTILINE
    ).group(1)


with open("README.md", "r") as f:
    long_description = f.read()


setuptools.setup(
    name="yolox",
    version=version,
    author="basedet team",
    python_requires=">=3.6",
    long_description=long_description,
    ext_modules=get_extensions(),
    classifiers=["Programming Language :: Python :: 3", "Operating System :: OS Independent"],
    cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
    packages=setuptools.find_packages(),
)
tools/demo.py
ADDED
@@ -0,0 +1,278 @@
import argparse
import os
import time
from loguru import logger

import cv2

import torch
import torch.backends.cudnn as cudnn

from yolox.data.data_augment import preproc
from yolox.data.datasets import COCO_CLASSES
from yolox.exp import get_exp
from yolox.utils import fuse_model, get_model_info, postprocess, setup_logger, vis, xyxy2xywh

IMAGE_EXT = ['.jpg', '.jpeg', '.webp', '.bmp', '.png']


def make_parser():
    parser = argparse.ArgumentParser("YOLOX Demo!")
    parser.add_argument('demo', default='image', help='demo type, e.g. image, video or webcam')
    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
    parser.add_argument("-n", "--name", type=str, default=None, help="model name")

    parser.add_argument('--path', default='./demo', help='path to images or video')
    parser.add_argument('--camid', type=int, default=0, help='webcam demo camera id')
    parser.add_argument(
        '--save_result', action='store_true',
        help='whether to save the inference result of image/video'
    )

    # exp file
    parser.add_argument(
        "-f",
        "--exp_file",
        default=None,
        type=str,
        help="please input your experiment description file",
    )
    parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt for eval")
    parser.add_argument("--conf", default=None, type=float, help="test conf")
    parser.add_argument("--nms", default=None, type=float, help="test nms threshold")
    parser.add_argument("--tsize", default=None, type=int, help="test img size")
    parser.add_argument(
        "--fp16",
        dest="fp16",
        default=False,
        action="store_true",
        help="Adopting mixed precision evaluating.",
    )
    parser.add_argument(
        "--fuse",
        dest="fuse",
        default=False,
        action="store_true",
        help="Fuse conv and bn for testing.",
    )
    parser.add_argument(
        "--trt",
        dest="trt",
        default=False,
        action="store_true",
        help="Using TensorRT model for testing.",
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    return parser


def get_image_list(path):
    image_names = []
    for maindir, subdir, file_name_list in os.walk(path):
        for filename in file_name_list:
            apath = os.path.join(maindir, filename)
            ext = os.path.splitext(apath)[1]
            if ext in IMAGE_EXT:
                image_names.append(apath)
    return image_names


class Predictor(object):
    def __init__(self, model, exp, cls_names=COCO_CLASSES, trt_file=None, decoder=None):
        self.model = model
        self.cls_names = cls_names
        self.decoder = decoder
        self.num_classes = exp.num_classes
        self.confthre = exp.test_conf
        self.nmsthre = exp.nmsthre
        self.test_size = exp.test_size
        if trt_file is not None:
            from torch2trt import TRTModule
            model_trt = TRTModule()
            model_trt.load_state_dict(torch.load(trt_file))

            x = torch.ones(1, 3, exp.test_size[0], exp.test_size[1]).cuda()
            self.model(x)
            self.model = model_trt
        self.rgb_means = (0.485, 0.456, 0.406)
        self.std = (0.229, 0.224, 0.225)

    def inference(self, img):
        img_info = {'id': 0}
        if isinstance(img, str):
            img_info['file_name'] = os.path.basename(img)
            img = cv2.imread(img)
        else:
            img_info['file_name'] = None

        height, width = img.shape[:2]
        img_info['height'] = height
        img_info['width'] = width
        img_info['raw_img'] = img

        img, ratio = preproc(img, self.test_size, self.rgb_means, self.std)
        img_info['ratio'] = ratio
        img = torch.from_numpy(img).unsqueeze(0).cuda()

        with torch.no_grad():
            t0 = time.time()
            outputs = self.model(img)
            if self.decoder is not None:
                outputs = self.decoder(outputs, dtype=outputs.type())
            outputs = postprocess(
                outputs, self.num_classes, self.confthre, self.nmsthre
            )
            logger.info('Infer time: {:.4f}s'.format(time.time() - t0))
        return outputs, img_info

    def visual(self, output, img_info, cls_conf=0.35):
        ratio = img_info['ratio']
        img = img_info['raw_img']
        output = output.cpu()

        bboxes = output[:, 0:4]

        # preprocessing: resize
        bboxes /= ratio
        bboxes = xyxy2xywh(bboxes)

        cls = output[:, 6]
        scores = output[:, 4] * output[:, 5]

        vis_res = vis(img, bboxes, scores, cls, cls_conf, self.cls_names)
        return vis_res


def image_demo(predictor, vis_folder, path, current_time, save_result):
    if os.path.isdir(path):
        files = get_image_list(path)
    else:
        files = [path]
    files.sort()
    for image_name in files:
        outputs, img_info = predictor.inference(image_name)
        result_image = predictor.visual(outputs[0], img_info)
        if save_result:
            save_folder = os.path.join(
                vis_folder, time.strftime("%Y_%m_%d_%H_%M_%S", current_time)
            )
            os.makedirs(save_folder, exist_ok=True)
            save_file_name = os.path.join(save_folder, os.path.basename(image_name))
            logger.info("Saving detection result in {}".format(save_file_name))
            cv2.imwrite(save_file_name, result_image)
        ch = cv2.waitKey(0)
        if ch == 27 or ch == ord('q') or ch == ord('Q'):
            break


def imageflow_demo(predictor, vis_folder, current_time, args):
    cap = cv2.VideoCapture(args.path if args.demo == 'video' else args.camid)
    width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)  # float
    height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)  # float
    fps = cap.get(cv2.CAP_PROP_FPS)
    save_folder = os.path.join(vis_folder, time.strftime("%Y_%m_%d_%H_%M_%S", current_time))
    os.makedirs(save_folder, exist_ok=True)
    if args.demo == "video":
        save_path = os.path.join(save_folder, args.path.split('/')[-1])
    else:
        save_path = os.path.join(save_folder, 'camera.mp4')
    logger.info(f'video save_path is {save_path}')
    vid_writer = cv2.VideoWriter(
        save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (int(width), int(height))
    )
    while True:
        ret_val, frame = cap.read()
        if ret_val:
            outputs, img_info = predictor.inference(frame)
            result_frame = predictor.visual(outputs[0], img_info)
            if args.save_result:
                vid_writer.write(result_frame)
            ch = cv2.waitKey(1)
            if ch == 27 or ch == ord('q') or ch == ord('Q'):
                break
        else:
            break


def main(exp, args):
    if not args.experiment_name:
        args.experiment_name = exp.exp_name

    # set environment variables for distributed training
    cudnn.benchmark = True
    rank = 0

    file_name = os.path.join(exp.output_dir, args.experiment_name)
    os.makedirs(file_name, exist_ok=True)

    vis_folder = None
    if args.save_result:
        vis_folder = os.path.join(file_name, 'vis_res')
        os.makedirs(vis_folder, exist_ok=True)

    setup_logger(
        file_name, distributed_rank=rank, filename="demo_log.txt", mode="a"
    )
    logger.info("Args: {}".format(args))

    if args.conf is not None:
        exp.test_conf = args.conf
    if args.nms is not None:
        exp.nmsthre = args.nms
    if args.tsize is not None:
        exp.test_size = (args.tsize, args.tsize)

    model = exp.get_model()
    logger.info("Model Summary: {}".format(get_model_info(model, exp.test_size)))

    torch.cuda.set_device(rank)
    model.cuda(rank)
    model.eval()

    if not args.trt:
        if args.ckpt is None:
            ckpt_file = os.path.join(file_name, "best_ckpt.pth.tar")
        else:
            ckpt_file = args.ckpt
        logger.info("loading checkpoint")
        loc = "cuda:{}".format(rank)
        ckpt = torch.load(ckpt_file, map_location=loc)
        # load the model state dict
        model.load_state_dict(ckpt["model"])
        logger.info("loaded checkpoint done.")

    if args.fuse:
        logger.info("\tFusing model...")
        model = fuse_model(model)

    if args.trt:
        assert not args.fuse, \
            "TensorRT model does not support model fusing!"
        trt_file = os.path.join(file_name, "model_trt.pth")
        assert os.path.exists(trt_file), (
            "TensorRT model is not found!\n Run python3 tools/trt.py first!"
        )
        model.head.decode_in_inference = False
        decoder = model.head.decode_outputs
        logger.info("Using TensorRT to inference")
    else:
        trt_file = None
        decoder = None

    predictor = Predictor(model, exp, COCO_CLASSES, trt_file, decoder)
    current_time = time.localtime()
    if args.demo == 'image':
        image_demo(predictor, vis_folder, args.path, current_time, args.save_result)
    elif args.demo == 'video' or args.demo == 'webcam':
        imageflow_demo(predictor, vis_folder, current_time, args)


if __name__ == "__main__":
    args = make_parser().parse_args()
    exp = get_exp(args.exp_file, args.name)

    main(exp, args)
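Besides the command-line entry point, the Predictor class above can be driven directly from Python. A minimal sketch using only entry points defined in this file and in yolox.exp (the model name, checkpoint path, and image path are illustrative assumptions):

import cv2
import torch

from yolox.data.datasets import COCO_CLASSES
from yolox.exp import get_exp

exp = get_exp(None, "yolox-s")          # exp name is an assumption
model = exp.get_model().cuda().eval()
ckpt = torch.load("best_ckpt.pth.tar", map_location="cuda:0")  # path is illustrative
model.load_state_dict(ckpt["model"])

predictor = Predictor(model, exp, COCO_CLASSES)   # class defined in tools/demo.py
outputs, img_info = predictor.inference(cv2.imread("dog.jpg"))
result = predictor.visual(outputs[0], img_info, cls_conf=0.35)
cv2.imwrite("result.jpg", result)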
tools/eval.py
ADDED
@@ -0,0 +1,195 @@
import argparse
import os
import random
import warnings
from loguru import logger

import torch
import torch.backends.cudnn as cudnn
from torch.nn.parallel import DistributedDataParallel as DDP

from yolox.core import launch
from yolox.exp import get_exp
from yolox.utils import configure_nccl, fuse_model, get_local_rank, get_model_info, setup_logger


def make_parser():
    parser = argparse.ArgumentParser("YOLOX Eval")
    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
    parser.add_argument("-n", "--name", type=str, default=None, help="model name")

    # distributed
    parser.add_argument(
        "--dist-backend", default="nccl", type=str, help="distributed backend"
    )
    parser.add_argument(
        "--dist-url", default=None, type=str, help="url used to set up distributed training"
    )
    parser.add_argument("-b", "--batch-size", type=int, default=64, help="batch size")
    parser.add_argument(
        "-d", "--devices", default=None, type=int, help="device for training"
    )
    parser.add_argument(
        "--local_rank", default=0, type=int, help="local rank for dist training"
    )
    parser.add_argument(
        "--num_machine", default=1, type=int, help="num of node for training"
    )
    parser.add_argument(
        "--machine_rank", default=0, type=int, help="node rank for multi-node training"
    )
    parser.add_argument(
        "-f",
        "--exp_file",
        default=None,
        type=str,
        help="please input your experiment description file",
    )
    parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt for eval")
    parser.add_argument("--conf", default=None, type=float, help="test conf")
    parser.add_argument("--nms", default=None, type=float, help="test nms threshold")
    parser.add_argument("--tsize", default=None, type=int, help="test img size")
    parser.add_argument("--seed", default=None, type=int, help="eval seed")
    parser.add_argument(
        "--fp16",
        dest="fp16",
        default=False,
        action="store_true",
        help="Adopting mixed precision evaluating.",
    )
    parser.add_argument(
        "--fuse",
        dest="fuse",
        default=False,
        action="store_true",
        help="Fuse conv and bn for testing.",
    )
    parser.add_argument(
        "--trt",
        dest="trt",
        default=False,
        action="store_true",
        help="Using TensorRT model for testing.",
    )
    parser.add_argument(
        "--test",
        dest="test",
        default=False,
        action="store_true",
        help="Evaluating on test-dev set.",
    )
    parser.add_argument(
        "--speed", dest="speed", default=False, action="store_true", help="speed test only."
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    return parser


@logger.catch
def main(exp, num_gpu, args):
    if not args.experiment_name:
        args.experiment_name = exp.exp_name

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn(
            "You have chosen to seed testing. This will turn on the CUDNN deterministic setting, "
            "which can slow down your testing considerably!"
        )

    is_distributed = num_gpu > 1

    # set environment variables for distributed training
    configure_nccl()
    cudnn.benchmark = True

    # rank = args.local_rank
    rank = get_local_rank()

    if rank == 0:
        if os.path.exists("./" + args.experiment_name + "ip_add.txt"):
            os.remove("./" + args.experiment_name + "ip_add.txt")

    file_name = os.path.join(exp.output_dir, args.experiment_name)

    if rank == 0:
        os.makedirs(file_name, exist_ok=True)

    setup_logger(
        file_name, distributed_rank=rank, filename="val_log.txt", mode="a"
    )
    logger.info("Args: {}".format(args))

    if args.conf is not None:
        exp.test_conf = args.conf
    if args.nms is not None:
        exp.nmsthre = args.nms
    if args.tsize is not None:
        exp.test_size = (args.tsize, args.tsize)

    model = exp.get_model()
    logger.info("Model Summary: {}".format(get_model_info(model, exp.test_size)))
    logger.info("Model Structure:\n{}".format(str(model)))

    evaluator = exp.get_evaluator(args.batch_size, is_distributed, args.test)

    torch.cuda.set_device(rank)
    model.cuda(rank)
    model.eval()

    if not args.speed and not args.trt:
        if args.ckpt is None:
            ckpt_file = os.path.join(file_name, "best_ckpt.pth.tar")
        else:
            ckpt_file = args.ckpt
        logger.info("loading checkpoint")
        loc = "cuda:{}".format(rank)
        ckpt = torch.load(ckpt_file, map_location=loc)
        # load the model state dict
        model.load_state_dict(ckpt["model"])
        logger.info("loaded checkpoint done.")

    if is_distributed:
        model = DDP(model, device_ids=[rank])

    if args.fuse:
        logger.info("\tFusing model...")
        model = fuse_model(model)

    if args.trt:
        assert (not args.fuse and not is_distributed and args.batch_size == 1), \
            "TensorRT model does not support model fusing and distributed inferencing!"
        trt_file = os.path.join(file_name, "model_trt.pth")
        assert os.path.exists(trt_file), "TensorRT model is not found!\n Run tools/trt.py first!"
        model.head.decode_in_inference = False
        decoder = model.head.decode_outputs
    else:
        trt_file = None
        decoder = None

    # start evaluate
    *_, summary = evaluator.evaluate(
        model, is_distributed, args.fp16, trt_file, decoder, exp.test_size
    )
    logger.info("\n" + summary)


if __name__ == "__main__":
    args = make_parser().parse_args()
    exp = get_exp(args.exp_file, args.name)
    exp.merge(args.opts)

    num_gpu = torch.cuda.device_count() if args.devices is None else args.devices
    assert num_gpu <= torch.cuda.device_count()

    dist_url = "auto" if args.dist_url is None else args.dist_url
    launch(
        main, num_gpu, args.num_machine, backend=args.dist_backend,
        dist_url=dist_url, args=(exp, num_gpu, args)
    )
tools/export_onnx.py
ADDED
@@ -0,0 +1,92 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-

import argparse
import os
from loguru import logger

import torch
from torch import nn

from yolox.exp import get_exp
from yolox.models.network_blocks import SiLU
from yolox.utils import replace_module


def make_parser():
    parser = argparse.ArgumentParser("YOLOX onnx deploy")
    parser.add_argument(
        "--output-name", type=str, default="yolox.onnx", help="output name of models"
    )
    parser.add_argument("--input", default="images", type=str, help="input name of onnx model")
    parser.add_argument("--output", default="output", type=str, help="output name of onnx model")
    parser.add_argument("-o", "--opset", default=11, type=int, help="onnx opset version")
    parser.add_argument("--no-onnxsim", action="store_true", help="use onnxsim or not")

    parser.add_argument(
        "-f",
        "--exp_file",
        default=None,
        type=str,
        help="experiment description file",
    )
    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
    parser.add_argument("-n", "--name", type=str, default=None, help="model name")
    parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt path")
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    return parser


@logger.catch
def main():
    args = make_parser().parse_args()
    logger.info("args value: {}".format(args))
    exp = get_exp(args.exp_file, args.name)
    exp.merge(args.opts)

    if not args.experiment_name:
        args.experiment_name = exp.exp_name

    model = exp.get_model()
    if args.ckpt is None:
        file_name = os.path.join(exp.output_dir, args.experiment_name)
        ckpt_file = os.path.join(file_name, "best_ckpt.pth.tar")
    else:
        ckpt_file = args.ckpt

    ckpt = torch.load(ckpt_file, map_location="cpu")
    # load the model state dict

    model.eval()
    if "model" in ckpt:
        ckpt = ckpt["model"]
    model.load_state_dict(ckpt)
    model = replace_module(model, nn.SiLU, SiLU)
    model.head.decode_in_inference = False

    logger.info("loaded checkpoint done.")
    dummy_input = torch.randn(1, 3, exp.test_size[0], exp.test_size[1])
    torch.onnx._export(
        model,
        dummy_input,
        args.output_name,
        input_names=[args.input],
        output_names=[args.output],
        opset_version=args.opset,
    )
    logger.info("generated onnx model named {}".format(args.output_name))

    if not args.no_onnxsim:
        # use onnx-simplifier to remove redundant operators from the model
        os.system("python3 -m onnxsim {} {}".format(args.output_name, args.output_name))
        logger.info("generated simplified onnx model named {}".format(args.output_name))


if __name__ == "__main__":
    main()
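Once the ONNX file is written, a quick sanity check can be run with onnxruntime (assumed installed; not part of this commit). The input/output names below match the script's --input/--output defaults, and since decode_in_inference is disabled above, the network emits the raw per-anchor prediction map:

import numpy as np
import onnxruntime

session = onnxruntime.InferenceSession("yolox.onnx")
dummy = np.random.randn(1, 3, 640, 640).astype(np.float32)  # exp.test_size assumed (640, 640)
(out,) = session.run(["output"], {"images": dummy})
print(out.shape)  # roughly (1, num_anchors, 5 + num_classes), still undecoded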
tools/train.py
ADDED
@@ -0,0 +1,112 @@
1 |
+
import argparse
|
2 |
+
import random
|
3 |
+
import warnings
|
4 |
+
from loguru import logger
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.backends.cudnn as cudnn
|
8 |
+
|
9 |
+
from yolox.core import Trainer, launch
|
10 |
+
from yolox.exp import get_exp
|
11 |
+
from yolox.utils import configure_nccl


def make_parser():
    parser = argparse.ArgumentParser("YOLOX train parser")
    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
    parser.add_argument("-n", "--name", type=str, default=None, help="model name")

    # distributed
    parser.add_argument(
        "--dist-backend", default="nccl", type=str, help="distributed backend"
    )
    parser.add_argument(
        "--dist-url", default=None, type=str, help="url used to set up distributed training"
    )
    parser.add_argument("-b", "--batch-size", type=int, default=64, help="batch size")
    parser.add_argument(
        "-d", "--devices", default=None, type=int, help="device for training"
    )
    parser.add_argument(
        "--local_rank", default=0, type=int, help="local rank for dist training"
    )
    parser.add_argument(
        "-f",
        "--exp_file",
        default=None,
        type=str,
        help="please input your experiment description file",
    )
    parser.add_argument(
        "--resume", default=False, action="store_true", help="resume training"
    )
    parser.add_argument("-c", "--ckpt", default=None, type=str, help="checkpoint file")
    parser.add_argument(
        "-e", "--start_epoch", default=None, type=int, help="resume training start epoch"
    )
    parser.add_argument(
        "--num_machine", default=1, type=int, help="num of nodes for training"
    )
    parser.add_argument(
        "--machine_rank", default=0, type=int, help="node rank for multi-node training"
    )
    parser.add_argument(
        "--fp16",
        dest="fp16",
        default=True,
        action="store_true",
        help="Adopting mixed-precision training.",
    )
    parser.add_argument(
        "-o",
        "--occumpy",
        dest="occumpy",
        default=False,
        action="store_true",
        help="occupy GPU memory first for training.",
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    return parser


@logger.catch
def main(exp, args):
    if not args.experiment_name:
        args.experiment_name = exp.exp_name

    if exp.seed is not None:
        random.seed(exp.seed)
        torch.manual_seed(exp.seed)
        cudnn.deterministic = True
        warnings.warn(
            "You have chosen to seed training. This will turn on the CUDNN deterministic setting, "
            "which can slow down your training considerably! You may see unexpected behavior "
            "when restarting from checkpoints."
        )

    # set environment variables for distributed training
    configure_nccl()
    cudnn.benchmark = True

    trainer = Trainer(exp, args)
    trainer.train()


if __name__ == "__main__":
    args = make_parser().parse_args()
    exp = get_exp(args.exp_file, args.name)
    exp.merge(args.opts)

    num_gpu = torch.cuda.device_count() if args.devices is None else args.devices
    assert num_gpu <= torch.cuda.device_count()

    dist_url = "auto" if args.dist_url is None else args.dist_url
    launch(
        # forward machine_rank so multi-node global ranks are computed correctly
        main, num_gpu, args.num_machine, args.machine_rank,
        backend=args.dist_backend, dist_url=dist_url, args=(exp, args)
    )
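
A minimal sketch of how this entry point is typically invoked on a single machine; the experiment file path, device count, and batch size below are illustrative, not taken from this commit:

    python tools/train.py -f exps/your_exp.py -d 8 -b 64 --fp16
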
tools/trt.py
ADDED
@@ -0,0 +1,67 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-

import argparse
import os
from loguru import logger

import tensorrt as trt
import torch
from torch2trt import torch2trt

from yolox.exp import get_exp


def make_parser():
    parser = argparse.ArgumentParser("YOLOX TensorRT deploy")
    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
    parser.add_argument("-n", "--name", type=str, default=None, help="model name")

    parser.add_argument(
        "-f",
        "--exp_file",
        default=None,
        type=str,
        help="please input your experiment description file",
    )
    parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt path")
    return parser


@logger.catch
def main():
    args = make_parser().parse_args()
    exp = get_exp(args.exp_file, args.name)
    if not args.experiment_name:
        args.experiment_name = exp.exp_name

    model = exp.get_model()
    file_name = os.path.join(exp.output_dir, args.experiment_name)
    os.makedirs(file_name, exist_ok=True)
    if args.ckpt is None:
        ckpt_file = os.path.join(file_name, "best_ckpt.pth.tar")
    else:
        ckpt_file = args.ckpt

    ckpt = torch.load(ckpt_file, map_location="cpu")
    # load the model state dict
    model.load_state_dict(ckpt["model"])
    logger.info("loaded checkpoint done.")
    model.eval()
    model.cuda()
    model.head.decode_in_inference = False
    x = torch.ones(1, 3, exp.test_size[0], exp.test_size[1]).cuda()
    model_trt = torch2trt(
        model,
        [x],
        fp16_mode=True,
        log_level=trt.Logger.INFO,
        max_workspace_size=(1 << 32),
    )
    torch.save(model_trt.state_dict(), os.path.join(file_name, "model_trt.pth"))
    logger.info("Converted TensorRT model done.")


if __name__ == "__main__":
    main()
yolox/__init__.py
ADDED
@@ -0,0 +1,8 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-

from .utils import configure_module

configure_module()

__version__ = "0.1.0"
yolox/core/__init__.py
ADDED
@@ -0,0 +1,5 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-

from .launch import launch
from .trainer import Trainer
yolox/core/launch.py
ADDED
@@ -0,0 +1,105 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Code is based on
# https://github.com/facebookresearch/detectron2/blob/master/detectron2/engine/launch.py
# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) Megvii, Inc. and its affiliates.


from loguru import logger

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

import yolox.utils.dist as comm

__all__ = ["launch"]


def _find_free_port():
    """
    Find an available port of the current machine / node.
    """
    import socket

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # Binding to port 0 will cause the OS to find an available port for us
    sock.bind(("", 0))
    port = sock.getsockname()[1]
    sock.close()
    # NOTE: there is still a chance the port could be taken by other processes.
    return port


def launch(
    main_func, num_gpus_per_machine, num_machines=1, machine_rank=0,
    backend="nccl", dist_url=None, args=()
):
    """
    Args:
        main_func: a function that will be called by `main_func(*args)`
        num_machines (int): the total number of machines
        machine_rank (int): the rank of this machine (one per machine)
        dist_url (str): url to connect to for distributed training, including protocol,
            e.g. "tcp://127.0.0.1:8686".
            Can be set to "auto" to automatically select a free port on localhost
        args (tuple): arguments passed to main_func
    """
    world_size = num_machines * num_gpus_per_machine
    if world_size > 1:
        # https://github.com/pytorch/pytorch/pull/14391
        # TODO prctl in spawned processes

        if dist_url == "auto":
            assert num_machines == 1, "dist_url=auto cannot work with distributed training."
            port = _find_free_port()
            dist_url = f"tcp://127.0.0.1:{port}"

        mp.spawn(
            _distributed_worker,
            nprocs=num_gpus_per_machine,
            args=(
                main_func, world_size, num_gpus_per_machine,
                machine_rank, backend, dist_url, args
            ),
            daemon=False,
        )
    else:
        main_func(*args)


def _distributed_worker(
    local_rank, main_func, world_size, num_gpus_per_machine,
    machine_rank, backend, dist_url, args
):
    assert torch.cuda.is_available(), "cuda is not available. Please check your installation."
    global_rank = machine_rank * num_gpus_per_machine + local_rank
    logger.info("Rank {} initialization finished.".format(global_rank))
    try:
        dist.init_process_group(
            backend=backend,
            init_method=dist_url,
            world_size=world_size,
            rank=global_rank,
        )
    except Exception:
        logger.error("Process group URL: {}".format(dist_url))
        raise
    # synchronize is needed here to prevent a possible timeout after calling init_process_group
    # See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172
    comm.synchronize()

    assert num_gpus_per_machine <= torch.cuda.device_count()
    torch.cuda.set_device(local_rank)

    # Setup the local process group (which contains ranks within the same machine)
    assert comm._LOCAL_PROCESS_GROUP is None
    num_machines = world_size // num_gpus_per_machine
    for i in range(num_machines):
        ranks_on_i = list(range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine))
        pg = dist.new_group(ranks_on_i)
        if i == machine_rank:
            comm._LOCAL_PROCESS_GROUP = pg

    main_func(*args)
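
A minimal sketch of driving `launch` with a toy worker; two local GPUs are assumed, and "auto" lets `_find_free_port` pick the rendezvous port:

from yolox.core import launch

def demo_worker(msg):
    # Called once in every spawned per-GPU process.
    print(msg)

if __name__ == "__main__":
    launch(demo_worker, num_gpus_per_machine=2, dist_url="auto", args=("hello",))
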
yolox/core/trainer.py
ADDED
@@ -0,0 +1,318 @@
import datetime
import os
import time
from loguru import logger

import apex
import torch
from apex import amp
from torch.utils.tensorboard import SummaryWriter

from yolox.data import DataPrefetcher
from yolox.utils import (
    MeterBuffer,
    ModelEMA,
    all_reduce_norm,
    get_local_rank,
    get_model_info,
    get_rank,
    get_world_size,
    gpu_mem_usage,
    load_ckpt,
    occumpy_mem,
    save_checkpoint,
    setup_logger,
    synchronize,
)


class Trainer:

    def __init__(self, exp, args):
        # The init function only defines some basic attributes; other attributes,
        # like the model and optimizer, are built in the before_train method.
        self.exp = exp
        self.args = args

        # training related attr
        self.max_epoch = exp.max_epoch
        self.amp_training = args.fp16
        self.is_distributed = get_world_size() > 1
        self.rank = get_rank()
        self.local_rank = get_local_rank()
        self.device = "cuda:{}".format(self.local_rank)
        self.use_model_ema = exp.ema

        # data/dataloader related attr
        self.data_type = torch.float16 if args.fp16 else torch.float32
        self.input_size = exp.input_size
        self.best_ap = 0

        # metric record
        self.meter = MeterBuffer(window_size=exp.print_interval)
        self.file_name = os.path.join(exp.output_dir, args.experiment_name)

        if self.rank == 0 and os.path.exists("./" + args.experiment_name + "ip_add.txt"):
            os.remove("./" + args.experiment_name + "ip_add.txt")

        if self.rank == 0:
            os.makedirs(self.file_name, exist_ok=True)

        setup_logger(self.file_name, distributed_rank=self.rank, filename="train_log.txt", mode="a")

    def train(self):
        self.before_train()
        try:
            self.train_in_epoch()
        except Exception:
            raise
        finally:
            self.after_train()

    def train_in_epoch(self):
        for self.epoch in range(self.start_epoch, self.max_epoch):
            self.before_epoch()
            self.train_in_iter()
            self.after_epoch()

    def train_in_iter(self):
        for self.iter in range(self.max_iter):
            self.before_iter()
            self.train_one_iter()
            self.after_iter()

    def train_one_iter(self):
        iter_start_time = time.time()

        inps, targets = self.prefetcher.next()
        inps = inps.to(self.data_type)
        targets = targets.to(self.data_type)
        targets.requires_grad = False
        data_end_time = time.time()

        outputs = self.model(inps, targets)
        loss = outputs["total_loss"]

        self.optimizer.zero_grad()
        if self.amp_training:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        self.optimizer.step()

        if self.use_model_ema:
            self.ema_model.update(self.model)

        lr = self.lr_scheduler.update_lr(self.progress_in_iter + 1)
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr

        iter_end_time = time.time()
        self.meter.update(
            iter_time=iter_end_time - iter_start_time,
            data_time=data_end_time - iter_start_time,
            lr=lr,
            **outputs,
        )

    def before_train(self):
        logger.info("args: {}".format(self.args))
        logger.info("exp value:\n{}".format(self.exp))

        # model related init
        torch.cuda.set_device(self.local_rank)
        model = self.exp.get_model()
        logger.info("Model Summary: {}".format(get_model_info(model, self.exp.test_size)))
        model.to(self.device)

        # solver related init
        self.optimizer = self.exp.get_optimizer(self.args.batch_size)

        if self.amp_training:
            # keep the (possibly wrapped) optimizer returned by apex
            model, self.optimizer = amp.initialize(model, self.optimizer, opt_level="O1")

        # value of epoch will be set in `resume_train`
        model = self.resume_train(model)

        # data related init
        self.no_aug = self.start_epoch >= self.max_epoch - self.exp.no_aug_epochs
        self.train_loader = self.exp.get_data_loader(
            batch_size=self.args.batch_size,
            is_distributed=self.is_distributed,
            no_aug=self.no_aug
        )
        logger.info("init prefetcher, this might take a while...")
        self.prefetcher = DataPrefetcher(self.train_loader)
        # max_iter means iters per epoch
        self.max_iter = len(self.train_loader)

        self.lr_scheduler = self.exp.get_lr_scheduler(
            self.exp.basic_lr_per_img * self.args.batch_size, self.max_iter
        )
        if self.args.occumpy:
            occumpy_mem(self.local_rank)

        if self.is_distributed:
            model = apex.parallel.DistributedDataParallel(model)
            # from torch.nn.parallel import DistributedDataParallel as DDP
            # model = DDP(model, device_ids=[self.local_rank], broadcast_buffers=False)

        if self.use_model_ema:
            self.ema_model = ModelEMA(model, 0.9998)
            self.ema_model.updates = self.max_iter * self.start_epoch

        self.model = model
        self.model.train()

        self.evaluator = self.exp.get_evaluator(
            batch_size=self.args.batch_size, is_distributed=self.is_distributed
        )
        # Tensorboard logger
        if self.rank == 0:
            self.tblogger = SummaryWriter(self.file_name)

        logger.info("Training start...")
        logger.info("\n{}".format(model))

    def after_train(self):
        logger.info(
            "Training of experiment is done and the best AP is {:.2f}".format(self.best_ap * 100)
        )

    def before_epoch(self):
        logger.info("---> start train epoch{}".format(self.epoch + 1))

        if self.epoch + 1 == self.max_epoch - self.exp.no_aug_epochs or self.no_aug:
            logger.info("--->No mosaic aug now!")
            self.train_loader.close_mosaic()
            logger.info("--->Add additional L1 loss now!")
            if self.is_distributed:
                self.model.module.head.use_l1 = True
            else:
                self.model.head.use_l1 = True
            self.exp.eval_interval = 1
            if not self.no_aug:
                self.save_ckpt(ckpt_name="last_mosaic_epoch")

    def after_epoch(self):
        if self.use_model_ema:
            self.ema_model.update_attr(self.model)

        self.save_ckpt(ckpt_name="latest")

        if (self.epoch + 1) % self.exp.eval_interval == 0:
            all_reduce_norm(self.model)
            self.evaluate_and_save_model()

    def before_iter(self):
        pass

    def after_iter(self):
        """
        `after_iter` contains two parts of logic:
            * log information
            * reset setting of resize
        """
        # log needed information
        if (self.iter + 1) % self.exp.print_interval == 0:
            # TODO check ETA logic
            left_iters = self.max_iter * self.max_epoch - (self.progress_in_iter + 1)
            eta_seconds = self.meter["iter_time"].global_avg * left_iters
            eta_str = "ETA: {}".format(datetime.timedelta(seconds=int(eta_seconds)))

            progress_str = "epoch: {}/{}, iter: {}/{}".format(
                self.epoch + 1, self.max_epoch, self.iter + 1, self.max_iter
            )
            loss_meter = self.meter.get_filtered_meter("loss")
            loss_str = ", ".join(["{}: {:.1f}".format(k, v.latest) for k, v in loss_meter.items()])

            time_meter = self.meter.get_filtered_meter("time")
            time_str = ", ".join(["{}: {:.3f}s".format(k, v.avg) for k, v in time_meter.items()])

            logger.info(
                "{}, mem: {:.0f}Mb, {}, {}, lr: {:.3e}".format(
                    progress_str,
                    gpu_mem_usage(),
                    time_str,
                    loss_str,
                    self.meter["lr"].latest,
                )
                + (", size: {:d}, {}".format(self.input_size[0], eta_str))
            )
            self.meter.clear_meters()

        # random resizing
        if self.exp.random_size is not None and (self.progress_in_iter + 1) % 10 == 0:
            self.input_size = self.exp.random_resize(
                self.train_loader, self.epoch, self.rank, self.is_distributed
            )

    @property
    def progress_in_iter(self):
        return self.epoch * self.max_iter + self.iter

    def resume_train(self, model):
        if self.args.resume:
            logger.info("resume training")
            if self.args.ckpt is None:
                ckpt_file = os.path.join(self.file_name, "latest" + "_ckpt.pth.tar")
            else:
                ckpt_file = self.args.ckpt

            ckpt = torch.load(ckpt_file, map_location=self.device)
            # resume the model/optimizer state dict
            model.load_state_dict(ckpt["model"])
            self.optimizer.load_state_dict(ckpt["optimizer"])
            # resume the training states variables
            if self.amp_training and "amp" in ckpt:
                amp.load_state_dict(ckpt["amp"])
            start_epoch = (
                self.args.start_epoch - 1
                if self.args.start_epoch is not None
                else ckpt["start_epoch"]
            )
            self.start_epoch = start_epoch
            logger.info("loaded checkpoint '{}' (epoch {})".format(ckpt_file, self.start_epoch))
        else:
            if self.args.ckpt is not None:
                logger.info("loading checkpoint for fine tuning")
                ckpt_file = self.args.ckpt
                ckpt = torch.load(ckpt_file, map_location=self.device)["model"]
                # self.model is not built yet at this point; load into the local model
                model = load_ckpt(model, ckpt)
            self.start_epoch = 0

        return model

    def evaluate_and_save_model(self):
        evalmodel = self.ema_model.ema if self.use_model_ema else self.model
        ap50_95, ap50, summary = self.exp.eval(evalmodel, self.evaluator, self.is_distributed)
        self.model.train()
        if self.rank == 0:
            self.tblogger.add_scalar("val/COCOAP50", ap50, self.epoch + 1)
            self.tblogger.add_scalar("val/COCOAP50_95", ap50_95, self.epoch + 1)
            logger.info("\n" + summary)
        synchronize()

        self.save_ckpt("last_epoch", ap50_95 > self.best_ap)
        self.best_ap = max(self.best_ap, ap50_95)

    def save_ckpt(self, ckpt_name, update_best_ckpt=False):
        if self.rank == 0:
            save_model = self.ema_model.ema if self.use_model_ema else self.model
            logger.info("Save weights to {}".format(self.file_name))
            ckpt_state = {
                "start_epoch": self.epoch + 1,
                "model": save_model.state_dict(),
                "optimizer": self.optimizer.state_dict(),
            }
            if self.amp_training:
                # save amp state according to
                # https://nvidia.github.io/apex/amp.html#checkpointing
                ckpt_state["amp"] = amp.state_dict()
            save_checkpoint(
                ckpt_state,
                update_best_ckpt,
                self.file_name,
                ckpt_name,
            )
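
Because the loop is decomposed into before/after hooks, behavior can be added without touching the core iteration; a sketch of a hypothetical subclass:

class VerboseTrainer(Trainer):
    def before_epoch(self):
        super().before_epoch()
        # Hypothetical extra logging; the optimizer exists once before_train has run.
        logger.info("epoch {} starting lr: {:.3e}".format(
            self.epoch + 1, self.optimizer.param_groups[0]["lr"]
        ))
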
yolox/data/__init__.py
ADDED
@@ -0,0 +1,8 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-

from .data_augment import TrainTransform, ValTransform
from .data_prefetcher import DataPrefetcher
from .dataloading import DataLoader, get_yolox_datadir
from .datasets import *
from .samplers import InfiniteSampler, YoloBatchSampler
yolox/data/data_augment.py
ADDED
@@ -0,0 +1,389 @@
"""
Data augmentation functionality. Passed as callable transformations to
Dataset classes.

The data augmentation procedures were interpreted from @weiliu89's SSD paper
http://arxiv.org/abs/1512.02325
"""

import math
import random

import cv2
import numpy as np

import torch


def augment_hsv(img, hgain=0.015, sgain=0.7, vgain=0.4):
    r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1  # random gains
    hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
    dtype = img.dtype  # uint8

    x = np.arange(0, 256, dtype=np.int16)
    lut_hue = ((x * r[0]) % 180).astype(dtype)
    lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
    lut_val = np.clip(x * r[2], 0, 255).astype(dtype)

    img_hsv = cv2.merge(
        (cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val))
    ).astype(dtype)
    cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img)  # no return needed


def box_candidates(box1, box2, wh_thr=2, ar_thr=20, area_thr=0.2):
    # box1(4,n), box2(4,n)
    # Compute candidate boxes which include the following 5 things:
    # box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio
    w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
    w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
    ar = np.maximum(w2 / (h2 + 1e-16), h2 / (w2 + 1e-16))  # aspect ratio
    return (
        (w2 > wh_thr)
        & (h2 > wh_thr)
        & (w2 * h2 / (w1 * h1 + 1e-16) > area_thr)
        & (ar < ar_thr)
    )  # candidates


def random_perspective(
    img, targets=(), degrees=10, translate=0.1, scale=0.1, shear=10, perspective=0.0, border=(0, 0),
):
    # targets = [cls, xyxy]
    height = img.shape[0] + border[0] * 2  # shape(h,w,c)
    width = img.shape[1] + border[1] * 2

    # Center
    C = np.eye(3)
    C[0, 2] = -img.shape[1] / 2  # x translation (pixels)
    C[1, 2] = -img.shape[0] / 2  # y translation (pixels)

    # Rotation and Scale
    R = np.eye(3)
    a = random.uniform(-degrees, degrees)
    # a += random.choice([-180, -90, 0, 90])  # add 90deg rotations to small rotations
    s = random.uniform(scale[0], scale[1])
    # s = 2 ** random.uniform(-scale, scale)
    R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)

    # Shear
    S = np.eye(3)
    S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # x shear (deg)
    S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # y shear (deg)

    # Translation
    T = np.eye(3)
    T[0, 2] = (random.uniform(0.5 - translate, 0.5 + translate) * width)  # x translation (pixels)
    T[1, 2] = (random.uniform(0.5 - translate, 0.5 + translate) * height)  # y translation (pixels)

    # Combined rotation matrix
    M = T @ S @ R @ C  # order of operations (right to left) is IMPORTANT

    ###########################
    # For Aug out of Mosaic
    # s = 1.
    # M = np.eye(3)
    ###########################

    if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any():  # image changed
        if perspective:
            img = cv2.warpPerspective(img, M, dsize=(width, height), borderValue=(114, 114, 114))
        else:  # affine
            img = cv2.warpAffine(img, M[:2], dsize=(width, height), borderValue=(114, 114, 114))

    # Transform label coordinates
    n = len(targets)
    if n:
        # warp points
        xy = np.ones((n * 4, 3))
        xy[:, :2] = targets[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
        xy = xy @ M.T  # transform
        if perspective:
            xy = (xy[:, :2] / xy[:, 2:3]).reshape(n, 8)  # rescale
        else:  # affine
            xy = xy[:, :2].reshape(n, 8)

        # create new boxes
        x = xy[:, [0, 2, 4, 6]]
        y = xy[:, [1, 3, 5, 7]]
        xy = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T

        # clip boxes
        xy[:, [0, 2]] = xy[:, [0, 2]].clip(0, width)
        xy[:, [1, 3]] = xy[:, [1, 3]].clip(0, height)

        # filter candidates
        i = box_candidates(box1=targets[:, :4].T * s, box2=xy.T)
        targets = targets[i]
        targets[:, :4] = xy[i]

    return img, targets


def _distort(image):
    def _convert(image, alpha=1, beta=0):
        tmp = image.astype(float) * alpha + beta
        tmp[tmp < 0] = 0
        tmp[tmp > 255] = 255
        image[:] = tmp

    image = image.copy()

    if random.randrange(2):
        _convert(image, beta=random.uniform(-32, 32))

    if random.randrange(2):
        _convert(image, alpha=random.uniform(0.5, 1.5))

    image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)

    if random.randrange(2):
        tmp = image[:, :, 0].astype(int) + random.randint(-18, 18)
        tmp %= 180
        image[:, :, 0] = tmp

    if random.randrange(2):
        _convert(image[:, :, 1], alpha=random.uniform(0.5, 1.5))

    image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)

    return image


def _mirror(image, boxes):
    _, width, _ = image.shape
    if random.randrange(2):
        image = image[:, ::-1]
        boxes = boxes.copy()
        boxes[:, 0::2] = width - boxes[:, 2::-2]
    return image, boxes


# TODO: reorg: use mosaicDet instead
def _random_affine(
    img,
    targets=None,
    degrees=(-10, 10),
    translate=(0.1, 0.1),
    scale=(0.9, 1.1),
    shear=(-2, 2),
    borderValue=(114, 114, 114),
):
    # degrees = (0, 0)
    # shear = (0, 0)
    border = 0  # width of added border (optional)
    # height = max(img.shape[0], img.shape[1]) + border * 2
    height, width, _ = img.shape

    # Rotation and Scale
    R = np.eye(3)
    a = random.random() * (degrees[1] - degrees[0]) + degrees[0]
    # a += random.choice([-180, -90, 0, 90])  # 90deg rotations added to small rotations
    s = random.random() * (scale[1] - scale[0]) + scale[0]
    R[:2] = cv2.getRotationMatrix2D(
        angle=a, center=(img.shape[1] / 2, img.shape[0] / 2), scale=s
    )

    # Translation
    T = np.eye(3)
    # x translation (pixels)
    T[0, 2] = (random.random() * 2 - 1) * translate[0] * img.shape[0] + border
    # y translation (pixels)
    T[1, 2] = (random.random() * 2 - 1) * translate[1] * img.shape[1] + border

    # Shear
    S = np.eye(3)
    # x shear (deg)
    S[0, 1] = math.tan((random.random() * (shear[1] - shear[0]) + shear[0]) * math.pi / 180)
    # y shear (deg)
    S[1, 0] = math.tan((random.random() * (shear[1] - shear[0]) + shear[0]) * math.pi / 180)

    # Combined rotation matrix. NOTE: ORDER IS IMPORTANT HERE!!
    M = S @ T @ R
    # BGR order borderValue
    imw = cv2.warpPerspective(
        img, M, dsize=(width, height), flags=cv2.INTER_LINEAR, borderValue=borderValue
    )

    # Return warped points also
    if targets is not None:
        if len(targets) > 0:
            n = targets.shape[0]
            points = targets[:, 0:4].copy()

            # warp points
            xy = np.ones((n * 4, 3))
            xy[:, :2] = points[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(
                n * 4, 2
            )  # x1y1, x2y2, x1y2, x2y1
            xy = (xy @ M.T)[:, :2].reshape(n, 8)

            # create new boxes
            x = xy[:, [0, 2, 4, 6]]
            y = xy[:, [1, 3, 5, 7]]
            xy = (
                np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
            )

            # apply angle-based reduction
            radians = a * math.pi / 180
            reduction = max(abs(math.sin(radians)), abs(math.cos(radians))) ** 0.5
            x = (xy[:, 2] + xy[:, 0]) / 2
            y = (xy[:, 3] + xy[:, 1]) / 2
            w = (xy[:, 2] - xy[:, 0]) * reduction
            h = (xy[:, 3] - xy[:, 1]) * reduction
            xy = (
                np.concatenate((x - w / 2, y - h / 2, x + w / 2, y + h / 2))
                .reshape(4, n)
                .T
            )

            # reject warped points outside of image
            x1 = np.clip(xy[:, 0], 0, width)
            y1 = np.clip(xy[:, 1], 0, height)
            x2 = np.clip(xy[:, 2], 0, width)
            y2 = np.clip(xy[:, 3], 0, height)
            boxes = np.concatenate((x1, y1, x2, y2)).reshape(4, n).T

            return imw, boxes, M
        else:
            return imw


def preproc(image, input_size, mean, std, swap=(2, 0, 1)):
    if len(image.shape) == 3:
        padded_img = np.ones((input_size[0], input_size[1], 3)) * 114.0
    else:
        padded_img = np.ones(input_size) * 114.0
    img = np.array(image)
    r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
    resized_img = cv2.resize(
        img, (int(img.shape[1] * r), int(img.shape[0] * r)), interpolation=cv2.INTER_LINEAR
    ).astype(np.float32)
    padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
    image = padded_img

    image = image.astype(np.float32)
    image = image[:, :, ::-1]
    image /= 255.0
    if mean is not None:
        image -= mean
    if std is not None:
        image /= std
    image = image.transpose(swap)
    image = np.ascontiguousarray(image, dtype=np.float32)
    return image, r


class TrainTransform:
    def __init__(self, p=0.5, rgb_means=None, std=None, max_labels=50):
        self.means = rgb_means
        self.std = std
        self.p = p
        self.max_labels = max_labels

    def __call__(self, image, targets, input_dim):
        boxes = targets[:, :4].copy()
        labels = targets[:, 4].copy()
        if targets.shape[1] > 5:
            mixup = True
            ratios = targets[:, -1].copy()
            ratios_o = targets[:, -1].copy()
        else:
            mixup = False
            ratios = None
            ratios_o = None
        lshape = 6 if mixup else 5
        if len(boxes) == 0:
            targets = np.zeros((self.max_labels, lshape), dtype=np.float32)
            image, r_o = preproc(image, input_dim, self.means, self.std)
            image = np.ascontiguousarray(image, dtype=np.float32)
            return image, targets

        image_o = image.copy()
        targets_o = targets.copy()
        height_o, width_o, _ = image_o.shape
        boxes_o = targets_o[:, :4]
        labels_o = targets_o[:, 4]
        # bbox_o: [xyxy] to [c_x,c_y,w,h]
        b_x_o = (boxes_o[:, 2] + boxes_o[:, 0]) * 0.5
        b_y_o = (boxes_o[:, 3] + boxes_o[:, 1]) * 0.5
        b_w_o = (boxes_o[:, 2] - boxes_o[:, 0]) * 1.0
        b_h_o = (boxes_o[:, 3] - boxes_o[:, 1]) * 1.0
        boxes_o[:, 0] = b_x_o
        boxes_o[:, 1] = b_y_o
        boxes_o[:, 2] = b_w_o
        boxes_o[:, 3] = b_h_o

        image_t = _distort(image)
        image_t, boxes = _mirror(image_t, boxes)
        height, width, _ = image_t.shape
        image_t, r_ = preproc(image_t, input_dim, self.means, self.std)
        boxes = boxes.copy()
        # boxes [xyxy] to [cx,cy,w,h]
        b_x = (boxes[:, 2] + boxes[:, 0]) * 0.5
        b_y = (boxes[:, 3] + boxes[:, 1]) * 0.5
        b_w = (boxes[:, 2] - boxes[:, 0]) * 1.0
        b_h = (boxes[:, 3] - boxes[:, 1]) * 1.0
        boxes[:, 0] = b_x
        boxes[:, 1] = b_y
        boxes[:, 2] = b_w
        boxes[:, 3] = b_h

        boxes *= r_

        mask_b = np.minimum(boxes[:, 2], boxes[:, 3]) > 8
        boxes_t = boxes[mask_b]
        labels_t = labels[mask_b].copy()
        if mixup:
            ratios_t = ratios[mask_b].copy()

        if len(boxes_t) == 0:
            image_t, r_o = preproc(image_o, input_dim, self.means, self.std)
            boxes_o *= r_o
            boxes_t = boxes_o
            labels_t = labels_o
            ratios_t = ratios_o

        labels_t = np.expand_dims(labels_t, 1)
        if mixup:
            ratios_t = np.expand_dims(ratios_t, 1)
            targets_t = np.hstack((labels_t, boxes_t, ratios_t))
        else:
            targets_t = np.hstack((labels_t, boxes_t))
        padded_labels = np.zeros((self.max_labels, lshape))
        padded_labels[range(len(targets_t))[: self.max_labels]] = targets_t[
            : self.max_labels
        ]
        padded_labels = np.ascontiguousarray(padded_labels, dtype=np.float32)
        image_t = np.ascontiguousarray(image_t, dtype=np.float32)
        return image_t, padded_labels


class ValTransform:
    """
    Defines the transformations that should be applied to a test image
    for input into the network

    dimension -> tensorize -> color adj

    Arguments:
        resize (int): input dimension to SSD
        rgb_means ((int,int,int)): average RGB of the dataset
            (104,117,123)
        swap ((int,int,int)): final order of channels

    Returns:
        transform (transform): callable transform to be applied to test/val
        data
    """

    def __init__(self, rgb_means=None, std=None, swap=(2, 0, 1)):
        self.means = rgb_means
        self.swap = swap
        self.std = std

    # assume input is cv2 img for now
    def __call__(self, img, res, input_size):
        img, _ = preproc(img, input_size, self.means, self.std, self.swap)
        return torch.from_numpy(img), torch.zeros(1, 5)
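
A minimal sketch of exercising `TrainTransform` on a dummy frame; the image size and the single box below are made up:

import numpy as np
from yolox.data import TrainTransform

img = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
targets = np.array([[48.0, 240.0, 195.0, 371.0, 11.0]])  # x1, y1, x2, y2, class

transform = TrainTransform(max_labels=50)
img_t, labels_t = transform(img, targets, (416, 416))
# img_t: float32 CHW array of shape (3, 416, 416), letterboxed with value 114
# labels_t: (50, 5) rows of [class, cx, cy, w, h] in network-input pixels, zero-padded
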
yolox/data/data_prefetcher.py
ADDED
@@ -0,0 +1,76 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-

import random

import torch
import torch.distributed as dist

from yolox.utils import synchronize


class DataPrefetcher:
    """
    DataPrefetcher is inspired by the code of the following file:
    https://github.com/NVIDIA/apex/blob/master/examples/imagenet/main_amp.py
    It can speed up your PyTorch dataloader. For more information, please check
    https://github.com/NVIDIA/apex/issues/304#issuecomment-493562789.
    """

    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        self.input_cuda = self._input_cuda_for_image
        self.record_stream = DataPrefetcher._record_stream_for_image
        self.preload()

    def preload(self):
        try:
            self.next_input, self.next_target, _, _ = next(self.loader)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return

        with torch.cuda.stream(self.stream):
            self.input_cuda()
            self.next_target = self.next_target.cuda(non_blocking=True)

    def next(self):
        torch.cuda.current_stream().wait_stream(self.stream)
        input = self.next_input
        target = self.next_target
        if input is not None:
            self.record_stream(input)
        if target is not None:
            target.record_stream(torch.cuda.current_stream())
        self.preload()
        return input, target

    def _input_cuda_for_image(self):
        self.next_input = self.next_input.cuda(non_blocking=True)

    @staticmethod
    def _record_stream_for_image(input):
        input.record_stream(torch.cuda.current_stream())


def random_resize(data_loader, exp, epoch, rank, is_distributed):
    tensor = torch.LongTensor(1).cuda()
    if is_distributed:
        synchronize()

    if rank == 0:
        if epoch > exp.max_epoch - 10:
            # keep the nominal square size for the last epochs;
            # exp.input_size is (h, w), so take the scalar side length
            size = exp.input_size[0]
        else:
            size = random.randint(*exp.random_size)
            size = int(32 * size)
        tensor.fill_(size)

    if is_distributed:
        synchronize()
        dist.broadcast(tensor, 0)

    input_size = data_loader.change_input_dim(multiple=tensor.item(), random_range=None)
    return input_size
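
A sketch of the intended consumption pattern, mirroring `Trainer.train_one_iter`; `train_loader` is assumed to be a dataloader yielding `(inputs, targets, _, _)` batches:

prefetcher = DataPrefetcher(train_loader)
inps, targets = prefetcher.next()
while inps is not None:
    # forward/backward on inps and targets, both already on the GPU
    inps, targets = prefetcher.next()
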
yolox/data/dataloading.py
ADDED
@@ -0,0 +1,172 @@
import os
import random

import torch
from torch.utils.data.dataloader import DataLoader as torchDataLoader
from torch.utils.data.dataloader import default_collate

from .samplers import YoloBatchSampler


def get_yolox_datadir():
    """
    get dataset dir of YOLOX. If an environment variable named `YOLOX_DATADIR` is set,
    this function will return the value of the environment variable. Otherwise, the
    `data` directory inside the installed yolox package is used.
    """
    yolox_datadir = os.getenv("YOLOX_DATADIR", None)
    if yolox_datadir is None:
        import yolox
        yolox_datadir = os.path.join(os.path.dirname(yolox.__file__), "data")
    return yolox_datadir


class DataLoader(torchDataLoader):
    """
    Lightnet dataloader that enables on the fly resizing of the images.
    See :class:`torch.utils.data.DataLoader` for more information on the arguments.
    Check more on the following website:
    https://gitlab.com/EAVISE/lightnet/-/blob/master/lightnet/data/_dataloading.py

    Note:
        This dataloader only works with :class:`lightnet.data.Dataset` based datasets.

    Example:
        >>> class CustomSet(ln.data.Dataset):
        ...     def __len__(self):
        ...         return 4
        ...     @ln.data.Dataset.resize_getitem
        ...     def __getitem__(self, index):
        ...         # Should return (image, anno) but here we return (input_dim,)
        ...         return (self.input_dim,)
        >>> dl = ln.data.DataLoader(
        ...     CustomSet((200,200)),
        ...     batch_size = 2,
        ...     collate_fn = ln.data.list_collate   # We want the data to be grouped as a list
        ... )
        >>> dl.dataset.input_dim    # Default input_dim
        (200, 200)
        >>> for d in dl:
        ...     d
        [[(200, 200), (200, 200)]]
        [[(200, 200), (200, 200)]]
        >>> dl.change_input_dim(320, random_range=None)
        (320, 320)
        >>> for d in dl:
        ...     d
        [[(320, 320), (320, 320)]]
        [[(320, 320), (320, 320)]]
        >>> dl.change_input_dim((480, 320), random_range=None)
        (480, 320)
        >>> for d in dl:
        ...     d
        [[(480, 320), (480, 320)]]
        [[(480, 320), (480, 320)]]
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__initialized = False
        shuffle = False
        sampler = None
        batch_sampler = None
        if len(args) > 5:
            shuffle = args[2]
            sampler = args[3]
            batch_sampler = args[4]
        elif len(args) > 4:
            shuffle = args[2]
            sampler = args[3]
            if "batch_sampler" in kwargs:
                batch_sampler = kwargs["batch_sampler"]
        elif len(args) > 3:
            shuffle = args[2]
            if "sampler" in kwargs:
                sampler = kwargs["sampler"]
            if "batch_sampler" in kwargs:
                batch_sampler = kwargs["batch_sampler"]
        else:
            if "shuffle" in kwargs:
                shuffle = kwargs["shuffle"]
            if "sampler" in kwargs:
                sampler = kwargs["sampler"]
            if "batch_sampler" in kwargs:
                batch_sampler = kwargs["batch_sampler"]

        # Use custom BatchSampler
        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    sampler = torch.utils.data.sampler.RandomSampler(self.dataset)
                    # sampler = torch.utils.data.DistributedSampler(self.dataset)
                else:
                    sampler = torch.utils.data.sampler.SequentialSampler(self.dataset)
            batch_sampler = YoloBatchSampler(
                sampler,
                self.batch_size,
                self.drop_last,
                input_dimension=self.dataset.input_dim,
            )
            # batch_sampler = IterationBasedBatchSampler(batch_sampler, num_iterations =

        self.batch_sampler = batch_sampler

        self.__initialized = True

    def close_mosaic(self):
        self.batch_sampler.mosaic = False

    def change_input_dim(self, multiple=32, random_range=(10, 19)):
        """ This function will compute a new size and update it on the next mini_batch.

        Args:
            multiple (int or tuple, optional): values to multiply the randomly generated range by.
                Default **32**
            random_range (tuple, optional): This (min, max) tuple sets the range
                for the randomisation; Default **(10, 19)**

        Return:
            tuple: width, height tuple with new dimension

        Note:
            The new size is generated as follows: |br|
            First we compute a random integer inside ``[random_range]``.
            We then multiply that number with the ``multiple`` argument,
            which gives our final new input size. |br|
            If ``multiple`` is an integer we generate a square size. If you give a tuple
            of **(width, height)**, the size is computed
            as :math:`rng * multiple[0], rng * multiple[1]`.

        Note:
            You can set the ``random_range`` argument to **None** to set
            an exact size of multiply. |br|
            See the example above for how this works.
        """
        if random_range is None:
            size = 1
        else:
            size = random.randint(*random_range)

        if isinstance(multiple, int):
            size = (size * multiple, size * multiple)
        else:
            size = (size * multiple[0], size * multiple[1])

        self.batch_sampler.new_input_dim = size

        return size


def list_collate(batch):
    """
    Function that collates lists or tuples together into one list (of lists/tuples).
    Use this as the collate function in a Dataloader, if you want to have a list of
    items as an output, as opposed to tensors (eg. Brambox.boxes).
    """
    items = list(zip(*batch))

    for i in range(len(items)):
        if isinstance(items[i][0], (list, tuple)):
            items[i] = list(items[i])
        else:
            items[i] = default_collate(items[i])

    return items
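
A sketch of driving multi-scale training from a loop with this loader; the 10-iteration interval matches `Trainer.after_iter`, the rest is illustrative:

for it, batch in enumerate(dataloader):  # dataloader: an instance of the DataLoader above
    # ... training step ...
    if (it + 1) % 10 == 0:
        # the next batches are resized to a random multiple of 32 in [320, 608]
        new_size = dataloader.change_input_dim(multiple=32, random_range=(10, 19))
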
yolox/data/datasets/__init__.py
ADDED
@@ -0,0 +1,7 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-

from .coco import COCODataset
from .coco_classes import COCO_CLASSES
from .datasets_wrapper import ConcatDataset, Dataset, MixConcatDataset
from .mosaicdetection import MosaicDetection
yolox/data/datasets/coco.py
ADDED
@@ -0,0 +1,119 @@
import os

import cv2
import numpy as np
from pycocotools.coco import COCO

from ..dataloading import get_yolox_datadir
from .datasets_wrapper import Dataset


class COCODataset(Dataset):
    """
    COCO dataset class.
    """

    def __init__(
        self,
        data_dir=None,
        json_file="instances_train2017.json",
        name="train2017",
        img_size=(416, 416),
        preproc=None,
    ):
        """
        COCO dataset initialization. Annotation data are read into memory by COCO API.
        Args:
            data_dir (str): dataset root directory
            json_file (str): COCO json file name
            name (str): COCO data name (e.g. 'train2017' or 'val2017')
            img_size (tuple): target image size after pre-processing
            preproc: data augmentation strategy
        """
        super().__init__(img_size)
        if data_dir is None:
            data_dir = os.path.join(get_yolox_datadir(), "COCO")
        self.data_dir = data_dir
        self.json_file = json_file

        self.coco = COCO(os.path.join(self.data_dir, "annotations", self.json_file))
        self.ids = self.coco.getImgIds()
        self.class_ids = sorted(self.coco.getCatIds())
        cats = self.coco.loadCats(self.coco.getCatIds())
        self._classes = tuple([c["name"] for c in cats])
        self.name = name
        self.max_labels = 50
        self.img_size = img_size
        self.preproc = preproc

    def __len__(self):
        return len(self.ids)

    def pull_item(self, index):
        id_ = self.ids[index]

        im_ann = self.coco.loadImgs(id_)[0]
        width = im_ann["width"]
        height = im_ann["height"]
        anno_ids = self.coco.getAnnIds(imgIds=[int(id_)], iscrowd=False)
        annotations = self.coco.loadAnns(anno_ids)

        # load image and preprocess
        img_file = os.path.join(
            self.data_dir, self.name, "{:012}".format(id_) + ".jpg"
        )

        img = cv2.imread(img_file)
        assert img is not None

        # load labels
        valid_objs = []
        for obj in annotations:
            x1 = np.max((0, obj["bbox"][0]))
            y1 = np.max((0, obj["bbox"][1]))
            x2 = np.min((width - 1, x1 + np.max((0, obj["bbox"][2] - 1))))
            y2 = np.min((height - 1, y1 + np.max((0, obj["bbox"][3] - 1))))
            if obj["area"] > 0 and x2 >= x1 and y2 >= y1:
                obj["clean_bbox"] = [x1, y1, x2, y2]
                valid_objs.append(obj)
        objs = valid_objs
        num_objs = len(objs)

        res = np.zeros((num_objs, 5))

        for ix, obj in enumerate(objs):
            cls = self.class_ids.index(obj["category_id"])
            res[ix, 0:4] = obj["clean_bbox"]
            res[ix, 4] = cls

        img_info = (height, width)

        return img, res, img_info, id_

    @Dataset.resize_getitem
    def __getitem__(self, index):
        """
        One image / label pair for the given index is picked up and pre-processed.

        Args:
            index (int): data index

        Returns:
            img (numpy.ndarray): pre-processed image
            padded_labels (torch.Tensor): pre-processed label data.
                The shape is :math:`[self.max_labels, 5]`.
                Each label consists of [class, xc, yc, w, h]:
                    class (float): class index.
                    xc, yc (float): center of bbox.
                    w, h (float): size of bbox.
            img_info : tuple of (height, width) of the original image.
            img_id (int): COCO image id. Used for evaluation.
        """
        img, target, img_info, img_id = self.pull_item(index)

        if self.preproc is not None:
            img, target = self.preproc(img, target, self.input_dim)
        return img, target, img_info, img_id
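
A minimal sketch of instantiating the dataset, assuming COCO is unpacked under `$YOLOX_DATADIR/COCO` with `annotations/` and `train2017/` inside:

from yolox.data import COCODataset, TrainTransform

dataset = COCODataset(
    json_file="instances_train2017.json",
    name="train2017",
    img_size=(640, 640),
    preproc=TrainTransform(max_labels=50),
)
img, target, img_info, img_id = dataset[0]
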
yolox/data/datasets/coco_classes.py
ADDED
@@ -0,0 +1,86 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-


COCO_CLASSES = (
    "person",
    "bicycle",
    "car",
    "motorcycle",
    "airplane",
    "bus",
    "train",
    "truck",
    "boat",
    "traffic light",
    "fire hydrant",
    "stop sign",
    "parking meter",
    "bench",
    "bird",
    "cat",
    "dog",
    "horse",
    "sheep",
    "cow",
    "elephant",
    "bear",
    "zebra",
    "giraffe",
    "backpack",
    "umbrella",
    "handbag",
    "tie",
    "suitcase",
    "frisbee",
    "skis",
    "snowboard",
    "sports ball",
    "kite",
    "baseball bat",
    "baseball glove",
    "skateboard",
    "surfboard",
    "tennis racket",
    "bottle",
    "wine glass",
    "cup",
    "fork",
    "knife",
    "spoon",
    "bowl",
    "banana",
    "apple",
    "sandwich",
    "orange",
    "broccoli",
    "carrot",
    "hot dog",
    "pizza",
    "donut",
    "cake",
    "chair",
    "couch",
    "potted plant",
    "bed",
    "dining table",
    "toilet",
    "tv",
    "laptop",
    "mouse",
    "remote",
    "keyboard",
    "cell phone",
    "microwave",
    "oven",
    "toaster",
    "sink",
    "refrigerator",
    "book",
    "clock",
    "vase",
    "scissors",
    "teddy bear",
    "hair drier",
    "toothbrush",
)
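
Class indices predicted by the models index straight into this tuple, e.g.:

from yolox.data.datasets import COCO_CLASSES

assert len(COCO_CLASSES) == 80
print(COCO_CLASSES[17])  # "horse"
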
yolox/data/datasets/datasets_wrapper.py
ADDED
@@ -0,0 +1,126 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+import bisect
+from functools import wraps
+
+from torch.utils.data.dataset import ConcatDataset as torchConcatDataset
+from torch.utils.data.dataset import Dataset as torchDataset
+
+
+class ConcatDataset(torchConcatDataset):
+    def __init__(self, datasets):
+        super(ConcatDataset, self).__init__(datasets)
+        if hasattr(self.datasets[0], "input_dim"):
+            self._input_dim = self.datasets[0].input_dim
+            self.input_dim = self.datasets[0].input_dim
+
+    def pull_item(self, idx):
+        if idx < 0:
+            if -idx > len(self):
+                raise ValueError(
+                    "absolute value of index should not exceed dataset length"
+                )
+            idx = len(self) + idx
+        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+        if dataset_idx == 0:
+            sample_idx = idx
+        else:
+            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
+        return self.datasets[dataset_idx].pull_item(sample_idx)
+
+
+class MixConcatDataset(torchConcatDataset):
+    def __init__(self, datasets):
+        super(MixConcatDataset, self).__init__(datasets)
+        if hasattr(self.datasets[0], "input_dim"):
+            self._input_dim = self.datasets[0].input_dim
+            self.input_dim = self.datasets[0].input_dim
+
+    def __getitem__(self, index):
+
+        if not isinstance(index, int):
+            idx = index[1]
+        if idx < 0:
+            if -idx > len(self):
+                raise ValueError(
+                    "absolute value of index should not exceed dataset length"
+                )
+            idx = len(self) + idx
+        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+        if dataset_idx == 0:
+            sample_idx = idx
+        else:
+            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
+        if not isinstance(index, int):
+            index = (index[0], sample_idx, index[2])
+
+        return self.datasets[dataset_idx][index]
+
+
+class Dataset(torchDataset):
+    """ This class is a subclass of the base :class:`torch.utils.data.Dataset`,
+    that enables on the fly resizing of the ``input_dim``.
+
+    Args:
+        input_dimension (tuple): (width, height) tuple with the default dimensions of the network
+    """
+
+    def __init__(self, input_dimension, mosaic=True):
+        super().__init__()
+        self.__input_dim = input_dimension[:2]
+        self._mosaic = mosaic
+
+    @property
+    def input_dim(self):
+        """
+        Dimension that can be used by transforms to set the correct image size, etc.
+        This allows transforms to have a single source of truth
+        for the input dimension of the network.
+
+        Return:
+            tuple: the current (width, height)
+        """
+        if hasattr(self, "_input_dim"):
+            return self._input_dim
+        return self.__input_dim
+
+    @staticmethod
+    def resize_getitem(getitem_fn):
+        """
+        Decorator method that needs to be used around the ``__getitem__`` method. |br|
+        This decorator enables the on the fly resizing of
+        the ``input_dim`` with our :class:`~lightnet.data.DataLoader` class.
+
+        Example:
+            >>> class CustomSet(ln.data.Dataset):
+            ...     def __len__(self):
+            ...         return 10
+            ...     @ln.data.Dataset.resize_getitem
+            ...     def __getitem__(self, index):
+            ...         # Should return (image, anno) but here we return input_dim
+            ...         return self.input_dim
+            >>> data = CustomSet((200, 200))
+            >>> data[0]
+            (200, 200)
+            >>> data[(480, 320), 0]
+            (480, 320)
+        """
+
+        @wraps(getitem_fn)
+        def wrapper(self, index):
+            if not isinstance(index, int):
+                has_dim = True
+                self._input_dim = index[0]
+                self._mosaic = index[2]
+                index = index[1]
+            else:
+                has_dim = False
+
+            ret_val = getitem_fn(self, index)
+
+            if has_dim:
+                del self._input_dim
+
+            return ret_val
+
+        return wrapper
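Worth calling out how resize_getitem changes the indexing contract: when the loader passes a ((w, h), index, mosaic) tuple instead of a plain int, the wrapper pins _input_dim for the duration of the call and then removes it. A minimal sketch, not part of the diff, assuming the yolox package from this commit is importable:

from yolox.data.datasets.datasets_wrapper import Dataset


class ToyDataset(Dataset):
    def __len__(self):
        return 4

    @Dataset.resize_getitem
    def __getitem__(self, index):
        # A real dataset would return (image, target); returning input_dim
        # makes the on-the-fly resize visible.
        return self.input_dim


data = ToyDataset((416, 416))
print(data[0])                    # (416, 416) -- the default input_dim
print(data[(640, 640), 0, True])  # (640, 640) -- dim taken from the index tuple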
yolox/data/datasets/mosaicdetection.py
ADDED
@@ -0,0 +1,195 @@
+import random
+
+import cv2
+import numpy as np
+
+from yolox.utils import adjust_box_anns
+
+from ..data_augment import box_candidates, random_perspective
+from .datasets_wrapper import Dataset
+
+
+class MosaicDetection(Dataset):
+    """Detection dataset wrapper that performs mosaic and mixup on a normal dataset.
+
+    Parameters
+    ----------
+    dataset : Dataset
+        the PyTorch dataset object to wrap.
+    *args : list
+        Additional arguments for the mixup random sampler.
+    """
+
+    def __init__(
+        self, dataset, img_size, mosaic=True, preproc=None,
+        degrees=10.0, translate=0.1, scale=(0.5, 1.5), mscale=(0.5, 1.5),
+        shear=2.0, perspective=0.0, enable_mixup=True, *args
+    ):
+        super().__init__(img_size, mosaic=mosaic)
+        self._dataset = dataset
+        self.preproc = preproc
+        self.degrees = degrees
+        self.translate = translate
+        self.scale = scale
+        self.shear = shear
+        self.perspective = perspective
+        self.mixup_scale = mscale
+        self._mosaic = mosaic
+        self.enable_mixup = enable_mixup
+
+    def __len__(self):
+        return len(self._dataset)
+
+    @Dataset.resize_getitem
+    def __getitem__(self, idx):
+        if self._mosaic:
+            labels4 = []
+            s = self._dataset.input_dim[0]
+            # yc, xc = s, s  # mosaic center x, y
+            yc = int(random.uniform(0.5 * s, 1.5 * s))
+            xc = int(random.uniform(0.5 * s, 1.5 * s))
+
+            # 3 additional image indices
+            indices = [idx] + [random.randint(0, len(self._dataset) - 1) for _ in range(3)]
+
+            for i, index in enumerate(indices):
+                img, _labels, _, _ = self._dataset.pull_item(index)
+                h0, w0 = img.shape[:2]  # orig hw
+                r = 1.0 * s / max(h0, w0)  # resize image to img_size
+                interp = cv2.INTER_LINEAR
+                img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=interp)
+                (h, w) = img.shape[:2]
+
+                if i == 0:  # top left
+                    # base image with 4 tiles
+                    img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8)
+                    # xmin, ymin, xmax, ymax (large image)
+                    x1a, y1a, x2a, y2a = (max(xc - w, 0), max(yc - h, 0), xc, yc,)
+                    # xmin, ymin, xmax, ymax (small image)
+                    x1b, y1b, x2b, y2b = (w - (x2a - x1a), h - (y2a - y1a), w, h,)
+                elif i == 1:  # top right
+                    x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
+                    x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
+                elif i == 2:  # bottom left
+                    x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
+                    x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
+                elif i == 3:  # bottom right
+                    x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
+                    x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)
+
+                img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
+                padw = x1a - x1b
+                padh = y1a - y1b
+
+                labels = _labels.copy()  # [[xmin, ymin, xmax, ymax, label_ind], ... ]
+                if _labels.size > 0:  # Normalized xywh to pixel xyxy format
+                    labels[:, 0] = r * _labels[:, 0] + padw
+                    labels[:, 1] = r * _labels[:, 1] + padh
+                    labels[:, 2] = r * _labels[:, 2] + padw
+                    labels[:, 3] = r * _labels[:, 3] + padh
+                labels4.append(labels)
+
+            if len(labels4):
+                labels4 = np.concatenate(labels4, 0)
+                np.clip(labels4[:, :4], 0, 2 * s, out=labels4[:, :4])  # use with random_affine
+            img4, labels4 = random_perspective(
+                img4,
+                labels4,
+                degrees=self.degrees,
+                translate=self.translate,
+                scale=self.scale,
+                shear=self.shear,
+                perspective=self.perspective,
+                border=[-s // 2, -s // 2],
+            )  # border to remove
+
+            # -----------------------------------------------------------------
+            # CopyPaste: https://arxiv.org/abs/2012.07177
+            # -----------------------------------------------------------------
+            if self.enable_mixup and not len(labels4) == 0:
+                img4, labels4 = self.mixup(img4, labels4, self.input_dim)
+            mix_img, padded_labels = self.preproc(img4, labels4, self.input_dim)
+            img_info = (mix_img.shape[1], mix_img.shape[0])
+
+            return mix_img, padded_labels, img_info, int(idx)
+
+        else:
+            self._dataset._input_dim = self.input_dim
+            img, label, img_info, idx = self._dataset.pull_item(idx)
+            img, label = self.preproc(img, label, self.input_dim)
+            return img, label, img_info, int(idx)
+
+    def mixup(self, origin_img, origin_labels, input_dim):
+        # jit_factor = random.uniform(0.8, 1.2)
+        jit_factor = random.uniform(*self.mixup_scale)
+        FLIP = random.uniform(0, 1) > 0.5
+        cp_labels = []
+        while len(cp_labels) == 0:
+            cp_index = random.randint(0, self.__len__() - 1)
+            id_ = self._dataset.ids[cp_index]
+            anno_ids = self._dataset.coco.getAnnIds(imgIds=[int(id_)], iscrowd=False)
+            cp_labels = self._dataset.coco.loadAnns(anno_ids)
+        img, cp_labels, _, _ = self._dataset.pull_item(cp_index)
+
+        if len(img.shape) == 3:
+            cp_img = np.ones((input_dim[0], input_dim[1], 3)) * 114.0
+        else:
+            cp_img = np.ones(input_dim) * 114.0
+        cp_scale_ratio = input_dim[0] / max(img.shape[0], img.shape[1])
+        resized_img = cv2.resize(
+            img,
+            (int(img.shape[1] * cp_scale_ratio), int(img.shape[0] * cp_scale_ratio)),
+            interpolation=cv2.INTER_LINEAR,
+        ).astype(np.float32)
+        cp_img[
+            : int(img.shape[0] * cp_scale_ratio), : int(img.shape[1] * cp_scale_ratio)
+        ] = resized_img
+        cp_img = cv2.resize(
+            cp_img,
+            (int(cp_img.shape[1] * jit_factor), int(cp_img.shape[0] * jit_factor)),
+        )
+        cp_scale_ratio *= jit_factor
+        if FLIP:
+            cp_img = cp_img[:, ::-1, :]
+
+        origin_h, origin_w = cp_img.shape[:2]
+        target_h, target_w = origin_img.shape[:2]
+        padded_img = np.zeros(
+            (max(origin_h, target_h), max(origin_w, target_w), 3)
+        ).astype(np.uint8)
+        padded_img[:origin_h, :origin_w] = cp_img
+
+        x_offset, y_offset = 0, 0
+        if padded_img.shape[0] > target_h:
+            y_offset = random.randint(0, padded_img.shape[0] - target_h - 1)
+        if padded_img.shape[1] > target_w:
+            x_offset = random.randint(0, padded_img.shape[1] - target_w - 1)
+        padded_cropped_img = padded_img[
+            y_offset: y_offset + target_h, x_offset: x_offset + target_w
+        ]
+
+        cp_bboxes_origin_np = adjust_box_anns(
+            cp_labels[:, :4], cp_scale_ratio, 0, 0, origin_w, origin_h
+        )
+        if FLIP:
+            cp_bboxes_origin_np[:, 0::2] = (
+                origin_w - cp_bboxes_origin_np[:, 0::2][:, ::-1]
+            )
+        cp_bboxes_transformed_np = cp_bboxes_origin_np.copy()
+        cp_bboxes_transformed_np[:, 0::2] = np.clip(
+            cp_bboxes_transformed_np[:, 0::2] - x_offset, 0, target_w
+        )
+        cp_bboxes_transformed_np[:, 1::2] = np.clip(
+            cp_bboxes_transformed_np[:, 1::2] - y_offset, 0, target_h
+        )
+        keep_list = box_candidates(cp_bboxes_origin_np.T, cp_bboxes_transformed_np.T, 5)
+
+        if keep_list.sum() >= 1.0:
+            cls_labels = cp_labels[keep_list, 4:5]
+            box_labels = cp_bboxes_transformed_np[keep_list]
+            labels = np.hstack((box_labels, cls_labels))
+            origin_labels = np.vstack((origin_labels, labels))
+            origin_img = origin_img.astype(np.float32)
+            origin_img = 0.5 * origin_img + 0.5 * padded_cropped_img.astype(np.float32)
+
+        return origin_img.astype(np.uint8), origin_labels
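All four tile branches in __getitem__ follow one pattern: clip a destination rectangle on the 2s x 2s canvas, then derive a source rectangle of identical shape from the resized image. A standalone check of the top-left case, using hypothetical values for s, the mosaic center, and the image size (illustration only, not part of the commit):

# top-left tile placement math, with made-up numbers
s, xc, yc = 640, 700, 500   # target size and a mosaic center inside the 2s x 2s canvas
h, w = 480, 640             # shape of the resized source image

# destination region on the 2s x 2s canvas (clipped at the canvas border)
x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc
# matching source region, taken from the bottom-right of the small image
x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h

assert (y2a - y1a, x2a - x1a) == (y2b - y1b, x2b - x1b)  # the two shapes must agree
print((x1a, y1a, x2a, y2a), (x1b, y1b, x2b, y2b))  # (60, 20, 700, 500) (0, 0, 640, 480)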
yolox/data/datasets/voc.py
ADDED
@@ -0,0 +1,313 @@
+"""VOC Dataset Classes
+
+Original author: Francisco Massa
+https://github.com/fmassa/vision/blob/voc_dataset/torchvision/datasets/voc.py
+
+Updated by: Ellis Brown, Max deGroot
+"""
+
+import os
+import os.path
+import pickle
+import xml.etree.ElementTree as ET
+
+import cv2
+import numpy as np
+
+from yolox.evalutors.voc_eval import voc_eval
+
+from .datasets_wrapper import Dataset
+from .voc_classes import VOC_CLASSES
+
+# for making bounding boxes pretty
+COLORS = (
+    (255, 0, 0, 128),
+    (0, 255, 0, 128),
+    (0, 0, 255, 128),
+    (0, 255, 255, 128),
+    (255, 0, 255, 128),
+    (255, 255, 0, 128),
+)
+
+
+class AnnotationTransform(object):
+
+    """Transforms a VOC annotation into a Tensor of bbox coords and label index.
+    Initialized with a dictionary lookup of class names to indexes.
+
+    Arguments:
+        class_to_ind (dict, optional): dictionary lookup of classnames -> indexes
+            (default: alphabetic indexing of VOC's 20 classes)
+        keep_difficult (bool, optional): keep difficult instances or not
+            (default: False)
+        height (int): height
+        width (int): width
+    """
+
+    def __init__(self, class_to_ind=None, keep_difficult=True):
+        self.class_to_ind = class_to_ind or dict(zip(VOC_CLASSES, range(len(VOC_CLASSES))))
+        self.keep_difficult = keep_difficult
+
+    def __call__(self, target):
+        """
+        Arguments:
+            target (annotation) : the target annotation to be made usable;
+                will be an ET.Element
+        Returns:
+            a list containing lists of bounding boxes [bbox coords, class name]
+        """
+        res = np.empty((0, 5))
+        for obj in target.iter("object"):
+            difficult = int(obj.find("difficult").text) == 1
+            if not self.keep_difficult and difficult:
+                continue
+            name = obj.find("name").text.lower().strip()
+            bbox = obj.find("bndbox")
+
+            pts = ["xmin", "ymin", "xmax", "ymax"]
+            bndbox = []
+            for i, pt in enumerate(pts):
+                cur_pt = int(bbox.find(pt).text) - 1
+                # scale height or width
+                # cur_pt = cur_pt / width if i % 2 == 0 else cur_pt / height
+                bndbox.append(cur_pt)
+            label_idx = self.class_to_ind[name]
+            bndbox.append(label_idx)
+            res = np.vstack((res, bndbox))  # [xmin, ymin, xmax, ymax, label_ind]
+            # img_id = target.find('filename').text[:-4]
+
+        return res  # [[xmin, ymin, xmax, ymax, label_ind], ... ]
+
+
+class VOCDetection(Dataset):
+
+    """
+    VOC Detection Dataset Object
+
+    input is image, target is annotation
+
+    Args:
+        root (string): filepath to VOCdevkit folder.
+        image_sets (list): list of (year, imageset) tuples to use
+            (eg. [('2007', 'trainval')])
+        preproc (callable, optional): transformation to perform on the
+            input image
+        target_transform (callable, optional): transformation to perform on the
+            target `annotation`
+            (eg: take in caption string, return tensor of word indices)
+        dataset_name (string, optional): which dataset to load
+            (default: 'VOC0712')
+    """
+
+    def __init__(
+        self,
+        root,
+        image_sets,
+        preproc=None,
+        target_transform=AnnotationTransform(),
+        input_dim=(416, 416),
+        dataset_name="VOC0712",
+    ):
+        super().__init__(input_dim)
+        self.root = root
+        self.image_set = image_sets
+        self.preproc = preproc
+        self.target_transform = target_transform
+        self.name = dataset_name
+        self._annopath = os.path.join("%s", "Annotations", "%s.xml")
+        self._imgpath = os.path.join("%s", "JPEGImages", "%s.jpg")
+        self._classes = VOC_CLASSES
+        self.ids = list()
+        for (year, name) in image_sets:
+            self._year = year
+            rootpath = os.path.join(self.root, "VOC" + year)
+            for line in open(
+                os.path.join(rootpath, "ImageSets", "Main", name + ".txt")
+            ):
+                self.ids.append((rootpath, line.strip()))
+
+    @Dataset.resize_getitem
+    def __getitem__(self, index):
+        img_id = self.ids[index]
+        target = ET.parse(self._annopath % img_id).getroot()
+        img = cv2.imread(self._imgpath % img_id, cv2.IMREAD_COLOR)
+        # img = Image.open(self._imgpath % img_id).convert('RGB')
+
+        height, width, _ = img.shape
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        if self.preproc is not None:
+            img, target = self.preproc(img, target, self.input_dim)
+            # print(img.size())
+
+        img_info = (width, height)
+
+        return img, target, img_info, img_id
+
+    def __len__(self):
+        return len(self.ids)
+
+    def pull_image(self, index):
+        """Returns the original image at index as an OpenCV BGR ndarray
+
+        Note: not using self.__getitem__(), as any transformations passed in
+        could mess up this functionality.
+
+        Argument:
+            index (int): index of img to show
+        Return:
+            ndarray: the image as read by cv2.imread
+        """
+        img_id = self.ids[index]
+        return cv2.imread(self._imgpath % img_id, cv2.IMREAD_COLOR)
+
+    def pull_anno(self, index):
+        """Returns the original annotation of the image at index
+
+        Note: not using self.__getitem__(), as any transformations passed in
+        could mess up this functionality.
+
+        Argument:
+            index (int): index of img to get annotation of
+        Return:
+            img_id (str), and the transformed annotation array
+            [[xmin, ymin, xmax, ymax, label_ind], ...]
+        """
+        img_id = self.ids[index]
+        anno = ET.parse(self._annopath % img_id).getroot()
+        gt = self.target_transform(anno)
+        return img_id[1], gt
+
+    def pull_item(self, index):
+        """Returns the original image and target at an index for mixup
+
+        Note: not using self.__getitem__(), as any transformations passed in
+        could mess up this functionality.
+
+        Argument:
+            index (int): index of img to show
+        Return:
+            img, target
+        """
+        img_id = self.ids[index]
+        target = ET.parse(self._annopath % img_id).getroot()
+        img = cv2.imread(self._imgpath % img_id, cv2.IMREAD_COLOR)
+
+        height, width, _ = img.shape
+
+        img_info = (width, height)
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target, img_info, img_id
+
+    def evaluate_detections(self, all_boxes, output_dir=None):
+        """
+        all_boxes is a list of length number-of-classes.
+        Each list element is a list of length number-of-images.
+        Each of those list elements is either an empty list []
+        or a numpy array of detections.
+
+        all_boxes[class][image] = [] or np.array of shape #dets x 5
+        """
+        self._write_voc_results_file(all_boxes)
+        IouTh = np.linspace(0.5, 0.95, int(np.round((0.95 - 0.5) / 0.05)) + 1, endpoint=True)
+        mAPs = []
+        for iou in IouTh:
+            mAP = self._do_python_eval(output_dir, iou)
+            mAPs.append(mAP)
+
+        print("--------------------------------------------------------------")
+        print("map_5095:", np.mean(mAPs))
+        print("map_50:", mAPs[0])
+        print("--------------------------------------------------------------")
+        return np.mean(mAPs), mAPs[0]
+
+    def _get_voc_results_file_template(self):
+        filename = "comp4_det_test" + "_{:s}.txt"
+        filedir = os.path.join(self.root, "results", "VOC" + self._year, "Main")
+        if not os.path.exists(filedir):
+            os.makedirs(filedir)
+        path = os.path.join(filedir, filename)
+        return path
+
+    def _write_voc_results_file(self, all_boxes):
+        for cls_ind, cls in enumerate(VOC_CLASSES):
+            if cls == "__background__":
+                continue
+            print("Writing {} VOC results file".format(cls))
+            filename = self._get_voc_results_file_template().format(cls)
+            with open(filename, "wt") as f:
+                for im_ind, index in enumerate(self.ids):
+                    index = index[1]
+                    dets = all_boxes[cls_ind][im_ind]
+                    if len(dets) == 0:
+                        continue
+                    for k in range(dets.shape[0]):
+                        f.write(
+                            "{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n".format(
+                                index,
+                                dets[k, -1],
+                                dets[k, 0] + 1,
+                                dets[k, 1] + 1,
+                                dets[k, 2] + 1,
+                                dets[k, 3] + 1,
+                            )
+                        )
+
+    def _do_python_eval(self, output_dir="output", iou=0.5):
+        rootpath = os.path.join(self.root, "VOC" + self._year)
+        name = self.image_set[0][1]
+        annopath = os.path.join(rootpath, "Annotations", "{:s}.xml")
+        imagesetfile = os.path.join(rootpath, "ImageSets", "Main", name + ".txt")
+        cachedir = os.path.join(
+            self.root, "annotations_cache", "VOC" + self._year, name
+        )
+        if not os.path.exists(cachedir):
+            os.makedirs(cachedir)
+        aps = []
+        # The PASCAL VOC metric changed in 2010
+        use_07_metric = int(self._year) < 2010
+        print("VOC07 metric? " + ("Yes" if use_07_metric else "No"))
+        if output_dir is not None and not os.path.isdir(output_dir):
+            os.mkdir(output_dir)
+        for i, cls in enumerate(VOC_CLASSES):
+
+            if cls == "__background__":
+                continue
+
+            filename = self._get_voc_results_file_template().format(cls)
+            rec, prec, ap = voc_eval(
+                filename,
+                annopath,
+                imagesetfile,
+                cls,
+                cachedir,
+                ovthresh=iou,
+                use_07_metric=use_07_metric,
+            )
+            aps += [ap]
+            if iou == 0.5:
+                print("AP for {} = {:.4f}".format(cls, ap))
+            if output_dir is not None:
+                with open(os.path.join(output_dir, cls + "_pr.pkl"), "wb") as f:
+                    pickle.dump({"rec": rec, "prec": prec, "ap": ap}, f)
+        if iou == 0.5:
+            print("Mean AP = {:.4f}".format(np.mean(aps)))
+            print("~~~~~~~~")
+            print("Results:")
+            for ap in aps:
+                print("{:.3f}".format(ap))
+            print("{:.3f}".format(np.mean(aps)))
+            print("~~~~~~~~")
+            print("")
+            print("--------------------------------------------------------------")
+            print("Results computed with the **unofficial** Python eval code.")
+            print("Results should be very close to the official MATLAB eval code.")
+            print("Recompute with `./tools/reval.py --matlab ...` for your paper.")
+            print("-- Thanks, The Management")
+            print("--------------------------------------------------------------")
+
+        return np.mean(aps)
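evaluate_detections layers a COCO-style metric on top of the classic VOC evaluator: one _do_python_eval pass per IoU threshold, then a mean. The threshold grid it sweeps, shown standalone for illustration:

import numpy as np

# Ten thresholds from 0.50 to 0.95, same expression as in evaluate_detections.
IouTh = np.linspace(0.5, 0.95, int(np.round((0.95 - 0.5) / 0.05)) + 1, endpoint=True)
print(IouTh)  # [0.5 0.55 0.6 0.65 0.7 0.75 0.8 0.85 0.9 0.95]
# map_5095 is the mean AP over these thresholds; map_50 is the first entry.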
yolox/data/datasets/voc_classes.py
ADDED
@@ -0,0 +1,27 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+
+# VOC_CLASSES = ( '__background__', # always index 0
+VOC_CLASSES = (
+    "aeroplane",
+    "bicycle",
+    "bird",
+    "boat",
+    "bottle",
+    "bus",
+    "car",
+    "cat",
+    "chair",
+    "cow",
+    "diningtable",
+    "dog",
+    "horse",
+    "motorbike",
+    "person",
+    "pottedplant",
+    "sheep",
+    "sofa",
+    "train",
+    "tvmonitor",
+)
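With '__background__' commented out, indices run 0..19 in alphabetical order; AnnotationTransform's default lookup is built directly from this tuple (a sketch, assuming the yolox package from this commit is importable):

from yolox.data.datasets.voc_classes import VOC_CLASSES

# same expression as AnnotationTransform.__init__ in voc.py above
class_to_ind = dict(zip(VOC_CLASSES, range(len(VOC_CLASSES))))
print(class_to_ind["aeroplane"], class_to_ind["tvmonitor"])  # 0 19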
yolox/data/samplers.py
ADDED
@@ -0,0 +1,94 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+import itertools
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+from torch.utils.data.sampler import BatchSampler as torchBatchSampler
+from torch.utils.data.sampler import Sampler
+
+
+class YoloBatchSampler(torchBatchSampler):
+    """
+    This batch sampler will generate mini-batches of (dim, index) tuples from another sampler.
+    It works just like the :class:`torch.utils.data.sampler.BatchSampler`,
+    but it will prepend a dimension, whilst ensuring it stays the same across one mini-batch.
+    """
+
+    def __init__(self, *args, input_dimension=None, mosaic=True, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.input_dim = input_dimension
+        self.new_input_dim = None
+        self.mosaic = mosaic
+
+    def __iter__(self):
+        self.__set_input_dim()
+        for batch in super().__iter__():
+            yield [(self.input_dim, idx, self.mosaic) for idx in batch]
+            self.__set_input_dim()
+
+    def __set_input_dim(self):
+        """ Applies the new input dimension of the dataset, if one has been set. """
+        if self.new_input_dim is not None:
+            self.input_dim = (self.new_input_dim[0], self.new_input_dim[1])
+            self.new_input_dim = None
+
+
+class InfiniteSampler(Sampler):
+    """
+    In training, we only care about the "infinite stream" of training data.
+    So this sampler produces an infinite stream of indices and
+    all workers cooperate to correctly shuffle the indices and sample different indices.
+    The samplers in each worker effectively produce `indices[worker_id::num_workers]`,
+    where `indices` is an infinite stream of indices consisting of
+    `shuffle(range(size)) + shuffle(range(size)) + ...` (if shuffle is True)
+    or `range(size) + range(size) + ...` (if shuffle is False)
+    """
+
+    def __init__(
+        self,
+        size: int,
+        shuffle: bool = True,
+        seed: Optional[int] = 0,
+        rank=0,
+        world_size=1,
+    ):
+        """
+        Args:
+            size (int): the total number of data of the underlying dataset to sample from
+            shuffle (bool): whether to shuffle the indices or not
+            seed (int): the initial seed of the shuffle. Must be the same
+                across all workers. If None, will use a random seed shared
+                among workers (requires synchronization among all workers).
+        """
+        self._size = size
+        assert size > 0
+        self._shuffle = shuffle
+        self._seed = int(seed)
+
+        if dist.is_available() and dist.is_initialized():
+            self._rank = dist.get_rank()
+            self._world_size = dist.get_world_size()
+        else:
+            self._rank = rank
+            self._world_size = world_size
+
+    def __iter__(self):
+        start = self._rank
+        yield from itertools.islice(
+            self._infinite_indices(), start, None, self._world_size
+        )
+
+    def _infinite_indices(self):
+        g = torch.Generator()
+        g.manual_seed(self._seed)
+        while True:
+            if self._shuffle:
+                yield from torch.randperm(self._size, generator=g)
+            else:
+                yield from torch.arange(self._size)
+
+    def __len__(self):
+        return self._size // self._world_size
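These two samplers are meant to be stacked: InfiniteSampler feeds indices and YoloBatchSampler wraps each one into the (input_dim, index, mosaic) tuple that resize_getitem expects. A minimal single-process sketch, assuming the yolox package from this commit is importable:

from yolox.data.samplers import InfiniteSampler, YoloBatchSampler

sampler = InfiniteSampler(size=8, shuffle=False)  # no dist init: rank 0, world size 1
batch_sampler = YoloBatchSampler(
    sampler, batch_size=4, drop_last=False,
    input_dimension=(640, 640), mosaic=True,
)

first_batch = next(iter(batch_sampler))
print(first_batch)
# [((640, 640), tensor(0), True), ((640, 640), tensor(1), True), ...]
# (indices come out as 0-dim tensors because InfiniteSampler yields from torch.arange)

# The trainer can request a new resolution; it takes effect on the next batch:
batch_sampler.new_input_dim = (480, 480)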
yolox/evalutors/__init__.py
ADDED
@@ -0,0 +1,4 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+from .coco_evaluator import COCOEvaluator
yolox/evalutors/coco_evaluator.py
ADDED
@@ -0,0 +1,217 @@
+import contextlib
+import io
+import itertools
+import json
+import tempfile
+import time
+from loguru import logger
+from tqdm import tqdm
+
+import torch
+
+from yolox.utils import (
+    gather,
+    is_main_process,
+    postprocess,
+    synchronize,
+    time_synchronized,
+    xyxy2xywh
+)
+
+
+class COCOEvaluator:
+    """
+    COCO AP Evaluation class. All the data in the val2017 dataset are processed
+    and evaluated by the COCO API.
+    """
+
+    def __init__(
+        self, dataloader, img_size, confthre, nmsthre, num_classes, testdev=False
+    ):
+        """
+        Args:
+            dataloader (Dataloader): evaluate dataloader.
+            img_size (int): image size after preprocessing. Images are resized
+                to squares whose shape is (img_size, img_size).
+            confthre (float): confidence threshold ranging from 0 to 1, which
+                is defined in the config file.
+            nmsthre (float): IoU threshold of non-max suppression ranging from 0 to 1.
+        """
+        self.dataloader = dataloader
+        self.img_size = img_size
+        self.confthre = confthre
+        self.nmsthre = nmsthre
+        self.num_classes = num_classes
+        self.testdev = testdev
+
+    def evaluate(
+        self,
+        model,
+        distributed=False,
+        half=False,
+        trt_file=None,
+        decoder=None,
+        test_size=None,
+    ):
+        """
+        COCO average precision (AP) evaluation. Iterates inference on the test dataset
+        and the results are evaluated by the COCO API.
+
+        NOTE: This function will change training mode to False, please save states if needed.
+
+        Args:
+            model : model to evaluate.
+
+        Returns:
+            ap50_95 (float) : COCO AP of IoU=50:95
+            ap50 (float) : COCO AP of IoU=50
+            summary (str): summary info of evaluation.
+        """
+        # TODO half to amp_test
+        tensor_type = torch.cuda.HalfTensor if half else torch.cuda.FloatTensor
+        model = model.eval()
+        if half:
+            model = model.half()
+        ids = []
+        data_list = []
+        progress_bar = tqdm if is_main_process() else iter
+
+        inference_time = 0
+        nms_time = 0
+        n_samples = len(self.dataloader) - 1
+
+        if trt_file is not None:
+            from torch2trt import TRTModule
+
+            model_trt = TRTModule()
+            model_trt.load_state_dict(torch.load(trt_file))
+
+            x = torch.ones(1, 3, test_size[0], test_size[1]).cuda()
+            model(x)  # dry run before swapping in the TensorRT module
+            model = model_trt
+
+        for cur_iter, (imgs, _, info_imgs, ids) in enumerate(
+            progress_bar(self.dataloader)
+        ):
+            with torch.no_grad():
+                imgs = imgs.type(tensor_type)
+
+                # skip the last iters since the batch size might not be enough for batch inference
+                is_time_record = cur_iter < len(self.dataloader) - 1
+                if is_time_record:
+                    start = time.time()
+
+                outputs = model(imgs)
+                if decoder is not None:
+                    outputs = decoder(outputs, dtype=outputs.type())
+
+                if is_time_record:
+                    infer_end = time_synchronized()
+                    inference_time += infer_end - start
+
+                outputs = postprocess(
+                    outputs, self.num_classes, self.confthre, self.nmsthre
+                )
+                if is_time_record:
+                    nms_end = time_synchronized()
+                    nms_time += nms_end - infer_end
+
+            data_list.extend(self.convert_to_coco_format(outputs, info_imgs, ids))
+
+        statistics = torch.cuda.FloatTensor([inference_time, nms_time, n_samples])
+        if distributed:
+            data_list = gather(data_list, dst=0)
+            data_list = list(itertools.chain(*data_list))
+            torch.distributed.reduce(statistics, dst=0)
+
+        eval_results = self.evaluate_prediction(data_list, statistics)
+        synchronize()
+        return eval_results
+
+    def convert_to_coco_format(self, outputs, info_imgs, ids):
+        data_list = []
+        for (output, img_h, img_w, img_id) in zip(
+            outputs, info_imgs[0], info_imgs[1], ids
+        ):
+            if output is None:
+                continue
+            output = output.cpu()
+
+            bboxes = output[:, 0:4]
+
+            # preprocessing: resize
+            scale = min(
+                self.img_size[0] / float(img_h), self.img_size[1] / float(img_w)
+            )
+            bboxes /= scale
+            bboxes = xyxy2xywh(bboxes)
+
+            cls = output[:, 6]
+            scores = output[:, 4] * output[:, 5]
+            for ind in range(bboxes.shape[0]):
+                label = self.dataloader.dataset.class_ids[int(cls[ind])]
+                pred_data = {
+                    "image_id": int(img_id),
+                    "category_id": label,
+                    "bbox": bboxes[ind].numpy().tolist(),
+                    "score": scores[ind].numpy().item(),
+                    "segmentation": [],
+                }  # COCO json format
+                data_list.append(pred_data)
+        return data_list
+
+    def evaluate_prediction(self, data_dict, statistics):
+        if not is_main_process():
+            return 0, 0, None
+
+        logger.info("Evaluate in main process...")
+
+        annType = ["segm", "bbox", "keypoints"]
+
+        inference_time = statistics[0].item()
+        nms_time = statistics[1].item()
+        n_samples = statistics[2].item()
+
+        a_infer_time = 1000 * inference_time / (n_samples * self.dataloader.batch_size)
+        a_nms_time = 1000 * nms_time / (n_samples * self.dataloader.batch_size)
+
+        time_info = ", ".join(
+            [
+                "Average {} time: {:.2f} ms".format(k, v)
+                for k, v in zip(
+                    ["forward", "NMS", "inference"],
+                    [a_infer_time, a_nms_time, (a_infer_time + a_nms_time)],
+                )
+            ]
+        )
+
+        info = time_info + "\n"
+
+        # Evaluate the Dt (detection) json comparing with the ground truth
+        if len(data_dict) > 0:
+            cocoGt = self.dataloader.dataset.coco
+            # TODO: since pycocotools can't process dict in py36, write data to json file.
+            if self.testdev:
+                json.dump(data_dict, open("./yolox_testdev_2017.json", "w"))
+                cocoDt = cocoGt.loadRes("./yolox_testdev_2017.json")
+            else:
+                _, tmp = tempfile.mkstemp()
+                json.dump(data_dict, open(tmp, "w"))
+                cocoDt = cocoGt.loadRes(tmp)
+            try:
+                from yolox.layers import COCOeval_opt as COCOeval
+            except ImportError:
+                from pycocotools.cocoeval import COCOeval
+
+                logger.warning("Use standard COCOeval.")
+
+            cocoEval = COCOeval(cocoGt, cocoDt, annType[1])
+            cocoEval.evaluate()
+            cocoEval.accumulate()
+            redirect_string = io.StringIO()
+            with contextlib.redirect_stdout(redirect_string):
+                cocoEval.summarize()
+            info += redirect_string.getvalue()
+            return cocoEval.stats[0], cocoEval.stats[1], info
+        else:
+            return 0, 0, info
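convert_to_coco_format undoes the aspect-preserving resize applied during preprocessing before emitting COCO-format boxes. A worked instance with hypothetical numbers, for illustration only:

import numpy as np

# a 320x480 (h x w) image evaluated at test size 640x640
img_size = (640, 640)
img_h, img_w = 320, 480
scale = min(img_size[0] / float(img_h), img_size[1] / float(img_w))  # 640/480 = 1.333...

box_xyxy = np.array([133.33, 133.33, 266.67, 266.67])  # network-input coordinates
box_xyxy = box_xyxy / scale                             # back to original pixels
# xyxy -> COCO xywh (top-left x, y, width, height), as xyxy2xywh does
x1, y1, x2, y2 = box_xyxy
box_xywh = [x1, y1, x2 - x1, y2 - y1]
print([round(v, 1) for v in box_xywh])  # [100.0, 100.0, 100.0, 100.0]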
yolox/evalutors/voc_eval.py
ADDED
@@ -0,0 +1,184 @@
+# --------------------------------------------------------
+# Fast/er R-CNN
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Bharath Hariharan
+# --------------------------------------------------------
+
+import os
+import pickle
+import xml.etree.ElementTree as ET
+
+import numpy as np
+
+
+def parse_rec(filename):
+    """ Parse a PASCAL VOC xml file """
+    tree = ET.parse(filename)
+    objects = []
+    for obj in tree.findall("object"):
+        obj_struct = {}
+        obj_struct["name"] = obj.find("name").text
+        obj_struct["pose"] = obj.find("pose").text
+        obj_struct["truncated"] = int(obj.find("truncated").text)
+        obj_struct["difficult"] = int(obj.find("difficult").text)
+        bbox = obj.find("bndbox")
+        obj_struct["bbox"] = [
+            int(bbox.find("xmin").text),
+            int(bbox.find("ymin").text),
+            int(bbox.find("xmax").text),
+            int(bbox.find("ymax").text),
+        ]
+        objects.append(obj_struct)
+
+    return objects
+
+
+def voc_ap(rec, prec, use_07_metric=False):
+    """ ap = voc_ap(rec, prec, [use_07_metric])
+    Compute VOC AP given precision and recall.
+    If use_07_metric is true, uses the
+    VOC 07 11 point method (default: False).
+    """
+    if use_07_metric:
+        # 11 point metric
+        ap = 0.0
+        for t in np.arange(0.0, 1.1, 0.1):
+            if np.sum(rec >= t) == 0:
+                p = 0
+            else:
+                p = np.max(prec[rec >= t])
+            ap = ap + p / 11.0
+    else:
+        # correct AP calculation
+        # first append sentinel values at the end
+        mrec = np.concatenate(([0.0], rec, [1.0]))
+        mpre = np.concatenate(([0.0], prec, [0.0]))
+
+        # compute the precision envelope
+        for i in range(mpre.size - 1, 0, -1):
+            mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
+
+        # to calculate area under PR curve, look for points
+        # where X axis (recall) changes value
+        i = np.where(mrec[1:] != mrec[:-1])[0]
+
+        # and sum (\Delta recall) * prec
+        ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
+    return ap
+
+
+def voc_eval(
+    detpath,
+    annopath,
+    imagesetfile,
+    classname,
+    cachedir,
+    ovthresh=0.5,
+    use_07_metric=False,
+):
+    # first load gt
+    if not os.path.isdir(cachedir):
+        os.mkdir(cachedir)
+    cachefile = os.path.join(cachedir, "annots.pkl")
+    # read list of images
+    with open(imagesetfile, "r") as f:
+        lines = f.readlines()
+    imagenames = [x.strip() for x in lines]
+
+    if not os.path.isfile(cachefile):
+        # load annots
+        recs = {}
+        for i, imagename in enumerate(imagenames):
+            recs[imagename] = parse_rec(annopath.format(imagename))
+            if i % 100 == 0:
+                print("Reading annotation for {:d}/{:d}".format(i + 1, len(imagenames)))
+        # save
+        print("Saving cached annotations to {:s}".format(cachefile))
+        with open(cachefile, "wb") as f:
+            pickle.dump(recs, f)
+    else:
+        # load
+        with open(cachefile, "rb") as f:
+            recs = pickle.load(f)
+
+    # extract gt objects for this class
+    class_recs = {}
+    npos = 0
+    for imagename in imagenames:
+        R = [obj for obj in recs[imagename] if obj["name"] == classname]
+        bbox = np.array([x["bbox"] for x in R])
+        difficult = np.array([x["difficult"] for x in R]).astype(bool)
+        det = [False] * len(R)
+        npos = npos + sum(~difficult)
+        class_recs[imagename] = {"bbox": bbox, "difficult": difficult, "det": det}
+
+    # read dets
+    detfile = detpath.format(classname)
+    with open(detfile, "r") as f:
+        lines = f.readlines()
+
+    if len(lines) == 0:
+        return 0, 0, 0
+
+    splitlines = [x.strip().split(" ") for x in lines]
+    image_ids = [x[0] for x in splitlines]
+    confidence = np.array([float(x[1]) for x in splitlines])
+    BB = np.array([[float(z) for z in x[2:]] for x in splitlines])
+
+    # sort by confidence
+    sorted_ind = np.argsort(-confidence)
+    BB = BB[sorted_ind, :]
+    image_ids = [image_ids[x] for x in sorted_ind]
+
+    # go down dets and mark TPs and FPs
+    nd = len(image_ids)
+    tp = np.zeros(nd)
+    fp = np.zeros(nd)
+    for d in range(nd):
+        R = class_recs[image_ids[d]]
+        bb = BB[d, :].astype(float)
+        ovmax = -np.inf
+        BBGT = R["bbox"].astype(float)
+
+        if BBGT.size > 0:
+            # compute overlaps
+            # intersection
+            ixmin = np.maximum(BBGT[:, 0], bb[0])
+            iymin = np.maximum(BBGT[:, 1], bb[1])
+            ixmax = np.minimum(BBGT[:, 2], bb[2])
+            iymax = np.minimum(BBGT[:, 3], bb[3])
+            iw = np.maximum(ixmax - ixmin + 1.0, 0.0)
+            ih = np.maximum(iymax - iymin + 1.0, 0.0)
+            inters = iw * ih
+
+            # union
+            uni = (
+                (bb[2] - bb[0] + 1.0) * (bb[3] - bb[1] + 1.0)
+                + (BBGT[:, 2] - BBGT[:, 0] + 1.0) * (BBGT[:, 3] - BBGT[:, 1] + 1.0)
+                - inters
+            )
+
+            overlaps = inters / uni
+            ovmax = np.max(overlaps)
+            jmax = np.argmax(overlaps)
+
+        if ovmax > ovthresh:
+            if not R["difficult"][jmax]:
+                if not R["det"][jmax]:
+                    tp[d] = 1.0
+                    R["det"][jmax] = 1
+                else:
+                    fp[d] = 1.0
+        else:
+            fp[d] = 1.0
+
+    # compute precision recall
+    fp = np.cumsum(fp)
+    tp = np.cumsum(tp)
+    rec = tp / float(npos)
+    # avoid divide by zero in case the first detection matches a difficult
+    # ground truth
+    prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
+    ap = voc_ap(rec, prec, use_07_metric)
+
+    return rec, prec, ap
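voc_ap is a pure function, so both metric modes are easy to sanity-check by hand. With three ranked detections (TP, FP, TP) against two ground-truth boxes (a standalone sketch reusing voc_ap from the file above):

import numpy as np

from yolox.evalutors.voc_eval import voc_ap

# tp = [1, 0, 1], fp = [0, 1, 0], npos = 2, so the cumulated rec/prec arrays are:
rec = np.array([0.5, 0.5, 1.0])
prec = np.array([1.0, 0.5, 2.0 / 3.0])

print(voc_ap(rec, prec, use_07_metric=False))  # 5/6  ~0.833 (area under the envelope)
print(voc_ap(rec, prec, use_07_metric=True))   # 28/33 ~0.848 (VOC07 11-point average)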
yolox/evalutors/voc_evaluator.py
ADDED
@@ -0,0 +1,200 @@
+import sys
+import tempfile
+import time
+from tqdm import tqdm
+
+import numpy as np
+
+import torch
+from torch.autograd import Variable
+
+# TODO check VOC
+from yolox.data.data_augment import ValTransform
+from yolox.data.datasets import VOCDetection
+from yolox.utils import (
+    gather,
+    get_rank,
+    get_world_size,
+    is_main_process,
+    make_pred_vis,
+    make_vis,
+    postprocess,
+    synchronize,
+)
+
+# TODO refactor this file in the future.
+
+
+def _accumulate_predictions_from_multiple_gpus(predictions_per_gpu):
+    all_predictions = gather(predictions_per_gpu, dst=0)
+    if not is_main_process():
+        return
+    # merge the list of dicts
+    predictions = {}
+    for p in all_predictions:
+        predictions.update(p)
+    # the keys are image indices; check that they form a contiguous range
+    image_ids = list(sorted(predictions.keys()))
+    if len(image_ids) != image_ids[-1] + 1:
+        print("num_imgs: ", len(image_ids))
+        print("last img_id: ", image_ids[-1])
+        print(
+            "Number of images that were gathered from multiple processes is not "
+            "a contiguous set. Some images might be missing from the evaluation"
+        )
+
+    # convert to a list
+    predictions = [predictions[i] for i in image_ids]
+    return predictions
+
+
+class VOCEvaluator:
+    """
+    VOC AP Evaluation class.
+    All the data in the VOC2007 test set are processed
+    and evaluated.
+    """
+
+    def __init__(self, data_dir, img_size, confthre, nmsthre, vis=False):
+        """
+        Args:
+            data_dir (str): dataset root directory
+            img_size (int): image size after preprocessing. Images are resized
+                to squares whose shape is (img_size, img_size).
+            confthre (float):
+                confidence threshold ranging from 0 to 1,
+                which is defined in the config file.
+            nmsthre (float):
+                IoU threshold of non-max suppression ranging from 0 to 1.
+        """
+        test_sets = [("2007", "test")]
+        self.dataset = VOCDetection(
+            root=data_dir,
+            image_sets=test_sets,
+            input_dim=img_size,
+            preproc=ValTransform(
+                rgb_means=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
+            ),
+        )
+        self.num_images = len(self.dataset)
+        self.dataloader = torch.utils.data.DataLoader(
+            self.dataset, batch_size=1, shuffle=False, num_workers=0
+        )
+        self.img_size = img_size
+        self.confthre = confthre
+        self.nmsthre = nmsthre
+        self.vis = vis
+
+    def evaluate(self, model, distributed=False):
+        """
+        VOC average precision (AP) evaluation. Iterates inference on the test dataset
+        and the results are evaluated.
+        Args:
+            model : model object
+        Returns:
+            ap50_95 (float) : calculated AP for IoU=50:95
+            ap50 (float) : calculated AP for IoU=50
+        """
+        if isinstance(model, torch.nn.parallel.DistributedDataParallel):
+            model = model.module
+        model.eval()
+        cuda = torch.cuda.is_available()
+        Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
+
+        ids = []
+        data_dict = []
+        dataiterator = iter(self.dataloader)
+        img_num = 0
+        indices = list(range(self.num_images))
+        dis_indices = indices[get_rank() :: get_world_size()]
+        progress_bar = tqdm if is_main_process() else iter
+        num_classes = 20
+        predictions = {}
+
+        if is_main_process():
+            inference_time = 0
+            nms_time = 0
+            n_samples = len(dis_indices)
+
+        for i in progress_bar(dis_indices):
+            img, _, info_img, id_ = self.dataset[i]  # load a batch
+            info_img = [float(info) for info in info_img]
+            ids.append(id_)
+            with torch.no_grad():
+                img = Variable(img.type(Tensor).unsqueeze(0))
+
+                if is_main_process() and i > 9:
+                    start = time.time()
+
+                if self.vis:
+                    outputs, fuse_weights, fused_f = model(img)
+                else:
+                    outputs = model(img)
+
+                if is_main_process() and i > 9:
+                    infer_end = time.time()
+                    inference_time += infer_end - start
+
+                outputs = postprocess(outputs, 20, self.confthre, self.nmsthre)
+
+                if is_main_process() and i > 9:
+                    nms_end = time.time()
+                    nms_time += nms_end - infer_end
+
+            if outputs[0] is None:
+                predictions[i] = (None, None, None)
+                continue
+            outputs = outputs[0].cpu().data
+
+            bboxes = outputs[:, 0:4]
+            bboxes[:, 0::2] *= info_img[0] / self.img_size[0]
+            bboxes[:, 1::2] *= info_img[1] / self.img_size[1]
+            cls = outputs[:, 6]
+            scores = outputs[:, 4] * outputs[:, 5]
+            predictions[i] = (bboxes, cls, scores)
+
+            if self.vis:
+                o_img, _, _, _ = self.dataset.pull_item(i)
+                make_vis("VOC", i, o_img, fuse_weights, fused_f)
+                class_names = self.dataset._classes
+
+                bbox = bboxes.clone()
+                bbox[:, 2] = bbox[:, 2] - bbox[:, 0]
+                bbox[:, 3] = bbox[:, 3] - bbox[:, 1]
+
+                make_pred_vis("VOC", i, o_img, class_names, bbox, cls, scores)
+
+            if is_main_process():
+                o_img, _, _, _ = self.dataset.pull_item(i)
+                class_names = self.dataset._classes
+                bbox = bboxes.clone()
+                bbox[:, 2] = bbox[:, 2] - bbox[:, 0]
+                bbox[:, 3] = bbox[:, 3] - bbox[:, 1]
+                make_pred_vis("VOC", i, o_img, class_names, bbox, cls, scores)
+
+        synchronize()
+        predictions = _accumulate_predictions_from_multiple_gpus(predictions)
+        if not is_main_process():
+            return 0, 0
+
+        print("Main process Evaluating...")
+
+        a_infer_time = 1000 * inference_time / (n_samples - 10)
+        a_nms_time = 1000 * nms_time / (n_samples - 10)
+
+        print(
+            "Average forward time: %.2f ms, Average NMS time: %.2f ms, Average inference time: %.2f ms"
+            % (a_infer_time, a_nms_time, (a_infer_time + a_nms_time))
+        )
+
+        all_boxes = [[[] for _ in range(self.num_images)] for _ in range(num_classes)]
+        for img_num in range(self.num_images):
+            bboxes, cls, scores = predictions[img_num]
+            if bboxes is None:
+                for j in range(num_classes):
+                    all_boxes[j][img_num] = np.empty([0, 5], dtype=np.float32)
+                continue
+            for j in range(num_classes):
+                mask_c = cls == j
+                if sum(mask_c) == 0:
+                    all_boxes[j][img_num] = np.empty([0, 5], dtype=np.float32)
+                    continue
+
+                c_dets = torch.cat((bboxes, scores.unsqueeze(1)), dim=1)
+                all_boxes[j][img_num] = c_dets[mask_c].numpy()
+
+            sys.stdout.write(
+                "im_eval: {:d}/{:d} \r".format(img_num + 1, self.num_images)
+            )
+            sys.stdout.flush()
+
+        with tempfile.TemporaryDirectory() as tempdir:
+            mAP50_95, mAP50 = self.dataset.evaluate_detections(all_boxes, tempdir)
+            return mAP50_95, mAP50
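Because evaluate splits work round-robin (dis_indices = indices[rank::world_size]), each rank holds an interleaved slice of the image indices; _accumulate_predictions_from_multiple_gpus merges the per-rank dicts by index and then checks contiguity. The merge step, sketched single-process with hypothetical values (illustration only):

gathered = [
    {0: "det_a", 2: "det_c"},  # predictions from rank 0 (indices 0, 2, ...)
    {1: "det_b", 3: "det_d"},  # predictions from rank 1 (indices 1, 3, ...)
]

predictions = {}
for p in gathered:
    predictions.update(p)

image_ids = sorted(predictions.keys())
assert len(image_ids) == image_ids[-1] + 1  # contiguous: nothing is missing
ordered = [predictions[i] for i in image_ids]
print(ordered)  # ['det_a', 'det_b', 'det_c', 'det_d']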
yolox/exp/__init__.py
ADDED
@@ -0,0 +1,7 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+
+from .base_exp import BaseExp
+from .build import get_exp
+from .yolox_base import Exp
yolox/exp/base_exp.py
ADDED
@@ -0,0 +1,73 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+import ast
+import pprint
+from abc import ABCMeta, abstractmethod
+from typing import Dict
+from tabulate import tabulate
+
+import torch
+from torch.nn import Module
+
+from yolox.utils import LRScheduler
+
+
+class BaseExp(metaclass=ABCMeta):
+    """Basic class for any experiment."""
+
+    def __init__(self):
+        self.seed = None
+        self.output_dir = "/data/YOLOX_outputs"
+        self.print_interval = 100
+        self.eval_interval = 10
+
+    @abstractmethod
+    def get_model(self) -> Module:
+        pass
+
+    @abstractmethod
+    def get_data_loader(
+        self, batch_size: int, is_distributed: bool
+    ) -> Dict[str, torch.utils.data.DataLoader]:
+        pass
+
+    @abstractmethod
+    def get_optimizer(self, batch_size: int) -> torch.optim.Optimizer:
+        pass
+
+    @abstractmethod
+    def get_lr_scheduler(
+        self, lr: float, iters_per_epoch: int, **kwargs
+    ) -> LRScheduler:
+        pass
+
+    @abstractmethod
+    def get_evaluator(self):
+        pass
+
+    @abstractmethod
+    def eval(self, model, evaluator, weights):
+        pass
+
+    def __repr__(self):
+        table_header = ["keys", "values"]
+        exp_table = [
+            (str(k), pprint.pformat(v)) for k, v in vars(self).items() if not k.startswith("_")
+        ]
+        return tabulate(exp_table, headers=table_header, tablefmt="fancy_grid")
+
+    def merge(self, cfg_list):
+        assert len(cfg_list) % 2 == 0
+        for k, v in zip(cfg_list[0::2], cfg_list[1::2]):
+            # only update value with same key
+            if hasattr(self, k):
+                src_value = getattr(self, k)
+                src_type = type(src_value)
+                if src_value is not None and src_type != type(v):
+                    try:
+                        v = src_type(v)
+                    except Exception:
+                        v = ast.literal_eval(v)
+                setattr(self, k, v)
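merge is what lets command-line overrides reach an experiment: it walks a flat key/value list and casts each string to the type of the existing attribute. A short sketch using the Exp subclass added later in this diff (illustration only, assuming the yolox package from this commit is importable):

from yolox.exp import Exp

exp = Exp()
exp.merge(["max_epoch", "30", "basic_lr_per_img", "0.0002"])
# strings are cast to the source attribute types (int and float here)
print(exp.max_epoch, exp.basic_lr_per_img)  # 30 0.0002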
yolox/exp/build.py
ADDED
@@ -0,0 +1,48 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding:utf-8 -*-
|
3 |
+
|
4 |
+
import importlib
|
5 |
+
import os
|
6 |
+
import sys
|
7 |
+
|
8 |
+
|
9 |
+
def get_exp_by_file(exp_file):
|
10 |
+
try:
|
11 |
+
sys.path.append(os.path.dirname(exp_file))
|
12 |
+
current_exp = importlib.import_module(os.path.basename(exp_file).split(".")[0])
|
13 |
+
exp = current_exp.Exp()
|
14 |
+
except Exception:
|
15 |
+
raise ImportError("{} doesn't contains class named 'Exp'".format(exp_file))
|
16 |
+
return exp
|
17 |
+
|
18 |
+
|
19 |
+
def get_exp_by_name(exp_name):
|
20 |
+
import yolox
|
21 |
+
yolox_path = os.path.dirname(os.path.dirname(yolox.__file__))
|
22 |
+
filedict = {
|
23 |
+
"yolox-s": "yolox_s.py",
|
24 |
+
"yolox-m": "yolox_l.py",
|
25 |
+
"yolox-l": "yolox_l.py",
|
26 |
+
"yolox-x": "yolox_x.py",
|
27 |
+
"yolox-nano": "nano.py",
|
28 |
+
"yolov3": "yolov3.py",
|
29 |
+
}
|
30 |
+
filename = filedict[exp_name]
|
31 |
+
exp_path = os.path.join(yolox_path, "exps", "base", filename)
|
32 |
+
return get_exp_by_file(exp_path)
|
33 |
+
|
34 |
+
|
35 |
+
def get_exp(exp_file, exp_name):
|
36 |
+
"""
|
37 |
+
get Exp object by file or name. If exp_file and exp_name
|
38 |
+
are both provided, get Exp by exp_file.
|
39 |
+
|
40 |
+
Args:
|
41 |
+
exp_file (str): file path of experiment.
|
42 |
+
exp_name (str): name of experiment. "yolo-s",
|
43 |
+
"""
|
44 |
+
assert exp_file is not None or exp_name is not None, "plz provide exp file or exp name."
|
45 |
+
if exp_file is not None:
|
46 |
+
return get_exp_by_file(exp_file)
|
47 |
+
else:
|
48 |
+
return get_exp_by_name(exp_name)
|
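The dynamic-import pattern in get_exp_by_file can be exercised without the full repo: write a minimal experiment file to disk and load it the same way. Everything below is a self-contained sketch; the generated Exp class is a stand-in for a real experiment file, not yolox_base.Exp.

import importlib
import os
import sys
import tempfile

tmp_dir = tempfile.mkdtemp()
exp_file = os.path.join(tmp_dir, "my_exp.py")
with open(exp_file, "w") as f:
    f.write("class Exp:\n    max_epoch = 300\n")

# mirror get_exp_by_file: put the file's directory on sys.path,
# then import it by module name (filename without extension)
sys.path.append(os.path.dirname(exp_file))
module = importlib.import_module(os.path.basename(exp_file).split(".")[0])
exp = module.Exp()
print(exp.max_epoch)  # 300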
yolox/exp/yolox_base.py
ADDED
@@ -0,0 +1,251 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+
+import os
+import random
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+
+from .base_exp import BaseExp
+
+
+class Exp(BaseExp):
+
+    def __init__(self):
+        super().__init__()
+
+        # ---------------- model config ---------------- #
+        self.num_classes = 80
+        self.depth = 1.00
+        self.width = 1.00
+
+        # ---------------- dataloader config ---------------- #
+        self.data_num_workers = 6
+        self.input_size = (640, 640)
+        self.random_size = (14, 26)
+        self.train_ann = "instances_train2017.json"
+        self.val_ann = "instances_val2017.json"
+
+        # --------------- transform config ----------------- #
+        self.degrees = 10.0
+        self.translate = 0.1
+        self.scale = (0.1, 2)
+        self.mscale = (0.8, 1.6)
+        self.shear = 2.0
+        self.perspective = 0.0
+        self.enable_mixup = True
+
+        # -------------- training config --------------------- #
+        self.warmup_epochs = 5
+        self.max_epoch = 300
+        self.warmup_lr = 0
+        self.basic_lr_per_img = 0.01 / 64.0
+        self.scheduler = "yoloxwarmcos"
+        self.no_aug_epochs = 15
+        self.min_lr_ratio = 0.05
+        self.ema = True
+
+        self.weight_decay = 5e-4
+        self.momentum = 0.9
+        self.print_interval = 10
+        self.eval_interval = 10
+        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
+
+        # ----------------- testing config ------------------ #
+        self.test_size = (640, 640)
+        self.test_conf = 0.01
+        self.nmsthre = 0.65
+
+    def get_model(self):
+        from yolox.models import YOLOX, YOLOPAFPN, YOLOXHead
+
+        def init_yolo(M):
+            for m in M.modules():
+                if isinstance(m, nn.BatchNorm2d):
+                    m.eps = 1e-3
+                    m.momentum = 0.03
+
+        if getattr(self, "model", None) is None:
+            in_channels = [256, 512, 1024]
+            backbone = YOLOPAFPN(self.depth, self.width, in_channels=in_channels)
+            head = YOLOXHead(self.num_classes, self.width, in_channels=in_channels)
+            self.model = YOLOX(backbone, head)
+
+        self.model.apply(init_yolo)
+        self.model.head.initialize_biases(1e-2)
+        return self.model
+
+    def get_data_loader(self, batch_size, is_distributed, no_aug=False):
+        from yolox.data import (
+            COCODataset,
+            TrainTransform,
+            YoloBatchSampler,
+            DataLoader,
+            InfiniteSampler,
+            MosaicDetection,
+        )
+
+        dataset = COCODataset(
+            data_dir=None,
+            json_file=self.train_ann,
+            img_size=self.input_size,
+            preproc=TrainTransform(
+                rgb_means=(0.485, 0.456, 0.406),
+                std=(0.229, 0.224, 0.225),
+                max_labels=50,
+            ),
+        )
+
+        dataset = MosaicDetection(
+            dataset,
+            mosaic=not no_aug,
+            img_size=self.input_size,
+            preproc=TrainTransform(
+                rgb_means=(0.485, 0.456, 0.406),
+                std=(0.229, 0.224, 0.225),
+                max_labels=120,
+            ),
+            degrees=self.degrees,
+            translate=self.translate,
+            scale=self.scale,
+            shear=self.shear,
+            perspective=self.perspective,
+            enable_mixup=self.enable_mixup,
+        )
+
+        self.dataset = dataset
+
+        if is_distributed:
+            batch_size = batch_size // dist.get_world_size()
+            sampler = InfiniteSampler(
+                len(self.dataset), seed=self.seed if self.seed else 0
+            )
+        else:
+            sampler = torch.utils.data.RandomSampler(self.dataset)
+
+        batch_sampler = YoloBatchSampler(
+            sampler=sampler,
+            batch_size=batch_size,
+            drop_last=False,
+            input_dimension=self.input_size,
+            mosaic=not no_aug,
+        )
+
+        dataloader_kwargs = {"num_workers": self.data_num_workers, "pin_memory": True}
+        dataloader_kwargs["batch_sampler"] = batch_sampler
+        train_loader = DataLoader(self.dataset, **dataloader_kwargs)
+
+        return train_loader
+
+    def random_resize(self, data_loader, epoch, rank, is_distributed):
+        tensor = torch.LongTensor(1).cuda()
+
+        if rank == 0:
+            if epoch >= self.max_epoch - 1:
+                size = self.input_size[0]
+            else:
+                size = random.randint(*self.random_size)
+                size = int(32 * size)
+            tensor.fill_(size)
+
+        if is_distributed:
+            dist.barrier()
+            dist.broadcast(tensor, 0)
+
+        input_size = data_loader.change_input_dim(
+            multiple=tensor.item(), random_range=None
+        )
+        return input_size
+
+    def get_optimizer(self, batch_size):
+        if "optimizer" not in self.__dict__:
+            if self.warmup_epochs > 0:
+                lr = self.warmup_lr
+            else:
+                lr = self.basic_lr_per_img * batch_size
+
+            pg0, pg1, pg2 = [], [], []  # optimizer parameter groups
+
+            for k, v in self.model.named_modules():
+                if hasattr(v, "bias") and isinstance(v.bias, nn.Parameter):
+                    pg2.append(v.bias)  # biases
+                if isinstance(v, nn.BatchNorm2d) or "bn" in k:
+                    pg0.append(v.weight)  # no decay
+                elif hasattr(v, "weight") and isinstance(v.weight, nn.Parameter):
+                    pg1.append(v.weight)  # apply decay
+
+            optimizer = torch.optim.SGD(
+                pg0, lr=lr, momentum=self.momentum, nesterov=True
+            )
+            optimizer.add_param_group(
+                {"params": pg1, "weight_decay": self.weight_decay}
+            )  # add pg1 with weight_decay
+            optimizer.add_param_group({"params": pg2})
+            self.optimizer = optimizer
+
+        return self.optimizer
+
+    def get_lr_scheduler(self, lr, iters_per_epoch):
+        from yolox.utils import LRScheduler
+        scheduler = LRScheduler(
+            self.scheduler,
+            lr,
+            iters_per_epoch,
+            self.max_epoch,
+            warmup_epochs=self.warmup_epochs,
+            warmup_lr_start=self.warmup_lr,
+            no_aug_epochs=self.no_aug_epochs,
+            min_lr_ratio=self.min_lr_ratio,
+        )
+        return scheduler
+
+    def get_eval_loader(self, batch_size, is_distributed, testdev=False):
+        from yolox.data import COCODataset, ValTransform
+
+        valdataset = COCODataset(
+            data_dir=None,
+            json_file=self.val_ann if not testdev else "image_info_test-dev2017.json",
+            name="val2017" if not testdev else "test2017",
+            img_size=self.test_size,
+            preproc=ValTransform(
+                rgb_means=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
+            ),
+        )
+
+        if is_distributed:
+            batch_size = batch_size // dist.get_world_size()
+            sampler = torch.utils.data.distributed.DistributedSampler(
+                valdataset, shuffle=False
+            )
+        else:
+            sampler = torch.utils.data.SequentialSampler(valdataset)
+
+        dataloader_kwargs = {
+            "num_workers": self.data_num_workers,
+            "pin_memory": True,
+            "sampler": sampler,
+        }
+        dataloader_kwargs["batch_size"] = batch_size
+        val_loader = torch.utils.data.DataLoader(valdataset, **dataloader_kwargs)
+
+        return val_loader
+
+    def get_evaluator(self, batch_size, is_distributed, testdev=False):
+        from yolox.evalutors import COCOEvaluator
+
+        val_loader = self.get_eval_loader(batch_size, is_distributed, testdev=testdev)
+        evaluator = COCOEvaluator(
+            dataloader=val_loader,
+            img_size=self.test_size,
+            confthre=self.test_conf,
+            nmsthre=self.nmsthre,
+            num_classes=self.num_classes,
+            testdev=testdev,
+        )
+        return evaluator
+
+    def eval(self, model, evaluator, is_distributed, half=False):
+        return evaluator.evaluate(model, is_distributed, half)
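Two pieces of arithmetic in this base Exp are worth making explicit: the linear learning-rate scaling rule (basic_lr_per_img times the total batch size) and the multiscale range used by random_resize (a multiple of 32 drawn from random_size). A quick check of both, using only values defined above:

# Linear scaling: basic_lr_per_img = 0.01 / 64, so a batch of 64 trains at lr 0.01
basic_lr_per_img = 0.01 / 64.0
for batch_size in (8, 64, 128):
    print(batch_size, basic_lr_per_img * batch_size)  # 0.00125, 0.01, 0.02

# Multiscale training: random_size = (14, 26), and random.randint is inclusive,
# so input sides range from 32 * 14 = 448 to 32 * 26 = 832 in steps of 32
random_size = (14, 26)
sizes = [32 * k for k in range(random_size[0], random_size[1] + 1)]
print(sizes[0], sizes[-1])  # 448 832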
yolox/layers/__init__.py
ADDED
@@ -0,0 +1,4 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+from .fast_coco_eval_api import COCOeval_opt
yolox/layers/csrc/cocoeval/cocoeval.cpp
ADDED
@@ -0,0 +1,502 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+#include "cocoeval.h"
+#include <time.h>
+#include <algorithm>
+#include <cstdint>
+#include <numeric>
+
+using namespace pybind11::literals;
+
+namespace COCOeval {
+
+// Sort detections from highest score to lowest, such that
+// detection_instances[detection_sorted_indices[t]] >=
+// detection_instances[detection_sorted_indices[t+1]]. Use stable_sort to match
+// original COCO API
+void SortInstancesByDetectionScore(
+    const std::vector<InstanceAnnotation>& detection_instances,
+    std::vector<uint64_t>* detection_sorted_indices) {
+  detection_sorted_indices->resize(detection_instances.size());
+  std::iota(
+      detection_sorted_indices->begin(), detection_sorted_indices->end(), 0);
+  std::stable_sort(
+      detection_sorted_indices->begin(),
+      detection_sorted_indices->end(),
+      [&detection_instances](size_t j1, size_t j2) {
+        return detection_instances[j1].score > detection_instances[j2].score;
+      });
+}
+
+// Partition the ground truth objects based on whether or not to ignore them
+// based on area
+void SortInstancesByIgnore(
+    const std::array<double, 2>& area_range,
+    const std::vector<InstanceAnnotation>& ground_truth_instances,
+    std::vector<uint64_t>* ground_truth_sorted_indices,
+    std::vector<bool>* ignores) {
+  ignores->clear();
+  ignores->reserve(ground_truth_instances.size());
+  for (auto o : ground_truth_instances) {
+    ignores->push_back(
+        o.ignore || o.area < area_range[0] || o.area > area_range[1]);
+  }
+
+  ground_truth_sorted_indices->resize(ground_truth_instances.size());
+  std::iota(
+      ground_truth_sorted_indices->begin(),
+      ground_truth_sorted_indices->end(),
+      0);
+  std::stable_sort(
+      ground_truth_sorted_indices->begin(),
+      ground_truth_sorted_indices->end(),
+      [&ignores](size_t j1, size_t j2) {
+        return (int)(*ignores)[j1] < (int)(*ignores)[j2];
+      });
+}
+
+// For each IOU threshold, greedily match each detected instance to a ground
+// truth instance (if possible) and store the results
+void MatchDetectionsToGroundTruth(
+    const std::vector<InstanceAnnotation>& detection_instances,
+    const std::vector<uint64_t>& detection_sorted_indices,
+    const std::vector<InstanceAnnotation>& ground_truth_instances,
+    const std::vector<uint64_t>& ground_truth_sorted_indices,
+    const std::vector<bool>& ignores,
+    const std::vector<std::vector<double>>& ious,
+    const std::vector<double>& iou_thresholds,
+    const std::array<double, 2>& area_range,
+    ImageEvaluation* results) {
+  // Initialize memory to store return data matches and ignore
+  const int num_iou_thresholds = iou_thresholds.size();
+  const int num_ground_truth = ground_truth_sorted_indices.size();
+  const int num_detections = detection_sorted_indices.size();
+  std::vector<uint64_t> ground_truth_matches(
+      num_iou_thresholds * num_ground_truth, 0);
+  std::vector<uint64_t>& detection_matches = results->detection_matches;
+  std::vector<bool>& detection_ignores = results->detection_ignores;
+  std::vector<bool>& ground_truth_ignores = results->ground_truth_ignores;
+  detection_matches.resize(num_iou_thresholds * num_detections, 0);
+  detection_ignores.resize(num_iou_thresholds * num_detections, false);
+  ground_truth_ignores.resize(num_ground_truth);
+  for (auto g = 0; g < num_ground_truth; ++g) {
+    ground_truth_ignores[g] = ignores[ground_truth_sorted_indices[g]];
+  }
+
+  for (auto t = 0; t < num_iou_thresholds; ++t) {
+    for (auto d = 0; d < num_detections; ++d) {
+      // information about best match so far (match=-1 -> unmatched)
+      double best_iou = std::min(iou_thresholds[t], 1 - 1e-10);
+      int match = -1;
+      for (auto g = 0; g < num_ground_truth; ++g) {
+        // if this ground truth instance is already matched and not a
+        // crowd, it cannot be matched to another detection
+        if (ground_truth_matches[t * num_ground_truth + g] > 0 &&
+            !ground_truth_instances[ground_truth_sorted_indices[g]].is_crowd) {
+          continue;
+        }
+
+        // if detected instance matched to a regular ground truth
+        // instance, we can break on the first ground truth instance
+        // tagged as ignore (because they are sorted by the ignore tag)
+        if (match >= 0 && !ground_truth_ignores[match] &&
+            ground_truth_ignores[g]) {
+          break;
+        }
+
+        // if IOU overlap is the best so far, store the match appropriately
+        if (ious[d][ground_truth_sorted_indices[g]] >= best_iou) {
+          best_iou = ious[d][ground_truth_sorted_indices[g]];
+          match = g;
+        }
+      }
+      // if match was made, store id of match for both detection and
+      // ground truth
+      if (match >= 0) {
+        detection_ignores[t * num_detections + d] = ground_truth_ignores[match];
+        detection_matches[t * num_detections + d] =
+            ground_truth_instances[ground_truth_sorted_indices[match]].id;
+        ground_truth_matches[t * num_ground_truth + match] =
+            detection_instances[detection_sorted_indices[d]].id;
+      }
+
+      // set unmatched detections outside of area range to ignore
+      const InstanceAnnotation& detection =
+          detection_instances[detection_sorted_indices[d]];
+      detection_ignores[t * num_detections + d] =
+          detection_ignores[t * num_detections + d] ||
+          (detection_matches[t * num_detections + d] == 0 &&
+           (detection.area < area_range[0] || detection.area > area_range[1]));
+    }
+  }
+
+  // store detection score results
+  results->detection_scores.resize(detection_sorted_indices.size());
+  for (size_t d = 0; d < detection_sorted_indices.size(); ++d) {
+    results->detection_scores[d] =
+        detection_instances[detection_sorted_indices[d]].score;
+  }
+}
+
+std::vector<ImageEvaluation> EvaluateImages(
+    const std::vector<std::array<double, 2>>& area_ranges,
+    int max_detections,
+    const std::vector<double>& iou_thresholds,
+    const ImageCategoryInstances<std::vector<double>>& image_category_ious,
+    const ImageCategoryInstances<InstanceAnnotation>&
+        image_category_ground_truth_instances,
+    const ImageCategoryInstances<InstanceAnnotation>&
+        image_category_detection_instances) {
+  const int num_area_ranges = area_ranges.size();
+  const int num_images = image_category_ground_truth_instances.size();
+  const int num_categories =
+      image_category_ious.size() > 0 ? image_category_ious[0].size() : 0;
+  std::vector<uint64_t> detection_sorted_indices;
+  std::vector<uint64_t> ground_truth_sorted_indices;
+  std::vector<bool> ignores;
+  std::vector<ImageEvaluation> results_all(
+      num_images * num_area_ranges * num_categories);
+
+  // Store results for each image, category, and area range combination. Results
+  // for each IOU threshold are packed into the same ImageEvaluation object
+  for (auto i = 0; i < num_images; ++i) {
+    for (auto c = 0; c < num_categories; ++c) {
+      const std::vector<InstanceAnnotation>& ground_truth_instances =
+          image_category_ground_truth_instances[i][c];
+      const std::vector<InstanceAnnotation>& detection_instances =
+          image_category_detection_instances[i][c];
+
+      SortInstancesByDetectionScore(
+          detection_instances, &detection_sorted_indices);
+      if ((int)detection_sorted_indices.size() > max_detections) {
+        detection_sorted_indices.resize(max_detections);
+      }
+
+      for (size_t a = 0; a < area_ranges.size(); ++a) {
+        SortInstancesByIgnore(
+            area_ranges[a],
+            ground_truth_instances,
+            &ground_truth_sorted_indices,
+            &ignores);
+
+        MatchDetectionsToGroundTruth(
+            detection_instances,
+            detection_sorted_indices,
+            ground_truth_instances,
+            ground_truth_sorted_indices,
+            ignores,
+            image_category_ious[i][c],
+            iou_thresholds,
+            area_ranges[a],
+            &results_all
+                [c * num_area_ranges * num_images + a * num_images + i]);
+      }
+    }
+  }
+
+  return results_all;
+}
+
+// Convert a python list to a vector
+template <typename T>
+std::vector<T> list_to_vec(const py::list& l) {
+  std::vector<T> v(py::len(l));
+  for (int i = 0; i < (int)py::len(l); ++i) {
+    v[i] = l[i].cast<T>();
+  }
+  return v;
+}
+
+// Helper function to Accumulate()
+// Considers the evaluation results applicable to a particular category, area
+// range, and max_detections parameter setting, which begin at
+// evaluations[evaluation_index]. Extracts a sorted list of length n of all
+// applicable detection instances concatenated across all images in the dataset,
+// which are represented by the outputs evaluation_indices, detection_scores,
+// image_detection_indices, and detection_sorted_indices--all of which are
+// length n. evaluation_indices[i] stores the applicable index into
+// evaluations[] for instance i, which has detection score detection_score[i],
+// and is the image_detection_indices[i]'th of the list of detections
+// for the image containing i. detection_sorted_indices[] defines a sorted
+// permutation of the 3 other outputs
+int BuildSortedDetectionList(
+    const std::vector<ImageEvaluation>& evaluations,
+    const int64_t evaluation_index,
+    const int64_t num_images,
+    const int max_detections,
+    std::vector<uint64_t>* evaluation_indices,
+    std::vector<double>* detection_scores,
+    std::vector<uint64_t>* detection_sorted_indices,
+    std::vector<uint64_t>* image_detection_indices) {
+  assert(evaluations.size() >= evaluation_index + num_images);
+
+  // Extract a list of object instances of the applicable category, area
+  // range, and max detections requirements such that they can be sorted
+  image_detection_indices->clear();
+  evaluation_indices->clear();
+  detection_scores->clear();
+  image_detection_indices->reserve(num_images * max_detections);
+  evaluation_indices->reserve(num_images * max_detections);
+  detection_scores->reserve(num_images * max_detections);
+  int num_valid_ground_truth = 0;
+  for (auto i = 0; i < num_images; ++i) {
+    const ImageEvaluation& evaluation = evaluations[evaluation_index + i];
+
+    for (int d = 0;
+         d < (int)evaluation.detection_scores.size() && d < max_detections;
+         ++d) { // detected instances
+      evaluation_indices->push_back(evaluation_index + i);
+      image_detection_indices->push_back(d);
+      detection_scores->push_back(evaluation.detection_scores[d]);
+    }
+    for (auto ground_truth_ignore : evaluation.ground_truth_ignores) {
+      if (!ground_truth_ignore) {
+        ++num_valid_ground_truth;
+      }
+    }
+  }
+
+  // Sort detections by decreasing score, using stable sort to match
+  // python implementation
+  detection_sorted_indices->resize(detection_scores->size());
+  std::iota(
+      detection_sorted_indices->begin(), detection_sorted_indices->end(), 0);
+  std::stable_sort(
+      detection_sorted_indices->begin(),
+      detection_sorted_indices->end(),
+      [&detection_scores](size_t j1, size_t j2) {
+        return (*detection_scores)[j1] > (*detection_scores)[j2];
+      });
+
+  return num_valid_ground_truth;
+}
+
+// Helper function to Accumulate()
+// Compute a precision recall curve given a sorted list of detected instances
+// encoded in evaluations, evaluation_indices, detection_scores,
+// detection_sorted_indices, image_detection_indices (see
+// BuildSortedDetectionList()). Using vectors precisions and recalls
+// and temporary storage, output the results into precisions_out, recalls_out,
+// and scores_out, which are large buffers containing many precision/recall curves
+// for all possible parameter settings, with precisions_out_index and
+// recalls_out_index defining the applicable indices to store results.
+void ComputePrecisionRecallCurve(
+    const int64_t precisions_out_index,
+    const int64_t precisions_out_stride,
+    const int64_t recalls_out_index,
+    const std::vector<double>& recall_thresholds,
+    const int iou_threshold_index,
+    const int num_iou_thresholds,
+    const int num_valid_ground_truth,
+    const std::vector<ImageEvaluation>& evaluations,
+    const std::vector<uint64_t>& evaluation_indices,
+    const std::vector<double>& detection_scores,
+    const std::vector<uint64_t>& detection_sorted_indices,
+    const std::vector<uint64_t>& image_detection_indices,
+    std::vector<double>* precisions,
+    std::vector<double>* recalls,
+    std::vector<double>* precisions_out,
+    std::vector<double>* scores_out,
+    std::vector<double>* recalls_out) {
+  assert(recalls_out->size() > recalls_out_index);
+
+  // Compute precision/recall for each instance in the sorted list of detections
+  int64_t true_positives_sum = 0, false_positives_sum = 0;
+  precisions->clear();
+  recalls->clear();
+  precisions->reserve(detection_sorted_indices.size());
+  recalls->reserve(detection_sorted_indices.size());
+  assert(!evaluations.empty() || detection_sorted_indices.empty());
+  for (auto detection_sorted_index : detection_sorted_indices) {
+    const ImageEvaluation& evaluation =
+        evaluations[evaluation_indices[detection_sorted_index]];
+    const auto num_detections =
+        evaluation.detection_matches.size() / num_iou_thresholds;
+    const auto detection_index = iou_threshold_index * num_detections +
+        image_detection_indices[detection_sorted_index];
+    assert(evaluation.detection_matches.size() > detection_index);
+    assert(evaluation.detection_ignores.size() > detection_index);
+    const int64_t detection_match =
+        evaluation.detection_matches[detection_index];
+    const bool detection_ignores =
+        evaluation.detection_ignores[detection_index];
+    const auto true_positive = detection_match > 0 && !detection_ignores;
+    const auto false_positive = detection_match == 0 && !detection_ignores;
+    if (true_positive) {
+      ++true_positives_sum;
+    }
+    if (false_positive) {
+      ++false_positives_sum;
+    }
+
+    const double recall =
+        static_cast<double>(true_positives_sum) / num_valid_ground_truth;
+    recalls->push_back(recall);
+    const int64_t num_valid_detections =
+        true_positives_sum + false_positives_sum;
+    const double precision = num_valid_detections > 0
+        ? static_cast<double>(true_positives_sum) / num_valid_detections
+        : 0.0;
+    precisions->push_back(precision);
+  }
+
+  (*recalls_out)[recalls_out_index] = !recalls->empty() ? recalls->back() : 0;
+
+  for (int64_t i = static_cast<int64_t>(precisions->size()) - 1; i > 0; --i) {
+    if ((*precisions)[i] > (*precisions)[i - 1]) {
+      (*precisions)[i - 1] = (*precisions)[i];
+    }
+  }
+
+  // Sample the per instance precision/recall list at each recall threshold
+  for (size_t r = 0; r < recall_thresholds.size(); ++r) {
+    // first index in recalls >= recall_thresholds[r]
+    std::vector<double>::iterator low = std::lower_bound(
+        recalls->begin(), recalls->end(), recall_thresholds[r]);
+    size_t precisions_index = low - recalls->begin();
+
+    const auto results_ind = precisions_out_index + r * precisions_out_stride;
+    assert(results_ind < precisions_out->size());
+    assert(results_ind < scores_out->size());
+    if (precisions_index < precisions->size()) {
+      (*precisions_out)[results_ind] = (*precisions)[precisions_index];
+      (*scores_out)[results_ind] =
+          detection_scores[detection_sorted_indices[precisions_index]];
+    } else {
+      (*precisions_out)[results_ind] = 0;
+      (*scores_out)[results_ind] = 0;
+    }
+  }
+}
+py::dict Accumulate(
+    const py::object& params,
+    const std::vector<ImageEvaluation>& evaluations) {
+  const std::vector<double> recall_thresholds =
+      list_to_vec<double>(params.attr("recThrs"));
+  const std::vector<int> max_detections =
+      list_to_vec<int>(params.attr("maxDets"));
+  const int num_iou_thresholds = py::len(params.attr("iouThrs"));
+  const int num_recall_thresholds = py::len(params.attr("recThrs"));
+  const int num_categories = params.attr("useCats").cast<int>() == 1
+      ? py::len(params.attr("catIds"))
+      : 1;
+  const int num_area_ranges = py::len(params.attr("areaRng"));
+  const int num_max_detections = py::len(params.attr("maxDets"));
+  const int num_images = py::len(params.attr("imgIds"));
+
+  std::vector<double> precisions_out(
+      num_iou_thresholds * num_recall_thresholds * num_categories *
+          num_area_ranges * num_max_detections,
+      -1);
+  std::vector<double> recalls_out(
+      num_iou_thresholds * num_categories * num_area_ranges *
+          num_max_detections,
+      -1);
+  std::vector<double> scores_out(
+      num_iou_thresholds * num_recall_thresholds * num_categories *
+          num_area_ranges * num_max_detections,
+      -1);
+
+  // Consider the list of all detected instances in the entire dataset in one
+  // large list. evaluation_indices, detection_scores,
+  // image_detection_indices, and detection_sorted_indices all have the same
+  // length as this list, such that each entry corresponds to one detected
+  // instance
+  std::vector<uint64_t> evaluation_indices; // indices into evaluations[]
+  std::vector<double> detection_scores; // detection scores of each instance
+  std::vector<uint64_t> detection_sorted_indices; // sorted indices of all
+                                                  // instances in the dataset
+  std::vector<uint64_t>
+      image_detection_indices; // indices into the list of detected instances in
+                               // the same image as each instance
+  std::vector<double> precisions, recalls;
+
+  for (auto c = 0; c < num_categories; ++c) {
+    for (auto a = 0; a < num_area_ranges; ++a) {
+      for (auto m = 0; m < num_max_detections; ++m) {
+        // The COCO PythonAPI assumes evaluations[] (the return value of
+        // COCOeval::EvaluateImages()) is one long list storing results for each
+        // combination of category, area range, and image id, with categories in
+        // the outermost loop and images in the innermost loop.
+        const int64_t evaluations_index =
+            c * num_area_ranges * num_images + a * num_images;
+        int num_valid_ground_truth = BuildSortedDetectionList(
+            evaluations,
+            evaluations_index,
+            num_images,
+            max_detections[m],
+            &evaluation_indices,
+            &detection_scores,
+            &detection_sorted_indices,
+            &image_detection_indices);
+
+        if (num_valid_ground_truth == 0) {
+          continue;
+        }
+
+        for (auto t = 0; t < num_iou_thresholds; ++t) {
+          // recalls_out is a flattened vector representing a
+          // num_iou_thresholds X num_categories X num_area_ranges X
+          // num_max_detections matrix
+          const int64_t recalls_out_index =
+              t * num_categories * num_area_ranges * num_max_detections +
+              c * num_area_ranges * num_max_detections +
+              a * num_max_detections + m;
+
+          // precisions_out and scores_out are flattened vectors
+          // representing a num_iou_thresholds X num_recall_thresholds X
+          // num_categories X num_area_ranges X num_max_detections matrix
+          const int64_t precisions_out_stride =
+              num_categories * num_area_ranges * num_max_detections;
+          const int64_t precisions_out_index = t * num_recall_thresholds *
+                  num_categories * num_area_ranges * num_max_detections +
+              c * num_area_ranges * num_max_detections +
+              a * num_max_detections + m;
+
+          ComputePrecisionRecallCurve(
+              precisions_out_index,
+              precisions_out_stride,
+              recalls_out_index,
+              recall_thresholds,
+              t,
+              num_iou_thresholds,
+              num_valid_ground_truth,
+              evaluations,
+              evaluation_indices,
+              detection_scores,
+              detection_sorted_indices,
+              image_detection_indices,
+              &precisions,
+              &recalls,
+              &precisions_out,
+              &scores_out,
+              &recalls_out);
+        }
+      }
+    }
+  }
+
+  time_t rawtime;
+  struct tm local_time;
+  std::array<char, 200> buffer;
+  time(&rawtime);
+#ifdef _WIN32
+  localtime_s(&local_time, &rawtime);
+#else
+  localtime_r(&rawtime, &local_time);
+#endif
+  strftime(
+      buffer.data(), 200, "%Y-%m-%d %H:%M:%S", &local_time);
+  return py::dict(
+      "params"_a = params,
+      "counts"_a = std::vector<int64_t>({num_iou_thresholds,
+                                         num_recall_thresholds,
+                                         num_categories,
+                                         num_area_ranges,
+                                         num_max_detections}),
+      "date"_a = buffer,
+      "precision"_a = precisions_out,
+      "recall"_a = recalls_out,
+      "scores"_a = scores_out);
+}
+
+} // namespace COCOeval
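The heart of EvaluateImages is MatchDetectionsToGroundTruth: detections, taken in descending score order, each grab the not-yet-taken ground truth with the highest IoU at or above the threshold. The Python sketch below renders that greedy loop for a single IoU threshold, with no crowd or ignore handling; it is a simplification for intuition, not a port of the C++ above.

def greedy_match(scores, ious, iou_threshold):
    """scores: detection scores; ious[d][g]: IoU of detection d and GT g.
    Returns matches[d] = index of the matched GT, or -1 if unmatched."""
    order = sorted(range(len(scores)), key=lambda d: -scores[d])  # high score first
    gt_taken = set()
    matches = [-1] * len(scores)
    for d in order:
        best_iou, best_g = iou_threshold, -1
        for g in range(len(ious[d])):
            if g in gt_taken:
                continue  # a non-crowd GT can be matched at most once
            if ious[d][g] >= best_iou:
                best_iou, best_g = ious[d][g], g
        if best_g >= 0:
            gt_taken.add(best_g)
            matches[d] = best_g
    return matches

# two detections compete for one GT; the higher-scoring one wins it
print(greedy_match([0.9, 0.8], [[0.7], [0.6]], iou_threshold=0.5))  # [0, -1]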
yolox/layers/csrc/cocoeval/cocoeval.h
ADDED
@@ -0,0 +1,85 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+#pragma once
+
+#include <pybind11/numpy.h>
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+#include <pybind11/stl_bind.h>
+#include <vector>
+
+namespace py = pybind11;
+
+namespace COCOeval {
+
+// Annotation data for a single object instance in an image
+struct InstanceAnnotation {
+  InstanceAnnotation(
+      uint64_t id,
+      double score,
+      double area,
+      bool is_crowd,
+      bool ignore)
+      : id{id}, score{score}, area{area}, is_crowd{is_crowd}, ignore{ignore} {}
+  uint64_t id;
+  double score = 0.;
+  double area = 0.;
+  bool is_crowd = false;
+  bool ignore = false;
+};
+
+// Stores intermediate results for evaluating detection results for a single
+// image that has D detected instances and G ground truth instances. This stores
+// matches between detected and ground truth instances
+struct ImageEvaluation {
+  // For each of the D detected instances, the id of the matched ground truth
+  // instance, or 0 if unmatched
+  std::vector<uint64_t> detection_matches;
+
+  // The detection score of each of the D detected instances
+  std::vector<double> detection_scores;
+
+  // Marks whether or not each of G instances was ignored from evaluation (e.g.,
+  // because it's outside area_range)
+  std::vector<bool> ground_truth_ignores;
+
+  // Marks whether or not each of D instances was ignored from evaluation (e.g.,
+  // because it's outside aRng)
+  std::vector<bool> detection_ignores;
+};
+
+template <class T>
+using ImageCategoryInstances = std::vector<std::vector<std::vector<T>>>;
+
+// C++ implementation of COCO API cocoeval.py::COCOeval.evaluateImg(). For each
+// combination of image, category, area range settings, and IOU thresholds to
+// evaluate, it matches detected instances to ground truth instances and stores
+// the results into a vector of ImageEvaluation results, which will be
+// interpreted by the COCOeval::Accumulate() function to produce precision-recall
+// curves. The parameters of nested vectors have the following semantics:
+//   image_category_ious[i][c][d][g] is the intersection over union of the d'th
+//     detected instance and g'th ground truth instance of
+//     category category_ids[c] in image image_ids[i]
+//   image_category_ground_truth_instances[i][c] is a vector of ground truth
+//     instances in image image_ids[i] of category category_ids[c]
+//   image_category_detection_instances[i][c] is a vector of detected
+//     instances in image image_ids[i] of category category_ids[c]
+std::vector<ImageEvaluation> EvaluateImages(
+    const std::vector<std::array<double, 2>>& area_ranges, // vector of 2-tuples
+    int max_detections,
+    const std::vector<double>& iou_thresholds,
+    const ImageCategoryInstances<std::vector<double>>& image_category_ious,
+    const ImageCategoryInstances<InstanceAnnotation>&
+        image_category_ground_truth_instances,
+    const ImageCategoryInstances<InstanceAnnotation>&
+        image_category_detection_instances);
+
+// C++ implementation of COCOeval.accumulate(), which generates precision
+// recall curves for each set of category, IOU threshold, detection area range,
+// and max number of detections parameters. It is assumed that the parameter
+// evaluations is the return value of the function COCOeval::EvaluateImages(),
+// which was called with the same parameter settings params
+py::dict Accumulate(
+    const py::object& params,
+    const std::vector<ImageEvaluation>& evaluations);
+
+} // namespace COCOeval
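Accumulate() hands back flat row-major buffers plus a "counts" shape, which the Python side reshapes (see fast_coco_eval_api.py later in this diff). A toy illustration of that layout, with made-up small dimensions:

import numpy as np

T, R, K, A, M = 2, 3, 2, 4, 3  # iou thrs, recall thrs, categories, area ranges, max dets
flat = np.arange(T * R * K * A * M, dtype=float)  # stand-in for the precision buffer
precision = flat.reshape(T, R, K, A, M)

# the element found by flat index matches the one found by multi-index (C order)
t, r, k, a, m = 1, 2, 0, 3, 1
flat_index = (((t * R + r) * K + k) * A + a) * M + m
assert flat[flat_index] == precision[t, r, k, a, m]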
yolox/layers/csrc/vision.cpp
ADDED
@@ -0,0 +1,13 @@
+#include "cocoeval/cocoeval.h"
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("COCOevalAccumulate", &COCOeval::Accumulate, "COCOeval::Accumulate");
+  m.def(
+      "COCOevalEvaluateImages",
+      &COCOeval::EvaluateImages,
+      "COCOeval::EvaluateImages");
+  pybind11::class_<COCOeval::InstanceAnnotation>(m, "InstanceAnnotation")
+      .def(pybind11::init<uint64_t, double, double, bool, bool>());
+  pybind11::class_<COCOeval::ImageEvaluation>(m, "ImageEvaluation")
+      .def(pybind11::init<>());
+}
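This PYBIND11_MODULE block is what makes the C++ evaluator importable from Python as yolox._C (presumably compiled by setup.py in this commit). A sketch of the exposed surface, assuming the extension has been built; the argument values are made up:

from yolox import _C

# InstanceAnnotation(id, score, area, is_crowd, ignore) mirrors the
# init<uint64_t, double, double, bool, bool> constructor bound above
ann = _C.InstanceAnnotation(1, 0.92, 150.0, False, False)

# COCOevalEvaluateImages and COCOevalAccumulate are driven by COCOeval_opt
# (fast_coco_eval_api.py, next in this diff), which builds the nested
# [image][category] lists they expect.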
yolox/layers/fast_coco_eval_api.py
ADDED
@@ -0,0 +1,147 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# This file comes from
+# https://github.com/facebookresearch/detectron2/blob/master/detectron2/evaluation/fast_eval_api.py
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+
+import copy
+import time
+
+import numpy as np
+from pycocotools.cocoeval import COCOeval
+
+# import torch first to make yolox._C work without ImportError of libc10.so
+# in YOLOX, env is already set in __init__.py.
+from yolox import _C
+
+
+class COCOeval_opt(COCOeval):
+    """
+    This is a slightly modified version of the original COCO API, where the functions evaluateImg()
+    and accumulate() are implemented in C++ to speed up evaluation
+    """
+
+    def evaluate(self):
+        """
+        Run per image evaluation on given images and store results in self.evalImgs_cpp, a
+        datastructure that isn't readable from Python but is used by a C++ implementation of
+        accumulate(). Unlike the original COCO PythonAPI, we don't populate the datastructure
+        self.evalImgs because this datastructure is a computational bottleneck.
+        :return: None
+        """
+        tic = time.time()
+
+        print("Running per image evaluation...")
+        p = self.params
+        # add backward compatibility if useSegm is specified in params
+        if p.useSegm is not None:
+            p.iouType = "segm" if p.useSegm == 1 else "bbox"
+            print(
+                "useSegm (deprecated) is not None. Running {} evaluation".format(p.iouType)
+            )
+        print("Evaluate annotation type *{}*".format(p.iouType))
+        p.imgIds = list(np.unique(p.imgIds))
+        if p.useCats:
+            p.catIds = list(np.unique(p.catIds))
+        p.maxDets = sorted(p.maxDets)
+        self.params = p
+
+        self._prepare()
+
+        # loop through images, area range, max detection number
+        catIds = p.catIds if p.useCats else [-1]
+
+        if p.iouType == "segm" or p.iouType == "bbox":
+            computeIoU = self.computeIoU
+        elif p.iouType == "keypoints":
+            computeIoU = self.computeOks
+        self.ious = {
+            (imgId, catId): computeIoU(imgId, catId)
+            for imgId in p.imgIds
+            for catId in catIds
+        }
+
+        maxDet = p.maxDets[-1]
+
+        # <<<< Beginning of code differences with original COCO API
+        def convert_instances_to_cpp(instances, is_det=False):
+            # Convert annotations for a list of instances in an image to a format that's fast
+            # to access in C++
+            instances_cpp = []
+            for instance in instances:
+                instance_cpp = _C.InstanceAnnotation(
+                    int(instance["id"]),
+                    instance["score"] if is_det else instance.get("score", 0.0),
+                    instance["area"],
+                    bool(instance.get("iscrowd", 0)),
+                    bool(instance.get("ignore", 0)),
+                )
+                instances_cpp.append(instance_cpp)
+            return instances_cpp
+
+        # Convert GT annotations, detections, and IOUs to a format that's fast to access in C++
+        ground_truth_instances = [
+            [convert_instances_to_cpp(self._gts[imgId, catId]) for catId in p.catIds]
+            for imgId in p.imgIds
+        ]
+        detected_instances = [
+            [
+                convert_instances_to_cpp(self._dts[imgId, catId], is_det=True)
+                for catId in p.catIds
+            ]
+            for imgId in p.imgIds
+        ]
+        ious = [[self.ious[imgId, catId] for catId in catIds] for imgId in p.imgIds]
+
+        if not p.useCats:
+            # For each image, flatten per-category lists into a single list
+            ground_truth_instances = [
+                [[o for c in i for o in c]] for i in ground_truth_instances
+            ]
+            detected_instances = [
+                [[o for c in i for o in c]] for i in detected_instances
+            ]
+
+        # Call C++ implementation of self.evaluateImgs()
+        self._evalImgs_cpp = _C.COCOevalEvaluateImages(
+            p.areaRng,
+            maxDet,
+            p.iouThrs,
+            ious,
+            ground_truth_instances,
+            detected_instances,
+        )
+        self._evalImgs = None
+
+        self._paramsEval = copy.deepcopy(self.params)
+        toc = time.time()
+        print("COCOeval_opt.evaluate() finished in {:0.2f} seconds.".format(toc - tic))
+        # >>>> End of code differences with original COCO API
+
+    def accumulate(self):
+        """
+        Accumulate per image evaluation results and store the result in self.eval. Does not
+        support changing parameter settings from those used by self.evaluate()
+        """
+        print("Accumulating evaluation results...")
+        tic = time.time()
+        if not hasattr(self, "_evalImgs_cpp"):
+            print("Please run evaluate() first")
+
+        self.eval = _C.COCOevalAccumulate(self._paramsEval, self._evalImgs_cpp)
+
+        # recall is num_iou_thresholds X num_categories X num_area_ranges X num_max_detections
+        self.eval["recall"] = np.array(self.eval["recall"]).reshape(
+            self.eval["counts"][:1] + self.eval["counts"][2:]
+        )
+
+        # precision and scores are num_iou_thresholds X num_recall_thresholds X num_categories X
+        # num_area_ranges X num_max_detections
+        self.eval["precision"] = np.array(self.eval["precision"]).reshape(
+            self.eval["counts"]
+        )
+        self.eval["scores"] = np.array(self.eval["scores"]).reshape(self.eval["counts"])
+        toc = time.time()
+        print(
+            "COCOeval_opt.accumulate() finished in {:0.2f} seconds.".format(toc - tic)
+        )
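Because COCOeval_opt subclasses COCOeval and keeps its interface (summarize() is inherited unchanged), existing pycocotools evaluation scripts only need to swap the class. A usage sketch, assuming the yolox._C extension is built and that the annotation and detection JSON paths are placeholders for your own files:

from pycocotools.coco import COCO
from yolox.layers import COCOeval_opt

coco_gt = COCO("instances_val2017.json")      # COCO-format ground truth
coco_dt = coco_gt.loadRes("detections.json")  # your model's detections
cocoeval = COCOeval_opt(coco_gt, coco_dt, "bbox")
cocoeval.evaluate()    # C++ per-image matching
cocoeval.accumulate()  # C++ precision/recall accumulation
cocoeval.summarize()   # unchanged pycocotools summary (AP, AP50, ...)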
yolox/models/__init__.py
ADDED
@@ -0,0 +1,9 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+from .darknet import CSPDarknet, Darknet
+from .losses import IOUloss
+from .yolo_fpn import YOLOFPN
+from .yolo_head import YOLOXHead
+from .yolo_pafpn import YOLOPAFPN
+from .yolox import YOLOX
yolox/models/darknet.py
ADDED
@@ -0,0 +1,140 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+from torch import nn
+
+from .network_blocks import BaseConv, CSPLayer, DWConv, Focus, ResLayer, SPPBottleneck
+
+
+class Darknet(nn.Module):
+    # number of blocks from dark2 to dark5.
+    depth2blocks = {21: [1, 2, 2, 1], 53: [2, 8, 8, 4]}
+
+    def __init__(
+        self, depth, in_channels=3, stem_out_channels=32, out_features=("dark3", "dark4", "dark5"),
+    ):
+        """
+        Args:
+            depth (int): depth of darknet used in model, usually use [21, 53] for this param.
+            in_channels (int): number of input channels, for example, use 3 for RGB image.
+            stem_out_channels (int): number of output channels of darknet stem.
+                It decides channels of darknet layer2 to layer5.
+            out_features (Tuple[str]): desired output layer name.
+        """
+        super().__init__()
+        assert out_features, "please provide output features of Darknet"
+        self.out_features = out_features
+        self.stem = nn.Sequential(
+            BaseConv(in_channels, stem_out_channels, ksize=3, stride=1, act="lrelu"),
+            *self.make_group_layer(stem_out_channels, num_blocks=1, stride=2),
+        )
+        in_channels = stem_out_channels * 2  # 64
+
+        num_blocks = Darknet.depth2blocks[depth]
+        # create darknet with `stem_out_channels` and `num_blocks` layers.
+        # to make model structure more clear, we don't use `for` statement in python.
+        self.dark2 = nn.Sequential(*self.make_group_layer(in_channels, num_blocks[0], stride=2))
+        in_channels *= 2  # 128
+        self.dark3 = nn.Sequential(*self.make_group_layer(in_channels, num_blocks[1], stride=2))
+        in_channels *= 2  # 256
+        self.dark4 = nn.Sequential(*self.make_group_layer(in_channels, num_blocks[2], stride=2))
+        in_channels *= 2  # 512
+
+        self.dark5 = nn.Sequential(
+            *self.make_group_layer(in_channels, num_blocks[3], stride=2),
+            *self.make_spp_block([in_channels, in_channels * 2], in_channels * 2),
+        )
+
+    def make_group_layer(self, in_channels: int, num_blocks: int, stride: int = 1):
+        "starts with conv layer then has `num_blocks` `ResLayer`"
+        return [
+            BaseConv(in_channels, in_channels * 2, ksize=3, stride=stride, act="lrelu"),
+            *[(ResLayer(in_channels * 2)) for _ in range(num_blocks)]
+        ]
+
+    def make_spp_block(self, filters_list, in_filters):
+        m = nn.Sequential(
+            *[
+                BaseConv(in_filters, filters_list[0], 1, stride=1, act="lrelu"),
+                BaseConv(filters_list[0], filters_list[1], 3, stride=1, act="lrelu"),
+                SPPBottleneck(
+                    in_channels=filters_list[1],
+                    out_channels=filters_list[0],
+                    activation="lrelu"
+                ),
+                BaseConv(filters_list[0], filters_list[1], 3, stride=1, act="lrelu"),
+                BaseConv(filters_list[1], filters_list[0], 1, stride=1, act="lrelu"),
+            ]
+        )
+        return m
+
+    def forward(self, x):
+        outputs = {}
+        x = self.stem(x)
+        outputs["stem"] = x
+        x = self.dark2(x)
+        outputs["dark2"] = x
+        x = self.dark3(x)
+        outputs["dark3"] = x
+        x = self.dark4(x)
+        outputs["dark4"] = x
+        x = self.dark5(x)
+        outputs["dark5"] = x
+        return {k: v for k, v in outputs.items() if k in self.out_features}
+
+
+class CSPDarknet(nn.Module):
+
+    def __init__(self, dep_mul, wid_mul, out_features=("dark3", "dark4", "dark5"), depthwise=False):
+        super().__init__()
+        assert out_features, "please provide output features of Darknet"
+        self.out_features = out_features
+        Conv = DWConv if depthwise else BaseConv
+
+        base_channels = int(wid_mul * 64)  # 64
+        base_depth = max(round(dep_mul * 3), 1)  # 3
+
+        # stem
+        self.stem = Focus(3, base_channels, ksize=3)
+
+        # dark2
+        self.dark2 = nn.Sequential(
+            Conv(base_channels, base_channels * 2, 3, 2),
+            CSPLayer(base_channels * 2, base_channels * 2, n=base_depth, depthwise=depthwise),
+        )
+
+        # dark3
+        self.dark3 = nn.Sequential(
+            Conv(base_channels * 2, base_channels * 4, 3, 2),
+            CSPLayer(base_channels * 4, base_channels * 4, n=base_depth * 3, depthwise=depthwise),
+        )
+
+        # dark4
+        self.dark4 = nn.Sequential(
+            Conv(base_channels * 4, base_channels * 8, 3, 2),
+            CSPLayer(base_channels * 8, base_channels * 8, n=base_depth * 3, depthwise=depthwise),
+        )
+
+        # dark5
+        self.dark5 = nn.Sequential(
+            Conv(base_channels * 8, base_channels * 16, 3, 2),
+            SPPBottleneck(base_channels * 16, base_channels * 16),
+            CSPLayer(
+                base_channels * 16, base_channels * 16, n=base_depth,
+                shortcut=False, depthwise=depthwise,
+            ),
+        )
+
+    def forward(self, x):
+        outputs = {}
+        x = self.stem(x)
+        outputs["stem"] = x
+        x = self.dark2(x)
+        outputs["dark2"] = x
+        x = self.dark3(x)
+        outputs["dark3"] = x
+        x = self.dark4(x)
+        outputs["dark4"] = x
+        x = self.dark5(x)
+        outputs["dark5"] = x
+        return {k: v for k, v in outputs.items() if k in self.out_features}
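The dep_mul/wid_mul pair is how one CSPDarknet definition yields every model size: channel widths scale with wid_mul and the number of CSP bottlenecks with dep_mul. The quick table below works through that arithmetic; the multiplier values for the -s/-m/-x variants are stated as an assumption, since their exp files are not part of this excerpt:

for name, dep_mul, wid_mul in [
    ("yolox-s", 0.33, 0.50),  # assumed multipliers
    ("yolox-m", 0.67, 0.75),  # assumed multipliers
    ("yolox-l", 1.00, 1.00),  # matches depth/width in yolox_base.py above
    ("yolox-x", 1.33, 1.25),  # assumed multipliers
]:
    base_channels = int(wid_mul * 64)        # stem output channels
    base_depth = max(round(dep_mul * 3), 1)  # bottlenecks per CSP stage
    print(name, base_channels, base_depth)
# yolox-s 32 1, yolox-m 48 2, yolox-l 64 3, yolox-x 80 4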
yolox/models/losses.py
ADDED
@@ -0,0 +1,48 @@
+import torch
+import torch.nn as nn
+
+
+class IOUloss(nn.Module):
+    def __init__(self, reduction="none", loss_type="iou"):
+        super(IOUloss, self).__init__()
+        self.reduction = reduction
+        self.loss_type = loss_type
+
+    def forward(self, pred, target):
+        assert pred.shape[0] == target.shape[0]
+
+        pred = pred.view(-1, 4)
+        target = target.view(-1, 4)
+        tl = torch.max(
+            (pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2)
+        )
+        br = torch.min(
+            (pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2)
+        )
+
+        area_p = torch.prod(pred[:, 2:], 1)
+        area_g = torch.prod(target[:, 2:], 1)
+
+        en = (tl < br).type(tl.type()).prod(dim=1)
+        area_i = torch.prod(br - tl, 1) * en
+        iou = (area_i) / (area_p + area_g - area_i + 1e-16)
+
+        if self.loss_type == "iou":
+            loss = 1 - iou ** 2
+        elif self.loss_type == "giou":
+            c_tl = torch.min(
+                (pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2)
+            )
+            c_br = torch.max(
+                (pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2)
+            )
+            area_c = torch.prod(c_br - c_tl, 1)
+            giou = iou - (area_c - area_i) / area_c.clamp(1e-16)
+            loss = 1 - giou.clamp(min=-1.0, max=1.0)
+
+        if self.reduction == "mean":
+            loss = loss.mean()
+        elif self.reduction == "sum":
+            loss = loss.sum()
+
+        return loss
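A worked example of the IoU computation above on (cx, cy, w, h) boxes, recomputing the same intermediate quantities by hand for a single pair so the numbers can be checked mentally:

import torch

pred = torch.tensor([[5.0, 5.0, 4.0, 4.0]])    # centered at (5,5), 4x4 -> area 16
target = torch.tensor([[6.0, 6.0, 4.0, 4.0]])  # shifted by (1,1)    -> area 16

tl = torch.max(pred[:, :2] - pred[:, 2:] / 2, target[:, :2] - target[:, 2:] / 2)  # (4, 4)
br = torch.min(pred[:, :2] + pred[:, 2:] / 2, target[:, :2] + target[:, 2:] / 2)  # (7, 7)
area_i = torch.prod(br - tl, 1)            # 3 * 3 = 9
iou = area_i / (16 + 16 - area_i + 1e-16)  # 9 / 23 ~ 0.391
print(iou)           # tensor([0.3913])
print(1 - iou ** 2)  # the "iou" loss above: tensor([0.8469])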
yolox/models/network_blocks.py
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+import torch
+import torch.nn as nn
+
+
+class SiLU(nn.Module):
+    # export-friendly version of nn.SiLU()
+
+    @staticmethod
+    def forward(x):
+        return x * torch.sigmoid(x)
+
+
+def get_activation(name="silu", inplace=True):
+    if name == "silu":
+        module = nn.SiLU(inplace=inplace)
+    elif name == "relu":
+        module = nn.ReLU(inplace=inplace)
+    elif name == "lrelu":
+        module = nn.LeakyReLU(0.1, inplace=inplace)
+    else:
+        raise AttributeError("Unsupported act type: {}".format(name))
+    return module
+
+
+class BaseConv(nn.Module):
+    """
+    A Conv2d -> Batchnorm -> silu/leaky relu block
+    """
+
+    def __init__(self, in_channels, out_channels, ksize, stride, groups=1, bias=False, act="silu"):
+        super().__init__()
+        # same padding
+        pad = (ksize - 1) // 2
+        self.conv = nn.Conv2d(
+            in_channels,
+            out_channels,
+            kernel_size=ksize,
+            stride=stride,
+            padding=pad,
+            groups=groups,
+            bias=bias,
+        )
+        self.bn = nn.BatchNorm2d(out_channels)
+        self.act = get_activation(act, inplace=True)
+
+    def forward(self, x):
+        return self.act(self.bn(self.conv(x)))
+
+    def fuseforward(self, x):
+        return self.act(self.conv(x))
+
+
+class DWConv(nn.Module):
+    """Depthwise Conv + Conv"""
+    def __init__(self, in_channels, out_channels, ksize, stride=1, act="silu"):
+        super().__init__()
+        self.dconv = BaseConv(
+            in_channels, in_channels, ksize=ksize,
+            stride=stride, groups=in_channels, act=act
+        )
+        self.pconv = BaseConv(
+            in_channels, out_channels, ksize=1,
+            stride=1, groups=1, act=act
+        )
+
+    def forward(self, x):
+        x = self.dconv(x)
+        return self.pconv(x)
+
+
+class Bottleneck(nn.Module):
+    # Standard bottleneck
+    def __init__(self, in_channels, out_channels, shortcut=True, expansion=0.5, depthwise=False):
+        super().__init__()
+        hidden_channels = int(out_channels * expansion)
+        Conv = DWConv if depthwise else BaseConv
+        self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1)
+        self.conv2 = Conv(hidden_channels, out_channels, 3, stride=1)
+        self.use_add = shortcut and in_channels == out_channels
+
+    def forward(self, x):
+        y = self.conv2(self.conv1(x))
+        if self.use_add:
+            y = y + x
+        return y
+
+
+class ResLayer(nn.Module):
+    "Residual layer with `in_channels` inputs."
+    def __init__(self, in_channels: int):
+        super().__init__()
+        mid_channels = in_channels // 2
+        self.layer1 = BaseConv(in_channels, mid_channels, ksize=1, stride=1, act="lrelu")
+        self.layer2 = BaseConv(mid_channels, in_channels, ksize=3, stride=1, act="lrelu")
+
+    def forward(self, x):
+        out = self.layer2(self.layer1(x))
+        return x + out
+
+
+class SPPBottleneck(nn.Module):
+    # Spatial pyramid pooling layer used in YOLOv3-SPP
+    def __init__(self, in_channels, out_channels, kernel_sizes=(5, 9, 13), activation="silu"):
+        super().__init__()
+        hidden_channels = in_channels // 2
+        self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=activation)
+        self.m = nn.ModuleList(
+            [nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2) for ks in kernel_sizes]
+        )
+        conv2_channels = hidden_channels * (len(kernel_sizes) + 1)
+        self.conv2 = BaseConv(conv2_channels, out_channels, 1, stride=1, act=activation)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = torch.cat([x] + [m(x) for m in self.m], dim=1)
+        x = self.conv2(x)
+        return x
+
+
+class CSPLayer(nn.Module):
+    """C3 in yolov5, CSP Bottleneck with 3 convolutions"""
+
+    def __init__(
+        self, in_channels, out_channels, n=1,
+        shortcut=True, expansion=0.5, depthwise=False
+    ):
+        """
+        Args:
+            n (int): number of Bottlenecks. Default value: 1.
+        """
+        # ch_in, ch_out, number, shortcut, groups, expansion
+        super().__init__()
+        hidden_channels = int(out_channels * expansion)  # hidden channels
+        self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1)
+        self.conv2 = BaseConv(in_channels, hidden_channels, 1, stride=1)
+        self.conv3 = BaseConv(2 * hidden_channels, out_channels, 1, stride=1)  # act=FReLU(c2)
+        module_list = [
+            Bottleneck(hidden_channels, hidden_channels, shortcut, 1.0, depthwise)
+            for _ in range(n)
+        ]
+        self.m = nn.Sequential(*module_list)
+
+    def forward(self, x):
+        x_1 = self.conv1(x)
+        x_2 = self.conv2(x)
+        x_1 = self.m(x_1)
+        x = torch.cat((x_1, x_2), dim=1)
+        return self.conv3(x)
+
+
+class Focus(nn.Module):
+    """Focus width and height information into channel space."""
+    def __init__(self, in_channels, out_channels, ksize=1, stride=1):
+        super().__init__()
+        self.conv = BaseConv(in_channels * 4, out_channels, ksize, stride)
+
+    def forward(self, x):
+        # shape of x (b,c,w,h) -> y(b,4c,w/2,h/2)
+        patch_top_left = x[..., ::2, ::2]
+        patch_top_right = x[..., ::2, 1::2]
+        patch_bot_left = x[..., 1::2, ::2]
+        patch_bot_right = x[..., 1::2, 1::2]
+        x = torch.cat(
+            (patch_top_left, patch_bot_left, patch_top_right, patch_bot_right,), dim=1,
+        )
+        return self.conv(x)
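The Focus block trades spatial resolution for channels before its convolution runs; a quick shape check (a sketch, not part of the commit, assuming the blocks above are importable from yolox.models.network_blocks):

import torch
from yolox.models.network_blocks import Focus

focus = Focus(in_channels=3, out_channels=32).eval()
x = torch.randn(1, 3, 64, 64)
with torch.no_grad():
    y = focus(x)
# 3 channels become 3*4 = 12 before the 1x1 conv; spatial dims are halved
print(y.shape)  # torch.Size([1, 32, 32, 32])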
yolox/models/yolo_fpn.py
ADDED
@@ -0,0 +1,80 @@
+import torch
+import torch.nn as nn
+
+from .darknet import Darknet
+from .network_blocks import BaseConv
+
+
+class YOLOFPN(nn.Module):
+    """
+    YOLOFPN module. Darknet 53 is the default backbone of this model.
+    """
+
+    def __init__(
+        self, depth=53, in_features=["dark3", "dark4", "dark5"],
+    ):
+        super().__init__()
+
+        self.backbone = Darknet(depth)
+        self.in_features = in_features
+
+        # out 1
+        self.out1_cbl = self._make_cbl(512, 256, 1)
+        self.out1 = self._make_embedding([256, 512], 512 + 256)
+
+        # out 2
+        self.out2_cbl = self._make_cbl(256, 128, 1)
+        self.out2 = self._make_embedding([128, 256], 256 + 128)
+
+        # upsample
+        self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
+
+    def _make_cbl(self, _in, _out, ks):
+        return BaseConv(_in, _out, ks, stride=1, act="lrelu")
+
+    def _make_embedding(self, filters_list, in_filters):
+        m = nn.Sequential(
+            *[
+                self._make_cbl(in_filters, filters_list[0], 1),
+                self._make_cbl(filters_list[0], filters_list[1], 3),
+
+                self._make_cbl(filters_list[1], filters_list[0], 1),
+
+                self._make_cbl(filters_list[0], filters_list[1], 3),
+                self._make_cbl(filters_list[1], filters_list[0], 1),
+            ]
+        )
+        return m
+
+    def load_pretrained_model(self, filename="./weights/darknet53.mix.pth"):
+        with open(filename, "rb") as f:
+            state_dict = torch.load(f, map_location="cpu")
+        print("loading pretrained weights...")
+        self.backbone.load_state_dict(state_dict)
+
+    def forward(self, inputs):
+        """
+        Args:
+            inputs (Tensor): input image.
+
+        Returns:
+            Tuple[Tensor]: FPN output features.
+        """
+        # backbone
+        out_features = self.backbone(inputs)
+        x2, x1, x0 = [out_features[f] for f in self.in_features]
+
+        # yolo branch 1
+        x1_in = self.out1_cbl(x0)
+        x1_in = self.upsample(x1_in)
+        x1_in = torch.cat([x1_in, x1], 1)
+        out_dark4 = self.out1(x1_in)
+
+        # yolo branch 2
+        x2_in = self.out2_cbl(out_dark4)
+        x2_in = self.upsample(x2_in)
+        x2_in = torch.cat([x2_in, x2], 1)
+        out_dark3 = self.out2(x2_in)
+
+        outputs = (out_dark3, out_dark4, x0)
+        return outputs
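A hedged usage sketch of the top-down pathway above (not part of the commit; assumes the Darknet backbone from darknet.py in this same commit is importable and the input side is divisible by 32):

import torch
from yolox.models.yolo_fpn import YOLOFPN

fpn = YOLOFPN(depth=53).eval()
x = torch.randn(1, 3, 416, 416)
with torch.no_grad():
    out_dark3, out_dark4, dark5 = fpn(x)
# three feature maps at strides 8, 16 and 32 of the input
print(out_dark3.shape, out_dark4.shape, dark5.shape)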
yolox/models/yolo_head.py
ADDED
@@ -0,0 +1,525 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+import math
+from loguru import logger
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from yolox.utils import bboxes_iou
+
+from .losses import IOUloss
+from .network_blocks import BaseConv, DWConv
+
+
+class YOLOXHead(nn.Module):
+    def __init__(
+        self, num_classes, width=1.0, strides=[8, 16, 32],
+        in_channels=[256, 512, 1024], act="silu", depthwise=False
+    ):
+        """
+        Args:
+            act (str): activation type of conv. Default value: "silu".
+            depthwise (bool): whether apply depthwise conv in conv branch. Default value: False.
+        """
+        super().__init__()
+
+        self.n_anchors = 1
+        self.num_classes = num_classes
+        self.decode_in_inference = True  # for deploy, set to False
+
+        self.cls_convs = nn.ModuleList()
+        self.reg_convs = nn.ModuleList()
+        self.cls_preds = nn.ModuleList()
+        self.reg_preds = nn.ModuleList()
+        self.obj_preds = nn.ModuleList()
+        self.stems = nn.ModuleList()
+        Conv = DWConv if depthwise else BaseConv
+
+        for i in range(len(in_channels)):
+            self.stems.append(
+                BaseConv(
+                    in_channels=int(in_channels[i] * width),
+                    out_channels=int(256 * width),
+                    ksize=1,
+                    stride=1,
+                    act=act,
+                )
+            )
+            self.cls_convs.append(
+                nn.Sequential(
+                    *[
+                        Conv(
+                            in_channels=int(256 * width),
+                            out_channels=int(256 * width),
+                            ksize=3,
+                            stride=1,
+                            act=act,
+                        ),
+                        Conv(
+                            in_channels=int(256 * width),
+                            out_channels=int(256 * width),
+                            ksize=3,
+                            stride=1,
+                            act=act,
+                        ),
+                    ]
+                )
+            )
+            self.reg_convs.append(
+                nn.Sequential(
+                    *[
+                        Conv(
+                            in_channels=int(256 * width),
+                            out_channels=int(256 * width),
+                            ksize=3,
+                            stride=1,
+                            act=act,
+                        ),
+                        Conv(
+                            in_channels=int(256 * width),
+                            out_channels=int(256 * width),
+                            ksize=3,
+                            stride=1,
+                            act=act,
+                        ),
+                    ]
+                )
+            )
+            self.cls_preds.append(
+                nn.Conv2d(
+                    in_channels=int(256 * width),
+                    out_channels=self.n_anchors * self.num_classes,
+                    kernel_size=1,
+                    stride=1,
+                    padding=0,
+                )
+            )
+            self.reg_preds.append(
+                nn.Conv2d(
+                    in_channels=int(256 * width),
+                    out_channels=4,
+                    kernel_size=1,
+                    stride=1,
+                    padding=0,
+                )
+            )
+            self.obj_preds.append(
+                nn.Conv2d(
+                    in_channels=int(256 * width),
+                    out_channels=self.n_anchors * 1,
+                    kernel_size=1,
+                    stride=1,
+                    padding=0,
+                )
+            )
+
+        self.use_l1 = False
+        self.l1_loss = nn.L1Loss(reduction="none")
+        self.bcewithlog_loss = nn.BCEWithLogitsLoss(reduction="none")
+        self.iou_loss = IOUloss(reduction="none")
+        self.strides = strides
+        self.grids = [torch.zeros(1)] * len(in_channels)
+        self.expanded_strides = [None] * len(in_channels)
+
+    def initialize_biases(self, prior_prob):
+        for conv in self.cls_preds:
+            b = conv.bias.view(self.n_anchors, -1)
+            b.data.fill_(-math.log((1 - prior_prob) / prior_prob))
+            conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+
+        for conv in self.obj_preds:
+            b = conv.bias.view(self.n_anchors, -1)
+            b.data.fill_(-math.log((1 - prior_prob) / prior_prob))
+            conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+
+    def forward(self, xin, labels=None, imgs=None):
+        outputs = []
+        origin_preds = []
+        x_shifts = []
+        y_shifts = []
+        expanded_strides = []
+
+        for k, (cls_conv, reg_conv, stride_this_level, x) in enumerate(
+            zip(self.cls_convs, self.reg_convs, self.strides, xin)
+        ):
+            x = self.stems[k](x)
+            cls_x = x
+            reg_x = x
+
+            cls_feat = cls_conv(cls_x)
+            cls_output = self.cls_preds[k](cls_feat)
+
+            reg_feat = reg_conv(reg_x)
+            reg_output = self.reg_preds[k](reg_feat)
+            obj_output = self.obj_preds[k](reg_feat)
+
+            if self.training:
+                output = torch.cat([reg_output, obj_output, cls_output], 1)
+                output, grid = self.get_output_and_grid(output, k, stride_this_level, xin[0].type())
+                x_shifts.append(grid[:, :, 0])
+                y_shifts.append(grid[:, :, 1])
+                expanded_strides.append(
+                    torch.zeros(1, grid.shape[1]).fill_(stride_this_level).type_as(xin[0])
+                )
+                if self.use_l1:
+                    origin_preds.append(reg_output.clone())
+
+            else:
+                output = torch.cat([reg_output, obj_output.sigmoid(), cls_output.sigmoid()], 1)
+
+            outputs.append(output)
+
+        if self.training:
+            return self.get_losses(
+                imgs, x_shifts, y_shifts, expanded_strides, labels,
+                torch.cat(outputs, 1), origin_preds, dtype=xin[0].dtype
+            )
+        else:
+            self.hw = [x.shape[-2:] for x in outputs]
+            # [batch, n_anchors_all, 85]
+            outputs = torch.cat([x.flatten(start_dim=2) for x in outputs], dim=2).permute(0, 2, 1)
+            if self.decode_in_inference:
+                return self.decode_outputs(outputs, dtype=xin[0].type())
+            else:
+                return outputs
+
+    def get_output_and_grid(self, output, k, stride, dtype):
+        grid = self.grids[k]
+
+        batch_size = output.shape[0]
+        n_ch = 5 + self.num_classes
+        hsize, wsize = output.shape[-2:]
+        if grid.shape[2:3] != output.shape[2:3]:
+            yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)])
+            grid = torch.stack((xv, yv), 2).view(1, 1, hsize, wsize, 2).type(dtype)
+            self.grids[k] = grid
+
+        output = output.view(batch_size, self.n_anchors, n_ch, hsize, wsize)
+        output = (
+            output.permute(0, 1, 3, 4, 2)
+            .reshape(batch_size, self.n_anchors * hsize * wsize, -1)
+        )
+        grid = grid.view(1, -1, 2)
+        output[..., :2] = (output[..., :2] + grid) * stride
+        output[..., 2:4] = torch.exp(output[..., 2:4]) * stride
+        return output, grid
+
+    def decode_outputs(self, outputs, dtype):
+        grids = []
+        strides = []
+        for (hsize, wsize), stride in zip(self.hw, self.strides):
+            yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)])
+            grid = torch.stack((xv, yv), 2).view(1, -1, 2)
+            grids.append(grid)
+            shape = grid.shape[:2]
+            strides.append(torch.full((*shape, 1), stride))
+
+        grids = torch.cat(grids, dim=1).type(dtype)
+        strides = torch.cat(strides, dim=1).type(dtype)
+
+        outputs[..., :2] = (outputs[..., :2] + grids) * strides
+        outputs[..., 2:4] = torch.exp(outputs[..., 2:4]) * strides
+        return outputs
+
return outputs
|
226 |
+
|
227 |
+
def get_losses(
|
228 |
+
self, imgs, x_shifts, y_shifts, expanded_strides, labels, outputs, origin_preds, dtype,
|
229 |
+
):
|
230 |
+
bbox_preds = outputs[:, :, :4] # [batch, n_anchors_all, 4]
|
231 |
+
obj_preds = outputs[:, :, 4].unsqueeze(-1) # [batch, n_anchors_all, 1]
|
232 |
+
cls_preds = outputs[:, :, 5:] # [batch, n_anchors_all, n_cls]
|
233 |
+
|
234 |
+
# calculate targets
|
235 |
+
mixup = labels.shape[2] > 5
|
236 |
+
if mixup:
|
237 |
+
label_cut = labels[..., :5]
|
238 |
+
else:
|
239 |
+
label_cut = labels
|
240 |
+
nlabel = (label_cut.sum(dim=2) > 0).sum(dim=1) # number of objects
|
241 |
+
|
242 |
+
total_num_anchors = outputs.shape[1]
|
243 |
+
x_shifts = torch.cat(x_shifts, 1) # [1, n_anchors_all]
|
244 |
+
y_shifts = torch.cat(y_shifts, 1) # [1, n_anchors_all]
|
245 |
+
expanded_strides = torch.cat(expanded_strides, 1)
|
246 |
+
if self.use_l1:
|
247 |
+
origin_preds = torch.cat(origin_preds, 1)
|
248 |
+
|
249 |
+
cls_targets = []
|
250 |
+
reg_targets = []
|
251 |
+
l1_targets = []
|
252 |
+
obj_targets = []
|
253 |
+
fg_masks = []
|
254 |
+
|
255 |
+
num_fg = 0.0
|
256 |
+
num_gts = 0.0
|
257 |
+
|
258 |
+
for batch_idx in range(outputs.shape[0]):
|
259 |
+
num_gt = int(nlabel[batch_idx])
|
260 |
+
num_gts += num_gt
|
261 |
+
if num_gt == 0:
|
262 |
+
cls_target = outputs.new_zeros((0, self.num_classes))
|
263 |
+
reg_target = outputs.new_zeros((0, 4))
|
264 |
+
l1_target = outputs.new_zeros((0, 4))
|
265 |
+
obj_target = outputs.new_zeros((total_num_anchors, 1))
|
266 |
+
fg_mask = outputs.new_zeros(total_num_anchors).bool()
|
267 |
+
else:
|
268 |
+
gt_bboxes_per_image = labels[batch_idx, :num_gt, 1:5]
|
269 |
+
gt_classes = labels[batch_idx, :num_gt, 0]
|
270 |
+
bboxes_preds_per_image = bbox_preds[batch_idx]
|
271 |
+
|
272 |
+
try:
|
273 |
+
gt_matched_classes, fg_mask, pred_ious_this_matching, matched_gt_inds, num_fg_img = self.get_assignments( # noqa
|
274 |
+
batch_idx, num_gt, total_num_anchors, gt_bboxes_per_image, gt_classes,
|
275 |
+
bboxes_preds_per_image, expanded_strides, x_shifts, y_shifts,
|
276 |
+
cls_preds, bbox_preds, obj_preds, labels, imgs,
|
277 |
+
)
|
278 |
+
except RuntimeError:
|
279 |
+
logger.error(
|
280 |
+
"OOM RuntimeError is raised due to the huge memory cost during label assignment. \
|
281 |
+
CPU mode is applied in this batch. If you want to avoid this issue, \
|
282 |
+
try to reduce the batch size or image size."
|
283 |
+
)
|
284 |
+
torch.cuda.empty_cache()
|
285 |
+
gt_matched_classes, fg_mask, pred_ious_this_matching, matched_gt_inds, num_fg_img = self.get_assignments( # noqa
|
286 |
+
batch_idx, num_gt, total_num_anchors, gt_bboxes_per_image, gt_classes,
|
287 |
+
bboxes_preds_per_image, expanded_strides, x_shifts, y_shifts,
|
288 |
+
cls_preds, bbox_preds, obj_preds, labels, imgs, "cpu",
|
289 |
+
)
|
290 |
+
|
291 |
+
torch.cuda.empty_cache()
|
292 |
+
num_fg += num_fg_img
|
293 |
+
|
294 |
+
cls_target = F.one_hot(
|
295 |
+
gt_matched_classes.to(torch.int64), self.num_classes
|
296 |
+
) * pred_ious_this_matching.unsqueeze(-1)
|
297 |
+
obj_target = fg_mask.unsqueeze(-1)
|
298 |
+
reg_target = gt_bboxes_per_image[matched_gt_inds]
|
299 |
+
if self.use_l1:
|
300 |
+
l1_target = self.get_l1_target(
|
301 |
+
outputs.new_zeros((num_fg_img, 4)),
|
302 |
+
gt_bboxes_per_image[matched_gt_inds],
|
303 |
+
expanded_strides[0][fg_mask],
|
304 |
+
x_shifts=x_shifts[0][fg_mask],
|
305 |
+
y_shifts=y_shifts[0][fg_mask],
|
306 |
+
)
|
307 |
+
|
308 |
+
cls_targets.append(cls_target)
|
309 |
+
reg_targets.append(reg_target)
|
310 |
+
obj_targets.append(obj_target.to(dtype))
|
311 |
+
fg_masks.append(fg_mask)
|
312 |
+
if self.use_l1:
|
313 |
+
l1_targets.append(l1_target)
|
314 |
+
|
315 |
+
cls_targets = torch.cat(cls_targets, 0)
|
316 |
+
reg_targets = torch.cat(reg_targets, 0)
|
317 |
+
obj_targets = torch.cat(obj_targets, 0)
|
318 |
+
fg_masks = torch.cat(fg_masks, 0)
|
319 |
+
if self.use_l1:
|
320 |
+
l1_targets = torch.cat(l1_targets, 0)
|
321 |
+
|
322 |
+
num_fg = max(num_fg, 1)
|
323 |
+
loss_iou = (self.iou_loss(bbox_preds.view(-1, 4)[fg_masks], reg_targets)).sum() / num_fg
|
324 |
+
loss_obj = (self.bcewithlog_loss(obj_preds.view(-1, 1), obj_targets)).sum() / num_fg
|
325 |
+
loss_cls = (
|
326 |
+
self.bcewithlog_loss(cls_preds.view(-1, self.num_classes)[fg_masks], cls_targets)
|
327 |
+
).sum() / num_fg
|
328 |
+
if self.use_l1:
|
329 |
+
loss_l1 = (self.l1_loss(origin_preds.view(-1, 4)[fg_masks], l1_targets)).sum() / num_fg
|
330 |
+
else:
|
331 |
+
loss_l1 = 0.0
|
332 |
+
|
333 |
+
reg_weight = 5.0
|
334 |
+
loss = reg_weight * loss_iou + loss_obj + loss_cls + loss_l1
|
335 |
+
|
336 |
+
return loss, reg_weight * loss_iou, loss_obj, loss_cls, loss_l1, num_fg / max(num_gts, 1)
|
337 |
+
|
338 |
+
def get_l1_target(self, l1_target, gt, stride, x_shifts, y_shifts, eps=1e-8):
|
339 |
+
l1_target[:, 0] = gt[:, 0] / stride - x_shifts
|
340 |
+
l1_target[:, 1] = gt[:, 1] / stride - y_shifts
|
341 |
+
l1_target[:, 2] = torch.log(gt[:, 2] / stride + eps)
|
342 |
+
l1_target[:, 3] = torch.log(gt[:, 3] / stride + eps)
|
343 |
+
return l1_target
|
344 |
+
|
345 |
+
@torch.no_grad()
|
346 |
+
def get_assignments(
|
347 |
+
self, batch_idx, num_gt, total_num_anchors, gt_bboxes_per_image, gt_classes,
|
348 |
+
bboxes_preds_per_image, expanded_strides, x_shifts, y_shifts,
|
349 |
+
cls_preds, bbox_preds, obj_preds, labels, imgs, mode="gpu",
|
350 |
+
):
|
351 |
+
|
352 |
+
if mode == "cpu":
|
353 |
+
print("------------CPU Mode for This Batch-------------")
|
354 |
+
gt_bboxes_per_image = gt_bboxes_per_image.cpu().float()
|
355 |
+
bboxes_preds_per_image = bboxes_preds_per_image.cpu().float()
|
356 |
+
gt_classes = gt_classes.cpu().float()
|
357 |
+
expanded_strides = expanded_strides.cpu().float()
|
358 |
+
x_shifts = x_shifts.cpu()
|
359 |
+
y_shifts = y_shifts.cpu()
|
360 |
+
|
361 |
+
fg_mask, is_in_boxes_and_center = self.get_in_boxes_info(
|
362 |
+
gt_bboxes_per_image, expanded_strides, x_shifts, y_shifts, total_num_anchors, num_gt,
|
363 |
+
)
|
364 |
+
|
365 |
+
bboxes_preds_per_image = bboxes_preds_per_image[fg_mask]
|
366 |
+
cls_preds_ = cls_preds[batch_idx][fg_mask]
|
367 |
+
obj_preds_ = obj_preds[batch_idx][fg_mask]
|
368 |
+
num_in_boxes_anchor = bboxes_preds_per_image.shape[0]
|
369 |
+
|
370 |
+
if mode == "cpu":
|
371 |
+
gt_bboxes_per_image = gt_bboxes_per_image.cpu()
|
372 |
+
bboxes_preds_per_image = bboxes_preds_per_image.cpu()
|
373 |
+
|
374 |
+
pair_wise_ious = bboxes_iou(
|
375 |
+
gt_bboxes_per_image, bboxes_preds_per_image, False
|
376 |
+
)
|
377 |
+
|
378 |
+
gt_cls_per_image = (
|
379 |
+
F.one_hot(gt_classes.to(torch.int64), self.num_classes).float()
|
380 |
+
.unsqueeze(1).repeat(1, num_in_boxes_anchor, 1)
|
381 |
+
)
|
382 |
+
pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8)
|
383 |
+
|
384 |
+
if mode == "cpu":
|
385 |
+
cls_preds_, obj_preds_ = cls_preds_.cpu(), obj_preds_.cpu()
|
386 |
+
|
387 |
+
cls_preds_ = (
|
388 |
+
cls_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
|
389 |
+
* obj_preds_.unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
|
390 |
+
)
|
391 |
+
pair_wise_cls_loss = F.binary_cross_entropy(
|
392 |
+
cls_preds_.sqrt_(), gt_cls_per_image, reduction="none"
|
393 |
+
).sum(-1)
|
394 |
+
del cls_preds_
|
395 |
+
|
396 |
+
cost = (
|
397 |
+
pair_wise_cls_loss
|
398 |
+
+ 3.0 * pair_wise_ious_loss
|
399 |
+
+ 100000.0 * (~is_in_boxes_and_center)
|
400 |
+
)
|
401 |
+
|
402 |
+
(
|
403 |
+
num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds
|
404 |
+
) = self.dynamic_k_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask)
|
405 |
+
del pair_wise_cls_loss, cost, pair_wise_ious, pair_wise_ious_loss
|
406 |
+
|
407 |
+
if mode == "cpu":
|
408 |
+
gt_matched_classes = gt_matched_classes.cuda()
|
409 |
+
fg_mask = fg_mask.cuda()
|
410 |
+
pred_ious_this_matching = pred_ious_this_matching.cuda()
|
411 |
+
matched_gt_inds = matched_gt_inds.cuda()
|
412 |
+
|
413 |
+
return gt_matched_classes, fg_mask, pred_ious_this_matching, matched_gt_inds, num_fg
|
414 |
+
|
415 |
+
def get_in_boxes_info(
|
416 |
+
self, gt_bboxes_per_image, expanded_strides, x_shifts, y_shifts, total_num_anchors, num_gt,
|
417 |
+
):
|
418 |
+
expanded_strides_per_image = expanded_strides[0]
|
419 |
+
x_shifts_per_image = x_shifts[0] * expanded_strides_per_image
|
420 |
+
y_shifts_per_image = y_shifts[0] * expanded_strides_per_image
|
421 |
+
x_centers_per_image = (
|
422 |
+
(x_shifts_per_image + 0.5 * expanded_strides_per_image)
|
423 |
+
.unsqueeze(0)
|
424 |
+
.repeat(num_gt, 1)
|
425 |
+
) # [n_anchor] -> [n_gt, n_anchor]
|
426 |
+
y_centers_per_image = (
|
427 |
+
(y_shifts_per_image + 0.5 * expanded_strides_per_image)
|
428 |
+
.unsqueeze(0)
|
429 |
+
.repeat(num_gt, 1)
|
430 |
+
)
|
431 |
+
|
432 |
+
gt_bboxes_per_image_l = (
|
433 |
+
(gt_bboxes_per_image[:, 0] - 0.5 * gt_bboxes_per_image[:, 2])
|
434 |
+
.unsqueeze(1)
|
435 |
+
.repeat(1, total_num_anchors)
|
436 |
+
)
|
437 |
+
gt_bboxes_per_image_r = (
|
438 |
+
(gt_bboxes_per_image[:, 0] + 0.5 * gt_bboxes_per_image[:, 2])
|
439 |
+
.unsqueeze(1)
|
440 |
+
.repeat(1, total_num_anchors)
|
441 |
+
)
|
442 |
+
gt_bboxes_per_image_t = (
|
443 |
+
(gt_bboxes_per_image[:, 1] - 0.5 * gt_bboxes_per_image[:, 3])
|
444 |
+
.unsqueeze(1)
|
445 |
+
.repeat(1, total_num_anchors)
|
446 |
+
)
|
447 |
+
gt_bboxes_per_image_b = (
|
448 |
+
(gt_bboxes_per_image[:, 1] + 0.5 * gt_bboxes_per_image[:, 3])
|
449 |
+
.unsqueeze(1)
|
450 |
+
.repeat(1, total_num_anchors)
|
451 |
+
)
|
452 |
+
|
453 |
+
b_l = x_centers_per_image - gt_bboxes_per_image_l
|
454 |
+
b_r = gt_bboxes_per_image_r - x_centers_per_image
|
455 |
+
b_t = y_centers_per_image - gt_bboxes_per_image_t
|
456 |
+
b_b = gt_bboxes_per_image_b - y_centers_per_image
|
457 |
+
bbox_deltas = torch.stack([b_l, b_t, b_r, b_b], 2)
|
458 |
+
|
459 |
+
is_in_boxes = bbox_deltas.min(dim=-1).values > 0.0
|
460 |
+
is_in_boxes_all = is_in_boxes.sum(dim=0) > 0
|
461 |
+
# in fixed center
|
462 |
+
|
463 |
+
center_radius = 2.5
|
464 |
+
|
465 |
+
gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(
|
466 |
+
1, total_num_anchors
|
467 |
+
) - center_radius * expanded_strides_per_image.unsqueeze(0)
|
468 |
+
gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(
|
469 |
+
1, total_num_anchors
|
470 |
+
) + center_radius * expanded_strides_per_image.unsqueeze(0)
|
471 |
+
gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(
|
472 |
+
1, total_num_anchors
|
473 |
+
) - center_radius * expanded_strides_per_image.unsqueeze(0)
|
474 |
+
gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(
|
475 |
+
1, total_num_anchors
|
476 |
+
) + center_radius * expanded_strides_per_image.unsqueeze(0)
|
477 |
+
|
478 |
+
c_l = x_centers_per_image - gt_bboxes_per_image_l
|
479 |
+
c_r = gt_bboxes_per_image_r - x_centers_per_image
|
480 |
+
c_t = y_centers_per_image - gt_bboxes_per_image_t
|
481 |
+
c_b = gt_bboxes_per_image_b - y_centers_per_image
|
482 |
+
center_deltas = torch.stack([c_l, c_t, c_r, c_b], 2)
|
483 |
+
is_in_centers = center_deltas.min(dim=-1).values > 0.0
|
484 |
+
is_in_centers_all = is_in_centers.sum(dim=0) > 0
|
485 |
+
|
486 |
+
# in boxes and in centers
|
487 |
+
is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all
|
488 |
+
|
489 |
+
is_in_boxes_and_center = (
|
490 |
+
is_in_boxes[:, is_in_boxes_anchor] & is_in_centers[:, is_in_boxes_anchor]
|
491 |
+
)
|
492 |
+
return is_in_boxes_anchor, is_in_boxes_and_center
|
493 |
+
|
494 |
+
def dynamic_k_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
|
495 |
+
# Dynamic K
|
496 |
+
# ---------------------------------------------------------------
|
497 |
+
matching_matrix = torch.zeros_like(cost)
|
498 |
+
|
499 |
+
ious_in_boxes_matrix = pair_wise_ious
|
500 |
+
n_candidate_k = 10
|
501 |
+
topk_ious, _ = torch.topk(ious_in_boxes_matrix, n_candidate_k, dim=1)
|
502 |
+
dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
|
503 |
+
for gt_idx in range(num_gt):
|
504 |
+
_, pos_idx = torch.topk(
|
505 |
+
cost[gt_idx], k=dynamic_ks[gt_idx].item(), largest=False
|
506 |
+
)
|
507 |
+
matching_matrix[gt_idx][pos_idx] = 1.0
|
508 |
+
|
509 |
+
del topk_ious, dynamic_ks, pos_idx
|
510 |
+
|
511 |
+
anchor_matching_gt = matching_matrix.sum(0)
|
512 |
+
if (anchor_matching_gt > 1).sum() > 0:
|
513 |
+
cost_min, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
|
514 |
+
matching_matrix[:, anchor_matching_gt > 1] *= 0.0
|
515 |
+
matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1.0
|
516 |
+
fg_mask_inboxes = matching_matrix.sum(0) > 0.0
|
517 |
+
num_fg = fg_mask_inboxes.sum().item()
|
518 |
+
|
519 |
+
fg_mask[fg_mask.clone()] = fg_mask_inboxes
|
520 |
+
|
521 |
+
matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
|
522 |
+
gt_matched_classes = gt_classes[matched_gt_inds]
|
523 |
+
|
524 |
+
pred_ious_this_matching = (matching_matrix * pair_wise_ious).sum(0)[fg_mask_inboxes]
|
525 |
+
return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds
|
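dynamic_k_matching gives each ground truth roughly as many anchors as its summed top-k IoU suggests, then hands anchors claimed by several ground truths to the cheapest one. A standalone toy run of that idea (hypothetical cost/IoU values, not part of the commit; k is capped at the number of candidates, whereas the method above assumes at least 10 exist):

import torch

cost = torch.tensor([[0.2, 0.5, 1.0, 3.0, 4.0],
                     [4.0, 0.3, 0.4, 0.6, 5.0]])
ious = torch.tensor([[0.9, 0.8, 0.3, 0.1, 0.0],
                     [0.0, 0.7, 0.8, 0.6, 0.1]])

matching = torch.zeros_like(cost)
topk_ious, _ = torch.topk(ious, k=min(10, ious.shape[1]), dim=1)
dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)  # tensor([2, 2])
for gt_idx in range(cost.shape[0]):
    _, pos_idx = torch.topk(cost[gt_idx], k=dynamic_ks[gt_idx].item(), largest=False)
    matching[gt_idx][pos_idx] = 1.0
# anchor 1 is claimed by both GTs; keep only the cheaper assignment
multi = matching.sum(0) > 1
if multi.any():
    _, argmin = torch.min(cost[:, multi], dim=0)
    matching[:, multi] = 0.0
    matching[argmin, multi] = 1.0
print(matching)  # GT0 -> anchor 0; GT1 -> anchors 1 and 2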
yolox/models/yolo_pafpn.py
ADDED
@@ -0,0 +1,99 @@
+import torch
+import torch.nn as nn
+
+from .darknet import CSPDarknet
+from .network_blocks import BaseConv, CSPLayer, DWConv
+
+
+class YOLOPAFPN(nn.Module):
+    """
+    YOLOv3 model. Darknet 53 is the default backbone of this model.
+    """
+
+    def __init__(
+        self, depth=1.0, width=1.0, in_features=("dark3", "dark4", "dark5"),
+        in_channels=[256, 512, 1024], depthwise=False,
+    ):
+        super().__init__()
+        self.backbone = CSPDarknet(depth, width, depthwise=depthwise)
+        self.in_features = in_features
+        self.in_channels = in_channels
+        Conv = DWConv if depthwise else BaseConv
+
+        self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
+        self.lateral_conv0 = BaseConv(
+            int(in_channels[2] * width), int(in_channels[1] * width), 1, 1
+        )
+        self.C3_p4 = CSPLayer(
+            int(2 * in_channels[1] * width),
+            int(in_channels[1] * width),
+            round(3 * depth),
+            False,
+            depthwise=depthwise,
+        )  # cat
+
+        self.reduce_conv1 = BaseConv(
+            int(in_channels[1] * width), int(in_channels[0] * width), 1, 1
+        )
+        self.C3_p3 = CSPLayer(
+            int(2 * in_channels[0] * width),
+            int(in_channels[0] * width),
+            round(3 * depth),
+            False,
+            depthwise=depthwise,
+        )
+
+        # bottom-up conv
+        self.bu_conv2 = Conv(int(in_channels[0] * width), int(in_channels[0] * width), 3, 2)
+        self.C3_n3 = CSPLayer(
+            int(2 * in_channels[0] * width),
+            int(in_channels[1] * width),
+            round(3 * depth),
+            False,
+            depthwise=depthwise,
+        )
+
+        # bottom-up conv
+        self.bu_conv1 = Conv(int(in_channels[1] * width), int(in_channels[1] * width), 3, 2)
+        self.C3_n4 = CSPLayer(
+            int(2 * in_channels[1] * width),
+            int(in_channels[2] * width),
+            round(3 * depth),
+            False,
+            depthwise=depthwise,
+        )
+
+    def forward(self, input):
+        """
+        Args:
+            input: input images.
+
+        Returns:
+            Tuple[Tensor]: FPN feature.
+        """
+
+        # backbone
+        out_features = self.backbone(input)
+        features = [out_features[f] for f in self.in_features]
+        [x2, x1, x0] = features
+
+        fpn_out0 = self.lateral_conv0(x0)  # 1024->512/32
+        f_out0 = self.upsample(fpn_out0)  # 512/16
+        f_out0 = torch.cat([f_out0, x1], 1)  # 512->1024/16
+        f_out0 = self.C3_p4(f_out0)  # 1024->512/16
+
+        fpn_out1 = self.reduce_conv1(f_out0)  # 512->256/16
+        f_out1 = self.upsample(fpn_out1)  # 256/8
+        f_out1 = torch.cat([f_out1, x2], 1)  # 256->512/8
+        pan_out2 = self.C3_p3(f_out1)  # 512->256/8
+
+        p_out1 = self.bu_conv2(pan_out2)  # 256->256/16
+        p_out1 = torch.cat([p_out1, fpn_out1], 1)  # 256->512/16
+        pan_out1 = self.C3_n3(p_out1)  # 512->512/16
+
+        p_out0 = self.bu_conv1(pan_out1)  # 512->512/32
+        p_out0 = torch.cat([p_out0, fpn_out0], 1)  # 512->1024/32
+        pan_out0 = self.C3_n4(p_out0)  # 1024->1024/32
+
+        outputs = (pan_out2, pan_out1, pan_out0)
+        return outputs
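A hedged shape check for the PAFPN (not part of the commit; assumes the CSPDarknet backbone from darknet.py is importable; depth=1.0, width=1.0 are the defaults above):

import torch
from yolox.models.yolo_pafpn import YOLOPAFPN

pafpn = YOLOPAFPN(depth=1.0, width=1.0).eval()
x = torch.randn(1, 3, 640, 640)
with torch.no_grad():
    p3, p4, p5 = pafpn(x)
# channels (256, 512, 1024) at strides 8, 16, 32 for the default width
print(p3.shape, p4.shape, p5.shape)
# torch.Size([1, 256, 80, 80]) torch.Size([1, 512, 40, 40]) torch.Size([1, 1024, 20, 20])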
yolox/models/yolox.py
ADDED
@@ -0,0 +1,45 @@
+
+import torch.nn as nn
+
+from .yolo_head import YOLOXHead
+from .yolo_pafpn import YOLOPAFPN
+
+
+class YOLOX(nn.Module):
+    """
+    YOLOX model module. The module list is defined by create_yolov3_modules function.
+    The network returns loss values from three YOLO layers during training
+    and detection results during test.
+    """
+
+    def __init__(self, backbone=None, head=None):
+        super().__init__()
+        if backbone is None:
+            backbone = YOLOPAFPN()
+        if head is None:
+            head = YOLOXHead(80)
+
+        self.backbone = backbone
+        self.head = head
+
+    def forward(self, x, targets=None):
+        # fpn output content features of [dark3, dark4, dark5]
+        fpn_outs = self.backbone(x)
+
+        if self.training:
+            assert targets is not None
+            loss, iou_loss, conf_loss, cls_loss, l1_loss, num_fg = self.head(
+                fpn_outs, targets, x
+            )
+            outputs = {
+                "total_loss": loss,
+                "iou_loss": iou_loss,
+                "l1_loss": l1_loss,
+                "conf_loss": conf_loss,
+                "cls_loss": cls_loss,
+                "num_fg": num_fg,
+            }
+        else:
+            outputs = self.head(fpn_outs)
+
+        return outputs
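Putting backbone and head together, a minimal inference sketch (not part of the commit; assumes the yolox package is importable; the decoded output still needs score filtering and NMS, which live elsewhere in the repo under yolox/utils):

import torch
from yolox.models.yolox import YOLOX

model = YOLOX().eval()  # YOLOPAFPN backbone + 80-class YOLOXHead by default
x = torch.randn(1, 3, 640, 640)
with torch.no_grad():
    preds = model(x)
# decoded predictions: [batch, n_anchors_all, 4 box + 1 obj + 80 cls]
print(preds.shape)  # torch.Size([1, 8400, 85]) since 80*80 + 40*40 + 20*20 = 8400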
yolox/utils/__init__.py
ADDED
@@ -0,0 +1,14 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+from .allreduce_norm import *
+from .boxes import *
+from .checkpoint import load_ckpt, save_checkpoint
+from .dist import *
+from .ema import ModelEMA
+from .logger import setup_logger
+from .lr_scheduler import LRScheduler
+from .metric import *
+from .model_utils import *
+from .setup_env import *
+from .visualize import *
yolox/utils/allreduce_norm.py
ADDED
@@ -0,0 +1,99 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+import pickle
+from collections import OrderedDict
+
+import torch
+from torch import distributed as dist
+from torch import nn
+
+from .dist import _get_global_gloo_group, get_world_size
+
+ASYNC_NORM = (
+    nn.BatchNorm1d,
+    nn.BatchNorm2d,
+    nn.BatchNorm3d,
+    nn.InstanceNorm1d,
+    nn.InstanceNorm2d,
+    nn.InstanceNorm3d,
+)
+
+__all__ = [
+    "get_async_norm_states", "pyobj2tensor", "tensor2pyobj", "all_reduce", "all_reduce_norm"
+]
+
+
+def get_async_norm_states(module):
+    async_norm_states = OrderedDict()
+    for name, child in module.named_modules():
+        if isinstance(child, ASYNC_NORM):
+            for k, v in child.state_dict().items():
+                async_norm_states[".".join([name, k])] = v
+    return async_norm_states
+
+
+def pyobj2tensor(pyobj, device="cuda"):
+    """serialize picklable python object to tensor"""
+    storage = torch.ByteStorage.from_buffer(pickle.dumps(pyobj))
+    return torch.ByteTensor(storage).to(device=device)
+
+
+def tensor2pyobj(tensor):
+    """deserialize tensor to picklable python object"""
+    return pickle.loads(tensor.cpu().numpy().tobytes())
+
+
+def _get_reduce_op(op_name):
+    return {
+        "sum": dist.ReduceOp.SUM,
+        "mean": dist.ReduceOp.SUM,
+    }[op_name.lower()]
+
+
+def all_reduce(py_dict, op="sum", group=None):
+    """
+    Apply all reduce function for python dict object.
+    NOTE: make sure that every py_dict has the same keys and values are in the same shape.
+
+    Args:
+        py_dict (dict): dict to apply all reduce op.
+        op (str): operator, could be "sum" or "mean".
+    """
+    world_size = get_world_size()
+    if world_size == 1:
+        return py_dict
+    if group is None:
+        group = _get_global_gloo_group()
+    if dist.get_world_size(group) == 1:
+        return py_dict
+
+    # all reduce logic across different devices.
+    py_key = list(py_dict.keys())
+    py_key_tensor = pyobj2tensor(py_key)
+    dist.broadcast(py_key_tensor, src=0)
+    py_key = tensor2pyobj(py_key_tensor)
+
+    tensor_shapes = [py_dict[k].shape for k in py_key]
+    tensor_numels = [py_dict[k].numel() for k in py_key]
+
+    flatten_tensor = torch.cat([py_dict[k].flatten() for k in py_key])
+    dist.all_reduce(flatten_tensor, op=_get_reduce_op(op))
+    if op == "mean":
+        flatten_tensor /= world_size
+
+    split_tensors = [
+        x.reshape(shape) for x, shape in zip(
+            torch.split(flatten_tensor, tensor_numels), tensor_shapes
+        )
+    ]
+    return OrderedDict({k: v for k, v in zip(py_key, split_tensors)})
+
+
+def all_reduce_norm(module):
+    """
+    All reduce norm statistics in different devices.
+    """
+    states = get_async_norm_states(module)
+    states = all_reduce(states, op="mean")
+    module.load_state_dict(states, strict=False)
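pyobj2tensor/tensor2pyobj exist so that the dict keys can travel over torch.distributed as raw pickled bytes; a single-process roundtrip check (not part of the commit; device="cpu" so no GPU is needed, with the PyTorch versions this commit targets):

import torch
from yolox.utils.allreduce_norm import pyobj2tensor, tensor2pyobj

obj = ["bn.running_mean", "bn.running_var"]
t = pyobj2tensor(obj, device="cpu")  # pickled bytes wrapped in a ByteTensor
assert tensor2pyobj(t) == obj
print(t.dtype, t.numel())  # torch.uint8 and the pickled payload size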