Spaces:
Running
Running
Commit
·
181d94d
1
Parent(s):
61c2c74
model
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- model/.gitignore +27 -0
- model/LICENSE +12 -0
- model/LICENSE.fuck +13 -0
- model/LICENSE.gen +91 -0
- model/LICENSE.gpl +674 -0
- model/LICENSE.meta +8 -0
- model/LICENSE.mit +22 -0
- model/LICENSE.v1 +13 -0
- model/Makefile +105 -0
- model/README.md +124 -0
- model/examples/art.c +59 -0
- model/examples/attention.c +459 -0
- model/examples/captcha.c +353 -0
- model/examples/cifar.c +251 -0
- model/examples/classifier.c +1098 -0
- model/examples/coco.c +357 -0
- model/examples/darknet.c +503 -0
- model/examples/detector-scipy-opencv.py +56 -0
- model/examples/detector.c +850 -0
- model/examples/detector.py +27 -0
- model/examples/dice.c +116 -0
- model/examples/go.c +1370 -0
- model/examples/instance-segmenter.c +267 -0
- model/examples/lsd.c +1378 -0
- model/examples/nightmare.c +414 -0
- model/examples/regressor.c +240 -0
- model/examples/rnn.c +542 -0
- model/examples/rnn_vid.c +208 -0
- model/examples/segmenter.c +255 -0
- model/examples/super.c +120 -0
- model/examples/swag.c +83 -0
- model/examples/tag.c +140 -0
- model/examples/voxel.c +161 -0
- model/examples/writing.c +144 -0
- model/examples/yolo.c +327 -0
- model/include/darknet.h +805 -0
- model/libdarknet.a +0 -0
- model/libdarknet.so +0 -0
- model/python/darknet.py +156 -0
- model/python/proverbot.py +37 -0
- model/scripts/dice_label.sh +20 -0
- model/scripts/gen_tactic.sh +5 -0
- model/scripts/get_coco_dataset.sh +31 -0
- model/scripts/imagenet_label.sh +15 -0
- model/scripts/voc_label.py +59 -0
- model/src/activation_kernels.cu +206 -0
- model/src/activation_layer.c +63 -0
- model/src/activation_layer.h +19 -0
- model/src/activations.c +150 -0
- model/src/activations.h +87 -0
model/.gitignore
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.o
|
2 |
+
*.dSYM
|
3 |
+
*.csv
|
4 |
+
*.out
|
5 |
+
*.png
|
6 |
+
*.jpg
|
7 |
+
*.pyc
|
8 |
+
old/
|
9 |
+
mnist/
|
10 |
+
data/
|
11 |
+
caffe/
|
12 |
+
grasp/
|
13 |
+
images/
|
14 |
+
opencv/
|
15 |
+
convnet/
|
16 |
+
decaf/
|
17 |
+
submission/
|
18 |
+
cfg/
|
19 |
+
darknet
|
20 |
+
.fuse*
|
21 |
+
|
22 |
+
# OS Generated #
|
23 |
+
.DS_Store*
|
24 |
+
ehthumbs.db
|
25 |
+
Icon?
|
26 |
+
Thumbs.db
|
27 |
+
*.swp
|
model/LICENSE
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
YOLO LICENSE
|
2 |
+
Version 2, July 29 2016
|
3 |
+
|
4 |
+
THIS SOFTWARE LICENSE IS PROVIDED "ALL CAPS" SO THAT YOU KNOW IT IS SUPER
|
5 |
+
SERIOUS AND YOU DON'T MESS AROUND WITH COPYRIGHT LAW BECAUSE YOU WILL GET IN
|
6 |
+
TROUBLE HERE ARE SOME OTHER BUZZWORDS COMMONLY IN THESE THINGS WARRANTIES
|
7 |
+
LIABILITY CONTRACT TORT LIABLE CLAIMS RESTRICTION MERCHANTABILITY. NOW HERE'S
|
8 |
+
THE REAL LICENSE:
|
9 |
+
|
10 |
+
0. Darknet is public domain.
|
11 |
+
1. Do whatever you want with it.
|
12 |
+
2. Stop emailing me about it!
|
model/LICENSE.fuck
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
|
2 |
+
Version 2, December 2004
|
3 |
+
|
4 |
+
Copyright (C) 2004 Sam Hocevar <[email protected]>
|
5 |
+
|
6 |
+
Everyone is permitted to copy and distribute verbatim or modified
|
7 |
+
copies of this license document, and changing it is allowed as long
|
8 |
+
as the name is changed.
|
9 |
+
|
10 |
+
DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
|
11 |
+
TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION
|
12 |
+
|
13 |
+
0. You just DO WHAT THE FUCK YOU WANT TO.
|
model/LICENSE.gen
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
RNN LICENSE Version 3, June 21 2017
|
2 |
+
|
3 |
+
Copyright (c) 1990, 1989, 1999 Free87337 May 48 THIRD PARTIES OR ANY OTHER THE
|
4 |
+
COMPLAIN OR CONSEQUENTIAL DAMAGES AND REGARDLESS OF WHETHER IN CONTRACT, TO THE
|
5 |
+
EXTENT REPAIR OR AGENTS (NOT THE IN ANY EVENT). THE SOFTWARE WILL BE
|
6 |
+
UNINTERRUPTED OR ERROR-FREE OR ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
7 |
+
STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
|
8 |
+
OUT OF THE USE OF ALL THE WORK (GOVERNED CODE) HIM RESPONSES, OR OF FINES,
|
9 |
+
SPECIAL, INCIDENTAL, CONSEQUENTIAL, PUNITIVE OR ANY OTHER OR OTHER HARL UNDER NO
|
10 |
+
CIRCUMSTANCES AND UNDER NO LEGAL THEORY, WHETHER TORT (INCLUDING NEGLIGENCE),
|
11 |
+
PATENT PERMITTED BY THE INSTAGRAM PARENT STATE OR TORT (INCLUDING NEGLIGENCE),
|
12 |
+
PRODUCT LIABILITY OR OTHERWISE, ARISING OUT OF OR IN CONNECTION WITH THE
|
13 |
+
SOFTWARE OR THE USE OR ANYTHING PROVIDED IN THIS PRODUCT, COMMIS AND SERVICES
|
14 |
+
ARE LICENSED SOFTWARE AND ANY RESULE OR ANY OTHER THE COPYRIGHT HOLDERS BE
|
15 |
+
LIABLE FOR ANY SPECIAL, INCIDENTAL, CASE, SUCH WARRANTIES, EXPRESS OR IMPLIED,
|
16 |
+
INCLUDING, WITHOUT LIMITATION, WARRANTIES THAT THE COPYRIGHT HOLDERS AND/OR ANY
|
17 |
+
PERSON FOR ANY INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES OF ANY
|
18 |
+
EXPRESS OR DISTRIBUTE THAT ALL CLAIMS ARE SHALL CREATE DERAVE BE LIABLE TO YOU
|
19 |
+
WILL HAVE BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
|
20 |
+
|
21 |
+
6\. TERMINATION. TO THE EXTENT PERMITTED BY LAW, NO USE OF THE COVERED CODE IS
|
22 |
+
WITH YOU. SHOULD ANY COVERED CODE PROVE DEFECTIVE IN ANY RESPECT, YOU (NOT THE
|
23 |
+
INITIAL DEVELOPER OR ANY OTHER CONTRIBUTOR) ASSUME THE COST OF ANY NECESSARY
|
24 |
+
SERVICING, REPAIR OR COULT OR IN ANY WAY OUT OF THE USE OF THE WEBSITES OR
|
25 |
+
SERVICE WILL BE CONSEQUENTIAL DAMAGES OF ANY KIND HAS BEEN ADVISED OF THE
|
26 |
+
POSSIBILITY OF SUCH DAMAGES.
|
27 |
+
|
28 |
+
|
29 |
+
This paragraph Agreement constitutes the entire agreement between the parties
|
30 |
+
with respect to the Work licensed here. However, if you place the name of the
|
31 |
+
fact that the arbitration was the consultation of the parties as a "patent is".
|
32 |
+
Subject to the terms and conditions of this License, Contributor has knowledge
|
33 |
+
that a license under a third party may also be used to endorse or promote
|
34 |
+
products derived from the Work, and there is no warranty on the Software and
|
35 |
+
Science Fees. For the purposes of this Agreement, attach the following
|
36 |
+
disclaimers (without liabilities of written notice to the Subject Software) in a
|
37 |
+
manner that a product is under common control with you. The Free Software
|
38 |
+
Foundation may publish revised and/or new versions of the License for the
|
39 |
+
Modifications made by the applicable terms. The Recipient shall promptly retain
|
40 |
+
the covered works for any reason be entered in any federal or state or login
|
41 |
+
Restricted Laws appearing in the United States or any of its own information
|
42 |
+
that is not disabled from a derivative work except as expressly permitted in
|
43 |
+
this License, to the extent that they are in receiving the Software and Source
|
44 |
+
Code or any exercise of the rights granted to You by this License or a
|
45 |
+
Contributor made by the Licensor or are authorized to make a reasonable
|
46 |
+
retirement by the courts of the courts located in Santa Clara County, California
|
47 |
+
printed and related to the Work or “Company” and Apache Software Foundation. If
|
48 |
+
the Licensor shall be entitled to reflect your rights to use the Software and
|
49 |
+
the Software to exercise the rights granted to the recipient without a
|
50 |
+
requirement to exercise the rights granted by the Agreement to the provision
|
51 |
+
will begin will appear in such cases, you will use such information without such
|
52 |
+
corporation shall be an officer with respect to any part of the Software or any
|
53 |
+
portion thereof. Capitalized terms are included in the Initial Contributor and
|
54 |
+
under no circumstances will license the Service at any time and for any direct,
|
55 |
+
indirect, special, incidental, or consequential damages of or assist in
|
56 |
+
connection with any Services or the registration purposes only to the extent
|
57 |
+
that it includes any or all means including the processing of which you download
|
58 |
+
any derivative work. Any of the purchases’ transmission purposes are made
|
59 |
+
available, if any, in other circumstances, we may review the copyright notice.
|
60 |
+
In the event that this Agreement is required to give us strict content. The
|
61 |
+
inclusion of the other party hereunder may also notify you Intellectual Property
|
62 |
+
Rights to any third party. This means that the Source Code exists of the Work
|
63 |
+
will not charge a program available to you at any time. You must include a
|
64 |
+
prominent statement that the Software is governed under a particular version of
|
65 |
+
this Agreement. You must include a provision to the extent that there is no
|
66 |
+
warranty for the content of others. You agree that the Recipient was appointed
|
67 |
+
as a Contributor, (c) are effective until terminated by hereunder, then the
|
68 |
+
registration are not disabled and not limited to, submit any Customer Data
|
69 |
+
without the updated use of the Software and that no fee is released. You grant
|
70 |
+
to Use Other Arbitration Rules for Diagnostic or Services may use or modify the
|
71 |
+
Apple Software and Consolidated Apple Software or Services. The Company may have
|
72 |
+
full risk as a product of the Compatible Source. A Contribution by the Licensor
|
73 |
+
or by the updated Software under the following conditions we can redistribute
|
74 |
+
any General Provision of this Agreement. If the Program is used in accordance
|
75 |
+
with the terms of this Agreement, Customer may provide advertisements from your
|
76 |
+
devices that clause you can your employer or a transaction or country that has
|
77 |
+
been controlled by the arbitrator, that they will be useful of this Agreement.
|
78 |
+
The term "Open Source Software is available in connection with the program, and
|
79 |
+
you may not protect the combination of the Covered Code. You should like to
|
80 |
+
select a user's rights to charge a copy of this License. I are Contributor's
|
81 |
+
confidentiality of the exercise of the rights granted herein. Such a covered
|
82 |
+
work is released as a consequence, the Licensor shall be eligible for a purpose
|
83 |
+
or subcontractor of the person or entity to the user of the user, then the word
|
84 |
+
"Application" means having the original fee for any reason; and that no patent
|
85 |
+
license to more than fifty stated close of the license term. The terms of this
|
86 |
+
License will the license terms and conditions set forth in Section 2.2 (OPEC)
|
87 |
+
and You will not use the Software or any set of responsibility for any resulting
|
88 |
+
information that the Original Code warrants that you have the right to disclose
|
89 |
+
these information (or in the notification; or (iii) late use of the software or
|
90 |
+
any third party to the three (50) days before such belief to the extent that it
|
91 |
+
includes a court court obtains the rights granted by this License.
|
model/LICENSE.gpl
ADDED
@@ -0,0 +1,674 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
GNU GENERAL PUBLIC LICENSE
|
2 |
+
Version 3, 29 June 2007
|
3 |
+
|
4 |
+
Copyright (C) 2007 Free Software Foundation, Inc. <http://fsf.org/>
|
5 |
+
Everyone is permitted to copy and distribute verbatim copies
|
6 |
+
of this license document, but changing it is not allowed.
|
7 |
+
|
8 |
+
Preamble
|
9 |
+
|
10 |
+
The GNU General Public License is a free, copyleft license for
|
11 |
+
software and other kinds of works.
|
12 |
+
|
13 |
+
The licenses for most software and other practical works are designed
|
14 |
+
to take away your freedom to share and change the works. By contrast,
|
15 |
+
the GNU General Public License is intended to guarantee your freedom to
|
16 |
+
share and change all versions of a program--to make sure it remains free
|
17 |
+
software for all its users. We, the Free Software Foundation, use the
|
18 |
+
GNU General Public License for most of our software; it applies also to
|
19 |
+
any other work released this way by its authors. You can apply it to
|
20 |
+
your programs, too.
|
21 |
+
|
22 |
+
When we speak of free software, we are referring to freedom, not
|
23 |
+
price. Our General Public Licenses are designed to make sure that you
|
24 |
+
have the freedom to distribute copies of free software (and charge for
|
25 |
+
them if you wish), that you receive source code or can get it if you
|
26 |
+
want it, that you can change the software or use pieces of it in new
|
27 |
+
free programs, and that you know you can do these things.
|
28 |
+
|
29 |
+
To protect your rights, we need to prevent others from denying you
|
30 |
+
these rights or asking you to surrender the rights. Therefore, you have
|
31 |
+
certain responsibilities if you distribute copies of the software, or if
|
32 |
+
you modify it: responsibilities to respect the freedom of others.
|
33 |
+
|
34 |
+
For example, if you distribute copies of such a program, whether
|
35 |
+
gratis or for a fee, you must pass on to the recipients the same
|
36 |
+
freedoms that you received. You must make sure that they, too, receive
|
37 |
+
or can get the source code. And you must show them these terms so they
|
38 |
+
know their rights.
|
39 |
+
|
40 |
+
Developers that use the GNU GPL protect your rights with two steps:
|
41 |
+
(1) assert copyright on the software, and (2) offer you this License
|
42 |
+
giving you legal permission to copy, distribute and/or modify it.
|
43 |
+
|
44 |
+
For the developers' and authors' protection, the GPL clearly explains
|
45 |
+
that there is no warranty for this free software. For both users' and
|
46 |
+
authors' sake, the GPL requires that modified versions be marked as
|
47 |
+
changed, so that their problems will not be attributed erroneously to
|
48 |
+
authors of previous versions.
|
49 |
+
|
50 |
+
Some devices are designed to deny users access to install or run
|
51 |
+
modified versions of the software inside them, although the manufacturer
|
52 |
+
can do so. This is fundamentally incompatible with the aim of
|
53 |
+
protecting users' freedom to change the software. The systematic
|
54 |
+
pattern of such abuse occurs in the area of products for individuals to
|
55 |
+
use, which is precisely where it is most unacceptable. Therefore, we
|
56 |
+
have designed this version of the GPL to prohibit the practice for those
|
57 |
+
products. If such problems arise substantially in other domains, we
|
58 |
+
stand ready to extend this provision to those domains in future versions
|
59 |
+
of the GPL, as needed to protect the freedom of users.
|
60 |
+
|
61 |
+
Finally, every program is threatened constantly by software patents.
|
62 |
+
States should not allow patents to restrict development and use of
|
63 |
+
software on general-purpose computers, but in those that do, we wish to
|
64 |
+
avoid the special danger that patents applied to a free program could
|
65 |
+
make it effectively proprietary. To prevent this, the GPL assures that
|
66 |
+
patents cannot be used to render the program non-free.
|
67 |
+
|
68 |
+
The precise terms and conditions for copying, distribution and
|
69 |
+
modification follow.
|
70 |
+
|
71 |
+
TERMS AND CONDITIONS
|
72 |
+
|
73 |
+
0. Definitions.
|
74 |
+
|
75 |
+
"This License" refers to version 3 of the GNU General Public License.
|
76 |
+
|
77 |
+
"Copyright" also means copyright-like laws that apply to other kinds of
|
78 |
+
works, such as semiconductor masks.
|
79 |
+
|
80 |
+
"The Program" refers to any copyrightable work licensed under this
|
81 |
+
License. Each licensee is addressed as "you". "Licensees" and
|
82 |
+
"recipients" may be individuals or organizations.
|
83 |
+
|
84 |
+
To "modify" a work means to copy from or adapt all or part of the work
|
85 |
+
in a fashion requiring copyright permission, other than the making of an
|
86 |
+
exact copy. The resulting work is called a "modified version" of the
|
87 |
+
earlier work or a work "based on" the earlier work.
|
88 |
+
|
89 |
+
A "covered work" means either the unmodified Program or a work based
|
90 |
+
on the Program.
|
91 |
+
|
92 |
+
To "propagate" a work means to do anything with it that, without
|
93 |
+
permission, would make you directly or secondarily liable for
|
94 |
+
infringement under applicable copyright law, except executing it on a
|
95 |
+
computer or modifying a private copy. Propagation includes copying,
|
96 |
+
distribution (with or without modification), making available to the
|
97 |
+
public, and in some countries other activities as well.
|
98 |
+
|
99 |
+
To "convey" a work means any kind of propagation that enables other
|
100 |
+
parties to make or receive copies. Mere interaction with a user through
|
101 |
+
a computer network, with no transfer of a copy, is not conveying.
|
102 |
+
|
103 |
+
An interactive user interface displays "Appropriate Legal Notices"
|
104 |
+
to the extent that it includes a convenient and prominently visible
|
105 |
+
feature that (1) displays an appropriate copyright notice, and (2)
|
106 |
+
tells the user that there is no warranty for the work (except to the
|
107 |
+
extent that warranties are provided), that licensees may convey the
|
108 |
+
work under this License, and how to view a copy of this License. If
|
109 |
+
the interface presents a list of user commands or options, such as a
|
110 |
+
menu, a prominent item in the list meets this criterion.
|
111 |
+
|
112 |
+
1. Source Code.
|
113 |
+
|
114 |
+
The "source code" for a work means the preferred form of the work
|
115 |
+
for making modifications to it. "Object code" means any non-source
|
116 |
+
form of a work.
|
117 |
+
|
118 |
+
A "Standard Interface" means an interface that either is an official
|
119 |
+
standard defined by a recognized standards body, or, in the case of
|
120 |
+
interfaces specified for a particular programming language, one that
|
121 |
+
is widely used among developers working in that language.
|
122 |
+
|
123 |
+
The "System Libraries" of an executable work include anything, other
|
124 |
+
than the work as a whole, that (a) is included in the normal form of
|
125 |
+
packaging a Major Component, but which is not part of that Major
|
126 |
+
Component, and (b) serves only to enable use of the work with that
|
127 |
+
Major Component, or to implement a Standard Interface for which an
|
128 |
+
implementation is available to the public in source code form. A
|
129 |
+
"Major Component", in this context, means a major essential component
|
130 |
+
(kernel, window system, and so on) of the specific operating system
|
131 |
+
(if any) on which the executable work runs, or a compiler used to
|
132 |
+
produce the work, or an object code interpreter used to run it.
|
133 |
+
|
134 |
+
The "Corresponding Source" for a work in object code form means all
|
135 |
+
the source code needed to generate, install, and (for an executable
|
136 |
+
work) run the object code and to modify the work, including scripts to
|
137 |
+
control those activities. However, it does not include the work's
|
138 |
+
System Libraries, or general-purpose tools or generally available free
|
139 |
+
programs which are used unmodified in performing those activities but
|
140 |
+
which are not part of the work. For example, Corresponding Source
|
141 |
+
includes interface definition files associated with source files for
|
142 |
+
the work, and the source code for shared libraries and dynamically
|
143 |
+
linked subprograms that the work is specifically designed to require,
|
144 |
+
such as by intimate data communication or control flow between those
|
145 |
+
subprograms and other parts of the work.
|
146 |
+
|
147 |
+
The Corresponding Source need not include anything that users
|
148 |
+
can regenerate automatically from other parts of the Corresponding
|
149 |
+
Source.
|
150 |
+
|
151 |
+
The Corresponding Source for a work in source code form is that
|
152 |
+
same work.
|
153 |
+
|
154 |
+
2. Basic Permissions.
|
155 |
+
|
156 |
+
All rights granted under this License are granted for the term of
|
157 |
+
copyright on the Program, and are irrevocable provided the stated
|
158 |
+
conditions are met. This License explicitly affirms your unlimited
|
159 |
+
permission to run the unmodified Program. The output from running a
|
160 |
+
covered work is covered by this License only if the output, given its
|
161 |
+
content, constitutes a covered work. This License acknowledges your
|
162 |
+
rights of fair use or other equivalent, as provided by copyright law.
|
163 |
+
|
164 |
+
You may make, run and propagate covered works that you do not
|
165 |
+
convey, without conditions so long as your license otherwise remains
|
166 |
+
in force. You may convey covered works to others for the sole purpose
|
167 |
+
of having them make modifications exclusively for you, or provide you
|
168 |
+
with facilities for running those works, provided that you comply with
|
169 |
+
the terms of this License in conveying all material for which you do
|
170 |
+
not control copyright. Those thus making or running the covered works
|
171 |
+
for you must do so exclusively on your behalf, under your direction
|
172 |
+
and control, on terms that prohibit them from making any copies of
|
173 |
+
your copyrighted material outside their relationship with you.
|
174 |
+
|
175 |
+
Conveying under any other circumstances is permitted solely under
|
176 |
+
the conditions stated below. Sublicensing is not allowed; section 10
|
177 |
+
makes it unnecessary.
|
178 |
+
|
179 |
+
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
|
180 |
+
|
181 |
+
No covered work shall be deemed part of an effective technological
|
182 |
+
measure under any applicable law fulfilling obligations under article
|
183 |
+
11 of the WIPO copyright treaty adopted on 20 December 1996, or
|
184 |
+
similar laws prohibiting or restricting circumvention of such
|
185 |
+
measures.
|
186 |
+
|
187 |
+
When you convey a covered work, you waive any legal power to forbid
|
188 |
+
circumvention of technological measures to the extent such circumvention
|
189 |
+
is effected by exercising rights under this License with respect to
|
190 |
+
the covered work, and you disclaim any intention to limit operation or
|
191 |
+
modification of the work as a means of enforcing, against the work's
|
192 |
+
users, your or third parties' legal rights to forbid circumvention of
|
193 |
+
technological measures.
|
194 |
+
|
195 |
+
4. Conveying Verbatim Copies.
|
196 |
+
|
197 |
+
You may convey verbatim copies of the Program's source code as you
|
198 |
+
receive it, in any medium, provided that you conspicuously and
|
199 |
+
appropriately publish on each copy an appropriate copyright notice;
|
200 |
+
keep intact all notices stating that this License and any
|
201 |
+
non-permissive terms added in accord with section 7 apply to the code;
|
202 |
+
keep intact all notices of the absence of any warranty; and give all
|
203 |
+
recipients a copy of this License along with the Program.
|
204 |
+
|
205 |
+
You may charge any price or no price for each copy that you convey,
|
206 |
+
and you may offer support or warranty protection for a fee.
|
207 |
+
|
208 |
+
5. Conveying Modified Source Versions.
|
209 |
+
|
210 |
+
You may convey a work based on the Program, or the modifications to
|
211 |
+
produce it from the Program, in the form of source code under the
|
212 |
+
terms of section 4, provided that you also meet all of these conditions:
|
213 |
+
|
214 |
+
a) The work must carry prominent notices stating that you modified
|
215 |
+
it, and giving a relevant date.
|
216 |
+
|
217 |
+
b) The work must carry prominent notices stating that it is
|
218 |
+
released under this License and any conditions added under section
|
219 |
+
7. This requirement modifies the requirement in section 4 to
|
220 |
+
"keep intact all notices".
|
221 |
+
|
222 |
+
c) You must license the entire work, as a whole, under this
|
223 |
+
License to anyone who comes into possession of a copy. This
|
224 |
+
License will therefore apply, along with any applicable section 7
|
225 |
+
additional terms, to the whole of the work, and all its parts,
|
226 |
+
regardless of how they are packaged. This License gives no
|
227 |
+
permission to license the work in any other way, but it does not
|
228 |
+
invalidate such permission if you have separately received it.
|
229 |
+
|
230 |
+
d) If the work has interactive user interfaces, each must display
|
231 |
+
Appropriate Legal Notices; however, if the Program has interactive
|
232 |
+
interfaces that do not display Appropriate Legal Notices, your
|
233 |
+
work need not make them do so.
|
234 |
+
|
235 |
+
A compilation of a covered work with other separate and independent
|
236 |
+
works, which are not by their nature extensions of the covered work,
|
237 |
+
and which are not combined with it such as to form a larger program,
|
238 |
+
in or on a volume of a storage or distribution medium, is called an
|
239 |
+
"aggregate" if the compilation and its resulting copyright are not
|
240 |
+
used to limit the access or legal rights of the compilation's users
|
241 |
+
beyond what the individual works permit. Inclusion of a covered work
|
242 |
+
in an aggregate does not cause this License to apply to the other
|
243 |
+
parts of the aggregate.
|
244 |
+
|
245 |
+
6. Conveying Non-Source Forms.
|
246 |
+
|
247 |
+
You may convey a covered work in object code form under the terms
|
248 |
+
of sections 4 and 5, provided that you also convey the
|
249 |
+
machine-readable Corresponding Source under the terms of this License,
|
250 |
+
in one of these ways:
|
251 |
+
|
252 |
+
a) Convey the object code in, or embodied in, a physical product
|
253 |
+
(including a physical distribution medium), accompanied by the
|
254 |
+
Corresponding Source fixed on a durable physical medium
|
255 |
+
customarily used for software interchange.
|
256 |
+
|
257 |
+
b) Convey the object code in, or embodied in, a physical product
|
258 |
+
(including a physical distribution medium), accompanied by a
|
259 |
+
written offer, valid for at least three years and valid for as
|
260 |
+
long as you offer spare parts or customer support for that product
|
261 |
+
model, to give anyone who possesses the object code either (1) a
|
262 |
+
copy of the Corresponding Source for all the software in the
|
263 |
+
product that is covered by this License, on a durable physical
|
264 |
+
medium customarily used for software interchange, for a price no
|
265 |
+
more than your reasonable cost of physically performing this
|
266 |
+
conveying of source, or (2) access to copy the
|
267 |
+
Corresponding Source from a network server at no charge.
|
268 |
+
|
269 |
+
c) Convey individual copies of the object code with a copy of the
|
270 |
+
written offer to provide the Corresponding Source. This
|
271 |
+
alternative is allowed only occasionally and noncommercially, and
|
272 |
+
only if you received the object code with such an offer, in accord
|
273 |
+
with subsection 6b.
|
274 |
+
|
275 |
+
d) Convey the object code by offering access from a designated
|
276 |
+
place (gratis or for a charge), and offer equivalent access to the
|
277 |
+
Corresponding Source in the same way through the same place at no
|
278 |
+
further charge. You need not require recipients to copy the
|
279 |
+
Corresponding Source along with the object code. If the place to
|
280 |
+
copy the object code is a network server, the Corresponding Source
|
281 |
+
may be on a different server (operated by you or a third party)
|
282 |
+
that supports equivalent copying facilities, provided you maintain
|
283 |
+
clear directions next to the object code saying where to find the
|
284 |
+
Corresponding Source. Regardless of what server hosts the
|
285 |
+
Corresponding Source, you remain obligated to ensure that it is
|
286 |
+
available for as long as needed to satisfy these requirements.
|
287 |
+
|
288 |
+
e) Convey the object code using peer-to-peer transmission, provided
|
289 |
+
you inform other peers where the object code and Corresponding
|
290 |
+
Source of the work are being offered to the general public at no
|
291 |
+
charge under subsection 6d.
|
292 |
+
|
293 |
+
A separable portion of the object code, whose source code is excluded
|
294 |
+
from the Corresponding Source as a System Library, need not be
|
295 |
+
included in conveying the object code work.
|
296 |
+
|
297 |
+
A "User Product" is either (1) a "consumer product", which means any
|
298 |
+
tangible personal property which is normally used for personal, family,
|
299 |
+
or household purposes, or (2) anything designed or sold for incorporation
|
300 |
+
into a dwelling. In determining whether a product is a consumer product,
|
301 |
+
doubtful cases shall be resolved in favor of coverage. For a particular
|
302 |
+
product received by a particular user, "normally used" refers to a
|
303 |
+
typical or common use of that class of product, regardless of the status
|
304 |
+
of the particular user or of the way in which the particular user
|
305 |
+
actually uses, or expects or is expected to use, the product. A product
|
306 |
+
is a consumer product regardless of whether the product has substantial
|
307 |
+
commercial, industrial or non-consumer uses, unless such uses represent
|
308 |
+
the only significant mode of use of the product.
|
309 |
+
|
310 |
+
"Installation Information" for a User Product means any methods,
|
311 |
+
procedures, authorization keys, or other information required to install
|
312 |
+
and execute modified versions of a covered work in that User Product from
|
313 |
+
a modified version of its Corresponding Source. The information must
|
314 |
+
suffice to ensure that the continued functioning of the modified object
|
315 |
+
code is in no case prevented or interfered with solely because
|
316 |
+
modification has been made.
|
317 |
+
|
318 |
+
If you convey an object code work under this section in, or with, or
|
319 |
+
specifically for use in, a User Product, and the conveying occurs as
|
320 |
+
part of a transaction in which the right of possession and use of the
|
321 |
+
User Product is transferred to the recipient in perpetuity or for a
|
322 |
+
fixed term (regardless of how the transaction is characterized), the
|
323 |
+
Corresponding Source conveyed under this section must be accompanied
|
324 |
+
by the Installation Information. But this requirement does not apply
|
325 |
+
if neither you nor any third party retains the ability to install
|
326 |
+
modified object code on the User Product (for example, the work has
|
327 |
+
been installed in ROM).
|
328 |
+
|
329 |
+
The requirement to provide Installation Information does not include a
|
330 |
+
requirement to continue to provide support service, warranty, or updates
|
331 |
+
for a work that has been modified or installed by the recipient, or for
|
332 |
+
the User Product in which it has been modified or installed. Access to a
|
333 |
+
network may be denied when the modification itself materially and
|
334 |
+
adversely affects the operation of the network or violates the rules and
|
335 |
+
protocols for communication across the network.
|
336 |
+
|
337 |
+
Corresponding Source conveyed, and Installation Information provided,
|
338 |
+
in accord with this section must be in a format that is publicly
|
339 |
+
documented (and with an implementation available to the public in
|
340 |
+
source code form), and must require no special password or key for
|
341 |
+
unpacking, reading or copying.
|
342 |
+
|
343 |
+
7. Additional Terms.
|
344 |
+
|
345 |
+
"Additional permissions" are terms that supplement the terms of this
|
346 |
+
License by making exceptions from one or more of its conditions.
|
347 |
+
Additional permissions that are applicable to the entire Program shall
|
348 |
+
be treated as though they were included in this License, to the extent
|
349 |
+
that they are valid under applicable law. If additional permissions
|
350 |
+
apply only to part of the Program, that part may be used separately
|
351 |
+
under those permissions, but the entire Program remains governed by
|
352 |
+
this License without regard to the additional permissions.
|
353 |
+
|
354 |
+
When you convey a copy of a covered work, you may at your option
|
355 |
+
remove any additional permissions from that copy, or from any part of
|
356 |
+
it. (Additional permissions may be written to require their own
|
357 |
+
removal in certain cases when you modify the work.) You may place
|
358 |
+
additional permissions on material, added by you to a covered work,
|
359 |
+
for which you have or can give appropriate copyright permission.
|
360 |
+
|
361 |
+
Notwithstanding any other provision of this License, for material you
|
362 |
+
add to a covered work, you may (if authorized by the copyright holders of
|
363 |
+
that material) supplement the terms of this License with terms:
|
364 |
+
|
365 |
+
a) Disclaiming warranty or limiting liability differently from the
|
366 |
+
terms of sections 15 and 16 of this License; or
|
367 |
+
|
368 |
+
b) Requiring preservation of specified reasonable legal notices or
|
369 |
+
author attributions in that material or in the Appropriate Legal
|
370 |
+
Notices displayed by works containing it; or
|
371 |
+
|
372 |
+
c) Prohibiting misrepresentation of the origin of that material, or
|
373 |
+
requiring that modified versions of such material be marked in
|
374 |
+
reasonable ways as different from the original version; or
|
375 |
+
|
376 |
+
d) Limiting the use for publicity purposes of names of licensors or
|
377 |
+
authors of the material; or
|
378 |
+
|
379 |
+
e) Declining to grant rights under trademark law for use of some
|
380 |
+
trade names, trademarks, or service marks; or
|
381 |
+
|
382 |
+
f) Requiring indemnification of licensors and authors of that
|
383 |
+
material by anyone who conveys the material (or modified versions of
|
384 |
+
it) with contractual assumptions of liability to the recipient, for
|
385 |
+
any liability that these contractual assumptions directly impose on
|
386 |
+
those licensors and authors.
|
387 |
+
|
388 |
+
All other non-permissive additional terms are considered "further
|
389 |
+
restrictions" within the meaning of section 10. If the Program as you
|
390 |
+
received it, or any part of it, contains a notice stating that it is
|
391 |
+
governed by this License along with a term that is a further
|
392 |
+
restriction, you may remove that term. If a license document contains
|
393 |
+
a further restriction but permits relicensing or conveying under this
|
394 |
+
License, you may add to a covered work material governed by the terms
|
395 |
+
of that license document, provided that the further restriction does
|
396 |
+
not survive such relicensing or conveying.
|
397 |
+
|
398 |
+
If you add terms to a covered work in accord with this section, you
|
399 |
+
must place, in the relevant source files, a statement of the
|
400 |
+
additional terms that apply to those files, or a notice indicating
|
401 |
+
where to find the applicable terms.
|
402 |
+
|
403 |
+
Additional terms, permissive or non-permissive, may be stated in the
|
404 |
+
form of a separately written license, or stated as exceptions;
|
405 |
+
the above requirements apply either way.
|
406 |
+
|
407 |
+
8. Termination.
|
408 |
+
|
409 |
+
You may not propagate or modify a covered work except as expressly
|
410 |
+
provided under this License. Any attempt otherwise to propagate or
|
411 |
+
modify it is void, and will automatically terminate your rights under
|
412 |
+
this License (including any patent licenses granted under the third
|
413 |
+
paragraph of section 11).
|
414 |
+
|
415 |
+
However, if you cease all violation of this License, then your
|
416 |
+
license from a particular copyright holder is reinstated (a)
|
417 |
+
provisionally, unless and until the copyright holder explicitly and
|
418 |
+
finally terminates your license, and (b) permanently, if the copyright
|
419 |
+
holder fails to notify you of the violation by some reasonable means
|
420 |
+
prior to 60 days after the cessation.
|
421 |
+
|
422 |
+
Moreover, your license from a particular copyright holder is
|
423 |
+
reinstated permanently if the copyright holder notifies you of the
|
424 |
+
violation by some reasonable means, this is the first time you have
|
425 |
+
received notice of violation of this License (for any work) from that
|
426 |
+
copyright holder, and you cure the violation prior to 30 days after
|
427 |
+
your receipt of the notice.
|
428 |
+
|
429 |
+
Termination of your rights under this section does not terminate the
|
430 |
+
licenses of parties who have received copies or rights from you under
|
431 |
+
this License. If your rights have been terminated and not permanently
|
432 |
+
reinstated, you do not qualify to receive new licenses for the same
|
433 |
+
material under section 10.
|
434 |
+
|
435 |
+
9. Acceptance Not Required for Having Copies.
|
436 |
+
|
437 |
+
You are not required to accept this License in order to receive or
|
438 |
+
run a copy of the Program. Ancillary propagation of a covered work
|
439 |
+
occurring solely as a consequence of using peer-to-peer transmission
|
440 |
+
to receive a copy likewise does not require acceptance. However,
|
441 |
+
nothing other than this License grants you permission to propagate or
|
442 |
+
modify any covered work. These actions infringe copyright if you do
|
443 |
+
not accept this License. Therefore, by modifying or propagating a
|
444 |
+
covered work, you indicate your acceptance of this License to do so.
|
445 |
+
|
446 |
+
10. Automatic Licensing of Downstream Recipients.
|
447 |
+
|
448 |
+
Each time you convey a covered work, the recipient automatically
|
449 |
+
receives a license from the original licensors, to run, modify and
|
450 |
+
propagate that work, subject to this License. You are not responsible
|
451 |
+
for enforcing compliance by third parties with this License.
|
452 |
+
|
453 |
+
An "entity transaction" is a transaction transferring control of an
|
454 |
+
organization, or substantially all assets of one, or subdividing an
|
455 |
+
organization, or merging organizations. If propagation of a covered
|
456 |
+
work results from an entity transaction, each party to that
|
457 |
+
transaction who receives a copy of the work also receives whatever
|
458 |
+
licenses to the work the party's predecessor in interest had or could
|
459 |
+
give under the previous paragraph, plus a right to possession of the
|
460 |
+
Corresponding Source of the work from the predecessor in interest, if
|
461 |
+
the predecessor has it or can get it with reasonable efforts.
|
462 |
+
|
463 |
+
You may not impose any further restrictions on the exercise of the
|
464 |
+
rights granted or affirmed under this License. For example, you may
|
465 |
+
not impose a license fee, royalty, or other charge for exercise of
|
466 |
+
rights granted under this License, and you may not initiate litigation
|
467 |
+
(including a cross-claim or counterclaim in a lawsuit) alleging that
|
468 |
+
any patent claim is infringed by making, using, selling, offering for
|
469 |
+
sale, or importing the Program or any portion of it.
|
470 |
+
|
471 |
+
11. Patents.
|
472 |
+
|
473 |
+
A "contributor" is a copyright holder who authorizes use under this
|
474 |
+
License of the Program or a work on which the Program is based. The
|
475 |
+
work thus licensed is called the contributor's "contributor version".
|
476 |
+
|
477 |
+
A contributor's "essential patent claims" are all patent claims
|
478 |
+
owned or controlled by the contributor, whether already acquired or
|
479 |
+
hereafter acquired, that would be infringed by some manner, permitted
|
480 |
+
by this License, of making, using, or selling its contributor version,
|
481 |
+
but do not include claims that would be infringed only as a
|
482 |
+
consequence of further modification of the contributor version. For
|
483 |
+
purposes of this definition, "control" includes the right to grant
|
484 |
+
patent sublicenses in a manner consistent with the requirements of
|
485 |
+
this License.
|
486 |
+
|
487 |
+
Each contributor grants you a non-exclusive, worldwide, royalty-free
|
488 |
+
patent license under the contributor's essential patent claims, to
|
489 |
+
make, use, sell, offer for sale, import and otherwise run, modify and
|
490 |
+
propagate the contents of its contributor version.
|
491 |
+
|
492 |
+
In the following three paragraphs, a "patent license" is any express
|
493 |
+
agreement or commitment, however denominated, not to enforce a patent
|
494 |
+
(such as an express permission to practice a patent or covenant not to
|
495 |
+
sue for patent infringement). To "grant" such a patent license to a
|
496 |
+
party means to make such an agreement or commitment not to enforce a
|
497 |
+
patent against the party.
|
498 |
+
|
499 |
+
If you convey a covered work, knowingly relying on a patent license,
|
500 |
+
and the Corresponding Source of the work is not available for anyone
|
501 |
+
to copy, free of charge and under the terms of this License, through a
|
502 |
+
publicly available network server or other readily accessible means,
|
503 |
+
then you must either (1) cause the Corresponding Source to be so
|
504 |
+
available, or (2) arrange to deprive yourself of the benefit of the
|
505 |
+
patent license for this particular work, or (3) arrange, in a manner
|
506 |
+
consistent with the requirements of this License, to extend the patent
|
507 |
+
license to downstream recipients. "Knowingly relying" means you have
|
508 |
+
actual knowledge that, but for the patent license, your conveying the
|
509 |
+
covered work in a country, or your recipient's use of the covered work
|
510 |
+
in a country, would infringe one or more identifiable patents in that
|
511 |
+
country that you have reason to believe are valid.
|
512 |
+
|
513 |
+
If, pursuant to or in connection with a single transaction or
|
514 |
+
arrangement, you convey, or propagate by procuring conveyance of, a
|
515 |
+
covered work, and grant a patent license to some of the parties
|
516 |
+
receiving the covered work authorizing them to use, propagate, modify
|
517 |
+
or convey a specific copy of the covered work, then the patent license
|
518 |
+
you grant is automatically extended to all recipients of the covered
|
519 |
+
work and works based on it.
|
520 |
+
|
521 |
+
A patent license is "discriminatory" if it does not include within
|
522 |
+
the scope of its coverage, prohibits the exercise of, or is
|
523 |
+
conditioned on the non-exercise of one or more of the rights that are
|
524 |
+
specifically granted under this License. You may not convey a covered
|
525 |
+
work if you are a party to an arrangement with a third party that is
|
526 |
+
in the business of distributing software, under which you make payment
|
527 |
+
to the third party based on the extent of your activity of conveying
|
528 |
+
the work, and under which the third party grants, to any of the
|
529 |
+
parties who would receive the covered work from you, a discriminatory
|
530 |
+
patent license (a) in connection with copies of the covered work
|
531 |
+
conveyed by you (or copies made from those copies), or (b) primarily
|
532 |
+
for and in connection with specific products or compilations that
|
533 |
+
contain the covered work, unless you entered into that arrangement,
|
534 |
+
or that patent license was granted, prior to 28 March 2007.
|
535 |
+
|
536 |
+
Nothing in this License shall be construed as excluding or limiting
|
537 |
+
any implied license or other defenses to infringement that may
|
538 |
+
otherwise be available to you under applicable patent law.
|
539 |
+
|
540 |
+
12. No Surrender of Others' Freedom.
|
541 |
+
|
542 |
+
If conditions are imposed on you (whether by court order, agreement or
|
543 |
+
otherwise) that contradict the conditions of this License, they do not
|
544 |
+
excuse you from the conditions of this License. If you cannot convey a
|
545 |
+
covered work so as to satisfy simultaneously your obligations under this
|
546 |
+
License and any other pertinent obligations, then as a consequence you may
|
547 |
+
not convey it at all. For example, if you agree to terms that obligate you
|
548 |
+
to collect a royalty for further conveying from those to whom you convey
|
549 |
+
the Program, the only way you could satisfy both those terms and this
|
550 |
+
License would be to refrain entirely from conveying the Program.
|
551 |
+
|
552 |
+
13. Use with the GNU Affero General Public License.
|
553 |
+
|
554 |
+
Notwithstanding any other provision of this License, you have
|
555 |
+
permission to link or combine any covered work with a work licensed
|
556 |
+
under version 3 of the GNU Affero General Public License into a single
|
557 |
+
combined work, and to convey the resulting work. The terms of this
|
558 |
+
License will continue to apply to the part which is the covered work,
|
559 |
+
but the special requirements of the GNU Affero General Public License,
|
560 |
+
section 13, concerning interaction through a network will apply to the
|
561 |
+
combination as such.
|
562 |
+
|
563 |
+
14. Revised Versions of this License.
|
564 |
+
|
565 |
+
The Free Software Foundation may publish revised and/or new versions of
|
566 |
+
the GNU General Public License from time to time. Such new versions will
|
567 |
+
be similar in spirit to the present version, but may differ in detail to
|
568 |
+
address new problems or concerns.
|
569 |
+
|
570 |
+
Each version is given a distinguishing version number. If the
|
571 |
+
Program specifies that a certain numbered version of the GNU General
|
572 |
+
Public License "or any later version" applies to it, you have the
|
573 |
+
option of following the terms and conditions either of that numbered
|
574 |
+
version or of any later version published by the Free Software
|
575 |
+
Foundation. If the Program does not specify a version number of the
|
576 |
+
GNU General Public License, you may choose any version ever published
|
577 |
+
by the Free Software Foundation.
|
578 |
+
|
579 |
+
If the Program specifies that a proxy can decide which future
|
580 |
+
versions of the GNU General Public License can be used, that proxy's
|
581 |
+
public statement of acceptance of a version permanently authorizes you
|
582 |
+
to choose that version for the Program.
|
583 |
+
|
584 |
+
Later license versions may give you additional or different
|
585 |
+
permissions. However, no additional obligations are imposed on any
|
586 |
+
author or copyright holder as a result of your choosing to follow a
|
587 |
+
later version.
|
588 |
+
|
589 |
+
15. Disclaimer of Warranty.
|
590 |
+
|
591 |
+
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
|
592 |
+
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
|
593 |
+
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
|
594 |
+
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
|
595 |
+
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
596 |
+
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
|
597 |
+
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
|
598 |
+
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
|
599 |
+
|
600 |
+
16. Limitation of Liability.
|
601 |
+
|
602 |
+
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
|
603 |
+
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
|
604 |
+
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
|
605 |
+
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
|
606 |
+
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
|
607 |
+
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
|
608 |
+
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
|
609 |
+
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
|
610 |
+
SUCH DAMAGES.
|
611 |
+
|
612 |
+
17. Interpretation of Sections 15 and 16.
|
613 |
+
|
614 |
+
If the disclaimer of warranty and limitation of liability provided
|
615 |
+
above cannot be given local legal effect according to their terms,
|
616 |
+
reviewing courts shall apply local law that most closely approximates
|
617 |
+
an absolute waiver of all civil liability in connection with the
|
618 |
+
Program, unless a warranty or assumption of liability accompanies a
|
619 |
+
copy of the Program in return for a fee.
|
620 |
+
|
621 |
+
END OF TERMS AND CONDITIONS
|
622 |
+
|
623 |
+
How to Apply These Terms to Your New Programs
|
624 |
+
|
625 |
+
If you develop a new program, and you want it to be of the greatest
|
626 |
+
possible use to the public, the best way to achieve this is to make it
|
627 |
+
free software which everyone can redistribute and change under these terms.
|
628 |
+
|
629 |
+
To do so, attach the following notices to the program. It is safest
|
630 |
+
to attach them to the start of each source file to most effectively
|
631 |
+
state the exclusion of warranty; and each file should have at least
|
632 |
+
the "copyright" line and a pointer to where the full notice is found.
|
633 |
+
|
634 |
+
{one line to give the program's name and a brief idea of what it does.}
|
635 |
+
Copyright (C) {year} {name of author}
|
636 |
+
|
637 |
+
This program is free software: you can redistribute it and/or modify
|
638 |
+
it under the terms of the GNU General Public License as published by
|
639 |
+
the Free Software Foundation, either version 3 of the License, or
|
640 |
+
(at your option) any later version.
|
641 |
+
|
642 |
+
This program is distributed in the hope that it will be useful,
|
643 |
+
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
644 |
+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
645 |
+
GNU General Public License for more details.
|
646 |
+
|
647 |
+
You should have received a copy of the GNU General Public License
|
648 |
+
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
649 |
+
|
650 |
+
Also add information on how to contact you by electronic and paper mail.
|
651 |
+
|
652 |
+
If the program does terminal interaction, make it output a short
|
653 |
+
notice like this when it starts in an interactive mode:
|
654 |
+
|
655 |
+
{project} Copyright (C) {year} {fullname}
|
656 |
+
This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
|
657 |
+
This is free software, and you are welcome to redistribute it
|
658 |
+
under certain conditions; type `show c' for details.
|
659 |
+
|
660 |
+
The hypothetical commands `show w' and `show c' should show the appropriate
|
661 |
+
parts of the General Public License. Of course, your program's commands
|
662 |
+
might be different; for a GUI interface, you would use an "about box".
|
663 |
+
|
664 |
+
You should also get your employer (if you work as a programmer) or school,
|
665 |
+
if any, to sign a "copyright disclaimer" for the program, if necessary.
|
666 |
+
For more information on this, and how to apply and follow the GNU GPL, see
|
667 |
+
<http://www.gnu.org/licenses/>.
|
668 |
+
|
669 |
+
The GNU General Public License does not permit incorporating your program
|
670 |
+
into proprietary programs. If your program is a subroutine library, you
|
671 |
+
may consider it more useful to permit linking proprietary applications with
|
672 |
+
the library. If this is what you want to do, use the GNU Lesser General
|
673 |
+
Public License instead of this License. But first, please read
|
674 |
+
<http://www.gnu.org/philosophy/why-not-lgpl.html>.
|
model/LICENSE.meta
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
META-LICENSE
|
2 |
+
Version 1, June 21 2017
|
3 |
+
|
4 |
+
Any and all licenses may be applied to the software either individually
|
5 |
+
or in concert. Any issues, ambiguities, paradoxes, or metaphysical quandries
|
6 |
+
arising from this combination should be discussed with a local faith leader,
|
7 |
+
hermit, or guru. The Oxford comma shall be used.
|
8 |
+
|
model/LICENSE.mit
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2017 Joseph Redmon
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
22 |
+
|
model/LICENSE.v1
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
YOLO LICENSE
|
2 |
+
Version 1, July 10 2015
|
3 |
+
|
4 |
+
THIS SOFTWARE LICENSE IS PROVIDED "ALL CAPS" SO THAT YOU KNOW IT IS SUPER
|
5 |
+
SERIOUS AND YOU DON'T MESS AROUND WITH COPYRIGHT LAW BECAUSE YOU WILL GET IN
|
6 |
+
TROUBLE HERE ARE SOME OTHER BUZZWORDS COMMONLY IN THESE THINGS WARRANTIES
|
7 |
+
LIABILITY CONTRACT TORT LIABLE CLAIMS RESTRICTION MERCHANTABILITY SUBJECT TO
|
8 |
+
THE FOLLOWING CONDITIONS:
|
9 |
+
|
10 |
+
1. #yolo
|
11 |
+
2. #swag
|
12 |
+
3. #blazeit
|
13 |
+
|
model/Makefile
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
GPU=0
|
2 |
+
CUDNN=0
|
3 |
+
OPENCV=0
|
4 |
+
OPENMP=0
|
5 |
+
DEBUG=0
|
6 |
+
|
7 |
+
ARCH= -gencode arch=compute_30,code=sm_30 \
|
8 |
+
-gencode arch=compute_35,code=sm_35 \
|
9 |
+
-gencode arch=compute_50,code=[sm_50,compute_50] \
|
10 |
+
-gencode arch=compute_52,code=[sm_52,compute_52]
|
11 |
+
# -gencode arch=compute_20,code=[sm_20,sm_21] \ This one is deprecated?
|
12 |
+
|
13 |
+
# This is what I use, uncomment if you know your arch and want to specify
|
14 |
+
# ARCH= -gencode arch=compute_52,code=compute_52
|
15 |
+
|
16 |
+
VPATH=./src/:./examples
|
17 |
+
SLIB=libdarknet.so
|
18 |
+
ALIB=libdarknet.a
|
19 |
+
EXEC=darknet
|
20 |
+
OBJDIR=./obj/
|
21 |
+
|
22 |
+
CC=gcc
|
23 |
+
CPP=g++
|
24 |
+
NVCC=nvcc
|
25 |
+
AR=ar
|
26 |
+
ARFLAGS=rcs
|
27 |
+
OPTS=-Ofast
|
28 |
+
LDFLAGS= -lm -pthread
|
29 |
+
COMMON= -Iinclude/ -Isrc/
|
30 |
+
CFLAGS=-Wall -Wno-unused-result -Wno-unknown-pragmas -Wfatal-errors -fPIC
|
31 |
+
|
32 |
+
ifeq ($(OPENMP), 1)
|
33 |
+
CFLAGS+= -fopenmp
|
34 |
+
endif
|
35 |
+
|
36 |
+
ifeq ($(DEBUG), 1)
|
37 |
+
OPTS=-O0 -g
|
38 |
+
endif
|
39 |
+
|
40 |
+
CFLAGS+=$(OPTS)
|
41 |
+
|
42 |
+
ifeq ($(OPENCV), 1)
|
43 |
+
COMMON+= -DOPENCV
|
44 |
+
CFLAGS+= -DOPENCV
|
45 |
+
LDFLAGS+= `pkg-config --libs opencv` -lstdc++
|
46 |
+
COMMON+= `pkg-config --cflags opencv`
|
47 |
+
endif
|
48 |
+
|
49 |
+
ifeq ($(GPU), 1)
|
50 |
+
COMMON+= -DGPU -I/usr/local/cuda/include/
|
51 |
+
CFLAGS+= -DGPU
|
52 |
+
LDFLAGS+= -L/usr/local/cuda/lib64 -lcuda -lcudart -lcublas -lcurand
|
53 |
+
endif
|
54 |
+
|
55 |
+
ifeq ($(CUDNN), 1)
|
56 |
+
COMMON+= -DCUDNN
|
57 |
+
CFLAGS+= -DCUDNN
|
58 |
+
LDFLAGS+= -lcudnn
|
59 |
+
endif
|
60 |
+
|
61 |
+
OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o detection_layer.o route_layer.o upsample_layer.o box.o normalization_layer.o avgpool_layer.o layer.o local_layer.o shortcut_layer.o logistic_layer.o activation_layer.o rnn_layer.o gru_layer.o crnn_layer.o demo.o batchnorm_layer.o region_layer.o reorg_layer.o tree.o lstm_layer.o l2norm_layer.o yolo_layer.o iseg_layer.o image_opencv.o
|
62 |
+
EXECOBJA=captcha.o lsd.o super.o art.o tag.o cifar.o go.o rnn.o segmenter.o regressor.o classifier.o coco.o yolo.o detector.o nightmare.o instance-segmenter.o darknet.o
|
63 |
+
ifeq ($(GPU), 1)
|
64 |
+
LDFLAGS+= -lstdc++
|
65 |
+
OBJ+=convolutional_kernels.o deconvolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o avgpool_layer_kernels.o
|
66 |
+
endif
|
67 |
+
|
68 |
+
EXECOBJ = $(addprefix $(OBJDIR), $(EXECOBJA))
|
69 |
+
OBJS = $(addprefix $(OBJDIR), $(OBJ))
|
70 |
+
DEPS = $(wildcard src/*.h) Makefile include/darknet.h
|
71 |
+
|
72 |
+
all: obj backup results $(SLIB) $(ALIB) $(EXEC)
|
73 |
+
#all: obj results $(SLIB) $(ALIB) $(EXEC)
|
74 |
+
|
75 |
+
|
76 |
+
$(EXEC): $(EXECOBJ) $(ALIB)
|
77 |
+
$(CC) $(COMMON) $(CFLAGS) $^ -o $@ $(LDFLAGS) $(ALIB)
|
78 |
+
|
79 |
+
$(ALIB): $(OBJS)
|
80 |
+
$(AR) $(ARFLAGS) $@ $^
|
81 |
+
|
82 |
+
$(SLIB): $(OBJS)
|
83 |
+
$(CC) $(CFLAGS) -shared $^ -o $@ $(LDFLAGS)
|
84 |
+
|
85 |
+
$(OBJDIR)%.o: %.cpp $(DEPS)
|
86 |
+
$(CPP) $(COMMON) $(CFLAGS) -c $< -o $@
|
87 |
+
|
88 |
+
$(OBJDIR)%.o: %.c $(DEPS)
|
89 |
+
$(CC) $(COMMON) $(CFLAGS) -c $< -o $@
|
90 |
+
|
91 |
+
$(OBJDIR)%.o: %.cu $(DEPS)
|
92 |
+
$(NVCC) $(ARCH) $(COMMON) --compiler-options "$(CFLAGS)" -c $< -o $@
|
93 |
+
|
94 |
+
obj:
|
95 |
+
mkdir -p obj
|
96 |
+
backup:
|
97 |
+
mkdir -p backup
|
98 |
+
results:
|
99 |
+
mkdir -p results
|
100 |
+
|
101 |
+
.PHONY: clean
|
102 |
+
|
103 |
+
clean:
|
104 |
+
rm -rf $(OBJS) $(SLIB) $(ALIB) $(EXEC) $(EXECOBJ) $(OBJDIR)/*
|
105 |
+
|
model/README.md
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+

|
2 |
+
|
3 |
+
# Darknet #
|
4 |
+
Darknet is an open source neural network framework written in C and CUDA. It is fast, easy to install, and supports CPU and GPU computation.
|
5 |
+
|
6 |
+
**Discord** invite link for for communication and questions: https://discord.gg/zSq8rtW
|
7 |
+
|
8 |
+
## YOLOv7:
|
9 |
+
|
10 |
+
* **paper** - YOLOv7: Trainable bag-of-freebies sets new state-of-the-art for real-time object detectors: https://arxiv.org/abs/2207.02696
|
11 |
+
|
12 |
+
* **source code - Pytorch (use to reproduce results):** https://github.com/WongKinYiu/yolov7
|
13 |
+
|
14 |
+
----
|
15 |
+
|
16 |
+
Official YOLOv7 is more accurate and faster than YOLOv5 by **120%** FPS, than YOLOX by **180%** FPS, than Dual-Swin-T by **1200%** FPS, than ConvNext by **550%** FPS, than SWIN-L by **500%** FPS.
|
17 |
+
|
18 |
+
YOLOv7 surpasses all known object detectors in both speed and accuracy in the range from 5 FPS to 160 FPS and has the highest accuracy 56.8% AP among all known real-time object detectors with 30 FPS or higher on GPU V100, batch=1.
|
19 |
+
|
20 |
+
* YOLOv7-e6 (55.9% AP, 56 FPS V100 b=1) by `+500%` FPS faster than SWIN-L Cascade-Mask R-CNN (53.9% AP, 9.2 FPS A100 b=1)
|
21 |
+
* YOLOv7-e6 (55.9% AP, 56 FPS V100 b=1) by `+550%` FPS faster than ConvNeXt-XL C-M-RCNN (55.2% AP, 8.6 FPS A100 b=1)
|
22 |
+
* YOLOv7-w6 (54.6% AP, 84 FPS V100 b=1) by `+120%` FPS faster than YOLOv5-X6-r6.1 (55.0% AP, 38 FPS V100 b=1)
|
23 |
+
* YOLOv7-w6 (54.6% AP, 84 FPS V100 b=1) by `+1200%` FPS faster than Dual-Swin-T C-M-RCNN (53.6% AP, 6.5 FPS V100 b=1)
|
24 |
+
* YOLOv7x (52.9% AP, 114 FPS V100 b=1) by `+150%` FPS faster than PPYOLOE-X (51.9% AP, 45 FPS V100 b=1)
|
25 |
+
* YOLOv7 (51.2% AP, 161 FPS V100 b=1) by `+180%` FPS faster than YOLOX-X (51.1% AP, 58 FPS V100 b=1)
|
26 |
+
|
27 |
+
----
|
28 |
+
|
29 |
+

|
30 |
+
|
31 |
+
----
|
32 |
+
|
33 |
+

|
34 |
+
|
35 |
+
|
36 |
+
----
|
37 |
+
|
38 |
+

|
39 |
+
|
40 |
+
----
|
41 |
+
|
42 |
+
## Scaled-YOLOv4:
|
43 |
+
|
44 |
+
* **paper (CVPR 2021)**: https://openaccess.thecvf.com/content/CVPR2021/html/Wang_Scaled-YOLOv4_Scaling_Cross_Stage_Partial_Network_CVPR_2021_paper.html
|
45 |
+
|
46 |
+
* **source code - Pytorch (use to reproduce results):** https://github.com/WongKinYiu/ScaledYOLOv4
|
47 |
+
|
48 |
+
* **source code - Darknet:** https://github.com/AlexeyAB/darknet
|
49 |
+
|
50 |
+
* **Medium:** https://alexeyab84.medium.com/scaled-yolo-v4-is-the-best-neural-network-for-object-detection-on-ms-coco-dataset-39dfa22fa982?source=friends_link&sk=c8553bfed861b1a7932f739d26f487c8
|
51 |
+
|
52 |
+
## YOLOv4:
|
53 |
+
|
54 |
+
* **paper:** https://arxiv.org/abs/2004.10934
|
55 |
+
|
56 |
+
* **source code:** https://github.com/AlexeyAB/darknet
|
57 |
+
|
58 |
+
* **Wiki:** https://github.com/AlexeyAB/darknet/wiki
|
59 |
+
|
60 |
+
* **useful links:** https://medium.com/@alexeyab84/yolov4-the-most-accurate-real-time-neural-network-on-ms-coco-dataset-73adfd3602fe?source=friends_link&sk=6039748846bbcf1d960c3061542591d7
|
61 |
+
|
62 |
+
For more information see the [Darknet project website](http://pjreddie.com/darknet).
|
63 |
+
|
64 |
+
|
65 |
+
<details><summary> <b>Expand</b> </summary>
|
66 |
+
|
67 |
+
 https://paperswithcode.com/sota/object-detection-on-coco
|
68 |
+
|
69 |
+
----
|
70 |
+
|
71 |
+
 AP50:95 - FPS (Tesla V100) Paper: https://arxiv.org/abs/2011.08036
|
72 |
+
|
73 |
+
----
|
74 |
+
|
75 |
+

|
76 |
+
|
77 |
+
----
|
78 |
+
|
79 |
+

|
80 |
+
|
81 |
+
</details>
|
82 |
+
|
83 |
+
----
|
84 |
+
|
85 |
+

|
86 |
+
|
87 |
+
|
88 |
+
## Citation
|
89 |
+
|
90 |
+
|
91 |
+
```
|
92 |
+
@misc{https://doi.org/10.48550/arxiv.2207.02696,
|
93 |
+
doi = {10.48550/ARXIV.2207.02696},
|
94 |
+
url = {https://arxiv.org/abs/2207.02696},
|
95 |
+
author = {Wang, Chien-Yao and Bochkovskiy, Alexey and Liao, Hong-Yuan Mark},
|
96 |
+
keywords = {Computer Vision and Pattern Recognition (cs.CV), FOS: Computer and information sciences, FOS: Computer and information sciences},
|
97 |
+
title = {YOLOv7: Trainable bag-of-freebies sets new state-of-the-art for real-time object detectors},
|
98 |
+
publisher = {arXiv},
|
99 |
+
year = {2022},
|
100 |
+
copyright = {arXiv.org perpetual, non-exclusive license}
|
101 |
+
}
|
102 |
+
```
|
103 |
+
|
104 |
+
```
|
105 |
+
@misc{bochkovskiy2020yolov4,
|
106 |
+
title={YOLOv4: Optimal Speed and Accuracy of Object Detection},
|
107 |
+
author={Alexey Bochkovskiy and Chien-Yao Wang and Hong-Yuan Mark Liao},
|
108 |
+
year={2020},
|
109 |
+
eprint={2004.10934},
|
110 |
+
archivePrefix={arXiv},
|
111 |
+
primaryClass={cs.CV}
|
112 |
+
}
|
113 |
+
```
|
114 |
+
|
115 |
+
```
|
116 |
+
@InProceedings{Wang_2021_CVPR,
|
117 |
+
author = {Wang, Chien-Yao and Bochkovskiy, Alexey and Liao, Hong-Yuan Mark},
|
118 |
+
title = {{Scaled-YOLOv4}: Scaling Cross Stage Partial Network},
|
119 |
+
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
|
120 |
+
month = {June},
|
121 |
+
year = {2021},
|
122 |
+
pages = {13029-13038}
|
123 |
+
}
|
124 |
+
```
|
model/examples/art.c
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include "darknet.h"
|
2 |
+
|
3 |
+
#include <sys/time.h>
|
4 |
+
|
5 |
+
void demo_art(char *cfgfile, char *weightfile, int cam_index)
|
6 |
+
{
|
7 |
+
#ifdef OPENCV
|
8 |
+
network *net = load_network(cfgfile, weightfile, 0);
|
9 |
+
set_batch_network(net, 1);
|
10 |
+
|
11 |
+
srand(2222222);
|
12 |
+
|
13 |
+
void * cap = open_video_stream(0, cam_index, 0,0,0);
|
14 |
+
|
15 |
+
char *window = "ArtJudgementBot9000!!!";
|
16 |
+
if(!cap) error("Couldn't connect to webcam.\n");
|
17 |
+
int i;
|
18 |
+
int idx[] = {37, 401, 434};
|
19 |
+
int n = sizeof(idx)/sizeof(idx[0]);
|
20 |
+
|
21 |
+
while(1){
|
22 |
+
image in = get_image_from_stream(cap);
|
23 |
+
image in_s = resize_image(in, net->w, net->h);
|
24 |
+
|
25 |
+
float *p = network_predict(net, in_s.data);
|
26 |
+
|
27 |
+
printf("\033[2J");
|
28 |
+
printf("\033[1;1H");
|
29 |
+
|
30 |
+
float score = 0;
|
31 |
+
for(i = 0; i < n; ++i){
|
32 |
+
float s = p[idx[i]];
|
33 |
+
if (s > score) score = s;
|
34 |
+
}
|
35 |
+
score = score;
|
36 |
+
printf("I APPRECIATE THIS ARTWORK: %10.7f%%\n", score*100);
|
37 |
+
printf("[");
|
38 |
+
int upper = 30;
|
39 |
+
for(i = 0; i < upper; ++i){
|
40 |
+
printf("%c", ((i+.5) < score*upper) ? 219 : ' ');
|
41 |
+
}
|
42 |
+
printf("]\n");
|
43 |
+
|
44 |
+
show_image(in, window, 1);
|
45 |
+
free_image(in_s);
|
46 |
+
free_image(in);
|
47 |
+
}
|
48 |
+
#endif
|
49 |
+
}
|
50 |
+
|
51 |
+
|
52 |
+
void run_art(int argc, char **argv)
|
53 |
+
{
|
54 |
+
int cam_index = find_int_arg(argc, argv, "-c", 0);
|
55 |
+
char *cfg = argv[2];
|
56 |
+
char *weights = argv[3];
|
57 |
+
demo_art(cfg, weights, cam_index);
|
58 |
+
}
|
59 |
+
|
model/examples/attention.c
ADDED
@@ -0,0 +1,459 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include "darknet.h"
|
2 |
+
|
3 |
+
#include <sys/time.h>
|
4 |
+
#include <assert.h>
|
5 |
+
|
6 |
+
void extend_data_truth(data *d, int n, float val)
|
7 |
+
{
|
8 |
+
int i, j;
|
9 |
+
for(i = 0; i < d->y.rows; ++i){
|
10 |
+
d->y.vals[i] = realloc(d->y.vals[i], (d->y.cols+n)*sizeof(float));
|
11 |
+
for(j = 0; j < n; ++j){
|
12 |
+
d->y.vals[i][d->y.cols + j] = val;
|
13 |
+
}
|
14 |
+
}
|
15 |
+
d->y.cols += n;
|
16 |
+
}
|
17 |
+
|
18 |
+
matrix network_loss_data(network *net, data test)
|
19 |
+
{
|
20 |
+
int i,b;
|
21 |
+
int k = 1;
|
22 |
+
matrix pred = make_matrix(test.X.rows, k);
|
23 |
+
float *X = calloc(net->batch*test.X.cols, sizeof(float));
|
24 |
+
float *y = calloc(net->batch*test.y.cols, sizeof(float));
|
25 |
+
for(i = 0; i < test.X.rows; i += net->batch){
|
26 |
+
for(b = 0; b < net->batch; ++b){
|
27 |
+
if(i+b == test.X.rows) break;
|
28 |
+
memcpy(X+b*test.X.cols, test.X.vals[i+b], test.X.cols*sizeof(float));
|
29 |
+
memcpy(y+b*test.y.cols, test.y.vals[i+b], test.y.cols*sizeof(float));
|
30 |
+
}
|
31 |
+
|
32 |
+
network orig = *net;
|
33 |
+
net->input = X;
|
34 |
+
net->truth = y;
|
35 |
+
net->train = 0;
|
36 |
+
net->delta = 0;
|
37 |
+
forward_network(net);
|
38 |
+
*net = orig;
|
39 |
+
|
40 |
+
float *delta = net->layers[net->n-1].output;
|
41 |
+
for(b = 0; b < net->batch; ++b){
|
42 |
+
if(i+b == test.X.rows) break;
|
43 |
+
int t = max_index(y + b*test.y.cols, 1000);
|
44 |
+
float err = sum_array(delta + b*net->outputs, net->outputs);
|
45 |
+
pred.vals[i+b][0] = -err;
|
46 |
+
//pred.vals[i+b][0] = 1-delta[b*net->outputs + t];
|
47 |
+
}
|
48 |
+
}
|
49 |
+
free(X);
|
50 |
+
free(y);
|
51 |
+
return pred;
|
52 |
+
}
|
53 |
+
|
54 |
+
void train_attention(char *datacfg, char *cfgfile, char *weightfile, int *gpus, int ngpus, int clear)
|
55 |
+
{
|
56 |
+
int i, j;
|
57 |
+
|
58 |
+
float avg_cls_loss = -1;
|
59 |
+
float avg_att_loss = -1;
|
60 |
+
char *base = basecfg(cfgfile);
|
61 |
+
printf("%s\n", base);
|
62 |
+
printf("%d\n", ngpus);
|
63 |
+
network **nets = calloc(ngpus, sizeof(network*));
|
64 |
+
|
65 |
+
srand(time(0));
|
66 |
+
int seed = rand();
|
67 |
+
for(i = 0; i < ngpus; ++i){
|
68 |
+
srand(seed);
|
69 |
+
#ifdef GPU
|
70 |
+
cuda_set_device(gpus[i]);
|
71 |
+
#endif
|
72 |
+
nets[i] = load_network(cfgfile, weightfile, clear);
|
73 |
+
nets[i]->learning_rate *= ngpus;
|
74 |
+
}
|
75 |
+
srand(time(0));
|
76 |
+
network *net = nets[0];
|
77 |
+
|
78 |
+
int imgs = net->batch * net->subdivisions * ngpus;
|
79 |
+
|
80 |
+
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
|
81 |
+
list *options = read_data_cfg(datacfg);
|
82 |
+
|
83 |
+
char *backup_directory = option_find_str(options, "backup", "/backup/");
|
84 |
+
char *label_list = option_find_str(options, "labels", "data/labels.list");
|
85 |
+
char *train_list = option_find_str(options, "train", "data/train.list");
|
86 |
+
int classes = option_find_int(options, "classes", 2);
|
87 |
+
|
88 |
+
char **labels = get_labels(label_list);
|
89 |
+
list *plist = get_paths(train_list);
|
90 |
+
char **paths = (char **)list_to_array(plist);
|
91 |
+
printf("%d\n", plist->size);
|
92 |
+
int N = plist->size;
|
93 |
+
double time;
|
94 |
+
|
95 |
+
int divs=3;
|
96 |
+
int size=2;
|
97 |
+
|
98 |
+
load_args args = {0};
|
99 |
+
args.w = divs*net->w/size;
|
100 |
+
args.h = divs*net->h/size;
|
101 |
+
args.size = divs*net->w/size;
|
102 |
+
args.threads = 32;
|
103 |
+
args.hierarchy = net->hierarchy;
|
104 |
+
|
105 |
+
args.min = net->min_ratio*args.w;
|
106 |
+
args.max = net->max_ratio*args.w;
|
107 |
+
args.angle = net->angle;
|
108 |
+
args.aspect = net->aspect;
|
109 |
+
args.exposure = net->exposure;
|
110 |
+
args.saturation = net->saturation;
|
111 |
+
args.hue = net->hue;
|
112 |
+
|
113 |
+
args.paths = paths;
|
114 |
+
args.classes = classes;
|
115 |
+
args.n = imgs;
|
116 |
+
args.m = N;
|
117 |
+
args.labels = labels;
|
118 |
+
args.type = CLASSIFICATION_DATA;
|
119 |
+
|
120 |
+
data train;
|
121 |
+
data buffer;
|
122 |
+
pthread_t load_thread;
|
123 |
+
args.d = &buffer;
|
124 |
+
load_thread = load_data(args);
|
125 |
+
|
126 |
+
int epoch = (*net->seen)/N;
|
127 |
+
while(get_current_batch(net) < net->max_batches || net->max_batches == 0){
|
128 |
+
time = what_time_is_it_now();
|
129 |
+
|
130 |
+
pthread_join(load_thread, 0);
|
131 |
+
train = buffer;
|
132 |
+
load_thread = load_data(args);
|
133 |
+
data resized = resize_data(train, net->w, net->h);
|
134 |
+
extend_data_truth(&resized, divs*divs, 0);
|
135 |
+
data *tiles = tile_data(train, divs, size);
|
136 |
+
|
137 |
+
printf("Loaded: %lf seconds\n", what_time_is_it_now()-time);
|
138 |
+
time = what_time_is_it_now();
|
139 |
+
|
140 |
+
float aloss = 0;
|
141 |
+
float closs = 0;
|
142 |
+
int z;
|
143 |
+
for (i = 0; i < divs*divs/ngpus; ++i) {
|
144 |
+
#pragma omp parallel for
|
145 |
+
for(j = 0; j < ngpus; ++j){
|
146 |
+
int index = i*ngpus + j;
|
147 |
+
extend_data_truth(tiles+index, divs*divs, SECRET_NUM);
|
148 |
+
matrix deltas = network_loss_data(nets[j], tiles[index]);
|
149 |
+
for(z = 0; z < resized.y.rows; ++z){
|
150 |
+
resized.y.vals[z][train.y.cols + index] = deltas.vals[z][0];
|
151 |
+
}
|
152 |
+
free_matrix(deltas);
|
153 |
+
}
|
154 |
+
}
|
155 |
+
int *inds = calloc(resized.y.rows, sizeof(int));
|
156 |
+
for(z = 0; z < resized.y.rows; ++z){
|
157 |
+
int index = max_index(resized.y.vals[z] + train.y.cols, divs*divs);
|
158 |
+
inds[z] = index;
|
159 |
+
for(i = 0; i < divs*divs; ++i){
|
160 |
+
resized.y.vals[z][train.y.cols + i] = (i == index)? 1 : 0;
|
161 |
+
}
|
162 |
+
}
|
163 |
+
data best = select_data(tiles, inds);
|
164 |
+
free(inds);
|
165 |
+
#ifdef GPU
|
166 |
+
if (ngpus == 1) {
|
167 |
+
closs = train_network(net, best);
|
168 |
+
} else {
|
169 |
+
closs = train_networks(nets, ngpus, best, 4);
|
170 |
+
}
|
171 |
+
#endif
|
172 |
+
for (i = 0; i < divs*divs; ++i) {
|
173 |
+
printf("%.2f ", resized.y.vals[0][train.y.cols + i]);
|
174 |
+
if((i+1)%divs == 0) printf("\n");
|
175 |
+
free_data(tiles[i]);
|
176 |
+
}
|
177 |
+
free_data(best);
|
178 |
+
printf("\n");
|
179 |
+
image im = float_to_image(64,64,3,resized.X.vals[0]);
|
180 |
+
//show_image(im, "orig");
|
181 |
+
//cvWaitKey(100);
|
182 |
+
/*
|
183 |
+
image im1 = float_to_image(64,64,3,tiles[i].X.vals[0]);
|
184 |
+
image im2 = float_to_image(64,64,3,resized.X.vals[0]);
|
185 |
+
show_image(im1, "tile");
|
186 |
+
show_image(im2, "res");
|
187 |
+
*/
|
188 |
+
#ifdef GPU
|
189 |
+
if (ngpus == 1) {
|
190 |
+
aloss = train_network(net, resized);
|
191 |
+
} else {
|
192 |
+
aloss = train_networks(nets, ngpus, resized, 4);
|
193 |
+
}
|
194 |
+
#endif
|
195 |
+
for(i = 0; i < divs*divs; ++i){
|
196 |
+
printf("%f ", nets[0]->output[1000 + i]);
|
197 |
+
if ((i+1) % divs == 0) printf("\n");
|
198 |
+
}
|
199 |
+
printf("\n");
|
200 |
+
|
201 |
+
free_data(resized);
|
202 |
+
free_data(train);
|
203 |
+
if(avg_cls_loss == -1) avg_cls_loss = closs;
|
204 |
+
if(avg_att_loss == -1) avg_att_loss = aloss;
|
205 |
+
avg_cls_loss = avg_cls_loss*.9 + closs*.1;
|
206 |
+
avg_att_loss = avg_att_loss*.9 + aloss*.1;
|
207 |
+
|
208 |
+
printf("%ld, %.3f: Att: %f, %f avg, Class: %f, %f avg, %f rate, %lf seconds, %ld images\n", get_current_batch(net), (float)(*net->seen)/N, aloss, avg_att_loss, closs, avg_cls_loss, get_current_rate(net), what_time_is_it_now()-time, *net->seen);
|
209 |
+
if(*net->seen/N > epoch){
|
210 |
+
epoch = *net->seen/N;
|
211 |
+
char buff[256];
|
212 |
+
sprintf(buff, "%s/%s_%d.weights",backup_directory,base, epoch);
|
213 |
+
save_weights(net, buff);
|
214 |
+
}
|
215 |
+
if(get_current_batch(net)%1000 == 0){
|
216 |
+
char buff[256];
|
217 |
+
sprintf(buff, "%s/%s.backup",backup_directory,base);
|
218 |
+
save_weights(net, buff);
|
219 |
+
}
|
220 |
+
}
|
221 |
+
char buff[256];
|
222 |
+
sprintf(buff, "%s/%s.weights", backup_directory, base);
|
223 |
+
save_weights(net, buff);
|
224 |
+
pthread_join(load_thread, 0);
|
225 |
+
|
226 |
+
free_network(net);
|
227 |
+
free_ptrs((void**)labels, classes);
|
228 |
+
free_ptrs((void**)paths, plist->size);
|
229 |
+
free_list(plist);
|
230 |
+
free(base);
|
231 |
+
}
|
232 |
+
|
233 |
+
void validate_attention_single(char *datacfg, char *filename, char *weightfile)
|
234 |
+
{
|
235 |
+
int i, j;
|
236 |
+
network *net = load_network(filename, weightfile, 0);
|
237 |
+
set_batch_network(net, 1);
|
238 |
+
srand(time(0));
|
239 |
+
|
240 |
+
list *options = read_data_cfg(datacfg);
|
241 |
+
|
242 |
+
char *label_list = option_find_str(options, "labels", "data/labels.list");
|
243 |
+
char *leaf_list = option_find_str(options, "leaves", 0);
|
244 |
+
if(leaf_list) change_leaves(net->hierarchy, leaf_list);
|
245 |
+
char *valid_list = option_find_str(options, "valid", "data/train.list");
|
246 |
+
int classes = option_find_int(options, "classes", 2);
|
247 |
+
int topk = option_find_int(options, "top", 1);
|
248 |
+
|
249 |
+
char **labels = get_labels(label_list);
|
250 |
+
list *plist = get_paths(valid_list);
|
251 |
+
|
252 |
+
char **paths = (char **)list_to_array(plist);
|
253 |
+
int m = plist->size;
|
254 |
+
free_list(plist);
|
255 |
+
|
256 |
+
float avg_acc = 0;
|
257 |
+
float avg_topk = 0;
|
258 |
+
int *indexes = calloc(topk, sizeof(int));
|
259 |
+
int divs = 4;
|
260 |
+
int size = 2;
|
261 |
+
int extra = 0;
|
262 |
+
float *avgs = calloc(classes, sizeof(float));
|
263 |
+
int *inds = calloc(divs*divs, sizeof(int));
|
264 |
+
|
265 |
+
for(i = 0; i < m; ++i){
|
266 |
+
int class = -1;
|
267 |
+
char *path = paths[i];
|
268 |
+
for(j = 0; j < classes; ++j){
|
269 |
+
if(strstr(path, labels[j])){
|
270 |
+
class = j;
|
271 |
+
break;
|
272 |
+
}
|
273 |
+
}
|
274 |
+
image im = load_image_color(paths[i], 0, 0);
|
275 |
+
image resized = resize_min(im, net->w*divs/size);
|
276 |
+
image crop = crop_image(resized, (resized.w - net->w*divs/size)/2, (resized.h - net->h*divs/size)/2, net->w*divs/size, net->h*divs/size);
|
277 |
+
image rcrop = resize_image(crop, net->w, net->h);
|
278 |
+
//show_image(im, "orig");
|
279 |
+
//show_image(crop, "cropped");
|
280 |
+
//cvWaitKey(0);
|
281 |
+
float *pred = network_predict(net, rcrop.data);
|
282 |
+
//pred[classes + 56] = 0;
|
283 |
+
for(j = 0; j < divs*divs; ++j){
|
284 |
+
printf("%.2f ", pred[classes + j]);
|
285 |
+
if((j+1)%divs == 0) printf("\n");
|
286 |
+
}
|
287 |
+
printf("\n");
|
288 |
+
copy_cpu(classes, pred, 1, avgs, 1);
|
289 |
+
top_k(pred + classes, divs*divs, divs*divs, inds);
|
290 |
+
show_image(crop, "crop");
|
291 |
+
for(j = 0; j < extra; ++j){
|
292 |
+
int index = inds[j];
|
293 |
+
int row = index / divs;
|
294 |
+
int col = index % divs;
|
295 |
+
int y = row * crop.h / divs - (net->h - crop.h/divs)/2;
|
296 |
+
int x = col * crop.w / divs - (net->w - crop.w/divs)/2;
|
297 |
+
printf("%d %d %d %d\n", row, col, y, x);
|
298 |
+
image tile = crop_image(crop, x, y, net->w, net->h);
|
299 |
+
float *pred = network_predict(net, tile.data);
|
300 |
+
axpy_cpu(classes, 1., pred, 1, avgs, 1);
|
301 |
+
show_image(tile, "tile");
|
302 |
+
//cvWaitKey(10);
|
303 |
+
}
|
304 |
+
if(net->hierarchy) hierarchy_predictions(pred, net->outputs, net->hierarchy, 1, 1);
|
305 |
+
|
306 |
+
if(rcrop.data != resized.data) free_image(rcrop);
|
307 |
+
if(resized.data != im.data) free_image(resized);
|
308 |
+
free_image(im);
|
309 |
+
free_image(crop);
|
310 |
+
top_k(pred, classes, topk, indexes);
|
311 |
+
|
312 |
+
if(indexes[0] == class) avg_acc += 1;
|
313 |
+
for(j = 0; j < topk; ++j){
|
314 |
+
if(indexes[j] == class) avg_topk += 1;
|
315 |
+
}
|
316 |
+
|
317 |
+
printf("%d: top 1: %f, top %d: %f\n", i, avg_acc/(i+1), topk, avg_topk/(i+1));
|
318 |
+
}
|
319 |
+
}
|
320 |
+
|
321 |
+
void validate_attention_multi(char *datacfg, char *filename, char *weightfile)
|
322 |
+
{
|
323 |
+
int i, j;
|
324 |
+
network *net = load_network(filename, weightfile, 0);
|
325 |
+
set_batch_network(net, 1);
|
326 |
+
srand(time(0));
|
327 |
+
|
328 |
+
list *options = read_data_cfg(datacfg);
|
329 |
+
|
330 |
+
char *label_list = option_find_str(options, "labels", "data/labels.list");
|
331 |
+
char *valid_list = option_find_str(options, "valid", "data/train.list");
|
332 |
+
int classes = option_find_int(options, "classes", 2);
|
333 |
+
int topk = option_find_int(options, "top", 1);
|
334 |
+
|
335 |
+
char **labels = get_labels(label_list);
|
336 |
+
list *plist = get_paths(valid_list);
|
337 |
+
int scales[] = {224, 288, 320, 352, 384};
|
338 |
+
int nscales = sizeof(scales)/sizeof(scales[0]);
|
339 |
+
|
340 |
+
char **paths = (char **)list_to_array(plist);
|
341 |
+
int m = plist->size;
|
342 |
+
free_list(plist);
|
343 |
+
|
344 |
+
float avg_acc = 0;
|
345 |
+
float avg_topk = 0;
|
346 |
+
int *indexes = calloc(topk, sizeof(int));
|
347 |
+
|
348 |
+
for(i = 0; i < m; ++i){
|
349 |
+
int class = -1;
|
350 |
+
char *path = paths[i];
|
351 |
+
for(j = 0; j < classes; ++j){
|
352 |
+
if(strstr(path, labels[j])){
|
353 |
+
class = j;
|
354 |
+
break;
|
355 |
+
}
|
356 |
+
}
|
357 |
+
float *pred = calloc(classes, sizeof(float));
|
358 |
+
image im = load_image_color(paths[i], 0, 0);
|
359 |
+
for(j = 0; j < nscales; ++j){
|
360 |
+
image r = resize_min(im, scales[j]);
|
361 |
+
resize_network(net, r.w, r.h);
|
362 |
+
float *p = network_predict(net, r.data);
|
363 |
+
if(net->hierarchy) hierarchy_predictions(p, net->outputs, net->hierarchy, 1 , 1);
|
364 |
+
axpy_cpu(classes, 1, p, 1, pred, 1);
|
365 |
+
flip_image(r);
|
366 |
+
p = network_predict(net, r.data);
|
367 |
+
axpy_cpu(classes, 1, p, 1, pred, 1);
|
368 |
+
if(r.data != im.data) free_image(r);
|
369 |
+
}
|
370 |
+
free_image(im);
|
371 |
+
top_k(pred, classes, topk, indexes);
|
372 |
+
free(pred);
|
373 |
+
if(indexes[0] == class) avg_acc += 1;
|
374 |
+
for(j = 0; j < topk; ++j){
|
375 |
+
if(indexes[j] == class) avg_topk += 1;
|
376 |
+
}
|
377 |
+
|
378 |
+
printf("%d: top 1: %f, top %d: %f\n", i, avg_acc/(i+1), topk, avg_topk/(i+1));
|
379 |
+
}
|
380 |
+
}
|
381 |
+
|
382 |
+
void predict_attention(char *datacfg, char *cfgfile, char *weightfile, char *filename, int top)
|
383 |
+
{
|
384 |
+
network *net = load_network(cfgfile, weightfile, 0);
|
385 |
+
set_batch_network(net, 1);
|
386 |
+
srand(2222222);
|
387 |
+
|
388 |
+
list *options = read_data_cfg(datacfg);
|
389 |
+
|
390 |
+
char *name_list = option_find_str(options, "names", 0);
|
391 |
+
if(!name_list) name_list = option_find_str(options, "labels", "data/labels.list");
|
392 |
+
if(top == 0) top = option_find_int(options, "top", 1);
|
393 |
+
|
394 |
+
int i = 0;
|
395 |
+
char **names = get_labels(name_list);
|
396 |
+
clock_t time;
|
397 |
+
int *indexes = calloc(top, sizeof(int));
|
398 |
+
char buff[256];
|
399 |
+
char *input = buff;
|
400 |
+
while(1){
|
401 |
+
if(filename){
|
402 |
+
strncpy(input, filename, 256);
|
403 |
+
}else{
|
404 |
+
printf("Enter Image Path: ");
|
405 |
+
fflush(stdout);
|
406 |
+
input = fgets(input, 256, stdin);
|
407 |
+
if(!input) return;
|
408 |
+
strtok(input, "\n");
|
409 |
+
}
|
410 |
+
image im = load_image_color(input, 0, 0);
|
411 |
+
image r = letterbox_image(im, net->w, net->h);
|
412 |
+
//resize_network(&net, r.w, r.h);
|
413 |
+
//printf("%d %d\n", r.w, r.h);
|
414 |
+
|
415 |
+
float *X = r.data;
|
416 |
+
time=clock();
|
417 |
+
float *predictions = network_predict(net, X);
|
418 |
+
if(net->hierarchy) hierarchy_predictions(predictions, net->outputs, net->hierarchy, 1, 1);
|
419 |
+
top_k(predictions, net->outputs, top, indexes);
|
420 |
+
fprintf(stderr, "%s: Predicted in %f seconds.\n", input, sec(clock()-time));
|
421 |
+
for(i = 0; i < top; ++i){
|
422 |
+
int index = indexes[i];
|
423 |
+
//if(net->hierarchy) printf("%d, %s: %f, parent: %s \n",index, names[index], predictions[index], (net->hierarchy->parent[index] >= 0) ? names[net->hierarchy->parent[index]] : "Root");
|
424 |
+
//else printf("%s: %f\n",names[index], predictions[index]);
|
425 |
+
printf("%5.2f%%: %s\n", predictions[index]*100, names[index]);
|
426 |
+
}
|
427 |
+
if(r.data != im.data) free_image(r);
|
428 |
+
free_image(im);
|
429 |
+
if (filename) break;
|
430 |
+
}
|
431 |
+
}
|
432 |
+
|
433 |
+
|
434 |
+
void run_attention(int argc, char **argv)
|
435 |
+
{
|
436 |
+
if(argc < 4){
|
437 |
+
fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
|
438 |
+
return;
|
439 |
+
}
|
440 |
+
|
441 |
+
char *gpu_list = find_char_arg(argc, argv, "-gpus", 0);
|
442 |
+
int ngpus;
|
443 |
+
int *gpus = read_intlist(gpu_list, &ngpus, gpu_index);
|
444 |
+
|
445 |
+
|
446 |
+
int top = find_int_arg(argc, argv, "-t", 0);
|
447 |
+
int clear = find_arg(argc, argv, "-clear");
|
448 |
+
char *data = argv[3];
|
449 |
+
char *cfg = argv[4];
|
450 |
+
char *weights = (argc > 5) ? argv[5] : 0;
|
451 |
+
char *filename = (argc > 6) ? argv[6]: 0;
|
452 |
+
char *layer_s = (argc > 7) ? argv[7]: 0;
|
453 |
+
if(0==strcmp(argv[2], "predict")) predict_attention(data, cfg, weights, filename, top);
|
454 |
+
else if(0==strcmp(argv[2], "train")) train_attention(data, cfg, weights, gpus, ngpus, clear);
|
455 |
+
else if(0==strcmp(argv[2], "valid")) validate_attention_single(data, cfg, weights);
|
456 |
+
else if(0==strcmp(argv[2], "validmulti")) validate_attention_multi(data, cfg, weights);
|
457 |
+
}
|
458 |
+
|
459 |
+
|
model/examples/captcha.c
ADDED
@@ -0,0 +1,353 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include "darknet.h"
|
2 |
+
|
3 |
+
void fix_data_captcha(data d, int mask)
|
4 |
+
{
|
5 |
+
matrix labels = d.y;
|
6 |
+
int i, j;
|
7 |
+
for(i = 0; i < d.y.rows; ++i){
|
8 |
+
for(j = 0; j < d.y.cols; j += 2){
|
9 |
+
if (mask){
|
10 |
+
if(!labels.vals[i][j]){
|
11 |
+
labels.vals[i][j] = SECRET_NUM;
|
12 |
+
labels.vals[i][j+1] = SECRET_NUM;
|
13 |
+
}else if(labels.vals[i][j+1]){
|
14 |
+
labels.vals[i][j] = 0;
|
15 |
+
}
|
16 |
+
} else{
|
17 |
+
if (labels.vals[i][j]) {
|
18 |
+
labels.vals[i][j+1] = 0;
|
19 |
+
} else {
|
20 |
+
labels.vals[i][j+1] = 1;
|
21 |
+
}
|
22 |
+
}
|
23 |
+
}
|
24 |
+
}
|
25 |
+
}
|
26 |
+
|
27 |
+
void train_captcha(char *cfgfile, char *weightfile)
|
28 |
+
{
|
29 |
+
srand(time(0));
|
30 |
+
float avg_loss = -1;
|
31 |
+
char *base = basecfg(cfgfile);
|
32 |
+
printf("%s\n", base);
|
33 |
+
network *net = load_network(cfgfile, weightfile, 0);
|
34 |
+
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
|
35 |
+
int imgs = 1024;
|
36 |
+
int i = *net->seen/imgs;
|
37 |
+
int solved = 1;
|
38 |
+
list *plist;
|
39 |
+
char **labels = get_labels("/data/captcha/reimgs.labels.list");
|
40 |
+
if (solved){
|
41 |
+
plist = get_paths("/data/captcha/reimgs.solved.list");
|
42 |
+
}else{
|
43 |
+
plist = get_paths("/data/captcha/reimgs.raw.list");
|
44 |
+
}
|
45 |
+
char **paths = (char **)list_to_array(plist);
|
46 |
+
printf("%d\n", plist->size);
|
47 |
+
clock_t time;
|
48 |
+
pthread_t load_thread;
|
49 |
+
data train;
|
50 |
+
data buffer;
|
51 |
+
|
52 |
+
load_args args = {0};
|
53 |
+
args.w = net->w;
|
54 |
+
args.h = net->h;
|
55 |
+
args.paths = paths;
|
56 |
+
args.classes = 26;
|
57 |
+
args.n = imgs;
|
58 |
+
args.m = plist->size;
|
59 |
+
args.labels = labels;
|
60 |
+
args.d = &buffer;
|
61 |
+
args.type = CLASSIFICATION_DATA;
|
62 |
+
|
63 |
+
load_thread = load_data_in_thread(args);
|
64 |
+
while(1){
|
65 |
+
++i;
|
66 |
+
time=clock();
|
67 |
+
pthread_join(load_thread, 0);
|
68 |
+
train = buffer;
|
69 |
+
fix_data_captcha(train, solved);
|
70 |
+
|
71 |
+
/*
|
72 |
+
image im = float_to_image(256, 256, 3, train.X.vals[114]);
|
73 |
+
show_image(im, "training");
|
74 |
+
cvWaitKey(0);
|
75 |
+
*/
|
76 |
+
|
77 |
+
load_thread = load_data_in_thread(args);
|
78 |
+
printf("Loaded: %lf seconds\n", sec(clock()-time));
|
79 |
+
time=clock();
|
80 |
+
float loss = train_network(net, train);
|
81 |
+
if(avg_loss == -1) avg_loss = loss;
|
82 |
+
avg_loss = avg_loss*.9 + loss*.1;
|
83 |
+
printf("%d: %f, %f avg, %lf seconds, %ld images\n", i, loss, avg_loss, sec(clock()-time), *net->seen);
|
84 |
+
free_data(train);
|
85 |
+
if(i%100==0){
|
86 |
+
char buff[256];
|
87 |
+
sprintf(buff, "/home/pjreddie/imagenet_backup/%s_%d.weights",base, i);
|
88 |
+
save_weights(net, buff);
|
89 |
+
}
|
90 |
+
}
|
91 |
+
}
|
92 |
+
|
93 |
+
void test_captcha(char *cfgfile, char *weightfile, char *filename)
|
94 |
+
{
|
95 |
+
network *net = load_network(cfgfile, weightfile, 0);
|
96 |
+
set_batch_network(net, 1);
|
97 |
+
srand(2222222);
|
98 |
+
int i = 0;
|
99 |
+
char **names = get_labels("/data/captcha/reimgs.labels.list");
|
100 |
+
char buff[256];
|
101 |
+
char *input = buff;
|
102 |
+
int indexes[26];
|
103 |
+
while(1){
|
104 |
+
if(filename){
|
105 |
+
strncpy(input, filename, 256);
|
106 |
+
}else{
|
107 |
+
//printf("Enter Image Path: ");
|
108 |
+
//fflush(stdout);
|
109 |
+
input = fgets(input, 256, stdin);
|
110 |
+
if(!input) return;
|
111 |
+
strtok(input, "\n");
|
112 |
+
}
|
113 |
+
image im = load_image_color(input, net->w, net->h);
|
114 |
+
float *X = im.data;
|
115 |
+
float *predictions = network_predict(net, X);
|
116 |
+
top_predictions(net, 26, indexes);
|
117 |
+
//printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
|
118 |
+
for(i = 0; i < 26; ++i){
|
119 |
+
int index = indexes[i];
|
120 |
+
if(i != 0) printf(", ");
|
121 |
+
printf("%s %f", names[index], predictions[index]);
|
122 |
+
}
|
123 |
+
printf("\n");
|
124 |
+
fflush(stdout);
|
125 |
+
free_image(im);
|
126 |
+
if (filename) break;
|
127 |
+
}
|
128 |
+
}
|
129 |
+
|
130 |
+
void valid_captcha(char *cfgfile, char *weightfile, char *filename)
|
131 |
+
{
|
132 |
+
char **labels = get_labels("/data/captcha/reimgs.labels.list");
|
133 |
+
network *net = load_network(cfgfile, weightfile, 0);
|
134 |
+
list *plist = get_paths("/data/captcha/reimgs.fg.list");
|
135 |
+
char **paths = (char **)list_to_array(plist);
|
136 |
+
int N = plist->size;
|
137 |
+
int outputs = net->outputs;
|
138 |
+
|
139 |
+
set_batch_network(net, 1);
|
140 |
+
srand(2222222);
|
141 |
+
int i, j;
|
142 |
+
for(i = 0; i < N; ++i){
|
143 |
+
if (i%100 == 0) fprintf(stderr, "%d\n", i);
|
144 |
+
image im = load_image_color(paths[i], net->w, net->h);
|
145 |
+
float *X = im.data;
|
146 |
+
float *predictions = network_predict(net, X);
|
147 |
+
//printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
|
148 |
+
int truth = -1;
|
149 |
+
for(j = 0; j < 13; ++j){
|
150 |
+
if (strstr(paths[i], labels[j])) truth = j;
|
151 |
+
}
|
152 |
+
if (truth == -1){
|
153 |
+
fprintf(stderr, "bad: %s\n", paths[i]);
|
154 |
+
return;
|
155 |
+
}
|
156 |
+
printf("%d, ", truth);
|
157 |
+
for(j = 0; j < outputs; ++j){
|
158 |
+
if (j != 0) printf(", ");
|
159 |
+
printf("%f", predictions[j]);
|
160 |
+
}
|
161 |
+
printf("\n");
|
162 |
+
fflush(stdout);
|
163 |
+
free_image(im);
|
164 |
+
if (filename) break;
|
165 |
+
}
|
166 |
+
}
|
167 |
+
|
168 |
+
/*
|
169 |
+
void train_captcha(char *cfgfile, char *weightfile)
|
170 |
+
{
|
171 |
+
float avg_loss = -1;
|
172 |
+
srand(time(0));
|
173 |
+
char *base = basecfg(cfgfile);
|
174 |
+
printf("%s\n", base);
|
175 |
+
network net = parse_network_cfg(cfgfile);
|
176 |
+
if(weightfile){
|
177 |
+
load_weights(&net, weightfile);
|
178 |
+
}
|
179 |
+
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
|
180 |
+
int imgs = 1024;
|
181 |
+
int i = net->seen/imgs;
|
182 |
+
list *plist = get_paths("/data/captcha/train.auto5");
|
183 |
+
char **paths = (char **)list_to_array(plist);
|
184 |
+
printf("%d\n", plist->size);
|
185 |
+
clock_t time;
|
186 |
+
while(1){
|
187 |
+
++i;
|
188 |
+
time=clock();
|
189 |
+
data train = load_data_captcha(paths, imgs, plist->size, 10, 200, 60);
|
190 |
+
translate_data_rows(train, -128);
|
191 |
+
scale_data_rows(train, 1./128);
|
192 |
+
printf("Loaded: %lf seconds\n", sec(clock()-time));
|
193 |
+
time=clock();
|
194 |
+
float loss = train_network(net, train);
|
195 |
+
net->seen += imgs;
|
196 |
+
if(avg_loss == -1) avg_loss = loss;
|
197 |
+
avg_loss = avg_loss*.9 + loss*.1;
|
198 |
+
printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), net->seen);
|
199 |
+
free_data(train);
|
200 |
+
if(i%10==0){
|
201 |
+
char buff[256];
|
202 |
+
sprintf(buff, "/home/pjreddie/imagenet_backup/%s_%d.weights",base, i);
|
203 |
+
save_weights(net, buff);
|
204 |
+
}
|
205 |
+
}
|
206 |
+
}
|
207 |
+
|
208 |
+
void decode_captcha(char *cfgfile, char *weightfile)
|
209 |
+
{
|
210 |
+
setbuf(stdout, NULL);
|
211 |
+
srand(time(0));
|
212 |
+
network net = parse_network_cfg(cfgfile);
|
213 |
+
set_batch_network(&net, 1);
|
214 |
+
if(weightfile){
|
215 |
+
load_weights(&net, weightfile);
|
216 |
+
}
|
217 |
+
char filename[256];
|
218 |
+
while(1){
|
219 |
+
printf("Enter filename: ");
|
220 |
+
fgets(filename, 256, stdin);
|
221 |
+
strtok(filename, "\n");
|
222 |
+
image im = load_image_color(filename, 300, 57);
|
223 |
+
scale_image(im, 1./255.);
|
224 |
+
float *X = im.data;
|
225 |
+
float *predictions = network_predict(net, X);
|
226 |
+
image out = float_to_image(300, 57, 1, predictions);
|
227 |
+
show_image(out, "decoded");
|
228 |
+
#ifdef OPENCV
|
229 |
+
cvWaitKey(0);
|
230 |
+
#endif
|
231 |
+
free_image(im);
|
232 |
+
}
|
233 |
+
}
|
234 |
+
|
235 |
+
void encode_captcha(char *cfgfile, char *weightfile)
|
236 |
+
{
|
237 |
+
float avg_loss = -1;
|
238 |
+
srand(time(0));
|
239 |
+
char *base = basecfg(cfgfile);
|
240 |
+
printf("%s\n", base);
|
241 |
+
network net = parse_network_cfg(cfgfile);
|
242 |
+
if(weightfile){
|
243 |
+
load_weights(&net, weightfile);
|
244 |
+
}
|
245 |
+
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
|
246 |
+
int imgs = 1024;
|
247 |
+
int i = net->seen/imgs;
|
248 |
+
list *plist = get_paths("/data/captcha/encode.list");
|
249 |
+
char **paths = (char **)list_to_array(plist);
|
250 |
+
printf("%d\n", plist->size);
|
251 |
+
clock_t time;
|
252 |
+
while(1){
|
253 |
+
++i;
|
254 |
+
time=clock();
|
255 |
+
data train = load_data_captcha_encode(paths, imgs, plist->size, 300, 57);
|
256 |
+
scale_data_rows(train, 1./255);
|
257 |
+
printf("Loaded: %lf seconds\n", sec(clock()-time));
|
258 |
+
time=clock();
|
259 |
+
float loss = train_network(net, train);
|
260 |
+
net->seen += imgs;
|
261 |
+
if(avg_loss == -1) avg_loss = loss;
|
262 |
+
avg_loss = avg_loss*.9 + loss*.1;
|
263 |
+
printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), net->seen);
|
264 |
+
free_matrix(train.X);
|
265 |
+
if(i%100==0){
|
266 |
+
char buff[256];
|
267 |
+
sprintf(buff, "/home/pjreddie/imagenet_backup/%s_%d.weights",base, i);
|
268 |
+
save_weights(net, buff);
|
269 |
+
}
|
270 |
+
}
|
271 |
+
}
|
272 |
+
|
273 |
+
void validate_captcha(char *cfgfile, char *weightfile)
|
274 |
+
{
|
275 |
+
srand(time(0));
|
276 |
+
char *base = basecfg(cfgfile);
|
277 |
+
printf("%s\n", base);
|
278 |
+
network net = parse_network_cfg(cfgfile);
|
279 |
+
if(weightfile){
|
280 |
+
load_weights(&net, weightfile);
|
281 |
+
}
|
282 |
+
int numchars = 37;
|
283 |
+
list *plist = get_paths("/data/captcha/solved.hard");
|
284 |
+
char **paths = (char **)list_to_array(plist);
|
285 |
+
int imgs = plist->size;
|
286 |
+
data valid = load_data_captcha(paths, imgs, 0, 10, 200, 60);
|
287 |
+
translate_data_rows(valid, -128);
|
288 |
+
scale_data_rows(valid, 1./128);
|
289 |
+
matrix pred = network_predict_data(net, valid);
|
290 |
+
int i, k;
|
291 |
+
int correct = 0;
|
292 |
+
int total = 0;
|
293 |
+
int accuracy = 0;
|
294 |
+
for(i = 0; i < imgs; ++i){
|
295 |
+
int allcorrect = 1;
|
296 |
+
for(k = 0; k < 10; ++k){
|
297 |
+
char truth = int_to_alphanum(max_index(valid.y.vals[i]+k*numchars, numchars));
|
298 |
+
char prediction = int_to_alphanum(max_index(pred.vals[i]+k*numchars, numchars));
|
299 |
+
if (truth != prediction) allcorrect=0;
|
300 |
+
if (truth != '.' && truth == prediction) ++correct;
|
301 |
+
if (truth != '.' || truth != prediction) ++total;
|
302 |
+
}
|
303 |
+
accuracy += allcorrect;
|
304 |
+
}
|
305 |
+
printf("Word Accuracy: %f, Char Accuracy %f\n", (float)accuracy/imgs, (float)correct/total);
|
306 |
+
free_data(valid);
|
307 |
+
}
|
308 |
+
|
309 |
+
void test_captcha(char *cfgfile, char *weightfile)
|
310 |
+
{
|
311 |
+
setbuf(stdout, NULL);
|
312 |
+
srand(time(0));
|
313 |
+
//char *base = basecfg(cfgfile);
|
314 |
+
//printf("%s\n", base);
|
315 |
+
network net = parse_network_cfg(cfgfile);
|
316 |
+
set_batch_network(&net, 1);
|
317 |
+
if(weightfile){
|
318 |
+
load_weights(&net, weightfile);
|
319 |
+
}
|
320 |
+
char filename[256];
|
321 |
+
while(1){
|
322 |
+
//printf("Enter filename: ");
|
323 |
+
fgets(filename, 256, stdin);
|
324 |
+
strtok(filename, "\n");
|
325 |
+
image im = load_image_color(filename, 200, 60);
|
326 |
+
translate_image(im, -128);
|
327 |
+
scale_image(im, 1/128.);
|
328 |
+
float *X = im.data;
|
329 |
+
float *predictions = network_predict(net, X);
|
330 |
+
print_letters(predictions, 10);
|
331 |
+
free_image(im);
|
332 |
+
}
|
333 |
+
}
|
334 |
+
*/
|
335 |
+
void run_captcha(int argc, char **argv)
|
336 |
+
{
|
337 |
+
if(argc < 4){
|
338 |
+
fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
|
339 |
+
return;
|
340 |
+
}
|
341 |
+
|
342 |
+
char *cfg = argv[3];
|
343 |
+
char *weights = (argc > 4) ? argv[4] : 0;
|
344 |
+
char *filename = (argc > 5) ? argv[5]: 0;
|
345 |
+
if(0==strcmp(argv[2], "train")) train_captcha(cfg, weights);
|
346 |
+
else if(0==strcmp(argv[2], "test")) test_captcha(cfg, weights, filename);
|
347 |
+
else if(0==strcmp(argv[2], "valid")) valid_captcha(cfg, weights, filename);
|
348 |
+
//if(0==strcmp(argv[2], "test")) test_captcha(cfg, weights);
|
349 |
+
//else if(0==strcmp(argv[2], "encode")) encode_captcha(cfg, weights);
|
350 |
+
//else if(0==strcmp(argv[2], "decode")) decode_captcha(cfg, weights);
|
351 |
+
//else if(0==strcmp(argv[2], "valid")) validate_captcha(cfg, weights);
|
352 |
+
}
|
353 |
+
|
model/examples/cifar.c
ADDED
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include "darknet.h"
|
2 |
+
|
3 |
+
void train_cifar(char *cfgfile, char *weightfile)
|
4 |
+
{
|
5 |
+
srand(time(0));
|
6 |
+
float avg_loss = -1;
|
7 |
+
char *base = basecfg(cfgfile);
|
8 |
+
printf("%s\n", base);
|
9 |
+
network *net = load_network(cfgfile, weightfile, 0);
|
10 |
+
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
|
11 |
+
|
12 |
+
char *backup_directory = "/home/pjreddie/backup/";
|
13 |
+
int classes = 10;
|
14 |
+
int N = 50000;
|
15 |
+
|
16 |
+
char **labels = get_labels("data/cifar/labels.txt");
|
17 |
+
int epoch = (*net->seen)/N;
|
18 |
+
data train = load_all_cifar10();
|
19 |
+
while(get_current_batch(net) < net->max_batches || net->max_batches == 0){
|
20 |
+
clock_t time=clock();
|
21 |
+
|
22 |
+
float loss = train_network_sgd(net, train, 1);
|
23 |
+
if(avg_loss == -1) avg_loss = loss;
|
24 |
+
avg_loss = avg_loss*.95 + loss*.05;
|
25 |
+
printf("%ld, %.3f: %f, %f avg, %f rate, %lf seconds, %ld images\n", get_current_batch(net), (float)(*net->seen)/N, loss, avg_loss, get_current_rate(net), sec(clock()-time), *net->seen);
|
26 |
+
if(*net->seen/N > epoch){
|
27 |
+
epoch = *net->seen/N;
|
28 |
+
char buff[256];
|
29 |
+
sprintf(buff, "%s/%s_%d.weights",backup_directory,base, epoch);
|
30 |
+
save_weights(net, buff);
|
31 |
+
}
|
32 |
+
if(get_current_batch(net)%100 == 0){
|
33 |
+
char buff[256];
|
34 |
+
sprintf(buff, "%s/%s.backup",backup_directory,base);
|
35 |
+
save_weights(net, buff);
|
36 |
+
}
|
37 |
+
}
|
38 |
+
char buff[256];
|
39 |
+
sprintf(buff, "%s/%s.weights", backup_directory, base);
|
40 |
+
save_weights(net, buff);
|
41 |
+
|
42 |
+
free_network(net);
|
43 |
+
free_ptrs((void**)labels, classes);
|
44 |
+
free(base);
|
45 |
+
free_data(train);
|
46 |
+
}
|
47 |
+
|
48 |
+
void train_cifar_distill(char *cfgfile, char *weightfile)
|
49 |
+
{
|
50 |
+
srand(time(0));
|
51 |
+
float avg_loss = -1;
|
52 |
+
char *base = basecfg(cfgfile);
|
53 |
+
printf("%s\n", base);
|
54 |
+
network *net = load_network(cfgfile, weightfile, 0);
|
55 |
+
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
|
56 |
+
|
57 |
+
char *backup_directory = "/home/pjreddie/backup/";
|
58 |
+
int classes = 10;
|
59 |
+
int N = 50000;
|
60 |
+
|
61 |
+
char **labels = get_labels("data/cifar/labels.txt");
|
62 |
+
int epoch = (*net->seen)/N;
|
63 |
+
|
64 |
+
data train = load_all_cifar10();
|
65 |
+
matrix soft = csv_to_matrix("results/ensemble.csv");
|
66 |
+
|
67 |
+
float weight = .9;
|
68 |
+
scale_matrix(soft, weight);
|
69 |
+
scale_matrix(train.y, 1. - weight);
|
70 |
+
matrix_add_matrix(soft, train.y);
|
71 |
+
|
72 |
+
while(get_current_batch(net) < net->max_batches || net->max_batches == 0){
|
73 |
+
clock_t time=clock();
|
74 |
+
|
75 |
+
float loss = train_network_sgd(net, train, 1);
|
76 |
+
if(avg_loss == -1) avg_loss = loss;
|
77 |
+
avg_loss = avg_loss*.95 + loss*.05;
|
78 |
+
printf("%ld, %.3f: %f, %f avg, %f rate, %lf seconds, %ld images\n", get_current_batch(net), (float)(*net->seen)/N, loss, avg_loss, get_current_rate(net), sec(clock()-time), *net->seen);
|
79 |
+
if(*net->seen/N > epoch){
|
80 |
+
epoch = *net->seen/N;
|
81 |
+
char buff[256];
|
82 |
+
sprintf(buff, "%s/%s_%d.weights",backup_directory,base, epoch);
|
83 |
+
save_weights(net, buff);
|
84 |
+
}
|
85 |
+
if(get_current_batch(net)%100 == 0){
|
86 |
+
char buff[256];
|
87 |
+
sprintf(buff, "%s/%s.backup",backup_directory,base);
|
88 |
+
save_weights(net, buff);
|
89 |
+
}
|
90 |
+
}
|
91 |
+
char buff[256];
|
92 |
+
sprintf(buff, "%s/%s.weights", backup_directory, base);
|
93 |
+
save_weights(net, buff);
|
94 |
+
|
95 |
+
free_network(net);
|
96 |
+
free_ptrs((void**)labels, classes);
|
97 |
+
free(base);
|
98 |
+
free_data(train);
|
99 |
+
}
|
100 |
+
|
101 |
+
void test_cifar_multi(char *filename, char *weightfile)
|
102 |
+
{
|
103 |
+
network *net = load_network(filename, weightfile, 0);
|
104 |
+
set_batch_network(net, 1);
|
105 |
+
srand(time(0));
|
106 |
+
|
107 |
+
float avg_acc = 0;
|
108 |
+
data test = load_cifar10_data("data/cifar/cifar-10-batches-bin/test_batch.bin");
|
109 |
+
|
110 |
+
int i;
|
111 |
+
for(i = 0; i < test.X.rows; ++i){
|
112 |
+
image im = float_to_image(32, 32, 3, test.X.vals[i]);
|
113 |
+
|
114 |
+
float pred[10] = {0};
|
115 |
+
|
116 |
+
float *p = network_predict(net, im.data);
|
117 |
+
axpy_cpu(10, 1, p, 1, pred, 1);
|
118 |
+
flip_image(im);
|
119 |
+
p = network_predict(net, im.data);
|
120 |
+
axpy_cpu(10, 1, p, 1, pred, 1);
|
121 |
+
|
122 |
+
int index = max_index(pred, 10);
|
123 |
+
int class = max_index(test.y.vals[i], 10);
|
124 |
+
if(index == class) avg_acc += 1;
|
125 |
+
free_image(im);
|
126 |
+
printf("%4d: %.2f%%\n", i, 100.*avg_acc/(i+1));
|
127 |
+
}
|
128 |
+
}
|
129 |
+
|
130 |
+
void test_cifar(char *filename, char *weightfile)
|
131 |
+
{
|
132 |
+
network *net = load_network(filename, weightfile, 0);
|
133 |
+
srand(time(0));
|
134 |
+
|
135 |
+
clock_t time;
|
136 |
+
float avg_acc = 0;
|
137 |
+
float avg_top5 = 0;
|
138 |
+
data test = load_cifar10_data("data/cifar/cifar-10-batches-bin/test_batch.bin");
|
139 |
+
|
140 |
+
time=clock();
|
141 |
+
|
142 |
+
float *acc = network_accuracies(net, test, 2);
|
143 |
+
avg_acc += acc[0];
|
144 |
+
avg_top5 += acc[1];
|
145 |
+
printf("top1: %f, %lf seconds, %d images\n", avg_acc, sec(clock()-time), test.X.rows);
|
146 |
+
free_data(test);
|
147 |
+
}
|
148 |
+
|
149 |
+
void extract_cifar()
|
150 |
+
{
|
151 |
+
char *labels[] = {"airplane","automobile","bird","cat","deer","dog","frog","horse","ship","truck"};
|
152 |
+
int i;
|
153 |
+
data train = load_all_cifar10();
|
154 |
+
data test = load_cifar10_data("data/cifar/cifar-10-batches-bin/test_batch.bin");
|
155 |
+
for(i = 0; i < train.X.rows; ++i){
|
156 |
+
image im = float_to_image(32, 32, 3, train.X.vals[i]);
|
157 |
+
int class = max_index(train.y.vals[i], 10);
|
158 |
+
char buff[256];
|
159 |
+
sprintf(buff, "data/cifar/train/%d_%s",i,labels[class]);
|
160 |
+
save_image_options(im, buff, PNG, 0);
|
161 |
+
}
|
162 |
+
for(i = 0; i < test.X.rows; ++i){
|
163 |
+
image im = float_to_image(32, 32, 3, test.X.vals[i]);
|
164 |
+
int class = max_index(test.y.vals[i], 10);
|
165 |
+
char buff[256];
|
166 |
+
sprintf(buff, "data/cifar/test/%d_%s",i,labels[class]);
|
167 |
+
save_image_options(im, buff, PNG, 0);
|
168 |
+
}
|
169 |
+
}
|
170 |
+
|
171 |
+
void test_cifar_csv(char *filename, char *weightfile)
|
172 |
+
{
|
173 |
+
network *net = load_network(filename, weightfile, 0);
|
174 |
+
srand(time(0));
|
175 |
+
|
176 |
+
data test = load_cifar10_data("data/cifar/cifar-10-batches-bin/test_batch.bin");
|
177 |
+
|
178 |
+
matrix pred = network_predict_data(net, test);
|
179 |
+
|
180 |
+
int i;
|
181 |
+
for(i = 0; i < test.X.rows; ++i){
|
182 |
+
image im = float_to_image(32, 32, 3, test.X.vals[i]);
|
183 |
+
flip_image(im);
|
184 |
+
}
|
185 |
+
matrix pred2 = network_predict_data(net, test);
|
186 |
+
scale_matrix(pred, .5);
|
187 |
+
scale_matrix(pred2, .5);
|
188 |
+
matrix_add_matrix(pred2, pred);
|
189 |
+
|
190 |
+
matrix_to_csv(pred);
|
191 |
+
fprintf(stderr, "Accuracy: %f\n", matrix_topk_accuracy(test.y, pred, 1));
|
192 |
+
free_data(test);
|
193 |
+
}
|
194 |
+
|
195 |
+
void test_cifar_csvtrain(char *cfg, char *weights)
|
196 |
+
{
|
197 |
+
network *net = load_network(cfg, weights, 0);
|
198 |
+
srand(time(0));
|
199 |
+
|
200 |
+
data test = load_all_cifar10();
|
201 |
+
|
202 |
+
matrix pred = network_predict_data(net, test);
|
203 |
+
|
204 |
+
int i;
|
205 |
+
for(i = 0; i < test.X.rows; ++i){
|
206 |
+
image im = float_to_image(32, 32, 3, test.X.vals[i]);
|
207 |
+
flip_image(im);
|
208 |
+
}
|
209 |
+
matrix pred2 = network_predict_data(net, test);
|
210 |
+
scale_matrix(pred, .5);
|
211 |
+
scale_matrix(pred2, .5);
|
212 |
+
matrix_add_matrix(pred2, pred);
|
213 |
+
|
214 |
+
matrix_to_csv(pred);
|
215 |
+
fprintf(stderr, "Accuracy: %f\n", matrix_topk_accuracy(test.y, pred, 1));
|
216 |
+
free_data(test);
|
217 |
+
}
|
218 |
+
|
219 |
+
void eval_cifar_csv()
|
220 |
+
{
|
221 |
+
data test = load_cifar10_data("data/cifar/cifar-10-batches-bin/test_batch.bin");
|
222 |
+
|
223 |
+
matrix pred = csv_to_matrix("results/combined.csv");
|
224 |
+
fprintf(stderr, "%d %d\n", pred.rows, pred.cols);
|
225 |
+
|
226 |
+
fprintf(stderr, "Accuracy: %f\n", matrix_topk_accuracy(test.y, pred, 1));
|
227 |
+
free_data(test);
|
228 |
+
free_matrix(pred);
|
229 |
+
}
|
230 |
+
|
231 |
+
|
232 |
+
void run_cifar(int argc, char **argv)
|
233 |
+
{
|
234 |
+
if(argc < 4){
|
235 |
+
fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
|
236 |
+
return;
|
237 |
+
}
|
238 |
+
|
239 |
+
char *cfg = argv[3];
|
240 |
+
char *weights = (argc > 4) ? argv[4] : 0;
|
241 |
+
if(0==strcmp(argv[2], "train")) train_cifar(cfg, weights);
|
242 |
+
else if(0==strcmp(argv[2], "extract")) extract_cifar();
|
243 |
+
else if(0==strcmp(argv[2], "distill")) train_cifar_distill(cfg, weights);
|
244 |
+
else if(0==strcmp(argv[2], "test")) test_cifar(cfg, weights);
|
245 |
+
else if(0==strcmp(argv[2], "multi")) test_cifar_multi(cfg, weights);
|
246 |
+
else if(0==strcmp(argv[2], "csv")) test_cifar_csv(cfg, weights);
|
247 |
+
else if(0==strcmp(argv[2], "csvtrain")) test_cifar_csvtrain(cfg, weights);
|
248 |
+
else if(0==strcmp(argv[2], "eval")) eval_cifar_csv();
|
249 |
+
}
|
250 |
+
|
251 |
+
|
model/examples/classifier.c
ADDED
@@ -0,0 +1,1098 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include "darknet.h"
|
2 |
+
|
3 |
+
#include <sys/time.h>
|
4 |
+
#include <assert.h>
|
5 |
+
|
6 |
+
float *get_regression_values(char **labels, int n)
|
7 |
+
{
|
8 |
+
float *v = calloc(n, sizeof(float));
|
9 |
+
int i;
|
10 |
+
for(i = 0; i < n; ++i){
|
11 |
+
char *p = strchr(labels[i], ' ');
|
12 |
+
*p = 0;
|
13 |
+
v[i] = atof(p+1);
|
14 |
+
}
|
15 |
+
return v;
|
16 |
+
}
|
17 |
+
|
18 |
+
void train_classifier(char *datacfg, char *cfgfile, char *weightfile, int *gpus, int ngpus, int clear)
|
19 |
+
{
|
20 |
+
int i;
|
21 |
+
|
22 |
+
float avg_loss = -1;
|
23 |
+
char *base = basecfg(cfgfile);
|
24 |
+
printf("%s\n", base);
|
25 |
+
printf("%d\n", ngpus);
|
26 |
+
network **nets = calloc(ngpus, sizeof(network*));
|
27 |
+
|
28 |
+
srand(time(0));
|
29 |
+
int seed = rand();
|
30 |
+
for(i = 0; i < ngpus; ++i){
|
31 |
+
srand(seed);
|
32 |
+
#ifdef GPU
|
33 |
+
cuda_set_device(gpus[i]);
|
34 |
+
#endif
|
35 |
+
nets[i] = load_network(cfgfile, weightfile, clear);
|
36 |
+
nets[i]->learning_rate *= ngpus;
|
37 |
+
}
|
38 |
+
srand(time(0));
|
39 |
+
network *net = nets[0];
|
40 |
+
|
41 |
+
int imgs = net->batch * net->subdivisions * ngpus;
|
42 |
+
|
43 |
+
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
|
44 |
+
list *options = read_data_cfg(datacfg);
|
45 |
+
|
46 |
+
char *backup_directory = option_find_str(options, "backup", "/backup/");
|
47 |
+
int tag = option_find_int_quiet(options, "tag", 0);
|
48 |
+
char *label_list = option_find_str(options, "labels", "data/labels.list");
|
49 |
+
char *train_list = option_find_str(options, "train", "data/train.list");
|
50 |
+
char *tree = option_find_str(options, "tree", 0);
|
51 |
+
if (tree) net->hierarchy = read_tree(tree);
|
52 |
+
int classes = option_find_int(options, "classes", 2);
|
53 |
+
|
54 |
+
char **labels = 0;
|
55 |
+
if(!tag){
|
56 |
+
labels = get_labels(label_list);
|
57 |
+
}
|
58 |
+
list *plist = get_paths(train_list);
|
59 |
+
char **paths = (char **)list_to_array(plist);
|
60 |
+
printf("%d\n", plist->size);
|
61 |
+
int N = plist->size;
|
62 |
+
double time;
|
63 |
+
|
64 |
+
load_args args = {0};
|
65 |
+
args.w = net->w;
|
66 |
+
args.h = net->h;
|
67 |
+
args.threads = 32;
|
68 |
+
args.hierarchy = net->hierarchy;
|
69 |
+
|
70 |
+
args.min = net->min_ratio*net->w;
|
71 |
+
args.max = net->max_ratio*net->w;
|
72 |
+
printf("%d %d\n", args.min, args.max);
|
73 |
+
args.angle = net->angle;
|
74 |
+
args.aspect = net->aspect;
|
75 |
+
args.exposure = net->exposure;
|
76 |
+
args.saturation = net->saturation;
|
77 |
+
args.hue = net->hue;
|
78 |
+
args.size = net->w;
|
79 |
+
|
80 |
+
args.paths = paths;
|
81 |
+
args.classes = classes;
|
82 |
+
args.n = imgs;
|
83 |
+
args.m = N;
|
84 |
+
args.labels = labels;
|
85 |
+
if (tag){
|
86 |
+
args.type = TAG_DATA;
|
87 |
+
} else {
|
88 |
+
args.type = CLASSIFICATION_DATA;
|
89 |
+
}
|
90 |
+
|
91 |
+
data train;
|
92 |
+
data buffer;
|
93 |
+
pthread_t load_thread;
|
94 |
+
args.d = &buffer;
|
95 |
+
load_thread = load_data(args);
|
96 |
+
|
97 |
+
int count = 0;
|
98 |
+
int epoch = (*net->seen)/N;
|
99 |
+
while(get_current_batch(net) < net->max_batches || net->max_batches == 0){
|
100 |
+
if(net->random && count++%40 == 0){
|
101 |
+
printf("Resizing\n");
|
102 |
+
int dim = (rand() % 11 + 4) * 32;
|
103 |
+
//if (get_current_batch(net)+200 > net->max_batches) dim = 608;
|
104 |
+
//int dim = (rand() % 4 + 16) * 32;
|
105 |
+
printf("%d\n", dim);
|
106 |
+
args.w = dim;
|
107 |
+
args.h = dim;
|
108 |
+
args.size = dim;
|
109 |
+
args.min = net->min_ratio*dim;
|
110 |
+
args.max = net->max_ratio*dim;
|
111 |
+
printf("%d %d\n", args.min, args.max);
|
112 |
+
|
113 |
+
pthread_join(load_thread, 0);
|
114 |
+
train = buffer;
|
115 |
+
free_data(train);
|
116 |
+
load_thread = load_data(args);
|
117 |
+
|
118 |
+
for(i = 0; i < ngpus; ++i){
|
119 |
+
resize_network(nets[i], dim, dim);
|
120 |
+
}
|
121 |
+
net = nets[0];
|
122 |
+
}
|
123 |
+
time = what_time_is_it_now();
|
124 |
+
|
125 |
+
pthread_join(load_thread, 0);
|
126 |
+
train = buffer;
|
127 |
+
load_thread = load_data(args);
|
128 |
+
|
129 |
+
printf("Loaded: %lf seconds\n", what_time_is_it_now()-time);
|
130 |
+
time = what_time_is_it_now();
|
131 |
+
|
132 |
+
float loss = 0;
|
133 |
+
#ifdef GPU
|
134 |
+
if(ngpus == 1){
|
135 |
+
loss = train_network(net, train);
|
136 |
+
} else {
|
137 |
+
loss = train_networks(nets, ngpus, train, 4);
|
138 |
+
}
|
139 |
+
#else
|
140 |
+
loss = train_network(net, train);
|
141 |
+
#endif
|
142 |
+
if(avg_loss == -1) avg_loss = loss;
|
143 |
+
avg_loss = avg_loss*.9 + loss*.1;
|
144 |
+
printf("%ld, %.3f: %f, %f avg, %f rate, %lf seconds, %ld images\n", get_current_batch(net), (float)(*net->seen)/N, loss, avg_loss, get_current_rate(net), what_time_is_it_now()-time, *net->seen);
|
145 |
+
free_data(train);
|
146 |
+
if(*net->seen/N > epoch){
|
147 |
+
epoch = *net->seen/N;
|
148 |
+
char buff[256];
|
149 |
+
sprintf(buff, "%s/%s_%d.weights",backup_directory,base, epoch);
|
150 |
+
save_weights(net, buff);
|
151 |
+
}
|
152 |
+
if(get_current_batch(net)%1000 == 0){
|
153 |
+
char buff[256];
|
154 |
+
sprintf(buff, "%s/%s.backup",backup_directory,base);
|
155 |
+
save_weights(net, buff);
|
156 |
+
}
|
157 |
+
}
|
158 |
+
char buff[256];
|
159 |
+
sprintf(buff, "%s/%s.weights", backup_directory, base);
|
160 |
+
save_weights(net, buff);
|
161 |
+
pthread_join(load_thread, 0);
|
162 |
+
|
163 |
+
free_network(net);
|
164 |
+
if(labels) free_ptrs((void**)labels, classes);
|
165 |
+
free_ptrs((void**)paths, plist->size);
|
166 |
+
free_list(plist);
|
167 |
+
free(base);
|
168 |
+
}
|
169 |
+
|
170 |
+
void validate_classifier_crop(char *datacfg, char *filename, char *weightfile)
|
171 |
+
{
|
172 |
+
int i = 0;
|
173 |
+
network *net = load_network(filename, weightfile, 0);
|
174 |
+
srand(time(0));
|
175 |
+
|
176 |
+
list *options = read_data_cfg(datacfg);
|
177 |
+
|
178 |
+
char *label_list = option_find_str(options, "labels", "data/labels.list");
|
179 |
+
char *valid_list = option_find_str(options, "valid", "data/train.list");
|
180 |
+
int classes = option_find_int(options, "classes", 2);
|
181 |
+
int topk = option_find_int(options, "top", 1);
|
182 |
+
|
183 |
+
char **labels = get_labels(label_list);
|
184 |
+
list *plist = get_paths(valid_list);
|
185 |
+
|
186 |
+
char **paths = (char **)list_to_array(plist);
|
187 |
+
int m = plist->size;
|
188 |
+
free_list(plist);
|
189 |
+
|
190 |
+
clock_t time;
|
191 |
+
float avg_acc = 0;
|
192 |
+
float avg_topk = 0;
|
193 |
+
int splits = m/1000;
|
194 |
+
int num = (i+1)*m/splits - i*m/splits;
|
195 |
+
|
196 |
+
data val, buffer;
|
197 |
+
|
198 |
+
load_args args = {0};
|
199 |
+
args.w = net->w;
|
200 |
+
args.h = net->h;
|
201 |
+
|
202 |
+
args.paths = paths;
|
203 |
+
args.classes = classes;
|
204 |
+
args.n = num;
|
205 |
+
args.m = 0;
|
206 |
+
args.labels = labels;
|
207 |
+
args.d = &buffer;
|
208 |
+
args.type = OLD_CLASSIFICATION_DATA;
|
209 |
+
|
210 |
+
pthread_t load_thread = load_data_in_thread(args);
|
211 |
+
for(i = 1; i <= splits; ++i){
|
212 |
+
time=clock();
|
213 |
+
|
214 |
+
pthread_join(load_thread, 0);
|
215 |
+
val = buffer;
|
216 |
+
|
217 |
+
num = (i+1)*m/splits - i*m/splits;
|
218 |
+
char **part = paths+(i*m/splits);
|
219 |
+
if(i != splits){
|
220 |
+
args.paths = part;
|
221 |
+
load_thread = load_data_in_thread(args);
|
222 |
+
}
|
223 |
+
printf("Loaded: %d images in %lf seconds\n", val.X.rows, sec(clock()-time));
|
224 |
+
|
225 |
+
time=clock();
|
226 |
+
float *acc = network_accuracies(net, val, topk);
|
227 |
+
avg_acc += acc[0];
|
228 |
+
avg_topk += acc[1];
|
229 |
+
printf("%d: top 1: %f, top %d: %f, %lf seconds, %d images\n", i, avg_acc/i, topk, avg_topk/i, sec(clock()-time), val.X.rows);
|
230 |
+
free_data(val);
|
231 |
+
}
|
232 |
+
}
|
233 |
+
|
234 |
+
void validate_classifier_10(char *datacfg, char *filename, char *weightfile)
|
235 |
+
{
|
236 |
+
int i, j;
|
237 |
+
network *net = load_network(filename, weightfile, 0);
|
238 |
+
set_batch_network(net, 1);
|
239 |
+
srand(time(0));
|
240 |
+
|
241 |
+
list *options = read_data_cfg(datacfg);
|
242 |
+
|
243 |
+
char *label_list = option_find_str(options, "labels", "data/labels.list");
|
244 |
+
char *valid_list = option_find_str(options, "valid", "data/train.list");
|
245 |
+
int classes = option_find_int(options, "classes", 2);
|
246 |
+
int topk = option_find_int(options, "top", 1);
|
247 |
+
|
248 |
+
char **labels = get_labels(label_list);
|
249 |
+
list *plist = get_paths(valid_list);
|
250 |
+
|
251 |
+
char **paths = (char **)list_to_array(plist);
|
252 |
+
int m = plist->size;
|
253 |
+
free_list(plist);
|
254 |
+
|
255 |
+
float avg_acc = 0;
|
256 |
+
float avg_topk = 0;
|
257 |
+
int *indexes = calloc(topk, sizeof(int));
|
258 |
+
|
259 |
+
for(i = 0; i < m; ++i){
|
260 |
+
int class = -1;
|
261 |
+
char *path = paths[i];
|
262 |
+
for(j = 0; j < classes; ++j){
|
263 |
+
if(strstr(path, labels[j])){
|
264 |
+
class = j;
|
265 |
+
break;
|
266 |
+
}
|
267 |
+
}
|
268 |
+
int w = net->w;
|
269 |
+
int h = net->h;
|
270 |
+
int shift = 32;
|
271 |
+
image im = load_image_color(paths[i], w+shift, h+shift);
|
272 |
+
image images[10];
|
273 |
+
images[0] = crop_image(im, -shift, -shift, w, h);
|
274 |
+
images[1] = crop_image(im, shift, -shift, w, h);
|
275 |
+
images[2] = crop_image(im, 0, 0, w, h);
|
276 |
+
images[3] = crop_image(im, -shift, shift, w, h);
|
277 |
+
images[4] = crop_image(im, shift, shift, w, h);
|
278 |
+
flip_image(im);
|
279 |
+
images[5] = crop_image(im, -shift, -shift, w, h);
|
280 |
+
images[6] = crop_image(im, shift, -shift, w, h);
|
281 |
+
images[7] = crop_image(im, 0, 0, w, h);
|
282 |
+
images[8] = crop_image(im, -shift, shift, w, h);
|
283 |
+
images[9] = crop_image(im, shift, shift, w, h);
|
284 |
+
float *pred = calloc(classes, sizeof(float));
|
285 |
+
for(j = 0; j < 10; ++j){
|
286 |
+
float *p = network_predict(net, images[j].data);
|
287 |
+
if(net->hierarchy) hierarchy_predictions(p, net->outputs, net->hierarchy, 1, 1);
|
288 |
+
axpy_cpu(classes, 1, p, 1, pred, 1);
|
289 |
+
free_image(images[j]);
|
290 |
+
}
|
291 |
+
free_image(im);
|
292 |
+
top_k(pred, classes, topk, indexes);
|
293 |
+
free(pred);
|
294 |
+
if(indexes[0] == class) avg_acc += 1;
|
295 |
+
for(j = 0; j < topk; ++j){
|
296 |
+
if(indexes[j] == class) avg_topk += 1;
|
297 |
+
}
|
298 |
+
|
299 |
+
printf("%d: top 1: %f, top %d: %f\n", i, avg_acc/(i+1), topk, avg_topk/(i+1));
|
300 |
+
}
|
301 |
+
}
|
302 |
+
|
303 |
+
void validate_classifier_full(char *datacfg, char *filename, char *weightfile)
|
304 |
+
{
|
305 |
+
int i, j;
|
306 |
+
network *net = load_network(filename, weightfile, 0);
|
307 |
+
set_batch_network(net, 1);
|
308 |
+
srand(time(0));
|
309 |
+
|
310 |
+
list *options = read_data_cfg(datacfg);
|
311 |
+
|
312 |
+
char *label_list = option_find_str(options, "labels", "data/labels.list");
|
313 |
+
char *valid_list = option_find_str(options, "valid", "data/train.list");
|
314 |
+
int classes = option_find_int(options, "classes", 2);
|
315 |
+
int topk = option_find_int(options, "top", 1);
|
316 |
+
|
317 |
+
char **labels = get_labels(label_list);
|
318 |
+
list *plist = get_paths(valid_list);
|
319 |
+
|
320 |
+
char **paths = (char **)list_to_array(plist);
|
321 |
+
int m = plist->size;
|
322 |
+
free_list(plist);
|
323 |
+
|
324 |
+
float avg_acc = 0;
|
325 |
+
float avg_topk = 0;
|
326 |
+
int *indexes = calloc(topk, sizeof(int));
|
327 |
+
|
328 |
+
int size = net->w;
|
329 |
+
for(i = 0; i < m; ++i){
|
330 |
+
int class = -1;
|
331 |
+
char *path = paths[i];
|
332 |
+
for(j = 0; j < classes; ++j){
|
333 |
+
if(strstr(path, labels[j])){
|
334 |
+
class = j;
|
335 |
+
break;
|
336 |
+
}
|
337 |
+
}
|
338 |
+
image im = load_image_color(paths[i], 0, 0);
|
339 |
+
image resized = resize_min(im, size);
|
340 |
+
resize_network(net, resized.w, resized.h);
|
341 |
+
//show_image(im, "orig");
|
342 |
+
//show_image(crop, "cropped");
|
343 |
+
//cvWaitKey(0);
|
344 |
+
float *pred = network_predict(net, resized.data);
|
345 |
+
if(net->hierarchy) hierarchy_predictions(pred, net->outputs, net->hierarchy, 1, 1);
|
346 |
+
|
347 |
+
free_image(im);
|
348 |
+
free_image(resized);
|
349 |
+
top_k(pred, classes, topk, indexes);
|
350 |
+
|
351 |
+
if(indexes[0] == class) avg_acc += 1;
|
352 |
+
for(j = 0; j < topk; ++j){
|
353 |
+
if(indexes[j] == class) avg_topk += 1;
|
354 |
+
}
|
355 |
+
|
356 |
+
printf("%d: top 1: %f, top %d: %f\n", i, avg_acc/(i+1), topk, avg_topk/(i+1));
|
357 |
+
}
|
358 |
+
}
|
359 |
+
|
360 |
+
|
361 |
+
void validate_classifier_single(char *datacfg, char *filename, char *weightfile)
|
362 |
+
{
|
363 |
+
int i, j;
|
364 |
+
network *net = load_network(filename, weightfile, 0);
|
365 |
+
set_batch_network(net, 1);
|
366 |
+
srand(time(0));
|
367 |
+
|
368 |
+
list *options = read_data_cfg(datacfg);
|
369 |
+
|
370 |
+
char *label_list = option_find_str(options, "labels", "data/labels.list");
|
371 |
+
char *leaf_list = option_find_str(options, "leaves", 0);
|
372 |
+
if(leaf_list) change_leaves(net->hierarchy, leaf_list);
|
373 |
+
char *valid_list = option_find_str(options, "valid", "data/train.list");
|
374 |
+
int classes = option_find_int(options, "classes", 2);
|
375 |
+
int topk = option_find_int(options, "top", 1);
|
376 |
+
|
377 |
+
char **labels = get_labels(label_list);
|
378 |
+
list *plist = get_paths(valid_list);
|
379 |
+
|
380 |
+
char **paths = (char **)list_to_array(plist);
|
381 |
+
int m = plist->size;
|
382 |
+
free_list(plist);
|
383 |
+
|
384 |
+
float avg_acc = 0;
|
385 |
+
float avg_topk = 0;
|
386 |
+
int *indexes = calloc(topk, sizeof(int));
|
387 |
+
|
388 |
+
for(i = 0; i < m; ++i){
|
389 |
+
int class = -1;
|
390 |
+
char *path = paths[i];
|
391 |
+
for(j = 0; j < classes; ++j){
|
392 |
+
if(strstr(path, labels[j])){
|
393 |
+
class = j;
|
394 |
+
break;
|
395 |
+
}
|
396 |
+
}
|
397 |
+
image im = load_image_color(paths[i], 0, 0);
|
398 |
+
image crop = center_crop_image(im, net->w, net->h);
|
399 |
+
//grayscale_image_3c(crop);
|
400 |
+
//show_image(im, "orig");
|
401 |
+
//show_image(crop, "cropped");
|
402 |
+
//cvWaitKey(0);
|
403 |
+
float *pred = network_predict(net, crop.data);
|
404 |
+
if(net->hierarchy) hierarchy_predictions(pred, net->outputs, net->hierarchy, 1, 1);
|
405 |
+
|
406 |
+
free_image(im);
|
407 |
+
free_image(crop);
|
408 |
+
top_k(pred, classes, topk, indexes);
|
409 |
+
|
410 |
+
if(indexes[0] == class) avg_acc += 1;
|
411 |
+
for(j = 0; j < topk; ++j){
|
412 |
+
if(indexes[j] == class) avg_topk += 1;
|
413 |
+
}
|
414 |
+
|
415 |
+
printf("%s, %d, %f, %f, \n", paths[i], class, pred[0], pred[1]);
|
416 |
+
printf("%d: top 1: %f, top %d: %f\n", i, avg_acc/(i+1), topk, avg_topk/(i+1));
|
417 |
+
}
|
418 |
+
}
|
419 |
+
|
420 |
+
void validate_classifier_multi(char *datacfg, char *cfg, char *weights)
|
421 |
+
{
|
422 |
+
int i, j;
|
423 |
+
network *net = load_network(cfg, weights, 0);
|
424 |
+
set_batch_network(net, 1);
|
425 |
+
srand(time(0));
|
426 |
+
|
427 |
+
list *options = read_data_cfg(datacfg);
|
428 |
+
|
429 |
+
char *label_list = option_find_str(options, "labels", "data/labels.list");
|
430 |
+
char *valid_list = option_find_str(options, "valid", "data/train.list");
|
431 |
+
int classes = option_find_int(options, "classes", 2);
|
432 |
+
int topk = option_find_int(options, "top", 1);
|
433 |
+
|
434 |
+
char **labels = get_labels(label_list);
|
435 |
+
list *plist = get_paths(valid_list);
|
436 |
+
//int scales[] = {224, 288, 320, 352, 384};
|
437 |
+
int scales[] = {224, 256, 288, 320};
|
438 |
+
int nscales = sizeof(scales)/sizeof(scales[0]);
|
439 |
+
|
440 |
+
char **paths = (char **)list_to_array(plist);
|
441 |
+
int m = plist->size;
|
442 |
+
free_list(plist);
|
443 |
+
|
444 |
+
float avg_acc = 0;
|
445 |
+
float avg_topk = 0;
|
446 |
+
int *indexes = calloc(topk, sizeof(int));
|
447 |
+
|
448 |
+
for(i = 0; i < m; ++i){
|
449 |
+
int class = -1;
|
450 |
+
char *path = paths[i];
|
451 |
+
for(j = 0; j < classes; ++j){
|
452 |
+
if(strstr(path, labels[j])){
|
453 |
+
class = j;
|
454 |
+
break;
|
455 |
+
}
|
456 |
+
}
|
457 |
+
float *pred = calloc(classes, sizeof(float));
|
458 |
+
image im = load_image_color(paths[i], 0, 0);
|
459 |
+
for(j = 0; j < nscales; ++j){
|
460 |
+
image r = resize_max(im, scales[j]);
|
461 |
+
resize_network(net, r.w, r.h);
|
462 |
+
float *p = network_predict(net, r.data);
|
463 |
+
if(net->hierarchy) hierarchy_predictions(p, net->outputs, net->hierarchy, 1 , 1);
|
464 |
+
axpy_cpu(classes, 1, p, 1, pred, 1);
|
465 |
+
flip_image(r);
|
466 |
+
p = network_predict(net, r.data);
|
467 |
+
axpy_cpu(classes, 1, p, 1, pred, 1);
|
468 |
+
if(r.data != im.data) free_image(r);
|
469 |
+
}
|
470 |
+
free_image(im);
|
471 |
+
top_k(pred, classes, topk, indexes);
|
472 |
+
free(pred);
|
473 |
+
if(indexes[0] == class) avg_acc += 1;
|
474 |
+
for(j = 0; j < topk; ++j){
|
475 |
+
if(indexes[j] == class) avg_topk += 1;
|
476 |
+
}
|
477 |
+
|
478 |
+
printf("%d: top 1: %f, top %d: %f\n", i, avg_acc/(i+1), topk, avg_topk/(i+1));
|
479 |
+
}
|
480 |
+
}
|
481 |
+
|
482 |
+
void try_classifier(char *datacfg, char *cfgfile, char *weightfile, char *filename, int layer_num)
|
483 |
+
{
|
484 |
+
network *net = load_network(cfgfile, weightfile, 0);
|
485 |
+
set_batch_network(net, 1);
|
486 |
+
srand(2222222);
|
487 |
+
|
488 |
+
list *options = read_data_cfg(datacfg);
|
489 |
+
|
490 |
+
char *name_list = option_find_str(options, "names", 0);
|
491 |
+
if(!name_list) name_list = option_find_str(options, "labels", "data/labels.list");
|
492 |
+
int top = option_find_int(options, "top", 1);
|
493 |
+
|
494 |
+
int i = 0;
|
495 |
+
char **names = get_labels(name_list);
|
496 |
+
clock_t time;
|
497 |
+
int *indexes = calloc(top, sizeof(int));
|
498 |
+
char buff[256];
|
499 |
+
char *input = buff;
|
500 |
+
while(1){
|
501 |
+
if(filename){
|
502 |
+
strncpy(input, filename, 256);
|
503 |
+
}else{
|
504 |
+
printf("Enter Image Path: ");
|
505 |
+
fflush(stdout);
|
506 |
+
input = fgets(input, 256, stdin);
|
507 |
+
if(!input) return;
|
508 |
+
strtok(input, "\n");
|
509 |
+
}
|
510 |
+
image orig = load_image_color(input, 0, 0);
|
511 |
+
image r = resize_min(orig, 256);
|
512 |
+
image im = crop_image(r, (r.w - 224 - 1)/2 + 1, (r.h - 224 - 1)/2 + 1, 224, 224);
|
513 |
+
float mean[] = {0.48263312050943, 0.45230225481413, 0.40099074308742};
|
514 |
+
float std[] = {0.22590347483426, 0.22120921437787, 0.22103996251583};
|
515 |
+
float var[3];
|
516 |
+
var[0] = std[0]*std[0];
|
517 |
+
var[1] = std[1]*std[1];
|
518 |
+
var[2] = std[2]*std[2];
|
519 |
+
|
520 |
+
normalize_cpu(im.data, mean, var, 1, 3, im.w*im.h);
|
521 |
+
|
522 |
+
float *X = im.data;
|
523 |
+
time=clock();
|
524 |
+
float *predictions = network_predict(net, X);
|
525 |
+
|
526 |
+
layer l = net->layers[layer_num];
|
527 |
+
for(i = 0; i < l.c; ++i){
|
528 |
+
if(l.rolling_mean) printf("%f %f %f\n", l.rolling_mean[i], l.rolling_variance[i], l.scales[i]);
|
529 |
+
}
|
530 |
+
#ifdef GPU
|
531 |
+
cuda_pull_array(l.output_gpu, l.output, l.outputs);
|
532 |
+
#endif
|
533 |
+
for(i = 0; i < l.outputs; ++i){
|
534 |
+
printf("%f\n", l.output[i]);
|
535 |
+
}
|
536 |
+
/*
|
537 |
+
|
538 |
+
printf("\n\nWeights\n");
|
539 |
+
for(i = 0; i < l.n*l.size*l.size*l.c; ++i){
|
540 |
+
printf("%f\n", l.filters[i]);
|
541 |
+
}
|
542 |
+
|
543 |
+
printf("\n\nBiases\n");
|
544 |
+
for(i = 0; i < l.n; ++i){
|
545 |
+
printf("%f\n", l.biases[i]);
|
546 |
+
}
|
547 |
+
*/
|
548 |
+
|
549 |
+
top_predictions(net, top, indexes);
|
550 |
+
printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
|
551 |
+
for(i = 0; i < top; ++i){
|
552 |
+
int index = indexes[i];
|
553 |
+
printf("%s: %f\n", names[index], predictions[index]);
|
554 |
+
}
|
555 |
+
free_image(im);
|
556 |
+
if (filename) break;
|
557 |
+
}
|
558 |
+
}
|
559 |
+
|
560 |
+
void predict_classifier(char *datacfg, char *cfgfile, char *weightfile, char *filename, int top)
|
561 |
+
{
|
562 |
+
network *net = load_network(cfgfile, weightfile, 0);
|
563 |
+
set_batch_network(net, 1);
|
564 |
+
srand(2222222);
|
565 |
+
|
566 |
+
list *options = read_data_cfg(datacfg);
|
567 |
+
|
568 |
+
char *name_list = option_find_str(options, "names", 0);
|
569 |
+
if(!name_list) name_list = option_find_str(options, "labels", "data/labels.list");
|
570 |
+
if(top == 0) top = option_find_int(options, "top", 1);
|
571 |
+
|
572 |
+
int i = 0;
|
573 |
+
char **names = get_labels(name_list);
|
574 |
+
clock_t time;
|
575 |
+
int *indexes = calloc(top, sizeof(int));
|
576 |
+
char buff[256];
|
577 |
+
char *input = buff;
|
578 |
+
while(1){
|
579 |
+
if(filename){
|
580 |
+
strncpy(input, filename, 256);
|
581 |
+
}else{
|
582 |
+
printf("Enter Image Path: ");
|
583 |
+
fflush(stdout);
|
584 |
+
input = fgets(input, 256, stdin);
|
585 |
+
if(!input) return;
|
586 |
+
strtok(input, "\n");
|
587 |
+
}
|
588 |
+
image im = load_image_color(input, 0, 0);
|
589 |
+
image r = letterbox_image(im, net->w, net->h);
|
590 |
+
//image r = resize_min(im, 320);
|
591 |
+
//printf("%d %d\n", r.w, r.h);
|
592 |
+
//resize_network(net, r.w, r.h);
|
593 |
+
//printf("%d %d\n", r.w, r.h);
|
594 |
+
|
595 |
+
float *X = r.data;
|
596 |
+
time=clock();
|
597 |
+
float *predictions = network_predict(net, X);
|
598 |
+
if(net->hierarchy) hierarchy_predictions(predictions, net->outputs, net->hierarchy, 1, 1);
|
599 |
+
top_k(predictions, net->outputs, top, indexes);
|
600 |
+
fprintf(stderr, "%s: Predicted in %f seconds.\n", input, sec(clock()-time));
|
601 |
+
for(i = 0; i < top; ++i){
|
602 |
+
int index = indexes[i];
|
603 |
+
//if(net->hierarchy) printf("%d, %s: %f, parent: %s \n",index, names[index], predictions[index], (net->hierarchy->parent[index] >= 0) ? names[net->hierarchy->parent[index]] : "Root");
|
604 |
+
//else printf("%s: %f\n",names[index], predictions[index]);
|
605 |
+
printf("%5.2f%%: %s\n", predictions[index]*100, names[index]);
|
606 |
+
}
|
607 |
+
if(r.data != im.data) free_image(r);
|
608 |
+
free_image(im);
|
609 |
+
if (filename) break;
|
610 |
+
}
|
611 |
+
}
|
612 |
+
|
613 |
+
|
614 |
+
void label_classifier(char *datacfg, char *filename, char *weightfile)
|
615 |
+
{
|
616 |
+
int i;
|
617 |
+
network *net = load_network(filename, weightfile, 0);
|
618 |
+
set_batch_network(net, 1);
|
619 |
+
srand(time(0));
|
620 |
+
|
621 |
+
list *options = read_data_cfg(datacfg);
|
622 |
+
|
623 |
+
char *label_list = option_find_str(options, "names", "data/labels.list");
|
624 |
+
char *test_list = option_find_str(options, "test", "data/train.list");
|
625 |
+
int classes = option_find_int(options, "classes", 2);
|
626 |
+
|
627 |
+
char **labels = get_labels(label_list);
|
628 |
+
list *plist = get_paths(test_list);
|
629 |
+
|
630 |
+
char **paths = (char **)list_to_array(plist);
|
631 |
+
int m = plist->size;
|
632 |
+
free_list(plist);
|
633 |
+
|
634 |
+
for(i = 0; i < m; ++i){
|
635 |
+
image im = load_image_color(paths[i], 0, 0);
|
636 |
+
image resized = resize_min(im, net->w);
|
637 |
+
image crop = crop_image(resized, (resized.w - net->w)/2, (resized.h - net->h)/2, net->w, net->h);
|
638 |
+
float *pred = network_predict(net, crop.data);
|
639 |
+
|
640 |
+
if(resized.data != im.data) free_image(resized);
|
641 |
+
free_image(im);
|
642 |
+
free_image(crop);
|
643 |
+
int ind = max_index(pred, classes);
|
644 |
+
|
645 |
+
printf("%s\n", labels[ind]);
|
646 |
+
}
|
647 |
+
}
|
648 |
+
|
649 |
+
void csv_classifier(char *datacfg, char *cfgfile, char *weightfile)
|
650 |
+
{
|
651 |
+
int i,j;
|
652 |
+
network *net = load_network(cfgfile, weightfile, 0);
|
653 |
+
srand(time(0));
|
654 |
+
|
655 |
+
list *options = read_data_cfg(datacfg);
|
656 |
+
|
657 |
+
char *test_list = option_find_str(options, "test", "data/test.list");
|
658 |
+
int top = option_find_int(options, "top", 1);
|
659 |
+
|
660 |
+
list *plist = get_paths(test_list);
|
661 |
+
|
662 |
+
char **paths = (char **)list_to_array(plist);
|
663 |
+
int m = plist->size;
|
664 |
+
free_list(plist);
|
665 |
+
int *indexes = calloc(top, sizeof(int));
|
666 |
+
|
667 |
+
for(i = 0; i < m; ++i){
|
668 |
+
double time = what_time_is_it_now();
|
669 |
+
char *path = paths[i];
|
670 |
+
image im = load_image_color(path, 0, 0);
|
671 |
+
image r = letterbox_image(im, net->w, net->h);
|
672 |
+
float *predictions = network_predict(net, r.data);
|
673 |
+
if(net->hierarchy) hierarchy_predictions(predictions, net->outputs, net->hierarchy, 1, 1);
|
674 |
+
top_k(predictions, net->outputs, top, indexes);
|
675 |
+
|
676 |
+
printf("%s", path);
|
677 |
+
for(j = 0; j < top; ++j){
|
678 |
+
printf("\t%d", indexes[j]);
|
679 |
+
}
|
680 |
+
printf("\n");
|
681 |
+
|
682 |
+
free_image(im);
|
683 |
+
free_image(r);
|
684 |
+
|
685 |
+
fprintf(stderr, "%lf seconds, %d images, %d total\n", what_time_is_it_now() - time, i+1, m);
|
686 |
+
}
|
687 |
+
}
|
688 |
+
|
689 |
+
void test_classifier(char *datacfg, char *cfgfile, char *weightfile, int target_layer)
|
690 |
+
{
|
691 |
+
int curr = 0;
|
692 |
+
network *net = load_network(cfgfile, weightfile, 0);
|
693 |
+
srand(time(0));
|
694 |
+
|
695 |
+
list *options = read_data_cfg(datacfg);
|
696 |
+
|
697 |
+
char *test_list = option_find_str(options, "test", "data/test.list");
|
698 |
+
int classes = option_find_int(options, "classes", 2);
|
699 |
+
|
700 |
+
list *plist = get_paths(test_list);
|
701 |
+
|
702 |
+
char **paths = (char **)list_to_array(plist);
|
703 |
+
int m = plist->size;
|
704 |
+
free_list(plist);
|
705 |
+
|
706 |
+
clock_t time;
|
707 |
+
|
708 |
+
data val, buffer;
|
709 |
+
|
710 |
+
load_args args = {0};
|
711 |
+
args.w = net->w;
|
712 |
+
args.h = net->h;
|
713 |
+
args.paths = paths;
|
714 |
+
args.classes = classes;
|
715 |
+
args.n = net->batch;
|
716 |
+
args.m = 0;
|
717 |
+
args.labels = 0;
|
718 |
+
args.d = &buffer;
|
719 |
+
args.type = OLD_CLASSIFICATION_DATA;
|
720 |
+
|
721 |
+
pthread_t load_thread = load_data_in_thread(args);
|
722 |
+
for(curr = net->batch; curr < m; curr += net->batch){
|
723 |
+
time=clock();
|
724 |
+
|
725 |
+
pthread_join(load_thread, 0);
|
726 |
+
val = buffer;
|
727 |
+
|
728 |
+
if(curr < m){
|
729 |
+
args.paths = paths + curr;
|
730 |
+
if (curr + net->batch > m) args.n = m - curr;
|
731 |
+
load_thread = load_data_in_thread(args);
|
732 |
+
}
|
733 |
+
fprintf(stderr, "Loaded: %d images in %lf seconds\n", val.X.rows, sec(clock()-time));
|
734 |
+
|
735 |
+
time=clock();
|
736 |
+
matrix pred = network_predict_data(net, val);
|
737 |
+
|
738 |
+
int i, j;
|
739 |
+
if (target_layer >= 0){
|
740 |
+
//layer l = net->layers[target_layer];
|
741 |
+
}
|
742 |
+
|
743 |
+
for(i = 0; i < pred.rows; ++i){
|
744 |
+
printf("%s", paths[curr-net->batch+i]);
|
745 |
+
for(j = 0; j < pred.cols; ++j){
|
746 |
+
printf("\t%g", pred.vals[i][j]);
|
747 |
+
}
|
748 |
+
printf("\n");
|
749 |
+
}
|
750 |
+
|
751 |
+
free_matrix(pred);
|
752 |
+
|
753 |
+
fprintf(stderr, "%lf seconds, %d images, %d total\n", sec(clock()-time), val.X.rows, curr);
|
754 |
+
free_data(val);
|
755 |
+
}
|
756 |
+
}
|
757 |
+
|
758 |
+
void file_output_classifier(char *datacfg, char *filename, char *weightfile, char *listfile)
|
759 |
+
{
|
760 |
+
int i,j;
|
761 |
+
network *net = load_network(filename, weightfile, 0);
|
762 |
+
set_batch_network(net, 1);
|
763 |
+
srand(time(0));
|
764 |
+
|
765 |
+
list *options = read_data_cfg(datacfg);
|
766 |
+
|
767 |
+
//char *label_list = option_find_str(options, "names", "data/labels.list");
|
768 |
+
int classes = option_find_int(options, "classes", 2);
|
769 |
+
|
770 |
+
list *plist = get_paths(listfile);
|
771 |
+
|
772 |
+
char **paths = (char **)list_to_array(plist);
|
773 |
+
int m = plist->size;
|
774 |
+
free_list(plist);
|
775 |
+
|
776 |
+
for(i = 0; i < m; ++i){
|
777 |
+
image im = load_image_color(paths[i], 0, 0);
|
778 |
+
image resized = resize_min(im, net->w);
|
779 |
+
image crop = crop_image(resized, (resized.w - net->w)/2, (resized.h - net->h)/2, net->w, net->h);
|
780 |
+
|
781 |
+
float *pred = network_predict(net, crop.data);
|
782 |
+
if(net->hierarchy) hierarchy_predictions(pred, net->outputs, net->hierarchy, 0, 1);
|
783 |
+
|
784 |
+
if(resized.data != im.data) free_image(resized);
|
785 |
+
free_image(im);
|
786 |
+
free_image(crop);
|
787 |
+
|
788 |
+
printf("%s", paths[i]);
|
789 |
+
for(j = 0; j < classes; ++j){
|
790 |
+
printf("\t%g", pred[j]);
|
791 |
+
}
|
792 |
+
printf("\n");
|
793 |
+
}
|
794 |
+
}
|
795 |
+
|
796 |
+
|
797 |
+
void threat_classifier(char *datacfg, char *cfgfile, char *weightfile, int cam_index, const char *filename)
|
798 |
+
{
|
799 |
+
#ifdef OPENCV
|
800 |
+
float threat = 0;
|
801 |
+
float roll = .2;
|
802 |
+
|
803 |
+
printf("Classifier Demo\n");
|
804 |
+
network *net = load_network(cfgfile, weightfile, 0);
|
805 |
+
set_batch_network(net, 1);
|
806 |
+
list *options = read_data_cfg(datacfg);
|
807 |
+
|
808 |
+
srand(2222222);
|
809 |
+
void * cap = open_video_stream(filename, cam_index, 0,0,0);
|
810 |
+
|
811 |
+
int top = option_find_int(options, "top", 1);
|
812 |
+
|
813 |
+
char *name_list = option_find_str(options, "names", 0);
|
814 |
+
char **names = get_labels(name_list);
|
815 |
+
|
816 |
+
int *indexes = calloc(top, sizeof(int));
|
817 |
+
|
818 |
+
if(!cap) error("Couldn't connect to webcam.\n");
|
819 |
+
//cvNamedWindow("Threat", CV_WINDOW_NORMAL);
|
820 |
+
//cvResizeWindow("Threat", 512, 512);
|
821 |
+
float fps = 0;
|
822 |
+
int i;
|
823 |
+
|
824 |
+
int count = 0;
|
825 |
+
|
826 |
+
while(1){
|
827 |
+
++count;
|
828 |
+
struct timeval tval_before, tval_after, tval_result;
|
829 |
+
gettimeofday(&tval_before, NULL);
|
830 |
+
|
831 |
+
image in = get_image_from_stream(cap);
|
832 |
+
if(!in.data) break;
|
833 |
+
image in_s = resize_image(in, net->w, net->h);
|
834 |
+
|
835 |
+
image out = in;
|
836 |
+
int x1 = out.w / 20;
|
837 |
+
int y1 = out.h / 20;
|
838 |
+
int x2 = 2*x1;
|
839 |
+
int y2 = out.h - out.h/20;
|
840 |
+
|
841 |
+
int border = .01*out.h;
|
842 |
+
int h = y2 - y1 - 2*border;
|
843 |
+
int w = x2 - x1 - 2*border;
|
844 |
+
|
845 |
+
float *predictions = network_predict(net, in_s.data);
|
846 |
+
float curr_threat = 0;
|
847 |
+
if(1){
|
848 |
+
curr_threat = predictions[0] * 0 +
|
849 |
+
predictions[1] * .6 +
|
850 |
+
predictions[2];
|
851 |
+
} else {
|
852 |
+
curr_threat = predictions[218] +
|
853 |
+
predictions[539] +
|
854 |
+
predictions[540] +
|
855 |
+
predictions[368] +
|
856 |
+
predictions[369] +
|
857 |
+
predictions[370];
|
858 |
+
}
|
859 |
+
threat = roll * curr_threat + (1-roll) * threat;
|
860 |
+
|
861 |
+
draw_box_width(out, x2 + border, y1 + .02*h, x2 + .5 * w, y1 + .02*h + border, border, 0,0,0);
|
862 |
+
if(threat > .97) {
|
863 |
+
draw_box_width(out, x2 + .5 * w + border,
|
864 |
+
y1 + .02*h - 2*border,
|
865 |
+
x2 + .5 * w + 6*border,
|
866 |
+
y1 + .02*h + 3*border, 3*border, 1,0,0);
|
867 |
+
}
|
868 |
+
draw_box_width(out, x2 + .5 * w + border,
|
869 |
+
y1 + .02*h - 2*border,
|
870 |
+
x2 + .5 * w + 6*border,
|
871 |
+
y1 + .02*h + 3*border, .5*border, 0,0,0);
|
872 |
+
draw_box_width(out, x2 + border, y1 + .42*h, x2 + .5 * w, y1 + .42*h + border, border, 0,0,0);
|
873 |
+
if(threat > .57) {
|
874 |
+
draw_box_width(out, x2 + .5 * w + border,
|
875 |
+
y1 + .42*h - 2*border,
|
876 |
+
x2 + .5 * w + 6*border,
|
877 |
+
y1 + .42*h + 3*border, 3*border, 1,1,0);
|
878 |
+
}
|
879 |
+
draw_box_width(out, x2 + .5 * w + border,
|
880 |
+
y1 + .42*h - 2*border,
|
881 |
+
x2 + .5 * w + 6*border,
|
882 |
+
y1 + .42*h + 3*border, .5*border, 0,0,0);
|
883 |
+
|
884 |
+
draw_box_width(out, x1, y1, x2, y2, border, 0,0,0);
|
885 |
+
for(i = 0; i < threat * h ; ++i){
|
886 |
+
float ratio = (float) i / h;
|
887 |
+
float r = (ratio < .5) ? (2*(ratio)) : 1;
|
888 |
+
float g = (ratio < .5) ? 1 : 1 - 2*(ratio - .5);
|
889 |
+
draw_box_width(out, x1 + border, y2 - border - i, x2 - border, y2 - border - i, 1, r, g, 0);
|
890 |
+
}
|
891 |
+
top_predictions(net, top, indexes);
|
892 |
+
char buff[256];
|
893 |
+
sprintf(buff, "/home/pjreddie/tmp/threat_%06d", count);
|
894 |
+
//save_image(out, buff);
|
895 |
+
|
896 |
+
printf("\033[2J");
|
897 |
+
printf("\033[1;1H");
|
898 |
+
printf("\nFPS:%.0f\n",fps);
|
899 |
+
|
900 |
+
for(i = 0; i < top; ++i){
|
901 |
+
int index = indexes[i];
|
902 |
+
printf("%.1f%%: %s\n", predictions[index]*100, names[index]);
|
903 |
+
}
|
904 |
+
|
905 |
+
if(1){
|
906 |
+
show_image(out, "Threat", 10);
|
907 |
+
}
|
908 |
+
free_image(in_s);
|
909 |
+
free_image(in);
|
910 |
+
|
911 |
+
gettimeofday(&tval_after, NULL);
|
912 |
+
timersub(&tval_after, &tval_before, &tval_result);
|
913 |
+
float curr = 1000000.f/((long int)tval_result.tv_usec);
|
914 |
+
fps = .9*fps + .1*curr;
|
915 |
+
}
|
916 |
+
#endif
|
917 |
+
}
|
918 |
+
|
919 |
+
|
920 |
+
void gun_classifier(char *datacfg, char *cfgfile, char *weightfile, int cam_index, const char *filename)
|
921 |
+
{
|
922 |
+
#ifdef OPENCV
|
923 |
+
int bad_cats[] = {218, 539, 540, 1213, 1501, 1742, 1911, 2415, 4348, 19223, 368, 369, 370, 1133, 1200, 1306, 2122, 2301, 2537, 2823, 3179, 3596, 3639, 4489, 5107, 5140, 5289, 6240, 6631, 6762, 7048, 7171, 7969, 7984, 7989, 8824, 8927, 9915, 10270, 10448, 13401, 15205, 18358, 18894, 18895, 19249, 19697};
|
924 |
+
|
925 |
+
printf("Classifier Demo\n");
|
926 |
+
network *net = load_network(cfgfile, weightfile, 0);
|
927 |
+
set_batch_network(net, 1);
|
928 |
+
list *options = read_data_cfg(datacfg);
|
929 |
+
|
930 |
+
srand(2222222);
|
931 |
+
void * cap = open_video_stream(filename, cam_index, 0,0,0);
|
932 |
+
|
933 |
+
int top = option_find_int(options, "top", 1);
|
934 |
+
|
935 |
+
char *name_list = option_find_str(options, "names", 0);
|
936 |
+
char **names = get_labels(name_list);
|
937 |
+
|
938 |
+
int *indexes = calloc(top, sizeof(int));
|
939 |
+
|
940 |
+
if(!cap) error("Couldn't connect to webcam.\n");
|
941 |
+
float fps = 0;
|
942 |
+
int i;
|
943 |
+
|
944 |
+
while(1){
|
945 |
+
struct timeval tval_before, tval_after, tval_result;
|
946 |
+
gettimeofday(&tval_before, NULL);
|
947 |
+
|
948 |
+
image in = get_image_from_stream(cap);
|
949 |
+
image in_s = resize_image(in, net->w, net->h);
|
950 |
+
|
951 |
+
float *predictions = network_predict(net, in_s.data);
|
952 |
+
top_predictions(net, top, indexes);
|
953 |
+
|
954 |
+
printf("\033[2J");
|
955 |
+
printf("\033[1;1H");
|
956 |
+
|
957 |
+
int threat = 0;
|
958 |
+
for(i = 0; i < sizeof(bad_cats)/sizeof(bad_cats[0]); ++i){
|
959 |
+
int index = bad_cats[i];
|
960 |
+
if(predictions[index] > .01){
|
961 |
+
printf("Threat Detected!\n");
|
962 |
+
threat = 1;
|
963 |
+
break;
|
964 |
+
}
|
965 |
+
}
|
966 |
+
if(!threat) printf("Scanning...\n");
|
967 |
+
for(i = 0; i < sizeof(bad_cats)/sizeof(bad_cats[0]); ++i){
|
968 |
+
int index = bad_cats[i];
|
969 |
+
if(predictions[index] > .01){
|
970 |
+
printf("%s\n", names[index]);
|
971 |
+
}
|
972 |
+
}
|
973 |
+
|
974 |
+
show_image(in, "Threat Detection", 10);
|
975 |
+
free_image(in_s);
|
976 |
+
free_image(in);
|
977 |
+
|
978 |
+
gettimeofday(&tval_after, NULL);
|
979 |
+
timersub(&tval_after, &tval_before, &tval_result);
|
980 |
+
float curr = 1000000.f/((long int)tval_result.tv_usec);
|
981 |
+
fps = .9*fps + .1*curr;
|
982 |
+
}
|
983 |
+
#endif
|
984 |
+
}
|
985 |
+
|
986 |
+
void demo_classifier(char *datacfg, char *cfgfile, char *weightfile, int cam_index, const char *filename)
|
987 |
+
{
|
988 |
+
#ifdef OPENCV
|
989 |
+
char *base = basecfg(cfgfile);
|
990 |
+
image **alphabet = load_alphabet();
|
991 |
+
printf("Classifier Demo\n");
|
992 |
+
network *net = load_network(cfgfile, weightfile, 0);
|
993 |
+
set_batch_network(net, 1);
|
994 |
+
list *options = read_data_cfg(datacfg);
|
995 |
+
|
996 |
+
srand(2222222);
|
997 |
+
|
998 |
+
int w = 1280;
|
999 |
+
int h = 720;
|
1000 |
+
void * cap = open_video_stream(filename, cam_index, w, h, 0);
|
1001 |
+
|
1002 |
+
int top = option_find_int(options, "top", 1);
|
1003 |
+
|
1004 |
+
char *label_list = option_find_str(options, "labels", 0);
|
1005 |
+
char *name_list = option_find_str(options, "names", label_list);
|
1006 |
+
char **names = get_labels(name_list);
|
1007 |
+
|
1008 |
+
int *indexes = calloc(top, sizeof(int));
|
1009 |
+
|
1010 |
+
if(!cap) error("Couldn't connect to webcam.\n");
|
1011 |
+
float fps = 0;
|
1012 |
+
int i;
|
1013 |
+
|
1014 |
+
while(1){
|
1015 |
+
struct timeval tval_before, tval_after, tval_result;
|
1016 |
+
gettimeofday(&tval_before, NULL);
|
1017 |
+
|
1018 |
+
image in = get_image_from_stream(cap);
|
1019 |
+
//image in_s = resize_image(in, net->w, net->h);
|
1020 |
+
image in_s = letterbox_image(in, net->w, net->h);
|
1021 |
+
|
1022 |
+
float *predictions = network_predict(net, in_s.data);
|
1023 |
+
if(net->hierarchy) hierarchy_predictions(predictions, net->outputs, net->hierarchy, 1, 1);
|
1024 |
+
top_predictions(net, top, indexes);
|
1025 |
+
|
1026 |
+
printf("\033[2J");
|
1027 |
+
printf("\033[1;1H");
|
1028 |
+
printf("\nFPS:%.0f\n",fps);
|
1029 |
+
|
1030 |
+
int lh = in.h*.03;
|
1031 |
+
int toph = 3*lh;
|
1032 |
+
|
1033 |
+
float rgb[3] = {1,1,1};
|
1034 |
+
for(i = 0; i < top; ++i){
|
1035 |
+
printf("%d\n", toph);
|
1036 |
+
int index = indexes[i];
|
1037 |
+
printf("%.1f%%: %s\n", predictions[index]*100, names[index]);
|
1038 |
+
|
1039 |
+
char buff[1024];
|
1040 |
+
sprintf(buff, "%3.1f%%: %s\n", predictions[index]*100, names[index]);
|
1041 |
+
image label = get_label(alphabet, buff, lh);
|
1042 |
+
draw_label(in, toph, lh, label, rgb);
|
1043 |
+
toph += 2*lh;
|
1044 |
+
free_image(label);
|
1045 |
+
}
|
1046 |
+
|
1047 |
+
show_image(in, base, 10);
|
1048 |
+
free_image(in_s);
|
1049 |
+
free_image(in);
|
1050 |
+
|
1051 |
+
gettimeofday(&tval_after, NULL);
|
1052 |
+
timersub(&tval_after, &tval_before, &tval_result);
|
1053 |
+
float curr = 1000000.f/((long int)tval_result.tv_usec);
|
1054 |
+
fps = .9*fps + .1*curr;
|
1055 |
+
}
|
1056 |
+
#endif
|
1057 |
+
}
|
1058 |
+
|
1059 |
+
|
1060 |
+
void run_classifier(int argc, char **argv)
|
1061 |
+
{
|
1062 |
+
if(argc < 4){
|
1063 |
+
fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
|
1064 |
+
return;
|
1065 |
+
}
|
1066 |
+
|
1067 |
+
char *gpu_list = find_char_arg(argc, argv, "-gpus", 0);
|
1068 |
+
int ngpus;
|
1069 |
+
int *gpus = read_intlist(gpu_list, &ngpus, gpu_index);
|
1070 |
+
|
1071 |
+
|
1072 |
+
int cam_index = find_int_arg(argc, argv, "-c", 0);
|
1073 |
+
int top = find_int_arg(argc, argv, "-t", 0);
|
1074 |
+
int clear = find_arg(argc, argv, "-clear");
|
1075 |
+
char *data = argv[3];
|
1076 |
+
char *cfg = argv[4];
|
1077 |
+
char *weights = (argc > 5) ? argv[5] : 0;
|
1078 |
+
char *filename = (argc > 6) ? argv[6]: 0;
|
1079 |
+
char *layer_s = (argc > 7) ? argv[7]: 0;
|
1080 |
+
int layer = layer_s ? atoi(layer_s) : -1;
|
1081 |
+
if(0==strcmp(argv[2], "predict")) predict_classifier(data, cfg, weights, filename, top);
|
1082 |
+
else if(0==strcmp(argv[2], "fout")) file_output_classifier(data, cfg, weights, filename);
|
1083 |
+
else if(0==strcmp(argv[2], "try")) try_classifier(data, cfg, weights, filename, atoi(layer_s));
|
1084 |
+
else if(0==strcmp(argv[2], "train")) train_classifier(data, cfg, weights, gpus, ngpus, clear);
|
1085 |
+
else if(0==strcmp(argv[2], "demo")) demo_classifier(data, cfg, weights, cam_index, filename);
|
1086 |
+
else if(0==strcmp(argv[2], "gun")) gun_classifier(data, cfg, weights, cam_index, filename);
|
1087 |
+
else if(0==strcmp(argv[2], "threat")) threat_classifier(data, cfg, weights, cam_index, filename);
|
1088 |
+
else if(0==strcmp(argv[2], "test")) test_classifier(data, cfg, weights, layer);
|
1089 |
+
else if(0==strcmp(argv[2], "csv")) csv_classifier(data, cfg, weights);
|
1090 |
+
else if(0==strcmp(argv[2], "label")) label_classifier(data, cfg, weights);
|
1091 |
+
else if(0==strcmp(argv[2], "valid")) validate_classifier_single(data, cfg, weights);
|
1092 |
+
else if(0==strcmp(argv[2], "validmulti")) validate_classifier_multi(data, cfg, weights);
|
1093 |
+
else if(0==strcmp(argv[2], "valid10")) validate_classifier_10(data, cfg, weights);
|
1094 |
+
else if(0==strcmp(argv[2], "validcrop")) validate_classifier_crop(data, cfg, weights);
|
1095 |
+
else if(0==strcmp(argv[2], "validfull")) validate_classifier_full(data, cfg, weights);
|
1096 |
+
}
|
1097 |
+
|
1098 |
+
|
model/examples/coco.c
ADDED
@@ -0,0 +1,357 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include "darknet.h"
|
2 |
+
|
3 |
+
#include <stdio.h>
|
4 |
+
|
5 |
+
char *coco_classes[] = {"person","bicycle","car","motorcycle","airplane","bus","train","truck","boat","traffic light","fire hydrant","stop sign","parking meter","bench","bird","cat","dog","horse","sheep","cow","elephant","bear","zebra","giraffe","backpack","umbrella","handbag","tie","suitcase","frisbee","skis","snowboard","sports ball","kite","baseball bat","baseball glove","skateboard","surfboard","tennis racket","bottle","wine glass","cup","fork","knife","spoon","bowl","banana","apple","sandwich","orange","broccoli","carrot","hot dog","pizza","donut","cake","chair","couch","potted plant","bed","dining table","toilet","tv","laptop","mouse","remote","keyboard","cell phone","microwave","oven","toaster","sink","refrigerator","book","clock","vase","scissors","teddy bear","hair drier","toothbrush"};
|
6 |
+
|
7 |
+
int coco_ids[] = {1,2,3,4,5,6,7,8,9,10,11,13,14,15,16,17,18,19,20,21,22,23,24,25,27,28,31,32,33,34,35,36,37,38,39,40,41,42,43,44,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,67,70,72,73,74,75,76,77,78,79,80,81,82,84,85,86,87,88,89,90};
|
8 |
+
|
9 |
+
void train_coco(char *cfgfile, char *weightfile)
|
10 |
+
{
|
11 |
+
//char *train_images = "/home/pjreddie/data/voc/test/train.txt";
|
12 |
+
//char *train_images = "/home/pjreddie/data/coco/train.txt";
|
13 |
+
char *train_images = "data/coco.trainval.txt";
|
14 |
+
//char *train_images = "data/bags.train.list";
|
15 |
+
char *backup_directory = "/home/pjreddie/backup/";
|
16 |
+
srand(time(0));
|
17 |
+
char *base = basecfg(cfgfile);
|
18 |
+
printf("%s\n", base);
|
19 |
+
float avg_loss = -1;
|
20 |
+
network *net = load_network(cfgfile, weightfile, 0);
|
21 |
+
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
|
22 |
+
int imgs = net->batch*net->subdivisions;
|
23 |
+
int i = *net->seen/imgs;
|
24 |
+
data train, buffer;
|
25 |
+
|
26 |
+
|
27 |
+
layer l = net->layers[net->n - 1];
|
28 |
+
|
29 |
+
int side = l.side;
|
30 |
+
int classes = l.classes;
|
31 |
+
float jitter = l.jitter;
|
32 |
+
|
33 |
+
list *plist = get_paths(train_images);
|
34 |
+
//int N = plist->size;
|
35 |
+
char **paths = (char **)list_to_array(plist);
|
36 |
+
|
37 |
+
load_args args = {0};
|
38 |
+
args.w = net->w;
|
39 |
+
args.h = net->h;
|
40 |
+
args.paths = paths;
|
41 |
+
args.n = imgs;
|
42 |
+
args.m = plist->size;
|
43 |
+
args.classes = classes;
|
44 |
+
args.jitter = jitter;
|
45 |
+
args.num_boxes = side;
|
46 |
+
args.d = &buffer;
|
47 |
+
args.type = REGION_DATA;
|
48 |
+
|
49 |
+
args.angle = net->angle;
|
50 |
+
args.exposure = net->exposure;
|
51 |
+
args.saturation = net->saturation;
|
52 |
+
args.hue = net->hue;
|
53 |
+
|
54 |
+
pthread_t load_thread = load_data_in_thread(args);
|
55 |
+
clock_t time;
|
56 |
+
//while(i*imgs < N*120){
|
57 |
+
while(get_current_batch(net) < net->max_batches){
|
58 |
+
i += 1;
|
59 |
+
time=clock();
|
60 |
+
pthread_join(load_thread, 0);
|
61 |
+
train = buffer;
|
62 |
+
load_thread = load_data_in_thread(args);
|
63 |
+
|
64 |
+
printf("Loaded: %lf seconds\n", sec(clock()-time));
|
65 |
+
|
66 |
+
/*
|
67 |
+
image im = float_to_image(net->w, net->h, 3, train.X.vals[113]);
|
68 |
+
image copy = copy_image(im);
|
69 |
+
draw_coco(copy, train.y.vals[113], 7, "truth");
|
70 |
+
cvWaitKey(0);
|
71 |
+
free_image(copy);
|
72 |
+
*/
|
73 |
+
|
74 |
+
time=clock();
|
75 |
+
float loss = train_network(net, train);
|
76 |
+
if (avg_loss < 0) avg_loss = loss;
|
77 |
+
avg_loss = avg_loss*.9 + loss*.1;
|
78 |
+
|
79 |
+
printf("%d: %f, %f avg, %f rate, %lf seconds, %d images\n", i, loss, avg_loss, get_current_rate(net), sec(clock()-time), i*imgs);
|
80 |
+
if(i%1000==0 || (i < 1000 && i%100 == 0)){
|
81 |
+
char buff[256];
|
82 |
+
sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
|
83 |
+
save_weights(net, buff);
|
84 |
+
}
|
85 |
+
if(i%100==0){
|
86 |
+
char buff[256];
|
87 |
+
sprintf(buff, "%s/%s.backup", backup_directory, base);
|
88 |
+
save_weights(net, buff);
|
89 |
+
}
|
90 |
+
free_data(train);
|
91 |
+
}
|
92 |
+
char buff[256];
|
93 |
+
sprintf(buff, "%s/%s_final.weights", backup_directory, base);
|
94 |
+
save_weights(net, buff);
|
95 |
+
}
|
96 |
+
|
97 |
+
static void print_cocos(FILE *fp, int image_id, detection *dets, int num_boxes, int classes, int w, int h)
|
98 |
+
{
|
99 |
+
int i, j;
|
100 |
+
for(i = 0; i < num_boxes; ++i){
|
101 |
+
float xmin = dets[i].bbox.x - dets[i].bbox.w/2.;
|
102 |
+
float xmax = dets[i].bbox.x + dets[i].bbox.w/2.;
|
103 |
+
float ymin = dets[i].bbox.y - dets[i].bbox.h/2.;
|
104 |
+
float ymax = dets[i].bbox.y + dets[i].bbox.h/2.;
|
105 |
+
|
106 |
+
if (xmin < 0) xmin = 0;
|
107 |
+
if (ymin < 0) ymin = 0;
|
108 |
+
if (xmax > w) xmax = w;
|
109 |
+
if (ymax > h) ymax = h;
|
110 |
+
|
111 |
+
float bx = xmin;
|
112 |
+
float by = ymin;
|
113 |
+
float bw = xmax - xmin;
|
114 |
+
float bh = ymax - ymin;
|
115 |
+
|
116 |
+
for(j = 0; j < classes; ++j){
|
117 |
+
if (dets[i].prob[j]) fprintf(fp, "{\"image_id\":%d, \"category_id\":%d, \"bbox\":[%f, %f, %f, %f], \"score\":%f},\n", image_id, coco_ids[j], bx, by, bw, bh, dets[i].prob[j]);
|
118 |
+
}
|
119 |
+
}
|
120 |
+
}
|
121 |
+
|
122 |
+
int get_coco_image_id(char *filename)
|
123 |
+
{
|
124 |
+
char *p = strrchr(filename, '_');
|
125 |
+
return atoi(p+1);
|
126 |
+
}
|
127 |
+
|
128 |
+
void validate_coco(char *cfg, char *weights)
|
129 |
+
{
|
130 |
+
network *net = load_network(cfg, weights, 0);
|
131 |
+
set_batch_network(net, 1);
|
132 |
+
fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
|
133 |
+
srand(time(0));
|
134 |
+
|
135 |
+
char *base = "results/";
|
136 |
+
list *plist = get_paths("data/coco_val_5k.list");
|
137 |
+
//list *plist = get_paths("/home/pjreddie/data/people-art/test.txt");
|
138 |
+
//list *plist = get_paths("/home/pjreddie/data/voc/test/2007_test.txt");
|
139 |
+
char **paths = (char **)list_to_array(plist);
|
140 |
+
|
141 |
+
layer l = net->layers[net->n-1];
|
142 |
+
int classes = l.classes;
|
143 |
+
|
144 |
+
char buff[1024];
|
145 |
+
snprintf(buff, 1024, "%s/coco_results.json", base);
|
146 |
+
FILE *fp = fopen(buff, "w");
|
147 |
+
fprintf(fp, "[\n");
|
148 |
+
|
149 |
+
int m = plist->size;
|
150 |
+
int i=0;
|
151 |
+
int t;
|
152 |
+
|
153 |
+
float thresh = .01;
|
154 |
+
int nms = 1;
|
155 |
+
float iou_thresh = .5;
|
156 |
+
|
157 |
+
int nthreads = 8;
|
158 |
+
image *val = calloc(nthreads, sizeof(image));
|
159 |
+
image *val_resized = calloc(nthreads, sizeof(image));
|
160 |
+
image *buf = calloc(nthreads, sizeof(image));
|
161 |
+
image *buf_resized = calloc(nthreads, sizeof(image));
|
162 |
+
pthread_t *thr = calloc(nthreads, sizeof(pthread_t));
|
163 |
+
|
164 |
+
load_args args = {0};
|
165 |
+
args.w = net->w;
|
166 |
+
args.h = net->h;
|
167 |
+
args.type = IMAGE_DATA;
|
168 |
+
|
169 |
+
for(t = 0; t < nthreads; ++t){
|
170 |
+
args.path = paths[i+t];
|
171 |
+
args.im = &buf[t];
|
172 |
+
args.resized = &buf_resized[t];
|
173 |
+
thr[t] = load_data_in_thread(args);
|
174 |
+
}
|
175 |
+
time_t start = time(0);
|
176 |
+
for(i = nthreads; i < m+nthreads; i += nthreads){
|
177 |
+
fprintf(stderr, "%d\n", i);
|
178 |
+
for(t = 0; t < nthreads && i+t-nthreads < m; ++t){
|
179 |
+
pthread_join(thr[t], 0);
|
180 |
+
val[t] = buf[t];
|
181 |
+
val_resized[t] = buf_resized[t];
|
182 |
+
}
|
183 |
+
for(t = 0; t < nthreads && i+t < m; ++t){
|
184 |
+
args.path = paths[i+t];
|
185 |
+
args.im = &buf[t];
|
186 |
+
args.resized = &buf_resized[t];
|
187 |
+
thr[t] = load_data_in_thread(args);
|
188 |
+
}
|
189 |
+
for(t = 0; t < nthreads && i+t-nthreads < m; ++t){
|
190 |
+
char *path = paths[i+t-nthreads];
|
191 |
+
int image_id = get_coco_image_id(path);
|
192 |
+
float *X = val_resized[t].data;
|
193 |
+
network_predict(net, X);
|
194 |
+
int w = val[t].w;
|
195 |
+
int h = val[t].h;
|
196 |
+
int nboxes = 0;
|
197 |
+
detection *dets = get_network_boxes(net, w, h, thresh, 0, 0, 0, &nboxes);
|
198 |
+
if (nms) do_nms_sort(dets, l.side*l.side*l.n, classes, iou_thresh);
|
199 |
+
print_cocos(fp, image_id, dets, l.side*l.side*l.n, classes, w, h);
|
200 |
+
free_detections(dets, nboxes);
|
201 |
+
free_image(val[t]);
|
202 |
+
free_image(val_resized[t]);
|
203 |
+
}
|
204 |
+
}
|
205 |
+
fseek(fp, -2, SEEK_CUR);
|
206 |
+
fprintf(fp, "\n]\n");
|
207 |
+
fclose(fp);
|
208 |
+
|
209 |
+
fprintf(stderr, "Total Detection Time: %f Seconds\n", (double)(time(0) - start));
|
210 |
+
}
|
211 |
+
|
212 |
+
void validate_coco_recall(char *cfgfile, char *weightfile)
|
213 |
+
{
|
214 |
+
network *net = load_network(cfgfile, weightfile, 0);
|
215 |
+
set_batch_network(net, 1);
|
216 |
+
fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
|
217 |
+
srand(time(0));
|
218 |
+
|
219 |
+
char *base = "results/comp4_det_test_";
|
220 |
+
list *plist = get_paths("/home/pjreddie/data/voc/test/2007_test.txt");
|
221 |
+
char **paths = (char **)list_to_array(plist);
|
222 |
+
|
223 |
+
layer l = net->layers[net->n-1];
|
224 |
+
int classes = l.classes;
|
225 |
+
int side = l.side;
|
226 |
+
|
227 |
+
int j, k;
|
228 |
+
FILE **fps = calloc(classes, sizeof(FILE *));
|
229 |
+
for(j = 0; j < classes; ++j){
|
230 |
+
char buff[1024];
|
231 |
+
snprintf(buff, 1024, "%s%s.txt", base, coco_classes[j]);
|
232 |
+
fps[j] = fopen(buff, "w");
|
233 |
+
}
|
234 |
+
|
235 |
+
int m = plist->size;
|
236 |
+
int i=0;
|
237 |
+
|
238 |
+
float thresh = .001;
|
239 |
+
int nms = 0;
|
240 |
+
float iou_thresh = .5;
|
241 |
+
|
242 |
+
int total = 0;
|
243 |
+
int correct = 0;
|
244 |
+
int proposals = 0;
|
245 |
+
float avg_iou = 0;
|
246 |
+
|
247 |
+
for(i = 0; i < m; ++i){
|
248 |
+
char *path = paths[i];
|
249 |
+
image orig = load_image_color(path, 0, 0);
|
250 |
+
image sized = resize_image(orig, net->w, net->h);
|
251 |
+
char *id = basecfg(path);
|
252 |
+
network_predict(net, sized.data);
|
253 |
+
|
254 |
+
int nboxes = 0;
|
255 |
+
detection *dets = get_network_boxes(net, orig.w, orig.h, thresh, 0, 0, 1, &nboxes);
|
256 |
+
if (nms) do_nms_obj(dets, side*side*l.n, 1, nms);
|
257 |
+
|
258 |
+
char labelpath[4096];
|
259 |
+
find_replace(path, "images", "labels", labelpath);
|
260 |
+
find_replace(labelpath, "JPEGImages", "labels", labelpath);
|
261 |
+
find_replace(labelpath, ".jpg", ".txt", labelpath);
|
262 |
+
find_replace(labelpath, ".JPEG", ".txt", labelpath);
|
263 |
+
|
264 |
+
int num_labels = 0;
|
265 |
+
box_label *truth = read_boxes(labelpath, &num_labels);
|
266 |
+
for(k = 0; k < side*side*l.n; ++k){
|
267 |
+
if(dets[k].objectness > thresh){
|
268 |
+
++proposals;
|
269 |
+
}
|
270 |
+
}
|
271 |
+
for (j = 0; j < num_labels; ++j) {
|
272 |
+
++total;
|
273 |
+
box t = {truth[j].x, truth[j].y, truth[j].w, truth[j].h};
|
274 |
+
float best_iou = 0;
|
275 |
+
for(k = 0; k < side*side*l.n; ++k){
|
276 |
+
float iou = box_iou(dets[k].bbox, t);
|
277 |
+
if(dets[k].objectness > thresh && iou > best_iou){
|
278 |
+
best_iou = iou;
|
279 |
+
}
|
280 |
+
}
|
281 |
+
avg_iou += best_iou;
|
282 |
+
if(best_iou > iou_thresh){
|
283 |
+
++correct;
|
284 |
+
}
|
285 |
+
}
|
286 |
+
free_detections(dets, nboxes);
|
287 |
+
fprintf(stderr, "%5d %5d %5d\tRPs/Img: %.2f\tIOU: %.2f%%\tRecall:%.2f%%\n", i, correct, total, (float)proposals/(i+1), avg_iou*100/total, 100.*correct/total);
|
288 |
+
free(id);
|
289 |
+
free_image(orig);
|
290 |
+
free_image(sized);
|
291 |
+
}
|
292 |
+
}
|
293 |
+
|
294 |
+
void test_coco(char *cfgfile, char *weightfile, char *filename, float thresh)
|
295 |
+
{
|
296 |
+
image **alphabet = load_alphabet();
|
297 |
+
network *net = load_network(cfgfile, weightfile, 0);
|
298 |
+
layer l = net->layers[net->n-1];
|
299 |
+
set_batch_network(net, 1);
|
300 |
+
srand(2222222);
|
301 |
+
float nms = .4;
|
302 |
+
clock_t time;
|
303 |
+
char buff[256];
|
304 |
+
char *input = buff;
|
305 |
+
while(1){
|
306 |
+
if(filename){
|
307 |
+
strncpy(input, filename, 256);
|
308 |
+
} else {
|
309 |
+
printf("Enter Image Path: ");
|
310 |
+
fflush(stdout);
|
311 |
+
input = fgets(input, 256, stdin);
|
312 |
+
if(!input) return;
|
313 |
+
strtok(input, "\n");
|
314 |
+
}
|
315 |
+
image im = load_image_color(input,0,0);
|
316 |
+
image sized = resize_image(im, net->w, net->h);
|
317 |
+
float *X = sized.data;
|
318 |
+
time=clock();
|
319 |
+
network_predict(net, X);
|
320 |
+
printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
|
321 |
+
|
322 |
+
int nboxes = 0;
|
323 |
+
detection *dets = get_network_boxes(net, 1, 1, thresh, 0, 0, 0, &nboxes);
|
324 |
+
if (nms) do_nms_sort(dets, l.side*l.side*l.n, l.classes, nms);
|
325 |
+
|
326 |
+
draw_detections(im, dets, l.side*l.side*l.n, thresh, coco_classes, alphabet, 80);
|
327 |
+
save_image(im, "prediction");
|
328 |
+
show_image(im, "predictions", 0);
|
329 |
+
free_detections(dets, nboxes);
|
330 |
+
free_image(im);
|
331 |
+
free_image(sized);
|
332 |
+
if (filename) break;
|
333 |
+
}
|
334 |
+
}
|
335 |
+
|
336 |
+
void run_coco(int argc, char **argv)
|
337 |
+
{
|
338 |
+
char *prefix = find_char_arg(argc, argv, "-prefix", 0);
|
339 |
+
float thresh = find_float_arg(argc, argv, "-thresh", .2);
|
340 |
+
int cam_index = find_int_arg(argc, argv, "-c", 0);
|
341 |
+
int frame_skip = find_int_arg(argc, argv, "-s", 0);
|
342 |
+
|
343 |
+
if(argc < 4){
|
344 |
+
fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
|
345 |
+
return;
|
346 |
+
}
|
347 |
+
|
348 |
+
char *cfg = argv[3];
|
349 |
+
char *weights = (argc > 4) ? argv[4] : 0;
|
350 |
+
char *filename = (argc > 5) ? argv[5]: 0;
|
351 |
+
int avg = find_int_arg(argc, argv, "-avg", 1);
|
352 |
+
if(0==strcmp(argv[2], "test")) test_coco(cfg, weights, filename, thresh);
|
353 |
+
else if(0==strcmp(argv[2], "train")) train_coco(cfg, weights);
|
354 |
+
else if(0==strcmp(argv[2], "valid")) validate_coco(cfg, weights);
|
355 |
+
else if(0==strcmp(argv[2], "recall")) validate_coco_recall(cfg, weights);
|
356 |
+
else if(0==strcmp(argv[2], "demo")) demo(cfg, weights, thresh, cam_index, filename, coco_classes, 80, frame_skip, prefix, avg, .5, 0,0,0,0);
|
357 |
+
}
|
model/examples/darknet.c
ADDED
@@ -0,0 +1,503 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include "darknet.h"
|
2 |
+
|
3 |
+
#include <time.h>
|
4 |
+
#include <stdlib.h>
|
5 |
+
#include <stdio.h>
|
6 |
+
|
7 |
+
extern void predict_classifier(char *datacfg, char *cfgfile, char *weightfile, char *filename, int top);
|
8 |
+
extern void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filename, float thresh, float hier_thresh, char *outfile, int fullscreen);
|
9 |
+
extern void run_yolo(int argc, char **argv);
|
10 |
+
extern void run_detector(int argc, char **argv);
|
11 |
+
extern void run_coco(int argc, char **argv);
|
12 |
+
extern void run_nightmare(int argc, char **argv);
|
13 |
+
extern void run_classifier(int argc, char **argv);
|
14 |
+
extern void run_regressor(int argc, char **argv);
|
15 |
+
extern void run_segmenter(int argc, char **argv);
|
16 |
+
extern void run_isegmenter(int argc, char **argv);
|
17 |
+
extern void run_char_rnn(int argc, char **argv);
|
18 |
+
extern void run_tag(int argc, char **argv);
|
19 |
+
extern void run_cifar(int argc, char **argv);
|
20 |
+
extern void run_go(int argc, char **argv);
|
21 |
+
extern void run_art(int argc, char **argv);
|
22 |
+
extern void run_super(int argc, char **argv);
|
23 |
+
extern void run_lsd(int argc, char **argv);
|
24 |
+
|
25 |
+
void average(int argc, char *argv[])
|
26 |
+
{
|
27 |
+
char *cfgfile = argv[2];
|
28 |
+
char *outfile = argv[3];
|
29 |
+
gpu_index = -1;
|
30 |
+
network *net = parse_network_cfg(cfgfile);
|
31 |
+
network *sum = parse_network_cfg(cfgfile);
|
32 |
+
|
33 |
+
char *weightfile = argv[4];
|
34 |
+
load_weights(sum, weightfile);
|
35 |
+
|
36 |
+
int i, j;
|
37 |
+
int n = argc - 5;
|
38 |
+
for(i = 0; i < n; ++i){
|
39 |
+
weightfile = argv[i+5];
|
40 |
+
load_weights(net, weightfile);
|
41 |
+
for(j = 0; j < net->n; ++j){
|
42 |
+
layer l = net->layers[j];
|
43 |
+
layer out = sum->layers[j];
|
44 |
+
if(l.type == CONVOLUTIONAL){
|
45 |
+
int num = l.n*l.c*l.size*l.size;
|
46 |
+
axpy_cpu(l.n, 1, l.biases, 1, out.biases, 1);
|
47 |
+
axpy_cpu(num, 1, l.weights, 1, out.weights, 1);
|
48 |
+
if(l.batch_normalize){
|
49 |
+
axpy_cpu(l.n, 1, l.scales, 1, out.scales, 1);
|
50 |
+
axpy_cpu(l.n, 1, l.rolling_mean, 1, out.rolling_mean, 1);
|
51 |
+
axpy_cpu(l.n, 1, l.rolling_variance, 1, out.rolling_variance, 1);
|
52 |
+
}
|
53 |
+
}
|
54 |
+
if(l.type == CONNECTED){
|
55 |
+
axpy_cpu(l.outputs, 1, l.biases, 1, out.biases, 1);
|
56 |
+
axpy_cpu(l.outputs*l.inputs, 1, l.weights, 1, out.weights, 1);
|
57 |
+
}
|
58 |
+
}
|
59 |
+
}
|
60 |
+
n = n+1;
|
61 |
+
for(j = 0; j < net->n; ++j){
|
62 |
+
layer l = sum->layers[j];
|
63 |
+
if(l.type == CONVOLUTIONAL){
|
64 |
+
int num = l.n*l.c*l.size*l.size;
|
65 |
+
scal_cpu(l.n, 1./n, l.biases, 1);
|
66 |
+
scal_cpu(num, 1./n, l.weights, 1);
|
67 |
+
if(l.batch_normalize){
|
68 |
+
scal_cpu(l.n, 1./n, l.scales, 1);
|
69 |
+
scal_cpu(l.n, 1./n, l.rolling_mean, 1);
|
70 |
+
scal_cpu(l.n, 1./n, l.rolling_variance, 1);
|
71 |
+
}
|
72 |
+
}
|
73 |
+
if(l.type == CONNECTED){
|
74 |
+
scal_cpu(l.outputs, 1./n, l.biases, 1);
|
75 |
+
scal_cpu(l.outputs*l.inputs, 1./n, l.weights, 1);
|
76 |
+
}
|
77 |
+
}
|
78 |
+
save_weights(sum, outfile);
|
79 |
+
}
|
80 |
+
|
81 |
+
long numops(network *net)
|
82 |
+
{
|
83 |
+
int i;
|
84 |
+
long ops = 0;
|
85 |
+
for(i = 0; i < net->n; ++i){
|
86 |
+
layer l = net->layers[i];
|
87 |
+
if(l.type == CONVOLUTIONAL){
|
88 |
+
ops += 2l * l.n * l.size*l.size*l.c/l.groups * l.out_h*l.out_w;
|
89 |
+
} else if(l.type == CONNECTED){
|
90 |
+
ops += 2l * l.inputs * l.outputs;
|
91 |
+
} else if (l.type == RNN){
|
92 |
+
ops += 2l * l.input_layer->inputs * l.input_layer->outputs;
|
93 |
+
ops += 2l * l.self_layer->inputs * l.self_layer->outputs;
|
94 |
+
ops += 2l * l.output_layer->inputs * l.output_layer->outputs;
|
95 |
+
} else if (l.type == GRU){
|
96 |
+
ops += 2l * l.uz->inputs * l.uz->outputs;
|
97 |
+
ops += 2l * l.uh->inputs * l.uh->outputs;
|
98 |
+
ops += 2l * l.ur->inputs * l.ur->outputs;
|
99 |
+
ops += 2l * l.wz->inputs * l.wz->outputs;
|
100 |
+
ops += 2l * l.wh->inputs * l.wh->outputs;
|
101 |
+
ops += 2l * l.wr->inputs * l.wr->outputs;
|
102 |
+
} else if (l.type == LSTM){
|
103 |
+
ops += 2l * l.uf->inputs * l.uf->outputs;
|
104 |
+
ops += 2l * l.ui->inputs * l.ui->outputs;
|
105 |
+
ops += 2l * l.ug->inputs * l.ug->outputs;
|
106 |
+
ops += 2l * l.uo->inputs * l.uo->outputs;
|
107 |
+
ops += 2l * l.wf->inputs * l.wf->outputs;
|
108 |
+
ops += 2l * l.wi->inputs * l.wi->outputs;
|
109 |
+
ops += 2l * l.wg->inputs * l.wg->outputs;
|
110 |
+
ops += 2l * l.wo->inputs * l.wo->outputs;
|
111 |
+
}
|
112 |
+
}
|
113 |
+
return ops;
|
114 |
+
}
|
115 |
+
|
116 |
+
void speed(char *cfgfile, int tics)
|
117 |
+
{
|
118 |
+
if (tics == 0) tics = 1000;
|
119 |
+
network *net = parse_network_cfg(cfgfile);
|
120 |
+
set_batch_network(net, 1);
|
121 |
+
int i;
|
122 |
+
double time=what_time_is_it_now();
|
123 |
+
image im = make_image(net->w, net->h, net->c*net->batch);
|
124 |
+
for(i = 0; i < tics; ++i){
|
125 |
+
network_predict(net, im.data);
|
126 |
+
}
|
127 |
+
double t = what_time_is_it_now() - time;
|
128 |
+
long ops = numops(net);
|
129 |
+
printf("\n%d evals, %f Seconds\n", tics, t);
|
130 |
+
printf("Floating Point Operations: %.2f Bn\n", (float)ops/1000000000.);
|
131 |
+
printf("FLOPS: %.2f Bn\n", (float)ops/1000000000.*tics/t);
|
132 |
+
printf("Speed: %f sec/eval\n", t/tics);
|
133 |
+
printf("Speed: %f Hz\n", tics/t);
|
134 |
+
}
|
135 |
+
|
136 |
+
void operations(char *cfgfile)
|
137 |
+
{
|
138 |
+
gpu_index = -1;
|
139 |
+
network *net = parse_network_cfg(cfgfile);
|
140 |
+
long ops = numops(net);
|
141 |
+
printf("Floating Point Operations: %ld\n", ops);
|
142 |
+
printf("Floating Point Operations: %.2f Bn\n", (float)ops/1000000000.);
|
143 |
+
}
|
144 |
+
|
145 |
+
void oneoff(char *cfgfile, char *weightfile, char *outfile)
|
146 |
+
{
|
147 |
+
gpu_index = -1;
|
148 |
+
network *net = parse_network_cfg(cfgfile);
|
149 |
+
int oldn = net->layers[net->n - 2].n;
|
150 |
+
int c = net->layers[net->n - 2].c;
|
151 |
+
scal_cpu(oldn*c, .1, net->layers[net->n - 2].weights, 1);
|
152 |
+
scal_cpu(oldn, 0, net->layers[net->n - 2].biases, 1);
|
153 |
+
net->layers[net->n - 2].n = 11921;
|
154 |
+
net->layers[net->n - 2].biases += 5;
|
155 |
+
net->layers[net->n - 2].weights += 5*c;
|
156 |
+
if(weightfile){
|
157 |
+
load_weights(net, weightfile);
|
158 |
+
}
|
159 |
+
net->layers[net->n - 2].biases -= 5;
|
160 |
+
net->layers[net->n - 2].weights -= 5*c;
|
161 |
+
net->layers[net->n - 2].n = oldn;
|
162 |
+
printf("%d\n", oldn);
|
163 |
+
layer l = net->layers[net->n - 2];
|
164 |
+
copy_cpu(l.n/3, l.biases, 1, l.biases + l.n/3, 1);
|
165 |
+
copy_cpu(l.n/3, l.biases, 1, l.biases + 2*l.n/3, 1);
|
166 |
+
copy_cpu(l.n/3*l.c, l.weights, 1, l.weights + l.n/3*l.c, 1);
|
167 |
+
copy_cpu(l.n/3*l.c, l.weights, 1, l.weights + 2*l.n/3*l.c, 1);
|
168 |
+
*net->seen = 0;
|
169 |
+
save_weights(net, outfile);
|
170 |
+
}
|
171 |
+
|
172 |
+
void oneoff2(char *cfgfile, char *weightfile, char *outfile, int l)
|
173 |
+
{
|
174 |
+
gpu_index = -1;
|
175 |
+
network *net = parse_network_cfg(cfgfile);
|
176 |
+
if(weightfile){
|
177 |
+
load_weights_upto(net, weightfile, 0, net->n);
|
178 |
+
load_weights_upto(net, weightfile, l, net->n);
|
179 |
+
}
|
180 |
+
*net->seen = 0;
|
181 |
+
save_weights_upto(net, outfile, net->n);
|
182 |
+
}
|
183 |
+
|
184 |
+
void partial(char *cfgfile, char *weightfile, char *outfile, int max)
|
185 |
+
{
|
186 |
+
gpu_index = -1;
|
187 |
+
network *net = load_network(cfgfile, weightfile, 1);
|
188 |
+
save_weights_upto(net, outfile, max);
|
189 |
+
}
|
190 |
+
|
191 |
+
void print_weights(char *cfgfile, char *weightfile, int n)
|
192 |
+
{
|
193 |
+
gpu_index = -1;
|
194 |
+
network *net = load_network(cfgfile, weightfile, 1);
|
195 |
+
layer l = net->layers[n];
|
196 |
+
int i, j;
|
197 |
+
//printf("[");
|
198 |
+
for(i = 0; i < l.n; ++i){
|
199 |
+
//printf("[");
|
200 |
+
for(j = 0; j < l.size*l.size*l.c; ++j){
|
201 |
+
//if(j > 0) printf(",");
|
202 |
+
printf("%g ", l.weights[i*l.size*l.size*l.c + j]);
|
203 |
+
}
|
204 |
+
printf("\n");
|
205 |
+
//printf("]%s\n", (i == l.n-1)?"":",");
|
206 |
+
}
|
207 |
+
//printf("]");
|
208 |
+
}
|
209 |
+
|
210 |
+
void rescale_net(char *cfgfile, char *weightfile, char *outfile)
|
211 |
+
{
|
212 |
+
gpu_index = -1;
|
213 |
+
network *net = load_network(cfgfile, weightfile, 0);
|
214 |
+
int i;
|
215 |
+
for(i = 0; i < net->n; ++i){
|
216 |
+
layer l = net->layers[i];
|
217 |
+
if(l.type == CONVOLUTIONAL){
|
218 |
+
rescale_weights(l, 2, -.5);
|
219 |
+
break;
|
220 |
+
}
|
221 |
+
}
|
222 |
+
save_weights(net, outfile);
|
223 |
+
}
|
224 |
+
|
225 |
+
void rgbgr_net(char *cfgfile, char *weightfile, char *outfile)
|
226 |
+
{
|
227 |
+
gpu_index = -1;
|
228 |
+
network *net = load_network(cfgfile, weightfile, 0);
|
229 |
+
int i;
|
230 |
+
for(i = 0; i < net->n; ++i){
|
231 |
+
layer l = net->layers[i];
|
232 |
+
if(l.type == CONVOLUTIONAL){
|
233 |
+
rgbgr_weights(l);
|
234 |
+
break;
|
235 |
+
}
|
236 |
+
}
|
237 |
+
save_weights(net, outfile);
|
238 |
+
}
|
239 |
+
|
240 |
+
void reset_normalize_net(char *cfgfile, char *weightfile, char *outfile)
|
241 |
+
{
|
242 |
+
gpu_index = -1;
|
243 |
+
network *net = load_network(cfgfile, weightfile, 0);
|
244 |
+
int i;
|
245 |
+
for (i = 0; i < net->n; ++i) {
|
246 |
+
layer l = net->layers[i];
|
247 |
+
if (l.type == CONVOLUTIONAL && l.batch_normalize) {
|
248 |
+
denormalize_convolutional_layer(l);
|
249 |
+
}
|
250 |
+
if (l.type == CONNECTED && l.batch_normalize) {
|
251 |
+
denormalize_connected_layer(l);
|
252 |
+
}
|
253 |
+
if (l.type == GRU && l.batch_normalize) {
|
254 |
+
denormalize_connected_layer(*l.input_z_layer);
|
255 |
+
denormalize_connected_layer(*l.input_r_layer);
|
256 |
+
denormalize_connected_layer(*l.input_h_layer);
|
257 |
+
denormalize_connected_layer(*l.state_z_layer);
|
258 |
+
denormalize_connected_layer(*l.state_r_layer);
|
259 |
+
denormalize_connected_layer(*l.state_h_layer);
|
260 |
+
}
|
261 |
+
}
|
262 |
+
save_weights(net, outfile);
|
263 |
+
}
|
264 |
+
|
265 |
+
layer normalize_layer(layer l, int n)
|
266 |
+
{
|
267 |
+
int j;
|
268 |
+
l.batch_normalize=1;
|
269 |
+
l.scales = calloc(n, sizeof(float));
|
270 |
+
for(j = 0; j < n; ++j){
|
271 |
+
l.scales[j] = 1;
|
272 |
+
}
|
273 |
+
l.rolling_mean = calloc(n, sizeof(float));
|
274 |
+
l.rolling_variance = calloc(n, sizeof(float));
|
275 |
+
return l;
|
276 |
+
}
|
277 |
+
|
278 |
+
void normalize_net(char *cfgfile, char *weightfile, char *outfile)
|
279 |
+
{
|
280 |
+
gpu_index = -1;
|
281 |
+
network *net = load_network(cfgfile, weightfile, 0);
|
282 |
+
int i;
|
283 |
+
for(i = 0; i < net->n; ++i){
|
284 |
+
layer l = net->layers[i];
|
285 |
+
if(l.type == CONVOLUTIONAL && !l.batch_normalize){
|
286 |
+
net->layers[i] = normalize_layer(l, l.n);
|
287 |
+
}
|
288 |
+
if (l.type == CONNECTED && !l.batch_normalize) {
|
289 |
+
net->layers[i] = normalize_layer(l, l.outputs);
|
290 |
+
}
|
291 |
+
if (l.type == GRU && l.batch_normalize) {
|
292 |
+
*l.input_z_layer = normalize_layer(*l.input_z_layer, l.input_z_layer->outputs);
|
293 |
+
*l.input_r_layer = normalize_layer(*l.input_r_layer, l.input_r_layer->outputs);
|
294 |
+
*l.input_h_layer = normalize_layer(*l.input_h_layer, l.input_h_layer->outputs);
|
295 |
+
*l.state_z_layer = normalize_layer(*l.state_z_layer, l.state_z_layer->outputs);
|
296 |
+
*l.state_r_layer = normalize_layer(*l.state_r_layer, l.state_r_layer->outputs);
|
297 |
+
*l.state_h_layer = normalize_layer(*l.state_h_layer, l.state_h_layer->outputs);
|
298 |
+
net->layers[i].batch_normalize=1;
|
299 |
+
}
|
300 |
+
}
|
301 |
+
save_weights(net, outfile);
|
302 |
+
}
|
303 |
+
|
304 |
+
void statistics_net(char *cfgfile, char *weightfile)
|
305 |
+
{
|
306 |
+
gpu_index = -1;
|
307 |
+
network *net = load_network(cfgfile, weightfile, 0);
|
308 |
+
int i;
|
309 |
+
for (i = 0; i < net->n; ++i) {
|
310 |
+
layer l = net->layers[i];
|
311 |
+
if (l.type == CONNECTED && l.batch_normalize) {
|
312 |
+
printf("Connected Layer %d\n", i);
|
313 |
+
statistics_connected_layer(l);
|
314 |
+
}
|
315 |
+
if (l.type == GRU && l.batch_normalize) {
|
316 |
+
printf("GRU Layer %d\n", i);
|
317 |
+
printf("Input Z\n");
|
318 |
+
statistics_connected_layer(*l.input_z_layer);
|
319 |
+
printf("Input R\n");
|
320 |
+
statistics_connected_layer(*l.input_r_layer);
|
321 |
+
printf("Input H\n");
|
322 |
+
statistics_connected_layer(*l.input_h_layer);
|
323 |
+
printf("State Z\n");
|
324 |
+
statistics_connected_layer(*l.state_z_layer);
|
325 |
+
printf("State R\n");
|
326 |
+
statistics_connected_layer(*l.state_r_layer);
|
327 |
+
printf("State H\n");
|
328 |
+
statistics_connected_layer(*l.state_h_layer);
|
329 |
+
}
|
330 |
+
printf("\n");
|
331 |
+
}
|
332 |
+
}
|
333 |
+
|
334 |
+
void denormalize_net(char *cfgfile, char *weightfile, char *outfile)
|
335 |
+
{
|
336 |
+
gpu_index = -1;
|
337 |
+
network *net = load_network(cfgfile, weightfile, 0);
|
338 |
+
int i;
|
339 |
+
for (i = 0; i < net->n; ++i) {
|
340 |
+
layer l = net->layers[i];
|
341 |
+
if ((l.type == DECONVOLUTIONAL || l.type == CONVOLUTIONAL) && l.batch_normalize) {
|
342 |
+
denormalize_convolutional_layer(l);
|
343 |
+
net->layers[i].batch_normalize=0;
|
344 |
+
}
|
345 |
+
if (l.type == CONNECTED && l.batch_normalize) {
|
346 |
+
denormalize_connected_layer(l);
|
347 |
+
net->layers[i].batch_normalize=0;
|
348 |
+
}
|
349 |
+
if (l.type == GRU && l.batch_normalize) {
|
350 |
+
denormalize_connected_layer(*l.input_z_layer);
|
351 |
+
denormalize_connected_layer(*l.input_r_layer);
|
352 |
+
denormalize_connected_layer(*l.input_h_layer);
|
353 |
+
denormalize_connected_layer(*l.state_z_layer);
|
354 |
+
denormalize_connected_layer(*l.state_r_layer);
|
355 |
+
denormalize_connected_layer(*l.state_h_layer);
|
356 |
+
l.input_z_layer->batch_normalize = 0;
|
357 |
+
l.input_r_layer->batch_normalize = 0;
|
358 |
+
l.input_h_layer->batch_normalize = 0;
|
359 |
+
l.state_z_layer->batch_normalize = 0;
|
360 |
+
l.state_r_layer->batch_normalize = 0;
|
361 |
+
l.state_h_layer->batch_normalize = 0;
|
362 |
+
net->layers[i].batch_normalize=0;
|
363 |
+
}
|
364 |
+
}
|
365 |
+
save_weights(net, outfile);
|
366 |
+
}
|
367 |
+
|
368 |
+
void mkimg(char *cfgfile, char *weightfile, int h, int w, int num, char *prefix)
|
369 |
+
{
|
370 |
+
network *net = load_network(cfgfile, weightfile, 0);
|
371 |
+
image *ims = get_weights(net->layers[0]);
|
372 |
+
int n = net->layers[0].n;
|
373 |
+
int z;
|
374 |
+
for(z = 0; z < num; ++z){
|
375 |
+
image im = make_image(h, w, 3);
|
376 |
+
fill_image(im, .5);
|
377 |
+
int i;
|
378 |
+
for(i = 0; i < 100; ++i){
|
379 |
+
image r = copy_image(ims[rand()%n]);
|
380 |
+
rotate_image_cw(r, rand()%4);
|
381 |
+
random_distort_image(r, 1, 1.5, 1.5);
|
382 |
+
int dx = rand()%(w-r.w);
|
383 |
+
int dy = rand()%(h-r.h);
|
384 |
+
ghost_image(r, im, dx, dy);
|
385 |
+
free_image(r);
|
386 |
+
}
|
387 |
+
char buff[256];
|
388 |
+
sprintf(buff, "%s/gen_%d", prefix, z);
|
389 |
+
save_image(im, buff);
|
390 |
+
free_image(im);
|
391 |
+
}
|
392 |
+
}
|
393 |
+
|
394 |
+
void visualize(char *cfgfile, char *weightfile)
|
395 |
+
{
|
396 |
+
network *net = load_network(cfgfile, weightfile, 0);
|
397 |
+
visualize_network(net);
|
398 |
+
}
|
399 |
+
|
400 |
+
int main(int argc, char **argv)
|
401 |
+
{
|
402 |
+
//test_resize("data/bad.jpg");
|
403 |
+
//test_box();
|
404 |
+
//test_convolutional_layer();
|
405 |
+
if(argc < 2){
|
406 |
+
fprintf(stderr, "usage: %s <function>\n", argv[0]);
|
407 |
+
return 0;
|
408 |
+
}
|
409 |
+
gpu_index = find_int_arg(argc, argv, "-i", 0);
|
410 |
+
if(find_arg(argc, argv, "-nogpu")) {
|
411 |
+
gpu_index = -1;
|
412 |
+
}
|
413 |
+
|
414 |
+
#ifndef GPU
|
415 |
+
gpu_index = -1;
|
416 |
+
#else
|
417 |
+
if(gpu_index >= 0){
|
418 |
+
cuda_set_device(gpu_index);
|
419 |
+
}
|
420 |
+
#endif
|
421 |
+
|
422 |
+
if (0 == strcmp(argv[1], "average")){
|
423 |
+
average(argc, argv);
|
424 |
+
} else if (0 == strcmp(argv[1], "yolo")){
|
425 |
+
run_yolo(argc, argv);
|
426 |
+
} else if (0 == strcmp(argv[1], "super")){
|
427 |
+
run_super(argc, argv);
|
428 |
+
} else if (0 == strcmp(argv[1], "lsd")){
|
429 |
+
run_lsd(argc, argv);
|
430 |
+
} else if (0 == strcmp(argv[1], "detector")){
|
431 |
+
run_detector(argc, argv);
|
432 |
+
} else if (0 == strcmp(argv[1], "detect")){
|
433 |
+
float thresh = find_float_arg(argc, argv, "-thresh", .5);
|
434 |
+
char *filename = (argc > 4) ? argv[4]: 0;
|
435 |
+
char *outfile = find_char_arg(argc, argv, "-out", 0);
|
436 |
+
int fullscreen = find_arg(argc, argv, "-fullscreen");
|
437 |
+
test_detector("cfg/coco.data", argv[2], argv[3], filename, thresh, .5, outfile, fullscreen);
|
438 |
+
} else if (0 == strcmp(argv[1], "cifar")){
|
439 |
+
run_cifar(argc, argv);
|
440 |
+
} else if (0 == strcmp(argv[1], "go")){
|
441 |
+
run_go(argc, argv);
|
442 |
+
} else if (0 == strcmp(argv[1], "rnn")){
|
443 |
+
run_char_rnn(argc, argv);
|
444 |
+
} else if (0 == strcmp(argv[1], "coco")){
|
445 |
+
run_coco(argc, argv);
|
446 |
+
} else if (0 == strcmp(argv[1], "classify")){
|
447 |
+
predict_classifier("cfg/imagenet1k.data", argv[2], argv[3], argv[4], 5);
|
448 |
+
} else if (0 == strcmp(argv[1], "classifier")){
|
449 |
+
run_classifier(argc, argv);
|
450 |
+
} else if (0 == strcmp(argv[1], "regressor")){
|
451 |
+
run_regressor(argc, argv);
|
452 |
+
} else if (0 == strcmp(argv[1], "isegmenter")){
|
453 |
+
run_isegmenter(argc, argv);
|
454 |
+
} else if (0 == strcmp(argv[1], "segmenter")){
|
455 |
+
run_segmenter(argc, argv);
|
456 |
+
} else if (0 == strcmp(argv[1], "art")){
|
457 |
+
run_art(argc, argv);
|
458 |
+
} else if (0 == strcmp(argv[1], "tag")){
|
459 |
+
run_tag(argc, argv);
|
460 |
+
} else if (0 == strcmp(argv[1], "3d")){
|
461 |
+
composite_3d(argv[2], argv[3], argv[4], (argc > 5) ? atof(argv[5]) : 0);
|
462 |
+
} else if (0 == strcmp(argv[1], "test")){
|
463 |
+
test_resize(argv[2]);
|
464 |
+
} else if (0 == strcmp(argv[1], "nightmare")){
|
465 |
+
run_nightmare(argc, argv);
|
466 |
+
} else if (0 == strcmp(argv[1], "rgbgr")){
|
467 |
+
rgbgr_net(argv[2], argv[3], argv[4]);
|
468 |
+
} else if (0 == strcmp(argv[1], "reset")){
|
469 |
+
reset_normalize_net(argv[2], argv[3], argv[4]);
|
470 |
+
} else if (0 == strcmp(argv[1], "denormalize")){
|
471 |
+
denormalize_net(argv[2], argv[3], argv[4]);
|
472 |
+
} else if (0 == strcmp(argv[1], "statistics")){
|
473 |
+
statistics_net(argv[2], argv[3]);
|
474 |
+
} else if (0 == strcmp(argv[1], "normalize")){
|
475 |
+
normalize_net(argv[2], argv[3], argv[4]);
|
476 |
+
} else if (0 == strcmp(argv[1], "rescale")){
|
477 |
+
rescale_net(argv[2], argv[3], argv[4]);
|
478 |
+
} else if (0 == strcmp(argv[1], "ops")){
|
479 |
+
operations(argv[2]);
|
480 |
+
} else if (0 == strcmp(argv[1], "speed")){
|
481 |
+
speed(argv[2], (argc > 3 && argv[3]) ? atoi(argv[3]) : 0);
|
482 |
+
} else if (0 == strcmp(argv[1], "oneoff")){
|
483 |
+
oneoff(argv[2], argv[3], argv[4]);
|
484 |
+
} else if (0 == strcmp(argv[1], "oneoff2")){
|
485 |
+
oneoff2(argv[2], argv[3], argv[4], atoi(argv[5]));
|
486 |
+
} else if (0 == strcmp(argv[1], "print")){
|
487 |
+
print_weights(argv[2], argv[3], atoi(argv[4]));
|
488 |
+
} else if (0 == strcmp(argv[1], "partial")){
|
489 |
+
partial(argv[2], argv[3], argv[4], atoi(argv[5]));
|
490 |
+
} else if (0 == strcmp(argv[1], "average")){
|
491 |
+
average(argc, argv);
|
492 |
+
} else if (0 == strcmp(argv[1], "visualize")){
|
493 |
+
visualize(argv[2], (argc > 3) ? argv[3] : 0);
|
494 |
+
} else if (0 == strcmp(argv[1], "mkimg")){
|
495 |
+
mkimg(argv[2], argv[3], atoi(argv[4]), atoi(argv[5]), atoi(argv[6]), argv[7]);
|
496 |
+
} else if (0 == strcmp(argv[1], "imtest")){
|
497 |
+
test_resize(argv[2]);
|
498 |
+
} else {
|
499 |
+
fprintf(stderr, "Not an option: %s\n", argv[1]);
|
500 |
+
}
|
501 |
+
return 0;
|
502 |
+
}
|
503 |
+
|
model/examples/detector-scipy-opencv.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Stupid python path shit.
|
2 |
+
# Instead just add darknet.py to somewhere in your python path
|
3 |
+
# OK actually that might not be a great idea, idk, work in progress
|
4 |
+
# Use at your own risk. or don't, i don't care
|
5 |
+
|
6 |
+
from scipy.misc import imread
|
7 |
+
import cv2
|
8 |
+
|
9 |
+
def array_to_image(arr):
|
10 |
+
arr = arr.transpose(2,0,1)
|
11 |
+
c = arr.shape[0]
|
12 |
+
h = arr.shape[1]
|
13 |
+
w = arr.shape[2]
|
14 |
+
arr = (arr/255.0).flatten()
|
15 |
+
data = dn.c_array(dn.c_float, arr)
|
16 |
+
im = dn.IMAGE(w,h,c,data)
|
17 |
+
return im
|
18 |
+
|
19 |
+
def detect2(net, meta, image, thresh=.5, hier_thresh=.5, nms=.45):
|
20 |
+
boxes = dn.make_boxes(net)
|
21 |
+
probs = dn.make_probs(net)
|
22 |
+
num = dn.num_boxes(net)
|
23 |
+
dn.network_detect(net, image, thresh, hier_thresh, nms, boxes, probs)
|
24 |
+
res = []
|
25 |
+
for j in range(num):
|
26 |
+
for i in range(meta.classes):
|
27 |
+
if probs[j][i] > 0:
|
28 |
+
res.append((meta.names[i], probs[j][i], (boxes[j].x, boxes[j].y, boxes[j].w, boxes[j].h)))
|
29 |
+
res = sorted(res, key=lambda x: -x[1])
|
30 |
+
dn.free_ptrs(dn.cast(probs, dn.POINTER(dn.c_void_p)), num)
|
31 |
+
return res
|
32 |
+
|
33 |
+
import sys, os
|
34 |
+
sys.path.append(os.path.join(os.getcwd(),'python/'))
|
35 |
+
|
36 |
+
import darknet as dn
|
37 |
+
|
38 |
+
# Darknet
|
39 |
+
net = dn.load_net("cfg/tiny-yolo.cfg", "tiny-yolo.weights", 0)
|
40 |
+
meta = dn.load_meta("cfg/coco.data")
|
41 |
+
r = dn.detect(net, meta, "data/dog.jpg")
|
42 |
+
print r
|
43 |
+
|
44 |
+
# scipy
|
45 |
+
arr= imread('data/dog.jpg')
|
46 |
+
im = array_to_image(arr)
|
47 |
+
r = detect2(net, meta, im)
|
48 |
+
print r
|
49 |
+
|
50 |
+
# OpenCV
|
51 |
+
arr = cv2.imread('data/dog.jpg')
|
52 |
+
im = array_to_image(arr)
|
53 |
+
dn.rgbgr_image(im)
|
54 |
+
r = detect2(net, meta, im)
|
55 |
+
print r
|
56 |
+
|
model/examples/detector.c
ADDED
@@ -0,0 +1,850 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include "darknet.h"
|
2 |
+
|
3 |
+
static int coco_ids[] = {1,2,3,4,5,6,7,8,9,10,11,13,14,15,16,17,18,19,20,21,22,23,24,25,27,28,31,32,33,34,35,36,37,38,39,40,41,42,43,44,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,67,70,72,73,74,75,76,77,78,79,80,81,82,84,85,86,87,88,89,90};
|
4 |
+
|
5 |
+
|
6 |
+
void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, int ngpus, int clear)
|
7 |
+
{
|
8 |
+
list *options = read_data_cfg(datacfg);
|
9 |
+
char *train_images = option_find_str(options, "train", "data/train.list");
|
10 |
+
char *backup_directory = option_find_str(options, "backup", "/backup/");
|
11 |
+
|
12 |
+
srand(time(0));
|
13 |
+
char *base = basecfg(cfgfile);
|
14 |
+
printf("%s\n", base);
|
15 |
+
float avg_loss = -1;
|
16 |
+
network **nets = calloc(ngpus, sizeof(network));
|
17 |
+
|
18 |
+
srand(time(0));
|
19 |
+
int seed = rand();
|
20 |
+
int i;
|
21 |
+
for(i = 0; i < ngpus; ++i){
|
22 |
+
srand(seed);
|
23 |
+
#ifdef GPU
|
24 |
+
cuda_set_device(gpus[i]);
|
25 |
+
#endif
|
26 |
+
nets[i] = load_network(cfgfile, weightfile, clear);
|
27 |
+
nets[i]->learning_rate *= ngpus;
|
28 |
+
}
|
29 |
+
srand(time(0));
|
30 |
+
network *net = nets[0];
|
31 |
+
|
32 |
+
int imgs = net->batch * net->subdivisions * ngpus;
|
33 |
+
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
|
34 |
+
data train, buffer;
|
35 |
+
|
36 |
+
layer l = net->layers[net->n - 1];
|
37 |
+
|
38 |
+
int classes = l.classes;
|
39 |
+
float jitter = l.jitter;
|
40 |
+
|
41 |
+
list *plist = get_paths(train_images);
|
42 |
+
//int N = plist->size;
|
43 |
+
char **paths = (char **)list_to_array(plist);
|
44 |
+
|
45 |
+
load_args args = get_base_args(net);
|
46 |
+
args.coords = l.coords;
|
47 |
+
args.paths = paths;
|
48 |
+
args.n = imgs;
|
49 |
+
args.m = plist->size;
|
50 |
+
args.classes = classes;
|
51 |
+
args.jitter = jitter;
|
52 |
+
args.num_boxes = l.max_boxes;
|
53 |
+
args.d = &buffer;
|
54 |
+
args.type = DETECTION_DATA;
|
55 |
+
//args.type = INSTANCE_DATA;
|
56 |
+
args.threads = 64;
|
57 |
+
|
58 |
+
pthread_t load_thread = load_data(args);
|
59 |
+
double time;
|
60 |
+
int count = 0;
|
61 |
+
//while(i*imgs < N*120){
|
62 |
+
while(get_current_batch(net) < net->max_batches){
|
63 |
+
if(l.random && count++%10 == 0){
|
64 |
+
printf("Resizing\n");
|
65 |
+
int dim = (rand() % 10 + 10) * 32;
|
66 |
+
if (get_current_batch(net)+200 > net->max_batches) dim = 608;
|
67 |
+
//int dim = (rand() % 4 + 16) * 32;
|
68 |
+
printf("%d\n", dim);
|
69 |
+
args.w = dim;
|
70 |
+
args.h = dim;
|
71 |
+
|
72 |
+
pthread_join(load_thread, 0);
|
73 |
+
train = buffer;
|
74 |
+
free_data(train);
|
75 |
+
load_thread = load_data(args);
|
76 |
+
|
77 |
+
#pragma omp parallel for
|
78 |
+
for(i = 0; i < ngpus; ++i){
|
79 |
+
resize_network(nets[i], dim, dim);
|
80 |
+
}
|
81 |
+
net = nets[0];
|
82 |
+
}
|
83 |
+
time=what_time_is_it_now();
|
84 |
+
pthread_join(load_thread, 0);
|
85 |
+
train = buffer;
|
86 |
+
load_thread = load_data(args);
|
87 |
+
|
88 |
+
/*
|
89 |
+
int k;
|
90 |
+
for(k = 0; k < l.max_boxes; ++k){
|
91 |
+
box b = float_to_box(train.y.vals[10] + 1 + k*5);
|
92 |
+
if(!b.x) break;
|
93 |
+
printf("loaded: %f %f %f %f\n", b.x, b.y, b.w, b.h);
|
94 |
+
}
|
95 |
+
*/
|
96 |
+
/*
|
97 |
+
int zz;
|
98 |
+
for(zz = 0; zz < train.X.cols; ++zz){
|
99 |
+
image im = float_to_image(net->w, net->h, 3, train.X.vals[zz]);
|
100 |
+
int k;
|
101 |
+
for(k = 0; k < l.max_boxes; ++k){
|
102 |
+
box b = float_to_box(train.y.vals[zz] + k*5, 1);
|
103 |
+
printf("%f %f %f %f\n", b.x, b.y, b.w, b.h);
|
104 |
+
draw_bbox(im, b, 1, 1,0,0);
|
105 |
+
}
|
106 |
+
show_image(im, "truth11");
|
107 |
+
cvWaitKey(0);
|
108 |
+
save_image(im, "truth11");
|
109 |
+
}
|
110 |
+
*/
|
111 |
+
|
112 |
+
printf("Loaded: %lf seconds\n", what_time_is_it_now()-time);
|
113 |
+
|
114 |
+
time=what_time_is_it_now();
|
115 |
+
float loss = 0;
|
116 |
+
#ifdef GPU
|
117 |
+
if(ngpus == 1){
|
118 |
+
loss = train_network(net, train);
|
119 |
+
} else {
|
120 |
+
loss = train_networks(nets, ngpus, train, 4);
|
121 |
+
}
|
122 |
+
#else
|
123 |
+
loss = train_network(net, train);
|
124 |
+
#endif
|
125 |
+
if (avg_loss < 0) avg_loss = loss;
|
126 |
+
avg_loss = avg_loss*.9 + loss*.1;
|
127 |
+
|
128 |
+
i = get_current_batch(net);
|
129 |
+
printf("%ld: %f, %f avg, %f rate, %lf seconds, %d images\n", get_current_batch(net), loss, avg_loss, get_current_rate(net), what_time_is_it_now()-time, i*imgs);
|
130 |
+
if(i%100==0){
|
131 |
+
#ifdef GPU
|
132 |
+
if(ngpus != 1) sync_nets(nets, ngpus, 0);
|
133 |
+
#endif
|
134 |
+
char buff[256];
|
135 |
+
sprintf(buff, "%s/%s.backup", backup_directory, base);
|
136 |
+
save_weights(net, buff);
|
137 |
+
}
|
138 |
+
if(i%10000==0 || (i < 1000 && i%100 == 0)){
|
139 |
+
#ifdef GPU
|
140 |
+
if(ngpus != 1) sync_nets(nets, ngpus, 0);
|
141 |
+
#endif
|
142 |
+
char buff[256];
|
143 |
+
sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
|
144 |
+
save_weights(net, buff);
|
145 |
+
}
|
146 |
+
free_data(train);
|
147 |
+
}
|
148 |
+
#ifdef GPU
|
149 |
+
if(ngpus != 1) sync_nets(nets, ngpus, 0);
|
150 |
+
#endif
|
151 |
+
char buff[256];
|
152 |
+
sprintf(buff, "%s/%s_final.weights", backup_directory, base);
|
153 |
+
save_weights(net, buff);
|
154 |
+
}
|
155 |
+
|
156 |
+
|
157 |
+
static int get_coco_image_id(char *filename)
|
158 |
+
{
|
159 |
+
char *p = strrchr(filename, '/');
|
160 |
+
char *c = strrchr(filename, '_');
|
161 |
+
if(c) p = c;
|
162 |
+
return atoi(p+1);
|
163 |
+
}
|
164 |
+
|
165 |
+
static void print_cocos(FILE *fp, char *image_path, detection *dets, int num_boxes, int classes, int w, int h)
|
166 |
+
{
|
167 |
+
int i, j;
|
168 |
+
int image_id = get_coco_image_id(image_path);
|
169 |
+
for(i = 0; i < num_boxes; ++i){
|
170 |
+
float xmin = dets[i].bbox.x - dets[i].bbox.w/2.;
|
171 |
+
float xmax = dets[i].bbox.x + dets[i].bbox.w/2.;
|
172 |
+
float ymin = dets[i].bbox.y - dets[i].bbox.h/2.;
|
173 |
+
float ymax = dets[i].bbox.y + dets[i].bbox.h/2.;
|
174 |
+
|
175 |
+
if (xmin < 0) xmin = 0;
|
176 |
+
if (ymin < 0) ymin = 0;
|
177 |
+
if (xmax > w) xmax = w;
|
178 |
+
if (ymax > h) ymax = h;
|
179 |
+
|
180 |
+
float bx = xmin;
|
181 |
+
float by = ymin;
|
182 |
+
float bw = xmax - xmin;
|
183 |
+
float bh = ymax - ymin;
|
184 |
+
|
185 |
+
for(j = 0; j < classes; ++j){
|
186 |
+
if (dets[i].prob[j]) fprintf(fp, "{\"image_id\":%d, \"category_id\":%d, \"bbox\":[%f, %f, %f, %f], \"score\":%f},\n", image_id, coco_ids[j], bx, by, bw, bh, dets[i].prob[j]);
|
187 |
+
}
|
188 |
+
}
|
189 |
+
}
|
190 |
+
|
191 |
+
void print_detector_detections(FILE **fps, char *id, detection *dets, int total, int classes, int w, int h)
|
192 |
+
{
|
193 |
+
int i, j;
|
194 |
+
for(i = 0; i < total; ++i){
|
195 |
+
float xmin = dets[i].bbox.x - dets[i].bbox.w/2. + 1;
|
196 |
+
float xmax = dets[i].bbox.x + dets[i].bbox.w/2. + 1;
|
197 |
+
float ymin = dets[i].bbox.y - dets[i].bbox.h/2. + 1;
|
198 |
+
float ymax = dets[i].bbox.y + dets[i].bbox.h/2. + 1;
|
199 |
+
|
200 |
+
if (xmin < 1) xmin = 1;
|
201 |
+
if (ymin < 1) ymin = 1;
|
202 |
+
if (xmax > w) xmax = w;
|
203 |
+
if (ymax > h) ymax = h;
|
204 |
+
|
205 |
+
for(j = 0; j < classes; ++j){
|
206 |
+
if (dets[i].prob[j]) fprintf(fps[j], "%s %f %f %f %f %f\n", id, dets[i].prob[j],
|
207 |
+
xmin, ymin, xmax, ymax);
|
208 |
+
}
|
209 |
+
}
|
210 |
+
}
|
211 |
+
|
212 |
+
void print_imagenet_detections(FILE *fp, int id, detection *dets, int total, int classes, int w, int h)
|
213 |
+
{
|
214 |
+
int i, j;
|
215 |
+
for(i = 0; i < total; ++i){
|
216 |
+
float xmin = dets[i].bbox.x - dets[i].bbox.w/2.;
|
217 |
+
float xmax = dets[i].bbox.x + dets[i].bbox.w/2.;
|
218 |
+
float ymin = dets[i].bbox.y - dets[i].bbox.h/2.;
|
219 |
+
float ymax = dets[i].bbox.y + dets[i].bbox.h/2.;
|
220 |
+
|
221 |
+
if (xmin < 0) xmin = 0;
|
222 |
+
if (ymin < 0) ymin = 0;
|
223 |
+
if (xmax > w) xmax = w;
|
224 |
+
if (ymax > h) ymax = h;
|
225 |
+
|
226 |
+
for(j = 0; j < classes; ++j){
|
227 |
+
int class = j;
|
228 |
+
if (dets[i].prob[class]) fprintf(fp, "%d %d %f %f %f %f %f\n", id, j+1, dets[i].prob[class],
|
229 |
+
xmin, ymin, xmax, ymax);
|
230 |
+
}
|
231 |
+
}
|
232 |
+
}
|
233 |
+
|
234 |
+
void validate_detector_flip(char *datacfg, char *cfgfile, char *weightfile, char *outfile)
|
235 |
+
{
|
236 |
+
int j;
|
237 |
+
list *options = read_data_cfg(datacfg);
|
238 |
+
char *valid_images = option_find_str(options, "valid", "data/train.list");
|
239 |
+
char *name_list = option_find_str(options, "names", "data/names.list");
|
240 |
+
char *prefix = option_find_str(options, "results", "results");
|
241 |
+
char **names = get_labels(name_list);
|
242 |
+
char *mapf = option_find_str(options, "map", 0);
|
243 |
+
int *map = 0;
|
244 |
+
if (mapf) map = read_map(mapf);
|
245 |
+
|
246 |
+
network *net = load_network(cfgfile, weightfile, 0);
|
247 |
+
set_batch_network(net, 2);
|
248 |
+
fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
|
249 |
+
srand(time(0));
|
250 |
+
|
251 |
+
list *plist = get_paths(valid_images);
|
252 |
+
char **paths = (char **)list_to_array(plist);
|
253 |
+
|
254 |
+
layer l = net->layers[net->n-1];
|
255 |
+
int classes = l.classes;
|
256 |
+
|
257 |
+
char buff[1024];
|
258 |
+
char *type = option_find_str(options, "eval", "voc");
|
259 |
+
FILE *fp = 0;
|
260 |
+
FILE **fps = 0;
|
261 |
+
int coco = 0;
|
262 |
+
int imagenet = 0;
|
263 |
+
if(0==strcmp(type, "coco")){
|
264 |
+
if(!outfile) outfile = "coco_results";
|
265 |
+
snprintf(buff, 1024, "%s/%s.json", prefix, outfile);
|
266 |
+
fp = fopen(buff, "w");
|
267 |
+
fprintf(fp, "[\n");
|
268 |
+
coco = 1;
|
269 |
+
} else if(0==strcmp(type, "imagenet")){
|
270 |
+
if(!outfile) outfile = "imagenet-detection";
|
271 |
+
snprintf(buff, 1024, "%s/%s.txt", prefix, outfile);
|
272 |
+
fp = fopen(buff, "w");
|
273 |
+
imagenet = 1;
|
274 |
+
classes = 200;
|
275 |
+
} else {
|
276 |
+
if(!outfile) outfile = "comp4_det_test_";
|
277 |
+
fps = calloc(classes, sizeof(FILE *));
|
278 |
+
for(j = 0; j < classes; ++j){
|
279 |
+
snprintf(buff, 1024, "%s/%s%s.txt", prefix, outfile, names[j]);
|
280 |
+
fps[j] = fopen(buff, "w");
|
281 |
+
}
|
282 |
+
}
|
283 |
+
|
284 |
+
int m = plist->size;
|
285 |
+
int i=0;
|
286 |
+
int t;
|
287 |
+
|
288 |
+
float thresh = .005;
|
289 |
+
float nms = .45;
|
290 |
+
|
291 |
+
int nthreads = 4;
|
292 |
+
image *val = calloc(nthreads, sizeof(image));
|
293 |
+
image *val_resized = calloc(nthreads, sizeof(image));
|
294 |
+
image *buf = calloc(nthreads, sizeof(image));
|
295 |
+
image *buf_resized = calloc(nthreads, sizeof(image));
|
296 |
+
pthread_t *thr = calloc(nthreads, sizeof(pthread_t));
|
297 |
+
|
298 |
+
image input = make_image(net->w, net->h, net->c*2);
|
299 |
+
|
300 |
+
load_args args = {0};
|
301 |
+
args.w = net->w;
|
302 |
+
args.h = net->h;
|
303 |
+
//args.type = IMAGE_DATA;
|
304 |
+
args.type = LETTERBOX_DATA;
|
305 |
+
|
306 |
+
for(t = 0; t < nthreads; ++t){
|
307 |
+
args.path = paths[i+t];
|
308 |
+
args.im = &buf[t];
|
309 |
+
args.resized = &buf_resized[t];
|
310 |
+
thr[t] = load_data_in_thread(args);
|
311 |
+
}
|
312 |
+
double start = what_time_is_it_now();
|
313 |
+
for(i = nthreads; i < m+nthreads; i += nthreads){
|
314 |
+
fprintf(stderr, "%d\n", i);
|
315 |
+
for(t = 0; t < nthreads && i+t-nthreads < m; ++t){
|
316 |
+
pthread_join(thr[t], 0);
|
317 |
+
val[t] = buf[t];
|
318 |
+
val_resized[t] = buf_resized[t];
|
319 |
+
}
|
320 |
+
for(t = 0; t < nthreads && i+t < m; ++t){
|
321 |
+
args.path = paths[i+t];
|
322 |
+
args.im = &buf[t];
|
323 |
+
args.resized = &buf_resized[t];
|
324 |
+
thr[t] = load_data_in_thread(args);
|
325 |
+
}
|
326 |
+
for(t = 0; t < nthreads && i+t-nthreads < m; ++t){
|
327 |
+
char *path = paths[i+t-nthreads];
|
328 |
+
char *id = basecfg(path);
|
329 |
+
copy_cpu(net->w*net->h*net->c, val_resized[t].data, 1, input.data, 1);
|
330 |
+
flip_image(val_resized[t]);
|
331 |
+
copy_cpu(net->w*net->h*net->c, val_resized[t].data, 1, input.data + net->w*net->h*net->c, 1);
|
332 |
+
|
333 |
+
network_predict(net, input.data);
|
334 |
+
int w = val[t].w;
|
335 |
+
int h = val[t].h;
|
336 |
+
int num = 0;
|
337 |
+
detection *dets = get_network_boxes(net, w, h, thresh, .5, map, 0, &num);
|
338 |
+
if (nms) do_nms_sort(dets, num, classes, nms);
|
339 |
+
if (coco){
|
340 |
+
print_cocos(fp, path, dets, num, classes, w, h);
|
341 |
+
} else if (imagenet){
|
342 |
+
print_imagenet_detections(fp, i+t-nthreads+1, dets, num, classes, w, h);
|
343 |
+
} else {
|
344 |
+
print_detector_detections(fps, id, dets, num, classes, w, h);
|
345 |
+
}
|
346 |
+
free_detections(dets, num);
|
347 |
+
free(id);
|
348 |
+
free_image(val[t]);
|
349 |
+
free_image(val_resized[t]);
|
350 |
+
}
|
351 |
+
}
|
352 |
+
for(j = 0; j < classes; ++j){
|
353 |
+
if(fps) fclose(fps[j]);
|
354 |
+
}
|
355 |
+
if(coco){
|
356 |
+
fseek(fp, -2, SEEK_CUR);
|
357 |
+
fprintf(fp, "\n]\n");
|
358 |
+
fclose(fp);
|
359 |
+
}
|
360 |
+
fprintf(stderr, "Total Detection Time: %f Seconds\n", what_time_is_it_now() - start);
|
361 |
+
}
|
362 |
+
|
363 |
+
|
364 |
+
void validate_detector(char *datacfg, char *cfgfile, char *weightfile, char *outfile)
|
365 |
+
{
|
366 |
+
int j;
|
367 |
+
list *options = read_data_cfg(datacfg);
|
368 |
+
char *valid_images = option_find_str(options, "valid", "data/train.list");
|
369 |
+
char *name_list = option_find_str(options, "names", "data/names.list");
|
370 |
+
char *prefix = option_find_str(options, "results", "results");
|
371 |
+
char **names = get_labels(name_list);
|
372 |
+
char *mapf = option_find_str(options, "map", 0);
|
373 |
+
int *map = 0;
|
374 |
+
if (mapf) map = read_map(mapf);
|
375 |
+
|
376 |
+
network *net = load_network(cfgfile, weightfile, 0);
|
377 |
+
set_batch_network(net, 1);
|
378 |
+
fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
|
379 |
+
srand(time(0));
|
380 |
+
|
381 |
+
list *plist = get_paths(valid_images);
|
382 |
+
char **paths = (char **)list_to_array(plist);
|
383 |
+
|
384 |
+
layer l = net->layers[net->n-1];
|
385 |
+
int classes = l.classes;
|
386 |
+
|
387 |
+
char buff[1024];
|
388 |
+
char *type = option_find_str(options, "eval", "voc");
|
389 |
+
FILE *fp = 0;
|
390 |
+
FILE **fps = 0;
|
391 |
+
int coco = 0;
|
392 |
+
int imagenet = 0;
|
393 |
+
if(0==strcmp(type, "coco")){
|
394 |
+
if(!outfile) outfile = "coco_results";
|
395 |
+
snprintf(buff, 1024, "%s/%s.json", prefix, outfile);
|
396 |
+
fp = fopen(buff, "w");
|
397 |
+
fprintf(fp, "[\n");
|
398 |
+
coco = 1;
|
399 |
+
} else if(0==strcmp(type, "imagenet")){
|
400 |
+
if(!outfile) outfile = "imagenet-detection";
|
401 |
+
snprintf(buff, 1024, "%s/%s.txt", prefix, outfile);
|
402 |
+
fp = fopen(buff, "w");
|
403 |
+
imagenet = 1;
|
404 |
+
classes = 200;
|
405 |
+
} else {
|
406 |
+
if(!outfile) outfile = "comp4_det_test_";
|
407 |
+
fps = calloc(classes, sizeof(FILE *));
|
408 |
+
for(j = 0; j < classes; ++j){
|
409 |
+
snprintf(buff, 1024, "%s/%s%s.txt", prefix, outfile, names[j]);
|
410 |
+
fps[j] = fopen(buff, "w");
|
411 |
+
}
|
412 |
+
}
|
413 |
+
|
414 |
+
|
415 |
+
int m = plist->size;
|
416 |
+
int i=0;
|
417 |
+
int t;
|
418 |
+
|
419 |
+
float thresh = .005;
|
420 |
+
float nms = .45;
|
421 |
+
|
422 |
+
int nthreads = 4;
|
423 |
+
image *val = calloc(nthreads, sizeof(image));
|
424 |
+
image *val_resized = calloc(nthreads, sizeof(image));
|
425 |
+
image *buf = calloc(nthreads, sizeof(image));
|
426 |
+
image *buf_resized = calloc(nthreads, sizeof(image));
|
427 |
+
pthread_t *thr = calloc(nthreads, sizeof(pthread_t));
|
428 |
+
|
429 |
+
load_args args = {0};
|
430 |
+
args.w = net->w;
|
431 |
+
args.h = net->h;
|
432 |
+
//args.type = IMAGE_DATA;
|
433 |
+
args.type = LETTERBOX_DATA;
|
434 |
+
|
435 |
+
for(t = 0; t < nthreads; ++t){
|
436 |
+
args.path = paths[i+t];
|
437 |
+
args.im = &buf[t];
|
438 |
+
args.resized = &buf_resized[t];
|
439 |
+
thr[t] = load_data_in_thread(args);
|
440 |
+
}
|
441 |
+
double start = what_time_is_it_now();
|
442 |
+
for(i = nthreads; i < m+nthreads; i += nthreads){
|
443 |
+
fprintf(stderr, "%d\n", i);
|
444 |
+
for(t = 0; t < nthreads && i+t-nthreads < m; ++t){
|
445 |
+
pthread_join(thr[t], 0);
|
446 |
+
val[t] = buf[t];
|
447 |
+
val_resized[t] = buf_resized[t];
|
448 |
+
}
|
449 |
+
for(t = 0; t < nthreads && i+t < m; ++t){
|
450 |
+
args.path = paths[i+t];
|
451 |
+
args.im = &buf[t];
|
452 |
+
args.resized = &buf_resized[t];
|
453 |
+
thr[t] = load_data_in_thread(args);
|
454 |
+
}
|
455 |
+
for(t = 0; t < nthreads && i+t-nthreads < m; ++t){
|
456 |
+
char *path = paths[i+t-nthreads];
|
457 |
+
char *id = basecfg(path);
|
458 |
+
float *X = val_resized[t].data;
|
459 |
+
network_predict(net, X);
|
460 |
+
int w = val[t].w;
|
461 |
+
int h = val[t].h;
|
462 |
+
int nboxes = 0;
|
463 |
+
detection *dets = get_network_boxes(net, w, h, thresh, .5, map, 0, &nboxes);
|
464 |
+
if (nms) do_nms_sort(dets, nboxes, classes, nms);
|
465 |
+
if (coco){
|
466 |
+
print_cocos(fp, path, dets, nboxes, classes, w, h);
|
467 |
+
} else if (imagenet){
|
468 |
+
print_imagenet_detections(fp, i+t-nthreads+1, dets, nboxes, classes, w, h);
|
469 |
+
} else {
|
470 |
+
print_detector_detections(fps, id, dets, nboxes, classes, w, h);
|
471 |
+
}
|
472 |
+
free_detections(dets, nboxes);
|
473 |
+
free(id);
|
474 |
+
free_image(val[t]);
|
475 |
+
free_image(val_resized[t]);
|
476 |
+
}
|
477 |
+
}
|
478 |
+
for(j = 0; j < classes; ++j){
|
479 |
+
if(fps) fclose(fps[j]);
|
480 |
+
}
|
481 |
+
if(coco){
|
482 |
+
fseek(fp, -2, SEEK_CUR);
|
483 |
+
fprintf(fp, "\n]\n");
|
484 |
+
fclose(fp);
|
485 |
+
}
|
486 |
+
fprintf(stderr, "Total Detection Time: %f Seconds\n", what_time_is_it_now() - start);
|
487 |
+
}
|
488 |
+
|
489 |
+
void validate_detector_recall(char *cfgfile, char *weightfile)
|
490 |
+
{
|
491 |
+
network *net = load_network(cfgfile, weightfile, 0);
|
492 |
+
set_batch_network(net, 1);
|
493 |
+
fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
|
494 |
+
srand(time(0));
|
495 |
+
|
496 |
+
list *plist = get_paths("data/coco_val_5k.list");
|
497 |
+
char **paths = (char **)list_to_array(plist);
|
498 |
+
|
499 |
+
layer l = net->layers[net->n-1];
|
500 |
+
|
501 |
+
int j, k;
|
502 |
+
|
503 |
+
int m = plist->size;
|
504 |
+
int i=0;
|
505 |
+
|
506 |
+
float thresh = .001;
|
507 |
+
float iou_thresh = .5;
|
508 |
+
float nms = .4;
|
509 |
+
|
510 |
+
int total = 0;
|
511 |
+
int correct = 0;
|
512 |
+
int proposals = 0;
|
513 |
+
float avg_iou = 0;
|
514 |
+
|
515 |
+
for(i = 0; i < m; ++i){
|
516 |
+
char *path = paths[i];
|
517 |
+
image orig = load_image_color(path, 0, 0);
|
518 |
+
image sized = resize_image(orig, net->w, net->h);
|
519 |
+
char *id = basecfg(path);
|
520 |
+
network_predict(net, sized.data);
|
521 |
+
int nboxes = 0;
|
522 |
+
detection *dets = get_network_boxes(net, sized.w, sized.h, thresh, .5, 0, 1, &nboxes);
|
523 |
+
if (nms) do_nms_obj(dets, nboxes, 1, nms);
|
524 |
+
|
525 |
+
char labelpath[4096];
|
526 |
+
find_replace(path, "images", "labels", labelpath);
|
527 |
+
find_replace(labelpath, "JPEGImages", "labels", labelpath);
|
528 |
+
find_replace(labelpath, ".jpg", ".txt", labelpath);
|
529 |
+
find_replace(labelpath, ".JPEG", ".txt", labelpath);
|
530 |
+
|
531 |
+
int num_labels = 0;
|
532 |
+
box_label *truth = read_boxes(labelpath, &num_labels);
|
533 |
+
for(k = 0; k < nboxes; ++k){
|
534 |
+
if(dets[k].objectness > thresh){
|
535 |
+
++proposals;
|
536 |
+
}
|
537 |
+
}
|
538 |
+
for (j = 0; j < num_labels; ++j) {
|
539 |
+
++total;
|
540 |
+
box t = {truth[j].x, truth[j].y, truth[j].w, truth[j].h};
|
541 |
+
float best_iou = 0;
|
542 |
+
for(k = 0; k < l.w*l.h*l.n; ++k){
|
543 |
+
float iou = box_iou(dets[k].bbox, t);
|
544 |
+
if(dets[k].objectness > thresh && iou > best_iou){
|
545 |
+
best_iou = iou;
|
546 |
+
}
|
547 |
+
}
|
548 |
+
avg_iou += best_iou;
|
549 |
+
if(best_iou > iou_thresh){
|
550 |
+
++correct;
|
551 |
+
}
|
552 |
+
}
|
553 |
+
|
554 |
+
fprintf(stderr, "%5d %5d %5d\tRPs/Img: %.2f\tIOU: %.2f%%\tRecall:%.2f%%\n", i, correct, total, (float)proposals/(i+1), avg_iou*100/total, 100.*correct/total);
|
555 |
+
free(id);
|
556 |
+
free_image(orig);
|
557 |
+
free_image(sized);
|
558 |
+
}
|
559 |
+
}
|
560 |
+
|
561 |
+
|
562 |
+
void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filename, float thresh, float hier_thresh, char *outfile, int fullscreen)
|
563 |
+
{
|
564 |
+
list *options = read_data_cfg(datacfg);
|
565 |
+
char *name_list = option_find_str(options, "names", "data/names.list");
|
566 |
+
char **names = get_labels(name_list);
|
567 |
+
|
568 |
+
image **alphabet = load_alphabet();
|
569 |
+
network *net = load_network(cfgfile, weightfile, 0);
|
570 |
+
set_batch_network(net, 1);
|
571 |
+
srand(2222222);
|
572 |
+
double time;
|
573 |
+
char buff[256];
|
574 |
+
char *input = buff;
|
575 |
+
float nms=.45;
|
576 |
+
while(1){
|
577 |
+
if(filename){
|
578 |
+
strncpy(input, filename, 256);
|
579 |
+
} else {
|
580 |
+
printf("Enter Image Path: ");
|
581 |
+
fflush(stdout);
|
582 |
+
input = fgets(input, 256, stdin);
|
583 |
+
if(!input) return;
|
584 |
+
strtok(input, "\n");
|
585 |
+
}
|
586 |
+
image im = load_image_color(input,0,0);
|
587 |
+
image sized = letterbox_image(im, net->w, net->h);
|
588 |
+
//image sized = resize_image(im, net->w, net->h);
|
589 |
+
//image sized2 = resize_max(im, net->w);
|
590 |
+
//image sized = crop_image(sized2, -((net->w - sized2.w)/2), -((net->h - sized2.h)/2), net->w, net->h);
|
591 |
+
//resize_network(net, sized.w, sized.h);
|
592 |
+
layer l = net->layers[net->n-1];
|
593 |
+
|
594 |
+
|
595 |
+
float *X = sized.data;
|
596 |
+
time=what_time_is_it_now();
|
597 |
+
network_predict(net, X);
|
598 |
+
printf("%s: Predicted in %f seconds.\n", input, what_time_is_it_now()-time);
|
599 |
+
int nboxes = 0;
|
600 |
+
detection *dets = get_network_boxes(net, im.w, im.h, thresh, hier_thresh, 0, 1, &nboxes);
|
601 |
+
//printf("%d\n", nboxes);
|
602 |
+
//if (nms) do_nms_obj(boxes, probs, l.w*l.h*l.n, l.classes, nms);
|
603 |
+
if (nms) do_nms_sort(dets, nboxes, l.classes, nms);
|
604 |
+
draw_detections(im, dets, nboxes, thresh, names, alphabet, l.classes);
|
605 |
+
free_detections(dets, nboxes);
|
606 |
+
if(outfile){
|
607 |
+
save_image(im, outfile);
|
608 |
+
}
|
609 |
+
else{
|
610 |
+
save_image(im, "predictions");
|
611 |
+
#ifdef OPENCV
|
612 |
+
make_window("predictions", 512, 512, 0);
|
613 |
+
show_image(im, "predictions", 0);
|
614 |
+
#endif
|
615 |
+
}
|
616 |
+
|
617 |
+
free_image(im);
|
618 |
+
free_image(sized);
|
619 |
+
if (filename) break;
|
620 |
+
}
|
621 |
+
}
|
622 |
+
|
623 |
+
/*
|
624 |
+
void censor_detector(char *datacfg, char *cfgfile, char *weightfile, int cam_index, const char *filename, int class, float thresh, int skip)
|
625 |
+
{
|
626 |
+
#ifdef OPENCV
|
627 |
+
char *base = basecfg(cfgfile);
|
628 |
+
network *net = load_network(cfgfile, weightfile, 0);
|
629 |
+
set_batch_network(net, 1);
|
630 |
+
|
631 |
+
srand(2222222);
|
632 |
+
CvCapture * cap;
|
633 |
+
|
634 |
+
int w = 1280;
|
635 |
+
int h = 720;
|
636 |
+
|
637 |
+
if(filename){
|
638 |
+
cap = cvCaptureFromFile(filename);
|
639 |
+
}else{
|
640 |
+
cap = cvCaptureFromCAM(cam_index);
|
641 |
+
}
|
642 |
+
|
643 |
+
if(w){
|
644 |
+
cvSetCaptureProperty(cap, CV_CAP_PROP_FRAME_WIDTH, w);
|
645 |
+
}
|
646 |
+
if(h){
|
647 |
+
cvSetCaptureProperty(cap, CV_CAP_PROP_FRAME_HEIGHT, h);
|
648 |
+
}
|
649 |
+
|
650 |
+
if(!cap) error("Couldn't connect to webcam.\n");
|
651 |
+
cvNamedWindow(base, CV_WINDOW_NORMAL);
|
652 |
+
cvResizeWindow(base, 512, 512);
|
653 |
+
float fps = 0;
|
654 |
+
int i;
|
655 |
+
float nms = .45;
|
656 |
+
|
657 |
+
while(1){
|
658 |
+
image in = get_image_from_stream(cap);
|
659 |
+
//image in_s = resize_image(in, net->w, net->h);
|
660 |
+
image in_s = letterbox_image(in, net->w, net->h);
|
661 |
+
layer l = net->layers[net->n-1];
|
662 |
+
|
663 |
+
float *X = in_s.data;
|
664 |
+
network_predict(net, X);
|
665 |
+
int nboxes = 0;
|
666 |
+
detection *dets = get_network_boxes(net, in.w, in.h, thresh, 0, 0, 0, &nboxes);
|
667 |
+
//if (nms) do_nms_obj(boxes, probs, l.w*l.h*l.n, l.classes, nms);
|
668 |
+
if (nms) do_nms_sort(dets, nboxes, l.classes, nms);
|
669 |
+
|
670 |
+
for(i = 0; i < nboxes; ++i){
|
671 |
+
if(dets[i].prob[class] > thresh){
|
672 |
+
box b = dets[i].bbox;
|
673 |
+
int left = b.x-b.w/2.;
|
674 |
+
int top = b.y-b.h/2.;
|
675 |
+
censor_image(in, left, top, b.w, b.h);
|
676 |
+
}
|
677 |
+
}
|
678 |
+
show_image(in, base);
|
679 |
+
cvWaitKey(10);
|
680 |
+
free_detections(dets, nboxes);
|
681 |
+
|
682 |
+
|
683 |
+
free_image(in_s);
|
684 |
+
free_image(in);
|
685 |
+
|
686 |
+
|
687 |
+
float curr = 0;
|
688 |
+
fps = .9*fps + .1*curr;
|
689 |
+
for(i = 0; i < skip; ++i){
|
690 |
+
image in = get_image_from_stream(cap);
|
691 |
+
free_image(in);
|
692 |
+
}
|
693 |
+
}
|
694 |
+
#endif
|
695 |
+
}
|
696 |
+
|
697 |
+
void extract_detector(char *datacfg, char *cfgfile, char *weightfile, int cam_index, const char *filename, int class, float thresh, int skip)
|
698 |
+
{
|
699 |
+
#ifdef OPENCV
|
700 |
+
char *base = basecfg(cfgfile);
|
701 |
+
network *net = load_network(cfgfile, weightfile, 0);
|
702 |
+
set_batch_network(net, 1);
|
703 |
+
|
704 |
+
srand(2222222);
|
705 |
+
CvCapture * cap;
|
706 |
+
|
707 |
+
int w = 1280;
|
708 |
+
int h = 720;
|
709 |
+
|
710 |
+
if(filename){
|
711 |
+
cap = cvCaptureFromFile(filename);
|
712 |
+
}else{
|
713 |
+
cap = cvCaptureFromCAM(cam_index);
|
714 |
+
}
|
715 |
+
|
716 |
+
if(w){
|
717 |
+
cvSetCaptureProperty(cap, CV_CAP_PROP_FRAME_WIDTH, w);
|
718 |
+
}
|
719 |
+
if(h){
|
720 |
+
cvSetCaptureProperty(cap, CV_CAP_PROP_FRAME_HEIGHT, h);
|
721 |
+
}
|
722 |
+
|
723 |
+
if(!cap) error("Couldn't connect to webcam.\n");
|
724 |
+
cvNamedWindow(base, CV_WINDOW_NORMAL);
|
725 |
+
cvResizeWindow(base, 512, 512);
|
726 |
+
float fps = 0;
|
727 |
+
int i;
|
728 |
+
int count = 0;
|
729 |
+
float nms = .45;
|
730 |
+
|
731 |
+
while(1){
|
732 |
+
image in = get_image_from_stream(cap);
|
733 |
+
//image in_s = resize_image(in, net->w, net->h);
|
734 |
+
image in_s = letterbox_image(in, net->w, net->h);
|
735 |
+
layer l = net->layers[net->n-1];
|
736 |
+
|
737 |
+
show_image(in, base);
|
738 |
+
|
739 |
+
int nboxes = 0;
|
740 |
+
float *X = in_s.data;
|
741 |
+
network_predict(net, X);
|
742 |
+
detection *dets = get_network_boxes(net, in.w, in.h, thresh, 0, 0, 1, &nboxes);
|
743 |
+
//if (nms) do_nms_obj(boxes, probs, l.w*l.h*l.n, l.classes, nms);
|
744 |
+
if (nms) do_nms_sort(dets, nboxes, l.classes, nms);
|
745 |
+
|
746 |
+
for(i = 0; i < nboxes; ++i){
|
747 |
+
if(dets[i].prob[class] > thresh){
|
748 |
+
box b = dets[i].bbox;
|
749 |
+
int size = b.w*in.w > b.h*in.h ? b.w*in.w : b.h*in.h;
|
750 |
+
int dx = b.x*in.w-size/2.;
|
751 |
+
int dy = b.y*in.h-size/2.;
|
752 |
+
image bim = crop_image(in, dx, dy, size, size);
|
753 |
+
char buff[2048];
|
754 |
+
sprintf(buff, "results/extract/%07d", count);
|
755 |
+
++count;
|
756 |
+
save_image(bim, buff);
|
757 |
+
free_image(bim);
|
758 |
+
}
|
759 |
+
}
|
760 |
+
free_detections(dets, nboxes);
|
761 |
+
|
762 |
+
|
763 |
+
free_image(in_s);
|
764 |
+
free_image(in);
|
765 |
+
|
766 |
+
|
767 |
+
float curr = 0;
|
768 |
+
fps = .9*fps + .1*curr;
|
769 |
+
for(i = 0; i < skip; ++i){
|
770 |
+
image in = get_image_from_stream(cap);
|
771 |
+
free_image(in);
|
772 |
+
}
|
773 |
+
}
|
774 |
+
#endif
|
775 |
+
}
|
776 |
+
*/
|
777 |
+
|
778 |
+
/*
|
779 |
+
void network_detect(network *net, image im, float thresh, float hier_thresh, float nms, detection *dets)
|
780 |
+
{
|
781 |
+
network_predict_image(net, im);
|
782 |
+
layer l = net->layers[net->n-1];
|
783 |
+
int nboxes = num_boxes(net);
|
784 |
+
fill_network_boxes(net, im.w, im.h, thresh, hier_thresh, 0, 0, dets);
|
785 |
+
if (nms) do_nms_sort(dets, nboxes, l.classes, nms);
|
786 |
+
}
|
787 |
+
*/
|
788 |
+
|
789 |
+
void run_detector(int argc, char **argv)
|
790 |
+
{
|
791 |
+
char *prefix = find_char_arg(argc, argv, "-prefix", 0);
|
792 |
+
float thresh = find_float_arg(argc, argv, "-thresh", .5);
|
793 |
+
float hier_thresh = find_float_arg(argc, argv, "-hier", .5);
|
794 |
+
int cam_index = find_int_arg(argc, argv, "-c", 0);
|
795 |
+
int frame_skip = find_int_arg(argc, argv, "-s", 0);
|
796 |
+
int avg = find_int_arg(argc, argv, "-avg", 3);
|
797 |
+
if(argc < 4){
|
798 |
+
fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
|
799 |
+
return;
|
800 |
+
}
|
801 |
+
char *gpu_list = find_char_arg(argc, argv, "-gpus", 0);
|
802 |
+
char *outfile = find_char_arg(argc, argv, "-out", 0);
|
803 |
+
int *gpus = 0;
|
804 |
+
int gpu = 0;
|
805 |
+
int ngpus = 0;
|
806 |
+
if(gpu_list){
|
807 |
+
printf("%s\n", gpu_list);
|
808 |
+
int len = strlen(gpu_list);
|
809 |
+
ngpus = 1;
|
810 |
+
int i;
|
811 |
+
for(i = 0; i < len; ++i){
|
812 |
+
if (gpu_list[i] == ',') ++ngpus;
|
813 |
+
}
|
814 |
+
gpus = calloc(ngpus, sizeof(int));
|
815 |
+
for(i = 0; i < ngpus; ++i){
|
816 |
+
gpus[i] = atoi(gpu_list);
|
817 |
+
gpu_list = strchr(gpu_list, ',')+1;
|
818 |
+
}
|
819 |
+
} else {
|
820 |
+
gpu = gpu_index;
|
821 |
+
gpus = &gpu;
|
822 |
+
ngpus = 1;
|
823 |
+
}
|
824 |
+
|
825 |
+
int clear = find_arg(argc, argv, "-clear");
|
826 |
+
int fullscreen = find_arg(argc, argv, "-fullscreen");
|
827 |
+
int width = find_int_arg(argc, argv, "-w", 0);
|
828 |
+
int height = find_int_arg(argc, argv, "-h", 0);
|
829 |
+
int fps = find_int_arg(argc, argv, "-fps", 0);
|
830 |
+
//int class = find_int_arg(argc, argv, "-class", 0);
|
831 |
+
|
832 |
+
char *datacfg = argv[3];
|
833 |
+
char *cfg = argv[4];
|
834 |
+
char *weights = (argc > 5) ? argv[5] : 0;
|
835 |
+
char *filename = (argc > 6) ? argv[6]: 0;
|
836 |
+
if(0==strcmp(argv[2], "test")) test_detector(datacfg, cfg, weights, filename, thresh, hier_thresh, outfile, fullscreen);
|
837 |
+
else if(0==strcmp(argv[2], "train")) train_detector(datacfg, cfg, weights, gpus, ngpus, clear);
|
838 |
+
else if(0==strcmp(argv[2], "valid")) validate_detector(datacfg, cfg, weights, outfile);
|
839 |
+
else if(0==strcmp(argv[2], "valid2")) validate_detector_flip(datacfg, cfg, weights, outfile);
|
840 |
+
else if(0==strcmp(argv[2], "recall")) validate_detector_recall(cfg, weights);
|
841 |
+
else if(0==strcmp(argv[2], "demo")) {
|
842 |
+
list *options = read_data_cfg(datacfg);
|
843 |
+
int classes = option_find_int(options, "classes", 20);
|
844 |
+
char *name_list = option_find_str(options, "names", "data/names.list");
|
845 |
+
char **names = get_labels(name_list);
|
846 |
+
demo(cfg, weights, thresh, cam_index, filename, names, classes, frame_skip, prefix, avg, hier_thresh, width, height, fps, fullscreen);
|
847 |
+
}
|
848 |
+
//else if(0==strcmp(argv[2], "extract")) extract_detector(datacfg, cfg, weights, cam_index, filename, class, thresh, frame_skip);
|
849 |
+
//else if(0==strcmp(argv[2], "censor")) censor_detector(datacfg, cfg, weights, cam_index, filename, class, thresh, frame_skip);
|
850 |
+
}
|
model/examples/detector.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Stupid python path shit.
|
2 |
+
# Instead just add darknet.py to somewhere in your python path
|
3 |
+
# OK actually that might not be a great idea, idk, work in progress
|
4 |
+
# Use at your own risk. or don't, i don't care
|
5 |
+
|
6 |
+
import sys, os
|
7 |
+
sys.path.append(os.path.join(os.getcwd(),'python/'))
|
8 |
+
|
9 |
+
import darknet as dn
|
10 |
+
import pdb
|
11 |
+
|
12 |
+
dn.set_gpu(0)
|
13 |
+
net = dn.load_net("cfg/yolo-thor.cfg", "/home/pjreddie/backup/yolo-thor_final.weights", 0)
|
14 |
+
meta = dn.load_meta("cfg/thor.data")
|
15 |
+
r = dn.detect(net, meta, "data/bedroom.jpg")
|
16 |
+
print r
|
17 |
+
|
18 |
+
# And then down here you could detect a lot more images like:
|
19 |
+
r = dn.detect(net, meta, "data/eagle.jpg")
|
20 |
+
print r
|
21 |
+
r = dn.detect(net, meta, "data/giraffe.jpg")
|
22 |
+
print r
|
23 |
+
r = dn.detect(net, meta, "data/horses.jpg")
|
24 |
+
print r
|
25 |
+
r = dn.detect(net, meta, "data/person.jpg")
|
26 |
+
print r
|
27 |
+
|
model/examples/dice.c
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include "darknet.h"
|
2 |
+
|
3 |
+
char *dice_labels[] = {"face1","face2","face3","face4","face5","face6"};
|
4 |
+
|
5 |
+
void train_dice(char *cfgfile, char *weightfile)
|
6 |
+
{
|
7 |
+
srand(time(0));
|
8 |
+
float avg_loss = -1;
|
9 |
+
char *base = basecfg(cfgfile);
|
10 |
+
char *backup_directory = "/home/pjreddie/backup/";
|
11 |
+
printf("%s\n", base);
|
12 |
+
network net = parse_network_cfg(cfgfile);
|
13 |
+
if(weightfile){
|
14 |
+
load_weights(&net, weightfile);
|
15 |
+
}
|
16 |
+
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
|
17 |
+
int imgs = 1024;
|
18 |
+
int i = *net.seen/imgs;
|
19 |
+
char **labels = dice_labels;
|
20 |
+
list *plist = get_paths("data/dice/dice.train.list");
|
21 |
+
char **paths = (char **)list_to_array(plist);
|
22 |
+
printf("%d\n", plist->size);
|
23 |
+
clock_t time;
|
24 |
+
while(1){
|
25 |
+
++i;
|
26 |
+
time=clock();
|
27 |
+
data train = load_data_old(paths, imgs, plist->size, labels, 6, net.w, net.h);
|
28 |
+
printf("Loaded: %lf seconds\n", sec(clock()-time));
|
29 |
+
|
30 |
+
time=clock();
|
31 |
+
float loss = train_network(net, train);
|
32 |
+
if(avg_loss == -1) avg_loss = loss;
|
33 |
+
avg_loss = avg_loss*.9 + loss*.1;
|
34 |
+
printf("%d: %f, %f avg, %lf seconds, %ld images\n", i, loss, avg_loss, sec(clock()-time), *net.seen);
|
35 |
+
free_data(train);
|
36 |
+
if((i % 100) == 0) net.learning_rate *= .1;
|
37 |
+
if(i%100==0){
|
38 |
+
char buff[256];
|
39 |
+
sprintf(buff, "%s/%s_%d.weights",backup_directory,base, i);
|
40 |
+
save_weights(net, buff);
|
41 |
+
}
|
42 |
+
}
|
43 |
+
}
|
44 |
+
|
45 |
+
void validate_dice(char *filename, char *weightfile)
|
46 |
+
{
|
47 |
+
network net = parse_network_cfg(filename);
|
48 |
+
if(weightfile){
|
49 |
+
load_weights(&net, weightfile);
|
50 |
+
}
|
51 |
+
srand(time(0));
|
52 |
+
|
53 |
+
char **labels = dice_labels;
|
54 |
+
list *plist = get_paths("data/dice/dice.val.list");
|
55 |
+
|
56 |
+
char **paths = (char **)list_to_array(plist);
|
57 |
+
int m = plist->size;
|
58 |
+
free_list(plist);
|
59 |
+
|
60 |
+
data val = load_data_old(paths, m, 0, labels, 6, net.w, net.h);
|
61 |
+
float *acc = network_accuracies(net, val, 2);
|
62 |
+
printf("Validation Accuracy: %f, %d images\n", acc[0], m);
|
63 |
+
free_data(val);
|
64 |
+
}
|
65 |
+
|
66 |
+
void test_dice(char *cfgfile, char *weightfile, char *filename)
|
67 |
+
{
|
68 |
+
network net = parse_network_cfg(cfgfile);
|
69 |
+
if(weightfile){
|
70 |
+
load_weights(&net, weightfile);
|
71 |
+
}
|
72 |
+
set_batch_network(&net, 1);
|
73 |
+
srand(2222222);
|
74 |
+
int i = 0;
|
75 |
+
char **names = dice_labels;
|
76 |
+
char buff[256];
|
77 |
+
char *input = buff;
|
78 |
+
int indexes[6];
|
79 |
+
while(1){
|
80 |
+
if(filename){
|
81 |
+
strncpy(input, filename, 256);
|
82 |
+
}else{
|
83 |
+
printf("Enter Image Path: ");
|
84 |
+
fflush(stdout);
|
85 |
+
input = fgets(input, 256, stdin);
|
86 |
+
if(!input) return;
|
87 |
+
strtok(input, "\n");
|
88 |
+
}
|
89 |
+
image im = load_image_color(input, net.w, net.h);
|
90 |
+
float *X = im.data;
|
91 |
+
float *predictions = network_predict(net, X);
|
92 |
+
top_predictions(net, 6, indexes);
|
93 |
+
for(i = 0; i < 6; ++i){
|
94 |
+
int index = indexes[i];
|
95 |
+
printf("%s: %f\n", names[index], predictions[index]);
|
96 |
+
}
|
97 |
+
free_image(im);
|
98 |
+
if (filename) break;
|
99 |
+
}
|
100 |
+
}
|
101 |
+
|
102 |
+
void run_dice(int argc, char **argv)
|
103 |
+
{
|
104 |
+
if(argc < 4){
|
105 |
+
fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
|
106 |
+
return;
|
107 |
+
}
|
108 |
+
|
109 |
+
char *cfg = argv[3];
|
110 |
+
char *weights = (argc > 4) ? argv[4] : 0;
|
111 |
+
char *filename = (argc > 5) ? argv[5]: 0;
|
112 |
+
if(0==strcmp(argv[2], "test")) test_dice(cfg, weights, filename);
|
113 |
+
else if(0==strcmp(argv[2], "train")) train_dice(cfg, weights);
|
114 |
+
else if(0==strcmp(argv[2], "valid")) validate_dice(cfg, weights);
|
115 |
+
}
|
116 |
+
|
model/examples/go.c
ADDED
@@ -0,0 +1,1370 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include "darknet.h"
|
2 |
+
|
3 |
+
#include <assert.h>
|
4 |
+
#include <math.h>
|
5 |
+
#include <unistd.h>
|
6 |
+
|
7 |
+
int inverted = 1;
|
8 |
+
int noi = 1;
|
9 |
+
static const int nind = 10;
|
10 |
+
int legal_go(float *b, float *ko, int p, int r, int c);
|
11 |
+
int check_ko(float *x, float *ko);
|
12 |
+
|
13 |
+
typedef struct {
|
14 |
+
char **data;
|
15 |
+
int n;
|
16 |
+
} moves;
|
17 |
+
|
18 |
+
char *fgetgo(FILE *fp)
|
19 |
+
{
|
20 |
+
if(feof(fp)) return 0;
|
21 |
+
size_t size = 96;
|
22 |
+
char *line = malloc(size*sizeof(char));
|
23 |
+
if(size != fread(line, sizeof(char), size, fp)){
|
24 |
+
free(line);
|
25 |
+
return 0;
|
26 |
+
}
|
27 |
+
|
28 |
+
return line;
|
29 |
+
}
|
30 |
+
|
31 |
+
moves load_go_moves(char *filename)
|
32 |
+
{
|
33 |
+
moves m;
|
34 |
+
m.n = 128;
|
35 |
+
m.data = calloc(128, sizeof(char*));
|
36 |
+
FILE *fp = fopen(filename, "rb");
|
37 |
+
int count = 0;
|
38 |
+
char *line = 0;
|
39 |
+
while ((line = fgetgo(fp))) {
|
40 |
+
if (count >= m.n) {
|
41 |
+
m.n *= 2;
|
42 |
+
m.data = realloc(m.data, m.n*sizeof(char*));
|
43 |
+
}
|
44 |
+
m.data[count] = line;
|
45 |
+
++count;
|
46 |
+
}
|
47 |
+
printf("%d\n", count);
|
48 |
+
m.n = count;
|
49 |
+
m.data = realloc(m.data, count*sizeof(char*));
|
50 |
+
return m;
|
51 |
+
}
|
52 |
+
|
53 |
+
void string_to_board(char *s, float *board)
|
54 |
+
{
|
55 |
+
int i, j;
|
56 |
+
memset(board, 0, 2*19*19*sizeof(float));
|
57 |
+
int count = 0;
|
58 |
+
for(i = 0; i < 91; ++i){
|
59 |
+
char c = s[i];
|
60 |
+
for(j = 0; j < 4; ++j){
|
61 |
+
int me = (c >> (2*j)) & 1;
|
62 |
+
int you = (c >> (2*j + 1)) & 1;
|
63 |
+
if (me) board[count] = 1;
|
64 |
+
else if (you) board[count + 19*19] = 1;
|
65 |
+
++count;
|
66 |
+
if(count >= 19*19) break;
|
67 |
+
}
|
68 |
+
}
|
69 |
+
}
|
70 |
+
|
71 |
+
void board_to_string(char *s, float *board)
|
72 |
+
{
|
73 |
+
int i, j;
|
74 |
+
memset(s, 0, (19*19/4+1)*sizeof(char));
|
75 |
+
int count = 0;
|
76 |
+
for(i = 0; i < 91; ++i){
|
77 |
+
for(j = 0; j < 4; ++j){
|
78 |
+
int me = (board[count] == 1);
|
79 |
+
int you = (board[count + 19*19] == 1);
|
80 |
+
if (me) s[i] = s[i] | (1<<(2*j));
|
81 |
+
if (you) s[i] = s[i] | (1<<(2*j + 1));
|
82 |
+
++count;
|
83 |
+
if(count >= 19*19) break;
|
84 |
+
}
|
85 |
+
}
|
86 |
+
}
|
87 |
+
|
88 |
+
static int occupied(float *b, int i)
|
89 |
+
{
|
90 |
+
if (b[i]) return 1;
|
91 |
+
if (b[i+19*19]) return -1;
|
92 |
+
return 0;
|
93 |
+
}
|
94 |
+
|
95 |
+
data random_go_moves(moves m, int n)
|
96 |
+
{
|
97 |
+
data d = {0};
|
98 |
+
d.X = make_matrix(n, 19*19*3);
|
99 |
+
d.y = make_matrix(n, 19*19+2);
|
100 |
+
int i, j;
|
101 |
+
for(i = 0; i < n; ++i){
|
102 |
+
float *board = d.X.vals[i];
|
103 |
+
float *label = d.y.vals[i];
|
104 |
+
char *b = m.data[rand()%m.n];
|
105 |
+
int player = b[0] - '0';
|
106 |
+
int result = b[1] - '0';
|
107 |
+
int row = b[2];
|
108 |
+
int col = b[3];
|
109 |
+
string_to_board(b+4, board);
|
110 |
+
if(player > 0) for(j = 0; j < 19*19; ++j) board[19*19*2 + j] = 1;
|
111 |
+
label[19*19+1] = (player==result);
|
112 |
+
if(row >= 19 || col >= 19){
|
113 |
+
label[19*19] = 1;
|
114 |
+
} else {
|
115 |
+
label[col + 19*row] = 1;
|
116 |
+
if(occupied(board, col + 19*row)) printf("hey\n");
|
117 |
+
}
|
118 |
+
|
119 |
+
int flip = rand()%2;
|
120 |
+
int rotate = rand()%4;
|
121 |
+
image in = float_to_image(19, 19, 3, board);
|
122 |
+
image out = float_to_image(19, 19, 1, label);
|
123 |
+
if(flip){
|
124 |
+
flip_image(in);
|
125 |
+
flip_image(out);
|
126 |
+
}
|
127 |
+
rotate_image_cw(in, rotate);
|
128 |
+
rotate_image_cw(out, rotate);
|
129 |
+
}
|
130 |
+
return d;
|
131 |
+
}
|
132 |
+
|
133 |
+
|
134 |
+
void train_go(char *cfgfile, char *weightfile, char *filename, int *gpus, int ngpus, int clear)
|
135 |
+
{
|
136 |
+
int i;
|
137 |
+
float avg_loss = -1;
|
138 |
+
char *base = basecfg(cfgfile);
|
139 |
+
printf("%s\n", base);
|
140 |
+
printf("%d\n", ngpus);
|
141 |
+
network **nets = calloc(ngpus, sizeof(network*));
|
142 |
+
|
143 |
+
srand(time(0));
|
144 |
+
int seed = rand();
|
145 |
+
for(i = 0; i < ngpus; ++i){
|
146 |
+
srand(seed);
|
147 |
+
#ifdef GPU
|
148 |
+
cuda_set_device(gpus[i]);
|
149 |
+
#endif
|
150 |
+
nets[i] = load_network(cfgfile, weightfile, clear);
|
151 |
+
nets[i]->learning_rate *= ngpus;
|
152 |
+
}
|
153 |
+
network *net = nets[0];
|
154 |
+
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
|
155 |
+
|
156 |
+
char *backup_directory = "/home/pjreddie/backup/";
|
157 |
+
|
158 |
+
char buff[256];
|
159 |
+
moves m = load_go_moves(filename);
|
160 |
+
//moves m = load_go_moves("games.txt");
|
161 |
+
|
162 |
+
int N = m.n;
|
163 |
+
printf("Moves: %d\n", N);
|
164 |
+
int epoch = (*net->seen)/N;
|
165 |
+
while(get_current_batch(net) < net->max_batches || net->max_batches == 0){
|
166 |
+
double time=what_time_is_it_now();
|
167 |
+
|
168 |
+
data train = random_go_moves(m, net->batch*net->subdivisions*ngpus);
|
169 |
+
printf("Loaded: %lf seconds\n", what_time_is_it_now() - time);
|
170 |
+
time=what_time_is_it_now();
|
171 |
+
|
172 |
+
float loss = 0;
|
173 |
+
#ifdef GPU
|
174 |
+
if(ngpus == 1){
|
175 |
+
loss = train_network(net, train);
|
176 |
+
} else {
|
177 |
+
loss = train_networks(nets, ngpus, train, 10);
|
178 |
+
}
|
179 |
+
#else
|
180 |
+
loss = train_network(net, train);
|
181 |
+
#endif
|
182 |
+
free_data(train);
|
183 |
+
|
184 |
+
if(avg_loss == -1) avg_loss = loss;
|
185 |
+
avg_loss = avg_loss*.95 + loss*.05;
|
186 |
+
printf("%ld, %.3f: %f, %f avg, %f rate, %lf seconds, %ld images\n", get_current_batch(net), (float)(*net->seen)/N, loss, avg_loss, get_current_rate(net), what_time_is_it_now()-time, *net->seen);
|
187 |
+
if(*net->seen/N > epoch){
|
188 |
+
epoch = *net->seen/N;
|
189 |
+
char buff[256];
|
190 |
+
sprintf(buff, "%s/%s_%d.weights", backup_directory,base, epoch);
|
191 |
+
save_weights(net, buff);
|
192 |
+
|
193 |
+
}
|
194 |
+
if(get_current_batch(net)%1000 == 0){
|
195 |
+
char buff[256];
|
196 |
+
sprintf(buff, "%s/%s.backup",backup_directory,base);
|
197 |
+
save_weights(net, buff);
|
198 |
+
}
|
199 |
+
if(get_current_batch(net)%10000 == 0){
|
200 |
+
char buff[256];
|
201 |
+
sprintf(buff, "%s/%s_%ld.backup",backup_directory,base,get_current_batch(net));
|
202 |
+
save_weights(net, buff);
|
203 |
+
}
|
204 |
+
}
|
205 |
+
sprintf(buff, "%s/%s.weights", backup_directory, base);
|
206 |
+
save_weights(net, buff);
|
207 |
+
|
208 |
+
free_network(net);
|
209 |
+
free(base);
|
210 |
+
}
|
211 |
+
|
212 |
+
static void propagate_liberty(float *board, int *lib, int *visited, int row, int col, int side)
|
213 |
+
{
|
214 |
+
if (row < 0 || row > 18 || col < 0 || col > 18) return;
|
215 |
+
int index = row*19 + col;
|
216 |
+
if (occupied(board,index) != side) return;
|
217 |
+
if (visited[index]) return;
|
218 |
+
visited[index] = 1;
|
219 |
+
lib[index] += 1;
|
220 |
+
propagate_liberty(board, lib, visited, row+1, col, side);
|
221 |
+
propagate_liberty(board, lib, visited, row-1, col, side);
|
222 |
+
propagate_liberty(board, lib, visited, row, col+1, side);
|
223 |
+
propagate_liberty(board, lib, visited, row, col-1, side);
|
224 |
+
}
|
225 |
+
|
226 |
+
|
227 |
+
static int *calculate_liberties(float *board)
|
228 |
+
{
|
229 |
+
int *lib = calloc(19*19, sizeof(int));
|
230 |
+
int visited[19*19];
|
231 |
+
int i, j;
|
232 |
+
for(j = 0; j < 19; ++j){
|
233 |
+
for(i = 0; i < 19; ++i){
|
234 |
+
memset(visited, 0, 19*19*sizeof(int));
|
235 |
+
int index = j*19 + i;
|
236 |
+
if(!occupied(board,index)){
|
237 |
+
if ((i > 0) && occupied(board,index - 1)) propagate_liberty(board, lib, visited, j, i-1, occupied(board,index-1));
|
238 |
+
if ((i < 18) && occupied(board,index + 1)) propagate_liberty(board, lib, visited, j, i+1, occupied(board,index+1));
|
239 |
+
if ((j > 0) && occupied(board,index - 19)) propagate_liberty(board, lib, visited, j-1, i, occupied(board,index-19));
|
240 |
+
if ((j < 18) && occupied(board,index + 19)) propagate_liberty(board, lib, visited, j+1, i, occupied(board,index+19));
|
241 |
+
}
|
242 |
+
}
|
243 |
+
}
|
244 |
+
return lib;
|
245 |
+
}
|
246 |
+
|
247 |
+
void print_board(FILE *stream, float *board, int player, int *indexes)
|
248 |
+
{
|
249 |
+
int i,j,n;
|
250 |
+
fprintf(stream, " ");
|
251 |
+
for(i = 0; i < 19; ++i){
|
252 |
+
fprintf(stream, "%c ", 'A' + i + 1*(i > 7 && noi));
|
253 |
+
}
|
254 |
+
fprintf(stream, "\n");
|
255 |
+
for(j = 0; j < 19; ++j){
|
256 |
+
fprintf(stream, "%2d", (inverted) ? 19-j : j+1);
|
257 |
+
for(i = 0; i < 19; ++i){
|
258 |
+
int index = j*19 + i;
|
259 |
+
if(indexes){
|
260 |
+
int found = 0;
|
261 |
+
for(n = 0; n < nind; ++n){
|
262 |
+
if(index == indexes[n]){
|
263 |
+
found = 1;
|
264 |
+
/*
|
265 |
+
if(n == 0) fprintf(stream, "\uff11");
|
266 |
+
else if(n == 1) fprintf(stream, "\uff12");
|
267 |
+
else if(n == 2) fprintf(stream, "\uff13");
|
268 |
+
else if(n == 3) fprintf(stream, "\uff14");
|
269 |
+
else if(n == 4) fprintf(stream, "\uff15");
|
270 |
+
*/
|
271 |
+
fprintf(stream, " %d", n+1);
|
272 |
+
}
|
273 |
+
}
|
274 |
+
if(found) continue;
|
275 |
+
}
|
276 |
+
//if(board[index]*-swap > 0) fprintf(stream, "\u25C9 ");
|
277 |
+
//else if(board[index]*-swap < 0) fprintf(stream, "\u25EF ");
|
278 |
+
if (occupied(board, index) == player) fprintf(stream, " X");
|
279 |
+
else if (occupied(board, index) ==-player) fprintf(stream, " O");
|
280 |
+
else fprintf(stream, " .");
|
281 |
+
}
|
282 |
+
fprintf(stream, "\n");
|
283 |
+
}
|
284 |
+
}
|
285 |
+
|
286 |
+
void flip_board(float *board)
|
287 |
+
{
|
288 |
+
int i;
|
289 |
+
for(i = 0; i < 19*19; ++i){
|
290 |
+
float swap = board[i];
|
291 |
+
board[i] = board[i+19*19];
|
292 |
+
board[i+19*19] = swap;
|
293 |
+
board[i+19*19*2] = 1-board[i+19*19*2];
|
294 |
+
}
|
295 |
+
}
|
296 |
+
|
297 |
+
float predict_move2(network *net, float *board, float *move, int multi)
|
298 |
+
{
|
299 |
+
float *output = network_predict(net, board);
|
300 |
+
copy_cpu(19*19+1, output, 1, move, 1);
|
301 |
+
float result = output[19*19 + 1];
|
302 |
+
int i;
|
303 |
+
if(multi){
|
304 |
+
image bim = float_to_image(19, 19, 3, board);
|
305 |
+
for(i = 1; i < 8; ++i){
|
306 |
+
rotate_image_cw(bim, i);
|
307 |
+
if(i >= 4) flip_image(bim);
|
308 |
+
|
309 |
+
float *output = network_predict(net, board);
|
310 |
+
image oim = float_to_image(19, 19, 1, output);
|
311 |
+
result += output[19*19 + 1];
|
312 |
+
|
313 |
+
if(i >= 4) flip_image(oim);
|
314 |
+
rotate_image_cw(oim, -i);
|
315 |
+
|
316 |
+
axpy_cpu(19*19+1, 1, output, 1, move, 1);
|
317 |
+
|
318 |
+
if(i >= 4) flip_image(bim);
|
319 |
+
rotate_image_cw(bim, -i);
|
320 |
+
}
|
321 |
+
result = result/8;
|
322 |
+
scal_cpu(19*19+1, 1./8., move, 1);
|
323 |
+
}
|
324 |
+
for(i = 0; i < 19*19; ++i){
|
325 |
+
if(board[i] || board[i+19*19]) move[i] = 0;
|
326 |
+
}
|
327 |
+
return result;
|
328 |
+
}
|
329 |
+
|
330 |
+
static void remove_connected(float *b, int *lib, int p, int r, int c)
|
331 |
+
{
|
332 |
+
if (r < 0 || r >= 19 || c < 0 || c >= 19) return;
|
333 |
+
if (occupied(b, r*19 + c) != p) return;
|
334 |
+
if (lib[r*19 + c] != 1) return;
|
335 |
+
b[r*19 + c] = 0;
|
336 |
+
b[19*19 + r*19 + c] = 0;
|
337 |
+
remove_connected(b, lib, p, r+1, c);
|
338 |
+
remove_connected(b, lib, p, r-1, c);
|
339 |
+
remove_connected(b, lib, p, r, c+1);
|
340 |
+
remove_connected(b, lib, p, r, c-1);
|
341 |
+
}
|
342 |
+
|
343 |
+
|
344 |
+
void move_go(float *b, int p, int r, int c)
|
345 |
+
{
|
346 |
+
int *l = calculate_liberties(b);
|
347 |
+
if(p > 0) b[r*19 + c] = 1;
|
348 |
+
else b[19*19 + r*19 + c] = 1;
|
349 |
+
remove_connected(b, l, -p, r+1, c);
|
350 |
+
remove_connected(b, l, -p, r-1, c);
|
351 |
+
remove_connected(b, l, -p, r, c+1);
|
352 |
+
remove_connected(b, l, -p, r, c-1);
|
353 |
+
free(l);
|
354 |
+
}
|
355 |
+
|
356 |
+
int compare_board(float *a, float *b)
|
357 |
+
{
|
358 |
+
if(memcmp(a, b, 19*19*3*sizeof(float)) == 0) return 1;
|
359 |
+
return 0;
|
360 |
+
}
|
361 |
+
|
362 |
+
typedef struct mcts_tree{
|
363 |
+
float *board;
|
364 |
+
struct mcts_tree **children;
|
365 |
+
float *prior;
|
366 |
+
int *visit_count;
|
367 |
+
float *value;
|
368 |
+
float *mean;
|
369 |
+
float *prob;
|
370 |
+
int total_count;
|
371 |
+
float result;
|
372 |
+
int done;
|
373 |
+
int pass;
|
374 |
+
} mcts_tree;
|
375 |
+
|
376 |
+
void free_mcts(mcts_tree *root)
|
377 |
+
{
|
378 |
+
if(!root) return;
|
379 |
+
int i;
|
380 |
+
free(root->board);
|
381 |
+
for(i = 0; i < 19*19+1; ++i){
|
382 |
+
if(root->children[i]) free_mcts(root->children[i]);
|
383 |
+
}
|
384 |
+
free(root->children);
|
385 |
+
free(root->prior);
|
386 |
+
free(root->visit_count);
|
387 |
+
free(root->value);
|
388 |
+
free(root->mean);
|
389 |
+
free(root->prob);
|
390 |
+
free(root);
|
391 |
+
}
|
392 |
+
|
393 |
+
float *network_predict_rotations(network *net, float *next)
|
394 |
+
{
|
395 |
+
int n = net->batch;
|
396 |
+
float *in = calloc(19*19*3*n, sizeof(float));
|
397 |
+
image im = float_to_image(19, 19, 3, next);
|
398 |
+
int i,j;
|
399 |
+
int *inds = random_index_order(0, 8);
|
400 |
+
for(j = 0; j < n; ++j){
|
401 |
+
i = inds[j];
|
402 |
+
rotate_image_cw(im, i);
|
403 |
+
if(i >= 4) flip_image(im);
|
404 |
+
memcpy(in + 19*19*3*j, im.data, 19*19*3*sizeof(float));
|
405 |
+
if(i >= 4) flip_image(im);
|
406 |
+
rotate_image_cw(im, -i);
|
407 |
+
}
|
408 |
+
float *pred = network_predict(net, in);
|
409 |
+
for(j = 0; j < n; ++j){
|
410 |
+
i = inds[j];
|
411 |
+
image im = float_to_image(19, 19, 1, pred + j*(19*19 + 2));
|
412 |
+
if(i >= 4) flip_image(im);
|
413 |
+
rotate_image_cw(im, -i);
|
414 |
+
if(j > 0){
|
415 |
+
axpy_cpu(19*19+2, 1, im.data, 1, pred, 1);
|
416 |
+
}
|
417 |
+
}
|
418 |
+
free(in);
|
419 |
+
free(inds);
|
420 |
+
scal_cpu(19*19+2, 1./n, pred, 1);
|
421 |
+
return pred;
|
422 |
+
}
|
423 |
+
|
424 |
+
mcts_tree *expand(float *next, float *ko, network *net)
|
425 |
+
{
|
426 |
+
mcts_tree *root = calloc(1, sizeof(mcts_tree));
|
427 |
+
root->board = next;
|
428 |
+
root->children = calloc(19*19+1, sizeof(mcts_tree*));
|
429 |
+
root->prior = calloc(19*19 + 1, sizeof(float));
|
430 |
+
root->prob = calloc(19*19 + 1, sizeof(float));
|
431 |
+
root->mean = calloc(19*19 + 1, sizeof(float));
|
432 |
+
root->value = calloc(19*19 + 1, sizeof(float));
|
433 |
+
root->visit_count = calloc(19*19 + 1, sizeof(int));
|
434 |
+
root->total_count = 1;
|
435 |
+
int i;
|
436 |
+
float *pred = network_predict_rotations(net, next);
|
437 |
+
copy_cpu(19*19+1, pred, 1, root->prior, 1);
|
438 |
+
float val = 2*pred[19*19 + 1] - 1;
|
439 |
+
root->result = val;
|
440 |
+
for(i = 0; i < 19*19+1; ++i) {
|
441 |
+
root->visit_count[i] = 0;
|
442 |
+
root->value[i] = 0;
|
443 |
+
root->mean[i] = val;
|
444 |
+
if(i < 19*19 && occupied(next, i)){
|
445 |
+
root->value[i] = -1;
|
446 |
+
root->mean[i] = -1;
|
447 |
+
root->prior[i] = 0;
|
448 |
+
}
|
449 |
+
}
|
450 |
+
//print_board(stderr, next, flip?-1:1, 0);
|
451 |
+
return root;
|
452 |
+
}
|
453 |
+
|
454 |
+
float *copy_board(float *board)
|
455 |
+
{
|
456 |
+
float *next = calloc(19*19*3, sizeof(float));
|
457 |
+
copy_cpu(19*19*3, board, 1, next, 1);
|
458 |
+
return next;
|
459 |
+
}
|
460 |
+
|
461 |
+
float select_mcts(mcts_tree *root, network *net, float *prev, float cpuct)
|
462 |
+
{
|
463 |
+
if(root->done) return -root->result;
|
464 |
+
int i;
|
465 |
+
float max = -1000;
|
466 |
+
int max_i = 0;
|
467 |
+
for(i = 0; i < 19*19+1; ++i){
|
468 |
+
root->prob[i] = root->mean[i] + cpuct*root->prior[i] * sqrt(root->total_count) / (1. + root->visit_count[i]);
|
469 |
+
if(root->prob[i] > max){
|
470 |
+
max = root->prob[i];
|
471 |
+
max_i = i;
|
472 |
+
}
|
473 |
+
}
|
474 |
+
float val;
|
475 |
+
i = max_i;
|
476 |
+
root->visit_count[i]++;
|
477 |
+
root->total_count++;
|
478 |
+
if (root->children[i]) {
|
479 |
+
val = select_mcts(root->children[i], net, root->board, cpuct);
|
480 |
+
} else {
|
481 |
+
if(max_i < 19*19 && !legal_go(root->board, prev, 1, max_i/19, max_i%19)) {
|
482 |
+
root->mean[i] = -1;
|
483 |
+
root->value[i] = -1;
|
484 |
+
root->prior[i] = 0;
|
485 |
+
--root->total_count;
|
486 |
+
return select_mcts(root, net, prev, cpuct);
|
487 |
+
//printf("Detected ko\n");
|
488 |
+
//getchar();
|
489 |
+
} else {
|
490 |
+
float *next = copy_board(root->board);
|
491 |
+
if (max_i < 19*19) {
|
492 |
+
move_go(next, 1, max_i / 19, max_i % 19);
|
493 |
+
}
|
494 |
+
flip_board(next);
|
495 |
+
root->children[i] = expand(next, root->board, net);
|
496 |
+
val = -root->children[i]->result;
|
497 |
+
if(max_i == 19*19){
|
498 |
+
root->children[i]->pass = 1;
|
499 |
+
if (root->pass){
|
500 |
+
root->children[i]->done = 1;
|
501 |
+
}
|
502 |
+
}
|
503 |
+
}
|
504 |
+
}
|
505 |
+
root->value[i] += val;
|
506 |
+
root->mean[i] = root->value[i]/root->visit_count[i];
|
507 |
+
return -val;
|
508 |
+
}
|
509 |
+
|
510 |
+
mcts_tree *run_mcts(mcts_tree *tree, network *net, float *board, float *ko, int player, int n, float cpuct, float secs)
|
511 |
+
{
|
512 |
+
int i;
|
513 |
+
double t = what_time_is_it_now();
|
514 |
+
if(player < 0) flip_board(board);
|
515 |
+
if(!tree) tree = expand(copy_board(board), ko, net);
|
516 |
+
assert(compare_board(tree->board, board));
|
517 |
+
for(i = 0; i < n; ++i){
|
518 |
+
if (secs > 0 && (what_time_is_it_now() - t) > secs) break;
|
519 |
+
int max_i = max_int_index(tree->visit_count, 19*19+1);
|
520 |
+
if (tree->visit_count[max_i] >= n) break;
|
521 |
+
select_mcts(tree, net, ko, cpuct);
|
522 |
+
}
|
523 |
+
if(player < 0) flip_board(board);
|
524 |
+
//fprintf(stderr, "%f Seconds\n", what_time_is_it_now() - t);
|
525 |
+
return tree;
|
526 |
+
}
|
527 |
+
|
528 |
+
mcts_tree *move_mcts(mcts_tree *tree, int index)
|
529 |
+
{
|
530 |
+
if(index < 0 || index > 19*19 || !tree || !tree->children[index]) {
|
531 |
+
free_mcts(tree);
|
532 |
+
tree = 0;
|
533 |
+
} else {
|
534 |
+
mcts_tree *swap = tree;
|
535 |
+
tree = tree->children[index];
|
536 |
+
swap->children[index] = 0;
|
537 |
+
free_mcts(swap);
|
538 |
+
}
|
539 |
+
return tree;
|
540 |
+
}
|
541 |
+
|
542 |
+
typedef struct {
|
543 |
+
float value;
|
544 |
+
float mcts;
|
545 |
+
int row;
|
546 |
+
int col;
|
547 |
+
} move;
|
548 |
+
|
549 |
+
move pick_move(mcts_tree *tree, float temp, int player)
|
550 |
+
{
|
551 |
+
int i;
|
552 |
+
float probs[19*19+1] = {0};
|
553 |
+
move m = {0};
|
554 |
+
double sum = 0;
|
555 |
+
/*
|
556 |
+
for(i = 0; i < 19*19+1; ++i){
|
557 |
+
probs[i] = tree->visit_count[i];
|
558 |
+
}
|
559 |
+
*/
|
560 |
+
//softmax(probs, 19*19+1, temp, 1, probs);
|
561 |
+
for(i = 0; i < 19*19+1; ++i){
|
562 |
+
sum += pow(tree->visit_count[i], 1./temp);
|
563 |
+
}
|
564 |
+
for(i = 0; i < 19*19+1; ++i){
|
565 |
+
probs[i] = pow(tree->visit_count[i], 1./temp) / sum;
|
566 |
+
}
|
567 |
+
|
568 |
+
int index = sample_array(probs, 19*19+1);
|
569 |
+
m.row = index / 19;
|
570 |
+
m.col = index % 19;
|
571 |
+
m.value = (tree->result+1.)/2.;
|
572 |
+
m.mcts = (tree->mean[index]+1.)/2.;
|
573 |
+
|
574 |
+
int indexes[nind];
|
575 |
+
top_k(probs, 19*19+1, nind, indexes);
|
576 |
+
print_board(stderr, tree->board, player, indexes);
|
577 |
+
|
578 |
+
fprintf(stderr, "%d %d, Result: %f, Prior: %f, Prob: %f, Mean Value: %f, Child Result: %f, Visited: %d\n", index/19, index%19, tree->result, tree->prior[index], probs[index], tree->mean[index], (tree->children[index])?tree->children[index]->result:0, tree->visit_count[index]);
|
579 |
+
int ind = max_index(probs, 19*19+1);
|
580 |
+
fprintf(stderr, "%d %d, Result: %f, Prior: %f, Prob: %f, Mean Value: %f, Child Result: %f, Visited: %d\n", ind/19, ind%19, tree->result, tree->prior[ind], probs[ind], tree->mean[ind], (tree->children[ind])?tree->children[ind]->result:0, tree->visit_count[ind]);
|
581 |
+
ind = max_index(tree->prior, 19*19+1);
|
582 |
+
fprintf(stderr, "%d %d, Result: %f, Prior: %f, Prob: %f, Mean Value: %f, Child Result: %f, Visited: %d\n", ind/19, ind%19, tree->result, tree->prior[ind], probs[ind], tree->mean[ind], (tree->children[ind])?tree->children[ind]->result:0, tree->visit_count[ind]);
|
583 |
+
return m;
|
584 |
+
}
|
585 |
+
|
586 |
+
/*
|
587 |
+
float predict_move(network *net, float *board, float *move, int multi, float *ko, float temp)
|
588 |
+
{
|
589 |
+
|
590 |
+
int i;
|
591 |
+
|
592 |
+
int max_v = 0;
|
593 |
+
int max_i = 0;
|
594 |
+
for(i = 0; i < 19*19+1; ++i){
|
595 |
+
if(root->visit_count[i] > max_v){
|
596 |
+
max_v = root->visit_count[i];
|
597 |
+
max_i = i;
|
598 |
+
}
|
599 |
+
}
|
600 |
+
fprintf(stderr, "%f Seconds\n", what_time_is_it_now() - t);
|
601 |
+
int ind = max_index(root->mean, 19*19+1);
|
602 |
+
fprintf(stderr, "%d %d, Result: %f, Prior: %f, Prob: %f, Mean Value: %f, Child Result: %f, Visited: %d\n", max_i/19, max_i%19, root->result, root->prior[max_i], root->prob[max_i], root->mean[max_i], (root->children[max_i])?root->children[max_i]->result:0, root->visit_count[max_i]);
|
603 |
+
fprintf(stderr, "%d %d, Result: %f, Prior: %f, Prob: %f, Mean Value: %f, Child Result: %f, Visited: %d\n", ind/19, ind%19, root->result, root->prior[ind], root->prob[ind], root->mean[ind], (root->children[ind])?root->children[ind]->result:0, root->visit_count[ind]);
|
604 |
+
ind = max_index(root->prior, 19*19+1);
|
605 |
+
fprintf(stderr, "%d %d, Result: %f, Prior: %f, Prob: %f, Mean Value: %f, Child Result: %f, Visited: %d\n", ind/19, ind%19, root->result, root->prior[ind], root->prob[ind], root->mean[ind], (root->children[ind])?root->children[ind]->result:0, root->visit_count[ind]);
|
606 |
+
if(root->result < -.9 && root->mean[max_i] < -.9) return -1000.f;
|
607 |
+
|
608 |
+
float val = root->result;
|
609 |
+
free_mcts(root);
|
610 |
+
return val;
|
611 |
+
}
|
612 |
+
*/
|
613 |
+
|
614 |
+
static int makes_safe_go(float *b, int *lib, int p, int r, int c){
|
615 |
+
if (r < 0 || r >= 19 || c < 0 || c >= 19) return 0;
|
616 |
+
if (occupied(b,r*19 + c) == -p){
|
617 |
+
if (lib[r*19 + c] > 1) return 0;
|
618 |
+
else return 1;
|
619 |
+
}
|
620 |
+
if (!occupied(b,r*19 + c)) return 1;
|
621 |
+
if (lib[r*19 + c] > 1) return 1;
|
622 |
+
return 0;
|
623 |
+
}
|
624 |
+
|
625 |
+
int suicide_go(float *b, int p, int r, int c)
|
626 |
+
{
|
627 |
+
int *l = calculate_liberties(b);
|
628 |
+
int safe = 0;
|
629 |
+
safe = safe || makes_safe_go(b, l, p, r+1, c);
|
630 |
+
safe = safe || makes_safe_go(b, l, p, r-1, c);
|
631 |
+
safe = safe || makes_safe_go(b, l, p, r, c+1);
|
632 |
+
safe = safe || makes_safe_go(b, l, p, r, c-1);
|
633 |
+
free(l);
|
634 |
+
return !safe;
|
635 |
+
}
|
636 |
+
|
637 |
+
int check_ko(float *x, float *ko)
|
638 |
+
{
|
639 |
+
if(!ko) return 0;
|
640 |
+
float curr[19*19*3];
|
641 |
+
copy_cpu(19*19*3, x, 1, curr, 1);
|
642 |
+
if(curr[19*19*2] != ko[19*19*2]) flip_board(curr);
|
643 |
+
if(compare_board(curr, ko)) return 1;
|
644 |
+
return 0;
|
645 |
+
}
|
646 |
+
|
647 |
+
int legal_go(float *b, float *ko, int p, int r, int c)
|
648 |
+
{
|
649 |
+
if (occupied(b, r*19+c)) return 0;
|
650 |
+
float curr[19*19*3];
|
651 |
+
copy_cpu(19*19*3, b, 1, curr, 1);
|
652 |
+
move_go(curr, p, r, c);
|
653 |
+
if(check_ko(curr, ko)) return 0;
|
654 |
+
if(suicide_go(b, p, r, c)) return 0;
|
655 |
+
return 1;
|
656 |
+
}
|
657 |
+
|
658 |
+
/*
|
659 |
+
move generate_move(mcts_tree *root, network *net, int player, float *board, int multi, float temp, float *ko, int print)
|
660 |
+
{
|
661 |
+
move m = {0};
|
662 |
+
//root = run_mcts(tree, network *net, float *board, float *ko, int n, float cpuct)
|
663 |
+
int i, j;
|
664 |
+
int empty = 1;
|
665 |
+
for(i = 0; i < 19*19; ++i){
|
666 |
+
if (occupied(board, i)) {
|
667 |
+
empty = 0;
|
668 |
+
break;
|
669 |
+
}
|
670 |
+
}
|
671 |
+
if(empty) {
|
672 |
+
m.value = .5;
|
673 |
+
m.mcts = .5;
|
674 |
+
m.row = 3;
|
675 |
+
m.col = 15;
|
676 |
+
return m;
|
677 |
+
}
|
678 |
+
|
679 |
+
float move[362];
|
680 |
+
if (player < 0) flip_board(board);
|
681 |
+
float result = predict_move(net, board, move, multi, ko, temp);
|
682 |
+
if (player < 0) flip_board(board);
|
683 |
+
if(result == -1000.f) return -2;
|
684 |
+
|
685 |
+
for(i = 0; i < 19; ++i){
|
686 |
+
for(j = 0; j < 19; ++j){
|
687 |
+
if (!legal_go(board, ko, player, i, j)) move[i*19 + j] = 0;
|
688 |
+
}
|
689 |
+
}
|
690 |
+
|
691 |
+
int indexes[nind];
|
692 |
+
top_k(move, 19*19+1, nind, indexes);
|
693 |
+
|
694 |
+
|
695 |
+
int max = max_index(move, 19*19+1);
|
696 |
+
int row = max / 19;
|
697 |
+
int col = max % 19;
|
698 |
+
int index = sample_array(move, 19*19+1);
|
699 |
+
|
700 |
+
if(print){
|
701 |
+
top_k(move, 19*19+1, nind, indexes);
|
702 |
+
for(i = 0; i < nind; ++i){
|
703 |
+
if (!move[indexes[i]]) indexes[i] = -1;
|
704 |
+
}
|
705 |
+
print_board(stderr, board, 1, indexes);
|
706 |
+
fprintf(stderr, "%s To Move\n", player > 0 ? "X" : "O");
|
707 |
+
fprintf(stderr, "%.2f%% Win Chance\n", (result+1)/2*100);
|
708 |
+
for(i = 0; i < nind; ++i){
|
709 |
+
int index = indexes[i];
|
710 |
+
int row = index / 19;
|
711 |
+
int col = index % 19;
|
712 |
+
if(row == 19){
|
713 |
+
fprintf(stderr, "%d: Pass, %.2f%%\n", i+1, move[index]*100);
|
714 |
+
} else {
|
715 |
+
fprintf(stderr, "%d: %c %d, %.2f%%\n", i+1, col + 'A' + 1*(col > 7 && noi), (inverted)?19 - row : row+1, move[index]*100);
|
716 |
+
}
|
717 |
+
}
|
718 |
+
}
|
719 |
+
if (row == 19) return -1;
|
720 |
+
|
721 |
+
if (suicide_go(board, player, row, col)){
|
722 |
+
return -1;
|
723 |
+
}
|
724 |
+
|
725 |
+
if (suicide_go(board, player, index/19, index%19)){
|
726 |
+
index = max;
|
727 |
+
}
|
728 |
+
if (index == 19*19) return -1;
|
729 |
+
return index;
|
730 |
+
}
|
731 |
+
*/
|
732 |
+
|
733 |
+
void valid_go(char *cfgfile, char *weightfile, int multi, char *filename)
|
734 |
+
{
|
735 |
+
srand(time(0));
|
736 |
+
char *base = basecfg(cfgfile);
|
737 |
+
printf("%s\n", base);
|
738 |
+
network *net = load_network(cfgfile, weightfile, 0);
|
739 |
+
set_batch_network(net, 1);
|
740 |
+
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
|
741 |
+
|
742 |
+
float *board = calloc(19*19*3, sizeof(float));
|
743 |
+
float *move = calloc(19*19+2, sizeof(float));
|
744 |
+
// moves m = load_go_moves("/home/pjreddie/backup/go.test");
|
745 |
+
moves m = load_go_moves(filename);
|
746 |
+
|
747 |
+
int N = m.n;
|
748 |
+
int i,j;
|
749 |
+
int correct = 0;
|
750 |
+
for (i = 0; i <N; ++i) {
|
751 |
+
char *b = m.data[i];
|
752 |
+
int player = b[0] - '0';
|
753 |
+
//int result = b[1] - '0';
|
754 |
+
int row = b[2];
|
755 |
+
int col = b[3];
|
756 |
+
int truth = col + 19*row;
|
757 |
+
string_to_board(b+4, board);
|
758 |
+
if(player > 0) for(j = 0; j < 19*19; ++j) board[19*19*2 + j] = 1;
|
759 |
+
predict_move2(net, board, move, multi);
|
760 |
+
int index = max_index(move, 19*19+1);
|
761 |
+
if(index == truth) ++correct;
|
762 |
+
printf("%d Accuracy %f\n", i, (float) correct/(i+1));
|
763 |
+
}
|
764 |
+
}
|
765 |
+
|
766 |
+
int print_game(float *board, FILE *fp)
|
767 |
+
{
|
768 |
+
int i, j;
|
769 |
+
int count = 3;
|
770 |
+
fprintf(fp, "komi 6.5\n");
|
771 |
+
fprintf(fp, "boardsize 19\n");
|
772 |
+
fprintf(fp, "clear_board\n");
|
773 |
+
for(j = 0; j < 19; ++j){
|
774 |
+
for(i = 0; i < 19; ++i){
|
775 |
+
if(occupied(board,j*19 + i) == 1) fprintf(fp, "play black %c%d\n", 'A'+i+(i>=8), 19-j);
|
776 |
+
if(occupied(board,j*19 + i) == -1) fprintf(fp, "play white %c%d\n", 'A'+i+(i>=8), 19-j);
|
777 |
+
if(occupied(board,j*19 + i)) ++count;
|
778 |
+
}
|
779 |
+
}
|
780 |
+
return count;
|
781 |
+
}
|
782 |
+
|
783 |
+
|
784 |
+
int stdin_ready()
|
785 |
+
{
|
786 |
+
fd_set readfds;
|
787 |
+
FD_ZERO(&readfds);
|
788 |
+
|
789 |
+
struct timeval timeout;
|
790 |
+
timeout.tv_sec = 0;
|
791 |
+
timeout.tv_usec = 0;
|
792 |
+
FD_SET(STDIN_FILENO, &readfds);
|
793 |
+
|
794 |
+
if (select(1, &readfds, NULL, NULL, &timeout)){
|
795 |
+
return 1;
|
796 |
+
}
|
797 |
+
return 0;
|
798 |
+
}
|
799 |
+
|
800 |
+
mcts_tree *ponder(mcts_tree *tree, network *net, float *b, float *ko, int player, float cpuct)
|
801 |
+
{
|
802 |
+
double t = what_time_is_it_now();
|
803 |
+
int count = 0;
|
804 |
+
if (tree) count = tree->total_count;
|
805 |
+
while(!stdin_ready()){
|
806 |
+
if (what_time_is_it_now() - t > 120) break;
|
807 |
+
tree = run_mcts(tree, net, b, ko, player, 100000, cpuct, .1);
|
808 |
+
}
|
809 |
+
fprintf(stderr, "Pondered %d moves...\n", tree->total_count - count);
|
810 |
+
return tree;
|
811 |
+
}
|
812 |
+
|
813 |
+
void engine_go(char *filename, char *weightfile, int mcts_iters, float secs, float temp, float cpuct, int anon, int resign)
|
814 |
+
{
|
815 |
+
mcts_tree *root = 0;
|
816 |
+
network *net = load_network(filename, weightfile, 0);
|
817 |
+
set_batch_network(net, 1);
|
818 |
+
srand(time(0));
|
819 |
+
float *board = calloc(19*19*3, sizeof(float));
|
820 |
+
flip_board(board);
|
821 |
+
float *one = calloc(19*19*3, sizeof(float));
|
822 |
+
float *two = calloc(19*19*3, sizeof(float));
|
823 |
+
int ponder_player = 0;
|
824 |
+
int passed = 0;
|
825 |
+
int move_num = 0;
|
826 |
+
int main_time = 0;
|
827 |
+
int byo_yomi_time = 0;
|
828 |
+
int byo_yomi_stones = 0;
|
829 |
+
int black_time_left = 0;
|
830 |
+
int black_stones_left = 0;
|
831 |
+
int white_time_left = 0;
|
832 |
+
int white_stones_left = 0;
|
833 |
+
float orig_time = secs;
|
834 |
+
int old_ponder = 0;
|
835 |
+
while(1){
|
836 |
+
if(ponder_player){
|
837 |
+
root = ponder(root, net, board, two, ponder_player, cpuct);
|
838 |
+
}
|
839 |
+
old_ponder = ponder_player;
|
840 |
+
ponder_player = 0;
|
841 |
+
char buff[256];
|
842 |
+
int id = 0;
|
843 |
+
int has_id = (scanf("%d", &id) == 1);
|
844 |
+
scanf("%s", buff);
|
845 |
+
if (feof(stdin)) break;
|
846 |
+
fprintf(stderr, "%s\n", buff);
|
847 |
+
char ids[256];
|
848 |
+
sprintf(ids, "%d", id);
|
849 |
+
//fprintf(stderr, "%s\n", buff);
|
850 |
+
if (!has_id) ids[0] = 0;
|
851 |
+
if (!strcmp(buff, "protocol_version")){
|
852 |
+
printf("=%s 2\n\n", ids);
|
853 |
+
} else if (!strcmp(buff, "name")){
|
854 |
+
if(anon){
|
855 |
+
printf("=%s The Fool!\n\n", ids);
|
856 |
+
}else{
|
857 |
+
printf("=%s DarkGo\n\n", ids);
|
858 |
+
}
|
859 |
+
} else if (!strcmp(buff, "time_settings")){
|
860 |
+
ponder_player = old_ponder;
|
861 |
+
scanf("%d %d %d", &main_time, &byo_yomi_time, &byo_yomi_stones);
|
862 |
+
printf("=%s \n\n", ids);
|
863 |
+
} else if (!strcmp(buff, "time_left")){
|
864 |
+
ponder_player = old_ponder;
|
865 |
+
char color[256];
|
866 |
+
int time = 0, stones = 0;
|
867 |
+
scanf("%s %d %d", color, &time, &stones);
|
868 |
+
if (color[0] == 'b' || color[0] == 'B'){
|
869 |
+
black_time_left = time;
|
870 |
+
black_stones_left = stones;
|
871 |
+
} else {
|
872 |
+
white_time_left = time;
|
873 |
+
white_stones_left = stones;
|
874 |
+
}
|
875 |
+
printf("=%s \n\n", ids);
|
876 |
+
} else if (!strcmp(buff, "version")){
|
877 |
+
if(anon){
|
878 |
+
printf("=%s :-DDDD\n\n", ids);
|
879 |
+
}else {
|
880 |
+
printf("=%s 1.0. Want more DarkGo? You can find me on OGS, unlimited games, no waiting! https://online-go.com/user/view/434218\n\n", ids);
|
881 |
+
}
|
882 |
+
} else if (!strcmp(buff, "known_command")){
|
883 |
+
char comm[256];
|
884 |
+
scanf("%s", comm);
|
885 |
+
int known = (!strcmp(comm, "protocol_version") ||
|
886 |
+
!strcmp(comm, "name") ||
|
887 |
+
!strcmp(comm, "version") ||
|
888 |
+
!strcmp(comm, "known_command") ||
|
889 |
+
!strcmp(comm, "list_commands") ||
|
890 |
+
!strcmp(comm, "quit") ||
|
891 |
+
!strcmp(comm, "boardsize") ||
|
892 |
+
!strcmp(comm, "clear_board") ||
|
893 |
+
!strcmp(comm, "komi") ||
|
894 |
+
!strcmp(comm, "final_status_list") ||
|
895 |
+
!strcmp(comm, "play") ||
|
896 |
+
!strcmp(comm, "genmove_white") ||
|
897 |
+
!strcmp(comm, "genmove_black") ||
|
898 |
+
!strcmp(comm, "fixed_handicap") ||
|
899 |
+
!strcmp(comm, "genmove"));
|
900 |
+
if(known) printf("=%s true\n\n", ids);
|
901 |
+
else printf("=%s false\n\n", ids);
|
902 |
+
} else if (!strcmp(buff, "list_commands")){
|
903 |
+
printf("=%s protocol_version\nshowboard\nname\nversion\nknown_command\nlist_commands\nquit\nboardsize\nclear_board\nkomi\nplay\ngenmove_black\ngenmove_white\ngenmove\nfinal_status_list\nfixed_handicap\n\n", ids);
|
904 |
+
} else if (!strcmp(buff, "quit")){
|
905 |
+
break;
|
906 |
+
} else if (!strcmp(buff, "boardsize")){
|
907 |
+
int boardsize = 0;
|
908 |
+
scanf("%d", &boardsize);
|
909 |
+
//fprintf(stderr, "%d\n", boardsize);
|
910 |
+
if(boardsize != 19){
|
911 |
+
printf("?%s unacceptable size\n\n", ids);
|
912 |
+
} else {
|
913 |
+
root = move_mcts(root, -1);
|
914 |
+
memset(board, 0, 3*19*19*sizeof(float));
|
915 |
+
flip_board(board);
|
916 |
+
move_num = 0;
|
917 |
+
printf("=%s \n\n", ids);
|
918 |
+
}
|
919 |
+
} else if (!strcmp(buff, "fixed_handicap")){
|
920 |
+
int handicap = 0;
|
921 |
+
scanf("%d", &handicap);
|
922 |
+
int indexes[] = {72, 288, 300, 60, 180, 174, 186, 66, 294};
|
923 |
+
int i;
|
924 |
+
for(i = 0; i < handicap; ++i){
|
925 |
+
board[indexes[i]] = 1;
|
926 |
+
++move_num;
|
927 |
+
}
|
928 |
+
root = move_mcts(root, -1);
|
929 |
+
} else if (!strcmp(buff, "clear_board")){
|
930 |
+
passed = 0;
|
931 |
+
memset(board, 0, 3*19*19*sizeof(float));
|
932 |
+
flip_board(board);
|
933 |
+
move_num = 0;
|
934 |
+
root = move_mcts(root, -1);
|
935 |
+
printf("=%s \n\n", ids);
|
936 |
+
} else if (!strcmp(buff, "komi")){
|
937 |
+
float komi = 0;
|
938 |
+
scanf("%f", &komi);
|
939 |
+
printf("=%s \n\n", ids);
|
940 |
+
} else if (!strcmp(buff, "showboard")){
|
941 |
+
printf("=%s \n", ids);
|
942 |
+
print_board(stdout, board, 1, 0);
|
943 |
+
printf("\n");
|
944 |
+
} else if (!strcmp(buff, "play") || !strcmp(buff, "black") || !strcmp(buff, "white")){
|
945 |
+
++move_num;
|
946 |
+
char color[256];
|
947 |
+
if(!strcmp(buff, "play"))
|
948 |
+
{
|
949 |
+
scanf("%s ", color);
|
950 |
+
} else {
|
951 |
+
scanf(" ");
|
952 |
+
color[0] = buff[0];
|
953 |
+
}
|
954 |
+
char c;
|
955 |
+
int r;
|
956 |
+
int count = scanf("%c%d", &c, &r);
|
957 |
+
int player = (color[0] == 'b' || color[0] == 'B') ? 1 : -1;
|
958 |
+
if((c == 'p' || c == 'P') && count < 2) {
|
959 |
+
passed = 1;
|
960 |
+
printf("=%s \n\n", ids);
|
961 |
+
char *line = fgetl(stdin);
|
962 |
+
free(line);
|
963 |
+
fflush(stdout);
|
964 |
+
fflush(stderr);
|
965 |
+
root = move_mcts(root, 19*19);
|
966 |
+
continue;
|
967 |
+
} else {
|
968 |
+
passed = 0;
|
969 |
+
}
|
970 |
+
if(c >= 'A' && c <= 'Z') c = c - 'A';
|
971 |
+
if(c >= 'a' && c <= 'z') c = c - 'a';
|
972 |
+
if(c >= 8) --c;
|
973 |
+
r = 19 - r;
|
974 |
+
fprintf(stderr, "move: %d %d\n", r, c);
|
975 |
+
|
976 |
+
float *swap = two;
|
977 |
+
two = one;
|
978 |
+
one = swap;
|
979 |
+
move_go(board, player, r, c);
|
980 |
+
copy_cpu(19*19*3, board, 1, one, 1);
|
981 |
+
if(root) fprintf(stderr, "Prior: %f\n", root->prior[r*19 + c]);
|
982 |
+
if(root) fprintf(stderr, "Mean: %f\n", root->mean[r*19 + c]);
|
983 |
+
if(root) fprintf(stderr, "Result: %f\n", root->result);
|
984 |
+
root = move_mcts(root, r*19 + c);
|
985 |
+
if(root) fprintf(stderr, "Visited: %d\n", root->total_count);
|
986 |
+
else fprintf(stderr, "NOT VISITED\n");
|
987 |
+
|
988 |
+
printf("=%s \n\n", ids);
|
989 |
+
//print_board(stderr, board, 1, 0);
|
990 |
+
} else if (!strcmp(buff, "genmove") || !strcmp(buff, "genmove_black") || !strcmp(buff, "genmove_white")){
|
991 |
+
++move_num;
|
992 |
+
int player = 0;
|
993 |
+
if(!strcmp(buff, "genmove")){
|
994 |
+
char color[256];
|
995 |
+
scanf("%s", color);
|
996 |
+
player = (color[0] == 'b' || color[0] == 'B') ? 1 : -1;
|
997 |
+
} else if (!strcmp(buff, "genmove_black")){
|
998 |
+
player = 1;
|
999 |
+
} else {
|
1000 |
+
player = -1;
|
1001 |
+
}
|
1002 |
+
if(player > 0){
|
1003 |
+
if(black_time_left <= 30) secs = 2.5;
|
1004 |
+
else secs = orig_time;
|
1005 |
+
} else {
|
1006 |
+
if(white_time_left <= 30) secs = 2.5;
|
1007 |
+
else secs = orig_time;
|
1008 |
+
}
|
1009 |
+
ponder_player = -player;
|
1010 |
+
|
1011 |
+
//tree = generate_move(net, player, board, multi, .1, two, 1);
|
1012 |
+
double t = what_time_is_it_now();
|
1013 |
+
root = run_mcts(root, net, board, two, player, mcts_iters, cpuct, secs);
|
1014 |
+
fprintf(stderr, "%f Seconds\n", what_time_is_it_now() - t);
|
1015 |
+
move m = pick_move(root, temp, player);
|
1016 |
+
root = move_mcts(root, m.row*19 + m.col);
|
1017 |
+
|
1018 |
+
|
1019 |
+
if(move_num > resign && m.value < .1 && m.mcts < .1){
|
1020 |
+
printf("=%s resign\n\n", ids);
|
1021 |
+
} else if(m.row == 19){
|
1022 |
+
printf("=%s pass\n\n", ids);
|
1023 |
+
passed = 0;
|
1024 |
+
} else {
|
1025 |
+
int row = m.row;
|
1026 |
+
int col = m.col;
|
1027 |
+
|
1028 |
+
float *swap = two;
|
1029 |
+
two = one;
|
1030 |
+
one = swap;
|
1031 |
+
|
1032 |
+
move_go(board, player, row, col);
|
1033 |
+
copy_cpu(19*19*3, board, 1, one, 1);
|
1034 |
+
row = 19 - row;
|
1035 |
+
if (col >= 8) ++col;
|
1036 |
+
printf("=%s %c%d\n\n", ids, 'A' + col, row);
|
1037 |
+
}
|
1038 |
+
|
1039 |
+
} else if (!strcmp(buff, "p")){
|
1040 |
+
//print_board(board, 1, 0);
|
1041 |
+
} else if (!strcmp(buff, "final_status_list")){
|
1042 |
+
char type[256];
|
1043 |
+
scanf("%s", type);
|
1044 |
+
fprintf(stderr, "final_status\n");
|
1045 |
+
char *line = fgetl(stdin);
|
1046 |
+
free(line);
|
1047 |
+
if(type[0] == 'd' || type[0] == 'D'){
|
1048 |
+
int i;
|
1049 |
+
FILE *f = fopen("game.txt", "w");
|
1050 |
+
int count = print_game(board, f);
|
1051 |
+
fprintf(f, "%s final_status_list dead\n", ids);
|
1052 |
+
fclose(f);
|
1053 |
+
FILE *p = popen("./gnugo --mode gtp < game.txt", "r");
|
1054 |
+
for(i = 0; i < count; ++i){
|
1055 |
+
free(fgetl(p));
|
1056 |
+
free(fgetl(p));
|
1057 |
+
}
|
1058 |
+
char *l = 0;
|
1059 |
+
while((l = fgetl(p))){
|
1060 |
+
printf("%s\n", l);
|
1061 |
+
free(l);
|
1062 |
+
}
|
1063 |
+
} else {
|
1064 |
+
printf("?%s unknown command\n\n", ids);
|
1065 |
+
}
|
1066 |
+
} else if (!strcmp(buff, "kgs-genmove_cleanup")){
|
1067 |
+
char type[256];
|
1068 |
+
scanf("%s", type);
|
1069 |
+
fprintf(stderr, "kgs-genmove_cleanup\n");
|
1070 |
+
char *line = fgetl(stdin);
|
1071 |
+
free(line);
|
1072 |
+
int i;
|
1073 |
+
FILE *f = fopen("game.txt", "w");
|
1074 |
+
int count = print_game(board, f);
|
1075 |
+
fprintf(f, "%s kgs-genmove_cleanup %s\n", ids, type);
|
1076 |
+
fclose(f);
|
1077 |
+
FILE *p = popen("./gnugo --mode gtp < game.txt", "r");
|
1078 |
+
for(i = 0; i < count; ++i){
|
1079 |
+
free(fgetl(p));
|
1080 |
+
free(fgetl(p));
|
1081 |
+
}
|
1082 |
+
char *l = 0;
|
1083 |
+
while((l = fgetl(p))){
|
1084 |
+
printf("%s\n", l);
|
1085 |
+
free(l);
|
1086 |
+
}
|
1087 |
+
} else {
|
1088 |
+
char *line = fgetl(stdin);
|
1089 |
+
free(line);
|
1090 |
+
printf("?%s unknown command\n\n", ids);
|
1091 |
+
}
|
1092 |
+
fflush(stdout);
|
1093 |
+
fflush(stderr);
|
1094 |
+
}
|
1095 |
+
printf("%d %d %d\n",passed, black_stones_left, white_stones_left);
|
1096 |
+
}
|
1097 |
+
|
1098 |
+
void test_go(char *cfg, char *weights, int multi)
|
1099 |
+
{
|
1100 |
+
int i;
|
1101 |
+
network *net = load_network(cfg, weights, 0);
|
1102 |
+
set_batch_network(net, 1);
|
1103 |
+
srand(time(0));
|
1104 |
+
float *board = calloc(19*19*3, sizeof(float));
|
1105 |
+
flip_board(board);
|
1106 |
+
float *move = calloc(19*19+1, sizeof(float));
|
1107 |
+
int color = 1;
|
1108 |
+
while(1){
|
1109 |
+
float result = predict_move2(net, board, move, multi);
|
1110 |
+
printf("%.2f%% Win Chance\n", (result+1)/2*100);
|
1111 |
+
|
1112 |
+
int indexes[nind];
|
1113 |
+
int row, col;
|
1114 |
+
top_k(move, 19*19+1, nind, indexes);
|
1115 |
+
print_board(stderr, board, color, indexes);
|
1116 |
+
for(i = 0; i < nind; ++i){
|
1117 |
+
int index = indexes[i];
|
1118 |
+
row = index / 19;
|
1119 |
+
col = index % 19;
|
1120 |
+
if(row == 19){
|
1121 |
+
printf("%d: Pass, %.2f%%\n", i+1, move[index]*100);
|
1122 |
+
} else {
|
1123 |
+
printf("%d: %c %d, %.2f%%\n", i+1, col + 'A' + 1*(col > 7 && noi), (inverted)?19 - row : row+1, move[index]*100);
|
1124 |
+
}
|
1125 |
+
}
|
1126 |
+
//if(color == 1) printf("\u25EF Enter move: ");
|
1127 |
+
//else printf("\u25C9 Enter move: ");
|
1128 |
+
if(color == 1) printf("X Enter move: ");
|
1129 |
+
else printf("O Enter move: ");
|
1130 |
+
|
1131 |
+
char c;
|
1132 |
+
char *line = fgetl(stdin);
|
1133 |
+
int picked = 1;
|
1134 |
+
int dnum = sscanf(line, "%d", &picked);
|
1135 |
+
int cnum = sscanf(line, "%c", &c);
|
1136 |
+
if (strlen(line) == 0 || dnum) {
|
1137 |
+
--picked;
|
1138 |
+
if (picked < nind){
|
1139 |
+
int index = indexes[picked];
|
1140 |
+
row = index / 19;
|
1141 |
+
col = index % 19;
|
1142 |
+
if(row < 19){
|
1143 |
+
move_go(board, 1, row, col);
|
1144 |
+
}
|
1145 |
+
}
|
1146 |
+
} else if (cnum){
|
1147 |
+
if (c <= 'T' && c >= 'A'){
|
1148 |
+
int num = sscanf(line, "%c %d", &c, &row);
|
1149 |
+
row = (inverted)?19 - row : row-1;
|
1150 |
+
col = c - 'A';
|
1151 |
+
if (col > 7 && noi) col -= 1;
|
1152 |
+
if (num == 2) move_go(board, 1, row, col);
|
1153 |
+
} else if (c == 'p') {
|
1154 |
+
// Pass
|
1155 |
+
} else if(c=='b' || c == 'w'){
|
1156 |
+
char g;
|
1157 |
+
int num = sscanf(line, "%c %c %d", &g, &c, &row);
|
1158 |
+
row = (inverted)?19 - row : row-1;
|
1159 |
+
col = c - 'A';
|
1160 |
+
if (col > 7 && noi) col -= 1;
|
1161 |
+
if (num == 3) {
|
1162 |
+
int mc = (g == 'b') ? 1 : -1;
|
1163 |
+
if (mc == color) {
|
1164 |
+
board[row*19 + col] = 1;
|
1165 |
+
} else {
|
1166 |
+
board[19*19 + row*19 + col] = 1;
|
1167 |
+
}
|
1168 |
+
}
|
1169 |
+
} else if(c == 'c'){
|
1170 |
+
char g;
|
1171 |
+
int num = sscanf(line, "%c %c %d", &g, &c, &row);
|
1172 |
+
row = (inverted)?19 - row : row-1;
|
1173 |
+
col = c - 'A';
|
1174 |
+
if (col > 7 && noi) col -= 1;
|
1175 |
+
if (num == 3) {
|
1176 |
+
board[row*19 + col] = 0;
|
1177 |
+
board[19*19 + row*19 + col] = 0;
|
1178 |
+
}
|
1179 |
+
}
|
1180 |
+
}
|
1181 |
+
free(line);
|
1182 |
+
flip_board(board);
|
1183 |
+
color = -color;
|
1184 |
+
}
|
1185 |
+
}
|
1186 |
+
|
1187 |
+
float score_game(float *board)
|
1188 |
+
{
|
1189 |
+
int i;
|
1190 |
+
FILE *f = fopen("game.txt", "w");
|
1191 |
+
int count = print_game(board, f);
|
1192 |
+
fprintf(f, "final_score\n");
|
1193 |
+
fclose(f);
|
1194 |
+
FILE *p = popen("./gnugo --mode gtp < game.txt", "r");
|
1195 |
+
for(i = 0; i < count; ++i){
|
1196 |
+
free(fgetl(p));
|
1197 |
+
free(fgetl(p));
|
1198 |
+
}
|
1199 |
+
char *l = 0;
|
1200 |
+
float score = 0;
|
1201 |
+
char player = 0;
|
1202 |
+
while((l = fgetl(p))){
|
1203 |
+
fprintf(stderr, "%s \t", l);
|
1204 |
+
int n = sscanf(l, "= %c+%f", &player, &score);
|
1205 |
+
free(l);
|
1206 |
+
if (n == 2) break;
|
1207 |
+
}
|
1208 |
+
if(player == 'W') score = -score;
|
1209 |
+
pclose(p);
|
1210 |
+
return score;
|
1211 |
+
}
|
1212 |
+
|
1213 |
+
void self_go(char *filename, char *weightfile, char *f2, char *w2, int multi)
|
1214 |
+
{
|
1215 |
+
mcts_tree *tree1 = 0;
|
1216 |
+
mcts_tree *tree2 = 0;
|
1217 |
+
network *net = load_network(filename, weightfile, 0);
|
1218 |
+
//set_batch_network(net, 1);
|
1219 |
+
|
1220 |
+
network *net2;
|
1221 |
+
if (f2) {
|
1222 |
+
net2 = parse_network_cfg(f2);
|
1223 |
+
if(w2){
|
1224 |
+
load_weights(net2, w2);
|
1225 |
+
}
|
1226 |
+
} else {
|
1227 |
+
net2 = calloc(1, sizeof(network));
|
1228 |
+
*net2 = *net;
|
1229 |
+
}
|
1230 |
+
srand(time(0));
|
1231 |
+
char boards[600][93];
|
1232 |
+
int count = 0;
|
1233 |
+
//set_batch_network(net, 1);
|
1234 |
+
//set_batch_network(net2, 1);
|
1235 |
+
float *board = calloc(19*19*3, sizeof(float));
|
1236 |
+
flip_board(board);
|
1237 |
+
float *one = calloc(19*19*3, sizeof(float));
|
1238 |
+
float *two = calloc(19*19*3, sizeof(float));
|
1239 |
+
int done = 0;
|
1240 |
+
int player = 1;
|
1241 |
+
int p1 = 0;
|
1242 |
+
int p2 = 0;
|
1243 |
+
int total = 0;
|
1244 |
+
float temp = .1;
|
1245 |
+
int mcts_iters = 500;
|
1246 |
+
float cpuct = 5;
|
1247 |
+
while(1){
|
1248 |
+
if (done){
|
1249 |
+
tree1 = move_mcts(tree1, -1);
|
1250 |
+
tree2 = move_mcts(tree2, -1);
|
1251 |
+
float score = score_game(board);
|
1252 |
+
if((score > 0) == (total%2==0)) ++p1;
|
1253 |
+
else ++p2;
|
1254 |
+
++total;
|
1255 |
+
fprintf(stderr, "Total: %d, Player 1: %f, Player 2: %f\n", total, (float)p1/total, (float)p2/total);
|
1256 |
+
sleep(1);
|
1257 |
+
/*
|
1258 |
+
int i = (score > 0)? 0 : 1;
|
1259 |
+
int j;
|
1260 |
+
for(; i < count; i += 2){
|
1261 |
+
for(j = 0; j < 93; ++j){
|
1262 |
+
printf("%c", boards[i][j]);
|
1263 |
+
}
|
1264 |
+
printf("\n");
|
1265 |
+
}
|
1266 |
+
*/
|
1267 |
+
memset(board, 0, 3*19*19*sizeof(float));
|
1268 |
+
flip_board(board);
|
1269 |
+
player = 1;
|
1270 |
+
done = 0;
|
1271 |
+
count = 0;
|
1272 |
+
fflush(stdout);
|
1273 |
+
fflush(stderr);
|
1274 |
+
}
|
1275 |
+
//print_board(stderr, board, 1, 0);
|
1276 |
+
//sleep(1);
|
1277 |
+
|
1278 |
+
if ((total%2==0) == (player==1)){
|
1279 |
+
//mcts_iters = 4500;
|
1280 |
+
cpuct = 5;
|
1281 |
+
} else {
|
1282 |
+
//mcts_iters = 500;
|
1283 |
+
cpuct = 1;
|
1284 |
+
}
|
1285 |
+
network *use = ((total%2==0) == (player==1)) ? net : net2;
|
1286 |
+
mcts_tree *t = ((total%2==0) == (player==1)) ? tree1 : tree2;
|
1287 |
+
t = run_mcts(t, use, board, two, player, mcts_iters, cpuct, 0);
|
1288 |
+
move m = pick_move(t, temp, player);
|
1289 |
+
if(((total%2==0) == (player==1))) tree1 = t;
|
1290 |
+
else tree2 = t;
|
1291 |
+
|
1292 |
+
tree1 = move_mcts(tree1, m.row*19 + m.col);
|
1293 |
+
tree2 = move_mcts(tree2, m.row*19 + m.col);
|
1294 |
+
|
1295 |
+
if(m.row == 19){
|
1296 |
+
done = 1;
|
1297 |
+
continue;
|
1298 |
+
}
|
1299 |
+
int row = m.row;
|
1300 |
+
int col = m.col;
|
1301 |
+
|
1302 |
+
float *swap = two;
|
1303 |
+
two = one;
|
1304 |
+
one = swap;
|
1305 |
+
|
1306 |
+
if(player < 0) flip_board(board);
|
1307 |
+
boards[count][0] = row;
|
1308 |
+
boards[count][1] = col;
|
1309 |
+
board_to_string(boards[count] + 2, board);
|
1310 |
+
if(player < 0) flip_board(board);
|
1311 |
+
++count;
|
1312 |
+
|
1313 |
+
move_go(board, player, row, col);
|
1314 |
+
copy_cpu(19*19*3, board, 1, one, 1);
|
1315 |
+
|
1316 |
+
player = -player;
|
1317 |
+
}
|
1318 |
+
}
|
1319 |
+
|
1320 |
+
void run_go(int argc, char **argv)
|
1321 |
+
{
|
1322 |
+
//boards_go();
|
1323 |
+
if(argc < 4){
|
1324 |
+
fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
|
1325 |
+
return;
|
1326 |
+
}
|
1327 |
+
|
1328 |
+
char *gpu_list = find_char_arg(argc, argv, "-gpus", 0);
|
1329 |
+
int *gpus = 0;
|
1330 |
+
int gpu = 0;
|
1331 |
+
int ngpus = 0;
|
1332 |
+
if(gpu_list){
|
1333 |
+
printf("%s\n", gpu_list);
|
1334 |
+
int len = strlen(gpu_list);
|
1335 |
+
ngpus = 1;
|
1336 |
+
int i;
|
1337 |
+
for(i = 0; i < len; ++i){
|
1338 |
+
if (gpu_list[i] == ',') ++ngpus;
|
1339 |
+
}
|
1340 |
+
gpus = calloc(ngpus, sizeof(int));
|
1341 |
+
for(i = 0; i < ngpus; ++i){
|
1342 |
+
gpus[i] = atoi(gpu_list);
|
1343 |
+
gpu_list = strchr(gpu_list, ',')+1;
|
1344 |
+
}
|
1345 |
+
} else {
|
1346 |
+
gpu = gpu_index;
|
1347 |
+
gpus = &gpu;
|
1348 |
+
ngpus = 1;
|
1349 |
+
}
|
1350 |
+
int clear = find_arg(argc, argv, "-clear");
|
1351 |
+
|
1352 |
+
char *cfg = argv[3];
|
1353 |
+
char *weights = (argc > 4) ? argv[4] : 0;
|
1354 |
+
char *c2 = (argc > 5) ? argv[5] : 0;
|
1355 |
+
char *w2 = (argc > 6) ? argv[6] : 0;
|
1356 |
+
int multi = find_arg(argc, argv, "-multi");
|
1357 |
+
int anon = find_arg(argc, argv, "-anon");
|
1358 |
+
int iters = find_int_arg(argc, argv, "-iters", 500);
|
1359 |
+
int resign = find_int_arg(argc, argv, "-resign", 175);
|
1360 |
+
float cpuct = find_float_arg(argc, argv, "-cpuct", 5);
|
1361 |
+
float temp = find_float_arg(argc, argv, "-temp", .1);
|
1362 |
+
float time = find_float_arg(argc, argv, "-time", 0);
|
1363 |
+
if(0==strcmp(argv[2], "train")) train_go(cfg, weights, c2, gpus, ngpus, clear);
|
1364 |
+
else if(0==strcmp(argv[2], "valid")) valid_go(cfg, weights, multi, c2);
|
1365 |
+
else if(0==strcmp(argv[2], "self")) self_go(cfg, weights, c2, w2, multi);
|
1366 |
+
else if(0==strcmp(argv[2], "test")) test_go(cfg, weights, multi);
|
1367 |
+
else if(0==strcmp(argv[2], "engine")) engine_go(cfg, weights, iters, time, temp, cpuct, anon, resign);
|
1368 |
+
}
|
1369 |
+
|
1370 |
+
|
model/examples/instance-segmenter.c
ADDED
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include "darknet.h"
|
2 |
+
#include <sys/time.h>
|
3 |
+
#include <assert.h>
|
4 |
+
|
5 |
+
void normalize_image2(image p);
|
6 |
+
void train_isegmenter(char *datacfg, char *cfgfile, char *weightfile, int *gpus, int ngpus, int clear, int display)
|
7 |
+
{
|
8 |
+
int i;
|
9 |
+
|
10 |
+
float avg_loss = -1;
|
11 |
+
char *base = basecfg(cfgfile);
|
12 |
+
printf("%s\n", base);
|
13 |
+
printf("%d\n", ngpus);
|
14 |
+
network **nets = calloc(ngpus, sizeof(network*));
|
15 |
+
|
16 |
+
srand(time(0));
|
17 |
+
int seed = rand();
|
18 |
+
for(i = 0; i < ngpus; ++i){
|
19 |
+
srand(seed);
|
20 |
+
#ifdef GPU
|
21 |
+
cuda_set_device(gpus[i]);
|
22 |
+
#endif
|
23 |
+
nets[i] = load_network(cfgfile, weightfile, clear);
|
24 |
+
nets[i]->learning_rate *= ngpus;
|
25 |
+
}
|
26 |
+
srand(time(0));
|
27 |
+
network *net = nets[0];
|
28 |
+
image pred = get_network_image(net);
|
29 |
+
|
30 |
+
image embed = pred;
|
31 |
+
embed.c = 3;
|
32 |
+
embed.data += embed.w*embed.h*80;
|
33 |
+
|
34 |
+
int div = net->w/pred.w;
|
35 |
+
assert(pred.w * div == net->w);
|
36 |
+
assert(pred.h * div == net->h);
|
37 |
+
|
38 |
+
int imgs = net->batch * net->subdivisions * ngpus;
|
39 |
+
|
40 |
+
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
|
41 |
+
list *options = read_data_cfg(datacfg);
|
42 |
+
|
43 |
+
char *backup_directory = option_find_str(options, "backup", "/backup/");
|
44 |
+
char *train_list = option_find_str(options, "train", "data/train.list");
|
45 |
+
|
46 |
+
list *plist = get_paths(train_list);
|
47 |
+
char **paths = (char **)list_to_array(plist);
|
48 |
+
printf("%d\n", plist->size);
|
49 |
+
int N = plist->size;
|
50 |
+
|
51 |
+
load_args args = {0};
|
52 |
+
args.w = net->w;
|
53 |
+
args.h = net->h;
|
54 |
+
args.threads = 32;
|
55 |
+
args.scale = div;
|
56 |
+
args.num_boxes = 90;
|
57 |
+
|
58 |
+
args.min = net->min_crop;
|
59 |
+
args.max = net->max_crop;
|
60 |
+
args.angle = net->angle;
|
61 |
+
args.aspect = net->aspect;
|
62 |
+
args.exposure = net->exposure;
|
63 |
+
args.saturation = net->saturation;
|
64 |
+
args.hue = net->hue;
|
65 |
+
args.size = net->w;
|
66 |
+
args.classes = 80;
|
67 |
+
|
68 |
+
args.paths = paths;
|
69 |
+
args.n = imgs;
|
70 |
+
args.m = N;
|
71 |
+
args.type = ISEG_DATA;
|
72 |
+
|
73 |
+
data train;
|
74 |
+
data buffer;
|
75 |
+
pthread_t load_thread;
|
76 |
+
args.d = &buffer;
|
77 |
+
load_thread = load_data(args);
|
78 |
+
|
79 |
+
int epoch = (*net->seen)/N;
|
80 |
+
while(get_current_batch(net) < net->max_batches || net->max_batches == 0){
|
81 |
+
double time = what_time_is_it_now();
|
82 |
+
|
83 |
+
pthread_join(load_thread, 0);
|
84 |
+
train = buffer;
|
85 |
+
load_thread = load_data(args);
|
86 |
+
|
87 |
+
printf("Loaded: %lf seconds\n", what_time_is_it_now()-time);
|
88 |
+
time = what_time_is_it_now();
|
89 |
+
|
90 |
+
float loss = 0;
|
91 |
+
#ifdef GPU
|
92 |
+
if(ngpus == 1){
|
93 |
+
loss = train_network(net, train);
|
94 |
+
} else {
|
95 |
+
loss = train_networks(nets, ngpus, train, 4);
|
96 |
+
}
|
97 |
+
#else
|
98 |
+
loss = train_network(net, train);
|
99 |
+
#endif
|
100 |
+
if(display){
|
101 |
+
image tr = float_to_image(net->w/div, net->h/div, 80, train.y.vals[net->batch*(net->subdivisions-1)]);
|
102 |
+
image im = float_to_image(net->w, net->h, net->c, train.X.vals[net->batch*(net->subdivisions-1)]);
|
103 |
+
pred.c = 80;
|
104 |
+
image mask = mask_to_rgb(tr);
|
105 |
+
image prmask = mask_to_rgb(pred);
|
106 |
+
image ecopy = copy_image(embed);
|
107 |
+
normalize_image2(ecopy);
|
108 |
+
show_image(ecopy, "embed", 1);
|
109 |
+
free_image(ecopy);
|
110 |
+
|
111 |
+
show_image(im, "input", 1);
|
112 |
+
show_image(prmask, "pred", 1);
|
113 |
+
show_image(mask, "truth", 100);
|
114 |
+
free_image(mask);
|
115 |
+
free_image(prmask);
|
116 |
+
}
|
117 |
+
if(avg_loss == -1) avg_loss = loss;
|
118 |
+
avg_loss = avg_loss*.9 + loss*.1;
|
119 |
+
printf("%ld, %.3f: %f, %f avg, %f rate, %lf seconds, %ld images\n", get_current_batch(net), (float)(*net->seen)/N, loss, avg_loss, get_current_rate(net), what_time_is_it_now()-time, *net->seen);
|
120 |
+
free_data(train);
|
121 |
+
if(*net->seen/N > epoch){
|
122 |
+
epoch = *net->seen/N;
|
123 |
+
char buff[256];
|
124 |
+
sprintf(buff, "%s/%s_%d.weights",backup_directory,base, epoch);
|
125 |
+
save_weights(net, buff);
|
126 |
+
}
|
127 |
+
if(get_current_batch(net)%100 == 0){
|
128 |
+
char buff[256];
|
129 |
+
sprintf(buff, "%s/%s.backup",backup_directory,base);
|
130 |
+
save_weights(net, buff);
|
131 |
+
}
|
132 |
+
}
|
133 |
+
char buff[256];
|
134 |
+
sprintf(buff, "%s/%s.weights", backup_directory, base);
|
135 |
+
save_weights(net, buff);
|
136 |
+
|
137 |
+
free_network(net);
|
138 |
+
free_ptrs((void**)paths, plist->size);
|
139 |
+
free_list(plist);
|
140 |
+
free(base);
|
141 |
+
}
|
142 |
+
|
143 |
+
void predict_isegmenter(char *datafile, char *cfg, char *weights, char *filename)
|
144 |
+
{
|
145 |
+
network *net = load_network(cfg, weights, 0);
|
146 |
+
set_batch_network(net, 1);
|
147 |
+
srand(2222222);
|
148 |
+
|
149 |
+
clock_t time;
|
150 |
+
char buff[256];
|
151 |
+
char *input = buff;
|
152 |
+
while(1){
|
153 |
+
if(filename){
|
154 |
+
strncpy(input, filename, 256);
|
155 |
+
}else{
|
156 |
+
printf("Enter Image Path: ");
|
157 |
+
fflush(stdout);
|
158 |
+
input = fgets(input, 256, stdin);
|
159 |
+
if(!input) return;
|
160 |
+
strtok(input, "\n");
|
161 |
+
}
|
162 |
+
image im = load_image_color(input, 0, 0);
|
163 |
+
image sized = letterbox_image(im, net->w, net->h);
|
164 |
+
|
165 |
+
float *X = sized.data;
|
166 |
+
time=clock();
|
167 |
+
float *predictions = network_predict(net, X);
|
168 |
+
image pred = get_network_image(net);
|
169 |
+
image prmask = mask_to_rgb(pred);
|
170 |
+
printf("Predicted: %f\n", predictions[0]);
|
171 |
+
printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
|
172 |
+
show_image(sized, "orig", 1);
|
173 |
+
show_image(prmask, "pred", 0);
|
174 |
+
free_image(im);
|
175 |
+
free_image(sized);
|
176 |
+
free_image(prmask);
|
177 |
+
if (filename) break;
|
178 |
+
}
|
179 |
+
}
|
180 |
+
|
181 |
+
|
182 |
+
void demo_isegmenter(char *datacfg, char *cfg, char *weights, int cam_index, const char *filename)
|
183 |
+
{
|
184 |
+
#ifdef OPENCV
|
185 |
+
printf("Classifier Demo\n");
|
186 |
+
network *net = load_network(cfg, weights, 0);
|
187 |
+
set_batch_network(net, 1);
|
188 |
+
|
189 |
+
srand(2222222);
|
190 |
+
void * cap = open_video_stream(filename, cam_index, 0,0,0);
|
191 |
+
|
192 |
+
if(!cap) error("Couldn't connect to webcam.\n");
|
193 |
+
float fps = 0;
|
194 |
+
|
195 |
+
while(1){
|
196 |
+
struct timeval tval_before, tval_after, tval_result;
|
197 |
+
gettimeofday(&tval_before, NULL);
|
198 |
+
|
199 |
+
image in = get_image_from_stream(cap);
|
200 |
+
image in_s = letterbox_image(in, net->w, net->h);
|
201 |
+
|
202 |
+
network_predict(net, in_s.data);
|
203 |
+
|
204 |
+
printf("\033[2J");
|
205 |
+
printf("\033[1;1H");
|
206 |
+
printf("\nFPS:%.0f\n",fps);
|
207 |
+
|
208 |
+
image pred = get_network_image(net);
|
209 |
+
image prmask = mask_to_rgb(pred);
|
210 |
+
show_image(prmask, "Segmenter", 10);
|
211 |
+
|
212 |
+
free_image(in_s);
|
213 |
+
free_image(in);
|
214 |
+
free_image(prmask);
|
215 |
+
|
216 |
+
gettimeofday(&tval_after, NULL);
|
217 |
+
timersub(&tval_after, &tval_before, &tval_result);
|
218 |
+
float curr = 1000000.f/((long int)tval_result.tv_usec);
|
219 |
+
fps = .9*fps + .1*curr;
|
220 |
+
}
|
221 |
+
#endif
|
222 |
+
}
|
223 |
+
|
224 |
+
|
225 |
+
void run_isegmenter(int argc, char **argv)
|
226 |
+
{
|
227 |
+
if(argc < 4){
|
228 |
+
fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
|
229 |
+
return;
|
230 |
+
}
|
231 |
+
|
232 |
+
char *gpu_list = find_char_arg(argc, argv, "-gpus", 0);
|
233 |
+
int *gpus = 0;
|
234 |
+
int gpu = 0;
|
235 |
+
int ngpus = 0;
|
236 |
+
if(gpu_list){
|
237 |
+
printf("%s\n", gpu_list);
|
238 |
+
int len = strlen(gpu_list);
|
239 |
+
ngpus = 1;
|
240 |
+
int i;
|
241 |
+
for(i = 0; i < len; ++i){
|
242 |
+
if (gpu_list[i] == ',') ++ngpus;
|
243 |
+
}
|
244 |
+
gpus = calloc(ngpus, sizeof(int));
|
245 |
+
for(i = 0; i < ngpus; ++i){
|
246 |
+
gpus[i] = atoi(gpu_list);
|
247 |
+
gpu_list = strchr(gpu_list, ',')+1;
|
248 |
+
}
|
249 |
+
} else {
|
250 |
+
gpu = gpu_index;
|
251 |
+
gpus = &gpu;
|
252 |
+
ngpus = 1;
|
253 |
+
}
|
254 |
+
|
255 |
+
int cam_index = find_int_arg(argc, argv, "-c", 0);
|
256 |
+
int clear = find_arg(argc, argv, "-clear");
|
257 |
+
int display = find_arg(argc, argv, "-display");
|
258 |
+
char *data = argv[3];
|
259 |
+
char *cfg = argv[4];
|
260 |
+
char *weights = (argc > 5) ? argv[5] : 0;
|
261 |
+
char *filename = (argc > 6) ? argv[6]: 0;
|
262 |
+
if(0==strcmp(argv[2], "test")) predict_isegmenter(data, cfg, weights, filename);
|
263 |
+
else if(0==strcmp(argv[2], "train")) train_isegmenter(data, cfg, weights, gpus, ngpus, clear, display);
|
264 |
+
else if(0==strcmp(argv[2], "demo")) demo_isegmenter(data, cfg, weights, cam_index, filename);
|
265 |
+
}
|
266 |
+
|
267 |
+
|
model/examples/lsd.c
ADDED
@@ -0,0 +1,1378 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <math.h>
|
2 |
+
#include "darknet.h"
|
3 |
+
|
4 |
+
/*
|
5 |
+
void train_lsd3(char *fcfg, char *fweight, char *gcfg, char *gweight, char *acfg, char *aweight, int clear)
|
6 |
+
{
|
7 |
+
#ifdef GPU
|
8 |
+
//char *train_images = "/home/pjreddie/data/coco/trainvalno5k.txt";
|
9 |
+
char *train_images = "/home/pjreddie/data/imagenet/imagenet1k.train.list";
|
10 |
+
//char *style_images = "/home/pjreddie/data/coco/trainvalno5k.txt";
|
11 |
+
char *style_images = "/home/pjreddie/zelda.txt";
|
12 |
+
char *backup_directory = "/home/pjreddie/backup/";
|
13 |
+
srand(time(0));
|
14 |
+
network fnet = load_network(fcfg, fweight, clear);
|
15 |
+
network gnet = load_network(gcfg, gweight, clear);
|
16 |
+
network anet = load_network(acfg, aweight, clear);
|
17 |
+
char *gbase = basecfg(gcfg);
|
18 |
+
char *abase = basecfg(acfg);
|
19 |
+
|
20 |
+
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", gnet->learning_rate, gnet->momentum, gnet->decay);
|
21 |
+
int imgs = gnet->batch*gnet->subdivisions;
|
22 |
+
int i = *gnet->seen/imgs;
|
23 |
+
data train, tbuffer;
|
24 |
+
data style, sbuffer;
|
25 |
+
|
26 |
+
|
27 |
+
list *slist = get_paths(style_images);
|
28 |
+
char **spaths = (char **)list_to_array(slist);
|
29 |
+
|
30 |
+
list *tlist = get_paths(train_images);
|
31 |
+
char **tpaths = (char **)list_to_array(tlist);
|
32 |
+
|
33 |
+
load_args targs= get_base_args(gnet);
|
34 |
+
targs.paths = tpaths;
|
35 |
+
targs.n = imgs;
|
36 |
+
targs.m = tlist->size;
|
37 |
+
targs.d = &tbuffer;
|
38 |
+
targs.type = CLASSIFICATION_DATA;
|
39 |
+
targs.classes = 1;
|
40 |
+
char *ls[1] = {"zelda"};
|
41 |
+
targs.labels = ls;
|
42 |
+
|
43 |
+
load_args sargs = get_base_args(gnet);
|
44 |
+
sargs.paths = spaths;
|
45 |
+
sargs.n = imgs;
|
46 |
+
sargs.m = slist->size;
|
47 |
+
sargs.d = &sbuffer;
|
48 |
+
sargs.type = CLASSIFICATION_DATA;
|
49 |
+
sargs.classes = 1;
|
50 |
+
sargs.labels = ls;
|
51 |
+
|
52 |
+
pthread_t tload_thread = load_data_in_thread(targs);
|
53 |
+
pthread_t sload_thread = load_data_in_thread(sargs);
|
54 |
+
clock_t time;
|
55 |
+
|
56 |
+
float aloss_avg = -1;
|
57 |
+
float floss_avg = -1;
|
58 |
+
|
59 |
+
fnet->train=1;
|
60 |
+
int x_size = fnet->inputs*fnet->batch;
|
61 |
+
int y_size = fnet->truths*fnet->batch;
|
62 |
+
float *X = calloc(x_size, sizeof(float));
|
63 |
+
float *y = calloc(y_size, sizeof(float));
|
64 |
+
|
65 |
+
|
66 |
+
int ax_size = anet->inputs*anet->batch;
|
67 |
+
int ay_size = anet->truths*anet->batch;
|
68 |
+
fill_gpu(ay_size, .9, anet->truth_gpu, 1);
|
69 |
+
anet->delta_gpu = cuda_make_array(0, ax_size);
|
70 |
+
anet->train = 1;
|
71 |
+
|
72 |
+
int gx_size = gnet->inputs*gnet->batch;
|
73 |
+
int gy_size = gnet->truths*gnet->batch;
|
74 |
+
gstate.input = cuda_make_array(0, gx_size);
|
75 |
+
gstate.truth = 0;
|
76 |
+
gstate.delta = 0;
|
77 |
+
gstate.train = 1;
|
78 |
+
|
79 |
+
while (get_current_batch(gnet) < gnet->max_batches) {
|
80 |
+
i += 1;
|
81 |
+
time=clock();
|
82 |
+
pthread_join(tload_thread, 0);
|
83 |
+
pthread_join(sload_thread, 0);
|
84 |
+
train = tbuffer;
|
85 |
+
style = sbuffer;
|
86 |
+
tload_thread = load_data_in_thread(targs);
|
87 |
+
sload_thread = load_data_in_thread(sargs);
|
88 |
+
|
89 |
+
printf("Loaded: %lf seconds\n", sec(clock()-time));
|
90 |
+
|
91 |
+
data generated = copy_data(train);
|
92 |
+
time=clock();
|
93 |
+
|
94 |
+
int j, k;
|
95 |
+
float floss = 0;
|
96 |
+
for(j = 0; j < fnet->subdivisions; ++j){
|
97 |
+
layer imlayer = gnet->layers[gnet->n - 1];
|
98 |
+
get_next_batch(train, fnet->batch, j*fnet->batch, X, y);
|
99 |
+
|
100 |
+
cuda_push_array(fstate.input, X, x_size);
|
101 |
+
cuda_push_array(gstate.input, X, gx_size);
|
102 |
+
*gnet->seen += gnet->batch;
|
103 |
+
|
104 |
+
forward_network_gpu(fnet, fstate);
|
105 |
+
float *feats = fnet->layers[fnet->n - 2].output_gpu;
|
106 |
+
copy_gpu(y_size, feats, 1, fstate.truth, 1);
|
107 |
+
|
108 |
+
forward_network_gpu(gnet, gstate);
|
109 |
+
float *gen = gnet->layers[gnet->n-1].output_gpu;
|
110 |
+
copy_gpu(x_size, gen, 1, fstate.input, 1);
|
111 |
+
|
112 |
+
fill_gpu(x_size, 0, fstate.delta, 1);
|
113 |
+
forward_network_gpu(fnet, fstate);
|
114 |
+
backward_network_gpu(fnet, fstate);
|
115 |
+
//HERE
|
116 |
+
|
117 |
+
astate.input = gen;
|
118 |
+
fill_gpu(ax_size, 0, astate.delta, 1);
|
119 |
+
forward_network_gpu(anet, astate);
|
120 |
+
backward_network_gpu(anet, astate);
|
121 |
+
|
122 |
+
float *delta = imlayer.delta_gpu;
|
123 |
+
fill_gpu(x_size, 0, delta, 1);
|
124 |
+
scal_gpu(x_size, 100, astate.delta, 1);
|
125 |
+
scal_gpu(x_size, .001, fstate.delta, 1);
|
126 |
+
axpy_gpu(x_size, 1, fstate.delta, 1, delta, 1);
|
127 |
+
axpy_gpu(x_size, 1, astate.delta, 1, delta, 1);
|
128 |
+
|
129 |
+
//fill_gpu(x_size, 0, delta, 1);
|
130 |
+
//cuda_push_array(delta, X, x_size);
|
131 |
+
//axpy_gpu(x_size, -1, imlayer.output_gpu, 1, delta, 1);
|
132 |
+
//printf("pix error: %f\n", cuda_mag_array(delta, x_size));
|
133 |
+
printf("fea error: %f\n", cuda_mag_array(fstate.delta, x_size));
|
134 |
+
printf("adv error: %f\n", cuda_mag_array(astate.delta, x_size));
|
135 |
+
//axpy_gpu(x_size, 1, astate.delta, 1, delta, 1);
|
136 |
+
|
137 |
+
backward_network_gpu(gnet, gstate);
|
138 |
+
|
139 |
+
floss += get_network_cost(fnet) /(fnet->subdivisions*fnet->batch);
|
140 |
+
|
141 |
+
cuda_pull_array(imlayer.output_gpu, imlayer.output, imlayer.outputs*imlayer.batch);
|
142 |
+
for(k = 0; k < gnet->batch; ++k){
|
143 |
+
int index = j*gnet->batch + k;
|
144 |
+
copy_cpu(imlayer.outputs, imlayer.output + k*imlayer.outputs, 1, generated.X.vals[index], 1);
|
145 |
+
generated.y.vals[index][0] = .1;
|
146 |
+
style.y.vals[index][0] = .9;
|
147 |
+
}
|
148 |
+
}
|
149 |
+
|
150 |
+
*/
|
151 |
+
/*
|
152 |
+
image sim = float_to_image(anet->w, anet->h, anet->c, style.X.vals[j]);
|
153 |
+
show_image(sim, "style");
|
154 |
+
cvWaitKey(0);
|
155 |
+
*/
|
156 |
+
/*
|
157 |
+
|
158 |
+
harmless_update_network_gpu(anet);
|
159 |
+
|
160 |
+
data merge = concat_data(style, generated);
|
161 |
+
randomize_data(merge);
|
162 |
+
float aloss = train_network(anet, merge);
|
163 |
+
|
164 |
+
update_network_gpu(gnet);
|
165 |
+
|
166 |
+
free_data(merge);
|
167 |
+
free_data(train);
|
168 |
+
free_data(generated);
|
169 |
+
free_data(style);
|
170 |
+
if (aloss_avg < 0) aloss_avg = aloss;
|
171 |
+
if (floss_avg < 0) floss_avg = floss;
|
172 |
+
aloss_avg = aloss_avg*.9 + aloss*.1;
|
173 |
+
floss_avg = floss_avg*.9 + floss*.1;
|
174 |
+
|
175 |
+
printf("%d: gen: %f, adv: %f | gen_avg: %f, adv_avg: %f, %f rate, %lf seconds, %d images\n", i, floss, aloss, floss_avg, aloss_avg, get_current_rate(gnet), sec(clock()-time), i*imgs);
|
176 |
+
if(i%1000==0){
|
177 |
+
char buff[256];
|
178 |
+
sprintf(buff, "%s/%s_%d.weights", backup_directory, gbase, i);
|
179 |
+
save_weights(gnet, buff);
|
180 |
+
sprintf(buff, "%s/%s_%d.weights", backup_directory, abase, i);
|
181 |
+
save_weights(anet, buff);
|
182 |
+
}
|
183 |
+
if(i%100==0){
|
184 |
+
char buff[256];
|
185 |
+
sprintf(buff, "%s/%s.backup", backup_directory, gbase);
|
186 |
+
save_weights(gnet, buff);
|
187 |
+
sprintf(buff, "%s/%s.backup", backup_directory, abase);
|
188 |
+
save_weights(anet, buff);
|
189 |
+
}
|
190 |
+
}
|
191 |
+
#endif
|
192 |
+
}
|
193 |
+
*/
|
194 |
+
|
195 |
+
/*
|
196 |
+
void train_pix2pix(char *cfg, char *weight, char *acfg, char *aweight, int clear)
|
197 |
+
{
|
198 |
+
#ifdef GPU
|
199 |
+
//char *train_images = "/home/pjreddie/data/coco/train1.txt";
|
200 |
+
//char *train_images = "/home/pjreddie/data/coco/trainvalno5k.txt";
|
201 |
+
char *train_images = "/home/pjreddie/data/imagenet/imagenet1k.train.list";
|
202 |
+
char *backup_directory = "/home/pjreddie/backup/";
|
203 |
+
srand(time(0));
|
204 |
+
char *base = basecfg(cfg);
|
205 |
+
char *abase = basecfg(acfg);
|
206 |
+
printf("%s\n", base);
|
207 |
+
network net = load_network(cfg, weight, clear);
|
208 |
+
network anet = load_network(acfg, aweight, clear);
|
209 |
+
|
210 |
+
int i, j, k;
|
211 |
+
layer imlayer = {0};
|
212 |
+
for (i = 0; i < net->n; ++i) {
|
213 |
+
if (net->layers[i].out_c == 3) {
|
214 |
+
imlayer = net->layers[i];
|
215 |
+
break;
|
216 |
+
}
|
217 |
+
}
|
218 |
+
|
219 |
+
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
|
220 |
+
int imgs = net->batch*net->subdivisions;
|
221 |
+
i = *net->seen/imgs;
|
222 |
+
data train, buffer;
|
223 |
+
|
224 |
+
|
225 |
+
list *plist = get_paths(train_images);
|
226 |
+
//int N = plist->size;
|
227 |
+
char **paths = (char **)list_to_array(plist);
|
228 |
+
|
229 |
+
load_args args = {0};
|
230 |
+
args.w = net->w;
|
231 |
+
args.h = net->h;
|
232 |
+
args.paths = paths;
|
233 |
+
args.n = imgs;
|
234 |
+
args.m = plist->size;
|
235 |
+
args.d = &buffer;
|
236 |
+
|
237 |
+
args.min = net->min_crop;
|
238 |
+
args.max = net->max_crop;
|
239 |
+
args.angle = net->angle;
|
240 |
+
args.aspect = net->aspect;
|
241 |
+
args.exposure = net->exposure;
|
242 |
+
args.saturation = net->saturation;
|
243 |
+
args.hue = net->hue;
|
244 |
+
args.size = net->w;
|
245 |
+
args.type = CLASSIFICATION_DATA;
|
246 |
+
args.classes = 1;
|
247 |
+
char *ls[1] = {"coco"};
|
248 |
+
args.labels = ls;
|
249 |
+
|
250 |
+
pthread_t load_thread = load_data_in_thread(args);
|
251 |
+
clock_t time;
|
252 |
+
|
253 |
+
network_state gstate = {0};
|
254 |
+
gstate.index = 0;
|
255 |
+
gstate.net = net;
|
256 |
+
int x_size = get_network_input_size(net)*net->batch;
|
257 |
+
int y_size = x_size;
|
258 |
+
gstate.input = cuda_make_array(0, x_size);
|
259 |
+
gstate.truth = cuda_make_array(0, y_size);
|
260 |
+
gstate.delta = 0;
|
261 |
+
gstate.train = 1;
|
262 |
+
float *pixs = calloc(x_size, sizeof(float));
|
263 |
+
float *graypixs = calloc(x_size, sizeof(float));
|
264 |
+
float *y = calloc(y_size, sizeof(float));
|
265 |
+
|
266 |
+
network_state astate = {0};
|
267 |
+
astate.index = 0;
|
268 |
+
astate.net = anet;
|
269 |
+
int ay_size = get_network_output_size(anet)*anet->batch;
|
270 |
+
astate.input = 0;
|
271 |
+
astate.truth = 0;
|
272 |
+
astate.delta = 0;
|
273 |
+
astate.train = 1;
|
274 |
+
|
275 |
+
float *imerror = cuda_make_array(0, imlayer.outputs);
|
276 |
+
float *ones_gpu = cuda_make_array(0, ay_size);
|
277 |
+
fill_gpu(ay_size, .9, ones_gpu, 1);
|
278 |
+
|
279 |
+
float aloss_avg = -1;
|
280 |
+
float gloss_avg = -1;
|
281 |
+
|
282 |
+
//data generated = copy_data(train);
|
283 |
+
|
284 |
+
while (get_current_batch(net) < net->max_batches) {
|
285 |
+
i += 1;
|
286 |
+
time=clock();
|
287 |
+
pthread_join(load_thread, 0);
|
288 |
+
train = buffer;
|
289 |
+
load_thread = load_data_in_thread(args);
|
290 |
+
|
291 |
+
printf("Loaded: %lf seconds\n", sec(clock()-time));
|
292 |
+
|
293 |
+
data gray = copy_data(train);
|
294 |
+
for(j = 0; j < imgs; ++j){
|
295 |
+
image gim = float_to_image(net->w, net->h, net->c, gray.X.vals[j]);
|
296 |
+
grayscale_image_3c(gim);
|
297 |
+
train.y.vals[j][0] = .9;
|
298 |
+
|
299 |
+
image yim = float_to_image(net->w, net->h, net->c, train.X.vals[j]);
|
300 |
+
//rgb_to_yuv(yim);
|
301 |
+
}
|
302 |
+
time=clock();
|
303 |
+
float gloss = 0;
|
304 |
+
|
305 |
+
for(j = 0; j < net->subdivisions; ++j){
|
306 |
+
get_next_batch(train, net->batch, j*net->batch, pixs, y);
|
307 |
+
get_next_batch(gray, net->batch, j*net->batch, graypixs, y);
|
308 |
+
cuda_push_array(gstate.input, graypixs, x_size);
|
309 |
+
cuda_push_array(gstate.truth, pixs, y_size);
|
310 |
+
*/
|
311 |
+
/*
|
312 |
+
image origi = float_to_image(net->w, net->h, 3, pixs);
|
313 |
+
image grayi = float_to_image(net->w, net->h, 3, graypixs);
|
314 |
+
show_image(grayi, "gray");
|
315 |
+
show_image(origi, "orig");
|
316 |
+
cvWaitKey(0);
|
317 |
+
*/
|
318 |
+
/*
|
319 |
+
*net->seen += net->batch;
|
320 |
+
forward_network_gpu(net, gstate);
|
321 |
+
|
322 |
+
fill_gpu(imlayer.outputs, 0, imerror, 1);
|
323 |
+
astate.input = imlayer.output_gpu;
|
324 |
+
astate.delta = imerror;
|
325 |
+
astate.truth = ones_gpu;
|
326 |
+
forward_network_gpu(anet, astate);
|
327 |
+
backward_network_gpu(anet, astate);
|
328 |
+
|
329 |
+
scal_gpu(imlayer.outputs, .1, net->layers[net->n-1].delta_gpu, 1);
|
330 |
+
|
331 |
+
backward_network_gpu(net, gstate);
|
332 |
+
|
333 |
+
scal_gpu(imlayer.outputs, 1000, imerror, 1);
|
334 |
+
|
335 |
+
printf("realness %f\n", cuda_mag_array(imerror, imlayer.outputs));
|
336 |
+
printf("features %f\n", cuda_mag_array(net->layers[net->n-1].delta_gpu, imlayer.outputs));
|
337 |
+
|
338 |
+
axpy_gpu(imlayer.outputs, 1, imerror, 1, imlayer.delta_gpu, 1);
|
339 |
+
|
340 |
+
gloss += get_network_cost(net) /(net->subdivisions*net->batch);
|
341 |
+
|
342 |
+
cuda_pull_array(imlayer.output_gpu, imlayer.output, imlayer.outputs*imlayer.batch);
|
343 |
+
for(k = 0; k < net->batch; ++k){
|
344 |
+
int index = j*net->batch + k;
|
345 |
+
copy_cpu(imlayer.outputs, imlayer.output + k*imlayer.outputs, 1, gray.X.vals[index], 1);
|
346 |
+
gray.y.vals[index][0] = .1;
|
347 |
+
}
|
348 |
+
}
|
349 |
+
harmless_update_network_gpu(anet);
|
350 |
+
|
351 |
+
data merge = concat_data(train, gray);
|
352 |
+
randomize_data(merge);
|
353 |
+
float aloss = train_network(anet, merge);
|
354 |
+
|
355 |
+
update_network_gpu(net);
|
356 |
+
update_network_gpu(anet);
|
357 |
+
free_data(merge);
|
358 |
+
free_data(train);
|
359 |
+
free_data(gray);
|
360 |
+
if (aloss_avg < 0) aloss_avg = aloss;
|
361 |
+
aloss_avg = aloss_avg*.9 + aloss*.1;
|
362 |
+
gloss_avg = gloss_avg*.9 + gloss*.1;
|
363 |
+
|
364 |
+
printf("%d: gen: %f, adv: %f | gen_avg: %f, adv_avg: %f, %f rate, %lf seconds, %d images\n", i, gloss, aloss, gloss_avg, aloss_avg, get_current_rate(net), sec(clock()-time), i*imgs);
|
365 |
+
if(i%1000==0){
|
366 |
+
char buff[256];
|
367 |
+
sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
|
368 |
+
save_weights(net, buff);
|
369 |
+
sprintf(buff, "%s/%s_%d.weights", backup_directory, abase, i);
|
370 |
+
save_weights(anet, buff);
|
371 |
+
}
|
372 |
+
if(i%100==0){
|
373 |
+
char buff[256];
|
374 |
+
sprintf(buff, "%s/%s.backup", backup_directory, base);
|
375 |
+
save_weights(net, buff);
|
376 |
+
sprintf(buff, "%s/%s.backup", backup_directory, abase);
|
377 |
+
save_weights(anet, buff);
|
378 |
+
}
|
379 |
+
}
|
380 |
+
char buff[256];
|
381 |
+
sprintf(buff, "%s/%s_final.weights", backup_directory, base);
|
382 |
+
save_weights(net, buff);
|
383 |
+
#endif
|
384 |
+
}
|
385 |
+
*/
|
386 |
+
|
387 |
+
void slerp(float *start, float *end, float s, int n, float *out)
|
388 |
+
{
|
389 |
+
float omega = acos(dot_cpu(n, start, 1, end, 1));
|
390 |
+
float so = sin(omega);
|
391 |
+
fill_cpu(n, 0, out, 1);
|
392 |
+
axpy_cpu(n, sin((1-s)*omega)/so, start, 1, out, 1);
|
393 |
+
axpy_cpu(n, sin(s*omega)/so, end, 1, out, 1);
|
394 |
+
|
395 |
+
float mag = mag_array(out, n);
|
396 |
+
scale_array(out, n, 1./mag);
|
397 |
+
}
|
398 |
+
|
399 |
+
image random_unit_vector_image(int w, int h, int c)
|
400 |
+
{
|
401 |
+
image im = make_image(w, h, c);
|
402 |
+
int i;
|
403 |
+
for(i = 0; i < im.w*im.h*im.c; ++i){
|
404 |
+
im.data[i] = rand_normal();
|
405 |
+
}
|
406 |
+
float mag = mag_array(im.data, im.w*im.h*im.c);
|
407 |
+
scale_array(im.data, im.w*im.h*im.c, 1./mag);
|
408 |
+
return im;
|
409 |
+
}
|
410 |
+
|
411 |
+
void inter_dcgan(char *cfgfile, char *weightfile)
|
412 |
+
{
|
413 |
+
network *net = load_network(cfgfile, weightfile, 0);
|
414 |
+
set_batch_network(net, 1);
|
415 |
+
srand(2222222);
|
416 |
+
|
417 |
+
clock_t time;
|
418 |
+
char buff[256];
|
419 |
+
char *input = buff;
|
420 |
+
int i, imlayer = 0;
|
421 |
+
|
422 |
+
for (i = 0; i < net->n; ++i) {
|
423 |
+
if (net->layers[i].out_c == 3) {
|
424 |
+
imlayer = i;
|
425 |
+
printf("%d\n", i);
|
426 |
+
break;
|
427 |
+
}
|
428 |
+
}
|
429 |
+
image start = random_unit_vector_image(net->w, net->h, net->c);
|
430 |
+
image end = random_unit_vector_image(net->w, net->h, net->c);
|
431 |
+
image im = make_image(net->w, net->h, net->c);
|
432 |
+
image orig = copy_image(start);
|
433 |
+
|
434 |
+
int c = 0;
|
435 |
+
int count = 0;
|
436 |
+
int max_count = 15;
|
437 |
+
while(1){
|
438 |
+
++c;
|
439 |
+
|
440 |
+
if(count == max_count){
|
441 |
+
count = 0;
|
442 |
+
free_image(start);
|
443 |
+
start = end;
|
444 |
+
end = random_unit_vector_image(net->w, net->h, net->c);
|
445 |
+
if(c > 300){
|
446 |
+
end = orig;
|
447 |
+
}
|
448 |
+
if(c>300 + max_count) return;
|
449 |
+
}
|
450 |
+
++count;
|
451 |
+
|
452 |
+
slerp(start.data, end.data, (float)count / max_count, im.w*im.h*im.c, im.data);
|
453 |
+
|
454 |
+
float *X = im.data;
|
455 |
+
time=clock();
|
456 |
+
network_predict(net, X);
|
457 |
+
image out = get_network_image_layer(net, imlayer);
|
458 |
+
//yuv_to_rgb(out);
|
459 |
+
normalize_image(out);
|
460 |
+
printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
|
461 |
+
//char buff[256];
|
462 |
+
sprintf(buff, "out%05d", c);
|
463 |
+
save_image(out, "out");
|
464 |
+
save_image(out, buff);
|
465 |
+
show_image(out, "out", 0);
|
466 |
+
}
|
467 |
+
}
|
468 |
+
|
469 |
+
void test_dcgan(char *cfgfile, char *weightfile)
|
470 |
+
{
|
471 |
+
network *net = load_network(cfgfile, weightfile, 0);
|
472 |
+
set_batch_network(net, 1);
|
473 |
+
srand(2222222);
|
474 |
+
|
475 |
+
clock_t time;
|
476 |
+
char buff[256];
|
477 |
+
char *input = buff;
|
478 |
+
int imlayer = 0;
|
479 |
+
|
480 |
+
imlayer = net->n-1;
|
481 |
+
|
482 |
+
while(1){
|
483 |
+
image im = make_image(net->w, net->h, net->c);
|
484 |
+
int i;
|
485 |
+
for(i = 0; i < im.w*im.h*im.c; ++i){
|
486 |
+
im.data[i] = rand_normal();
|
487 |
+
}
|
488 |
+
//float mag = mag_array(im.data, im.w*im.h*im.c);
|
489 |
+
//scale_array(im.data, im.w*im.h*im.c, 1./mag);
|
490 |
+
|
491 |
+
float *X = im.data;
|
492 |
+
time=clock();
|
493 |
+
network_predict(net, X);
|
494 |
+
image out = get_network_image_layer(net, imlayer);
|
495 |
+
//yuv_to_rgb(out);
|
496 |
+
normalize_image(out);
|
497 |
+
printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
|
498 |
+
save_image(out, "out");
|
499 |
+
show_image(out, "out", 0);
|
500 |
+
|
501 |
+
free_image(im);
|
502 |
+
}
|
503 |
+
}
|
504 |
+
|
505 |
+
void set_network_alpha_beta(network *net, float alpha, float beta)
|
506 |
+
{
|
507 |
+
int i;
|
508 |
+
for(i = 0; i < net->n; ++i){
|
509 |
+
if(net->layers[i].type == SHORTCUT){
|
510 |
+
net->layers[i].alpha = alpha;
|
511 |
+
net->layers[i].beta = beta;
|
512 |
+
}
|
513 |
+
}
|
514 |
+
}
|
515 |
+
|
516 |
+
void train_prog(char *cfg, char *weight, char *acfg, char *aweight, int clear, int display, char *train_images, int maxbatch)
|
517 |
+
{
|
518 |
+
#ifdef GPU
|
519 |
+
char *backup_directory = "/home/pjreddie/backup/";
|
520 |
+
srand(time(0));
|
521 |
+
char *base = basecfg(cfg);
|
522 |
+
char *abase = basecfg(acfg);
|
523 |
+
printf("%s\n", base);
|
524 |
+
network *gnet = load_network(cfg, weight, clear);
|
525 |
+
network *anet = load_network(acfg, aweight, clear);
|
526 |
+
|
527 |
+
int i, j, k;
|
528 |
+
layer imlayer = gnet->layers[gnet->n-1];
|
529 |
+
|
530 |
+
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", gnet->learning_rate, gnet->momentum, gnet->decay);
|
531 |
+
int imgs = gnet->batch*gnet->subdivisions;
|
532 |
+
i = *gnet->seen/imgs;
|
533 |
+
data train, buffer;
|
534 |
+
|
535 |
+
|
536 |
+
list *plist = get_paths(train_images);
|
537 |
+
char **paths = (char **)list_to_array(plist);
|
538 |
+
|
539 |
+
load_args args= get_base_args(anet);
|
540 |
+
args.paths = paths;
|
541 |
+
args.n = imgs;
|
542 |
+
args.m = plist->size;
|
543 |
+
args.d = &buffer;
|
544 |
+
args.type = CLASSIFICATION_DATA;
|
545 |
+
args.threads=16;
|
546 |
+
args.classes = 1;
|
547 |
+
char *ls[2] = {"imagenet", "zzzzzzzz"};
|
548 |
+
args.labels = ls;
|
549 |
+
|
550 |
+
pthread_t load_thread = load_data_in_thread(args);
|
551 |
+
clock_t time;
|
552 |
+
|
553 |
+
gnet->train = 1;
|
554 |
+
anet->train = 1;
|
555 |
+
|
556 |
+
int x_size = gnet->inputs*gnet->batch;
|
557 |
+
int y_size = gnet->truths*gnet->batch;
|
558 |
+
float *imerror = cuda_make_array(0, y_size);
|
559 |
+
|
560 |
+
float aloss_avg = -1;
|
561 |
+
|
562 |
+
if (maxbatch == 0) maxbatch = gnet->max_batches;
|
563 |
+
while (get_current_batch(gnet) < maxbatch) {
|
564 |
+
{
|
565 |
+
int cb = get_current_batch(gnet);
|
566 |
+
float alpha = (float) cb / (maxbatch/2);
|
567 |
+
if(alpha > 1) alpha = 1;
|
568 |
+
float beta = 1 - alpha;
|
569 |
+
printf("%f %f\n", alpha, beta);
|
570 |
+
set_network_alpha_beta(gnet, alpha, beta);
|
571 |
+
set_network_alpha_beta(anet, beta, alpha);
|
572 |
+
}
|
573 |
+
|
574 |
+
i += 1;
|
575 |
+
time=clock();
|
576 |
+
pthread_join(load_thread, 0);
|
577 |
+
train = buffer;
|
578 |
+
|
579 |
+
load_thread = load_data_in_thread(args);
|
580 |
+
|
581 |
+
printf("Loaded: %lf seconds\n", sec(clock()-time));
|
582 |
+
|
583 |
+
data gen = copy_data(train);
|
584 |
+
for (j = 0; j < imgs; ++j) {
|
585 |
+
train.y.vals[j][0] = 1;
|
586 |
+
gen.y.vals[j][0] = 0;
|
587 |
+
}
|
588 |
+
time=clock();
|
589 |
+
|
590 |
+
for (j = 0; j < gnet->subdivisions; ++j) {
|
591 |
+
get_next_batch(train, gnet->batch, j*gnet->batch, gnet->truth, 0);
|
592 |
+
int z;
|
593 |
+
for(z = 0; z < x_size; ++z){
|
594 |
+
gnet->input[z] = rand_normal();
|
595 |
+
}
|
596 |
+
/*
|
597 |
+
for(z = 0; z < gnet->batch; ++z){
|
598 |
+
float mag = mag_array(gnet->input + z*gnet->inputs, gnet->inputs);
|
599 |
+
scale_array(gnet->input + z*gnet->inputs, gnet->inputs, 1./mag);
|
600 |
+
}
|
601 |
+
*/
|
602 |
+
*gnet->seen += gnet->batch;
|
603 |
+
forward_network(gnet);
|
604 |
+
|
605 |
+
fill_gpu(imlayer.outputs*imlayer.batch, 0, imerror, 1);
|
606 |
+
fill_cpu(anet->truths*anet->batch, 1, anet->truth, 1);
|
607 |
+
copy_cpu(anet->inputs*anet->batch, imlayer.output, 1, anet->input, 1);
|
608 |
+
anet->delta_gpu = imerror;
|
609 |
+
forward_network(anet);
|
610 |
+
backward_network(anet);
|
611 |
+
|
612 |
+
//float genaloss = *anet->cost / anet->batch;
|
613 |
+
|
614 |
+
scal_gpu(imlayer.outputs*imlayer.batch, 1, imerror, 1);
|
615 |
+
scal_gpu(imlayer.outputs*imlayer.batch, 0, gnet->layers[gnet->n-1].delta_gpu, 1);
|
616 |
+
|
617 |
+
axpy_gpu(imlayer.outputs*imlayer.batch, 1, imerror, 1, gnet->layers[gnet->n-1].delta_gpu, 1);
|
618 |
+
|
619 |
+
backward_network(gnet);
|
620 |
+
|
621 |
+
for(k = 0; k < gnet->batch; ++k){
|
622 |
+
int index = j*gnet->batch + k;
|
623 |
+
copy_cpu(gnet->outputs, gnet->output + k*gnet->outputs, 1, gen.X.vals[index], 1);
|
624 |
+
}
|
625 |
+
}
|
626 |
+
harmless_update_network_gpu(anet);
|
627 |
+
|
628 |
+
data merge = concat_data(train, gen);
|
629 |
+
float aloss = train_network(anet, merge);
|
630 |
+
|
631 |
+
#ifdef OPENCV
|
632 |
+
if(display){
|
633 |
+
image im = float_to_image(anet->w, anet->h, anet->c, gen.X.vals[0]);
|
634 |
+
image im2 = float_to_image(anet->w, anet->h, anet->c, train.X.vals[0]);
|
635 |
+
show_image(im, "gen", 1);
|
636 |
+
show_image(im2, "train", 1);
|
637 |
+
save_image(im, "gen");
|
638 |
+
save_image(im2, "train");
|
639 |
+
}
|
640 |
+
#endif
|
641 |
+
|
642 |
+
update_network_gpu(gnet);
|
643 |
+
|
644 |
+
free_data(merge);
|
645 |
+
free_data(train);
|
646 |
+
free_data(gen);
|
647 |
+
if (aloss_avg < 0) aloss_avg = aloss;
|
648 |
+
aloss_avg = aloss_avg*.9 + aloss*.1;
|
649 |
+
|
650 |
+
printf("%d: adv: %f | adv_avg: %f, %f rate, %lf seconds, %d images\n", i, aloss, aloss_avg, get_current_rate(gnet), sec(clock()-time), i*imgs);
|
651 |
+
if(i%10000==0){
|
652 |
+
char buff[256];
|
653 |
+
sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
|
654 |
+
save_weights(gnet, buff);
|
655 |
+
sprintf(buff, "%s/%s_%d.weights", backup_directory, abase, i);
|
656 |
+
save_weights(anet, buff);
|
657 |
+
}
|
658 |
+
if(i%1000==0){
|
659 |
+
char buff[256];
|
660 |
+
sprintf(buff, "%s/%s.backup", backup_directory, base);
|
661 |
+
save_weights(gnet, buff);
|
662 |
+
sprintf(buff, "%s/%s.backup", backup_directory, abase);
|
663 |
+
save_weights(anet, buff);
|
664 |
+
}
|
665 |
+
}
|
666 |
+
char buff[256];
|
667 |
+
sprintf(buff, "%s/%s_final.weights", backup_directory, base);
|
668 |
+
save_weights(gnet, buff);
|
669 |
+
#endif
|
670 |
+
}
|
671 |
+
|
672 |
+
void train_dcgan(char *cfg, char *weight, char *acfg, char *aweight, int clear, int display, char *train_images, int maxbatch)
|
673 |
+
{
|
674 |
+
#ifdef GPU
|
675 |
+
char *backup_directory = "/home/pjreddie/backup/";
|
676 |
+
srand(time(0));
|
677 |
+
char *base = basecfg(cfg);
|
678 |
+
char *abase = basecfg(acfg);
|
679 |
+
printf("%s\n", base);
|
680 |
+
network *gnet = load_network(cfg, weight, clear);
|
681 |
+
network *anet = load_network(acfg, aweight, clear);
|
682 |
+
//float orig_rate = anet->learning_rate;
|
683 |
+
|
684 |
+
int i, j, k;
|
685 |
+
layer imlayer = {0};
|
686 |
+
for (i = 0; i < gnet->n; ++i) {
|
687 |
+
if (gnet->layers[i].out_c == 3) {
|
688 |
+
imlayer = gnet->layers[i];
|
689 |
+
break;
|
690 |
+
}
|
691 |
+
}
|
692 |
+
|
693 |
+
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", gnet->learning_rate, gnet->momentum, gnet->decay);
|
694 |
+
int imgs = gnet->batch*gnet->subdivisions;
|
695 |
+
i = *gnet->seen/imgs;
|
696 |
+
data train, buffer;
|
697 |
+
|
698 |
+
|
699 |
+
list *plist = get_paths(train_images);
|
700 |
+
//int N = plist->size;
|
701 |
+
char **paths = (char **)list_to_array(plist);
|
702 |
+
|
703 |
+
load_args args= get_base_args(anet);
|
704 |
+
args.paths = paths;
|
705 |
+
args.n = imgs;
|
706 |
+
args.m = plist->size;
|
707 |
+
args.d = &buffer;
|
708 |
+
args.type = CLASSIFICATION_DATA;
|
709 |
+
args.threads=16;
|
710 |
+
args.classes = 1;
|
711 |
+
char *ls[2] = {"imagenet", "zzzzzzzz"};
|
712 |
+
args.labels = ls;
|
713 |
+
|
714 |
+
pthread_t load_thread = load_data_in_thread(args);
|
715 |
+
clock_t time;
|
716 |
+
|
717 |
+
gnet->train = 1;
|
718 |
+
anet->train = 1;
|
719 |
+
|
720 |
+
int x_size = gnet->inputs*gnet->batch;
|
721 |
+
int y_size = gnet->truths*gnet->batch;
|
722 |
+
float *imerror = cuda_make_array(0, y_size);
|
723 |
+
|
724 |
+
//int ay_size = anet->truths*anet->batch;
|
725 |
+
|
726 |
+
float aloss_avg = -1;
|
727 |
+
|
728 |
+
//data generated = copy_data(train);
|
729 |
+
|
730 |
+
if (maxbatch == 0) maxbatch = gnet->max_batches;
|
731 |
+
while (get_current_batch(gnet) < maxbatch) {
|
732 |
+
i += 1;
|
733 |
+
time=clock();
|
734 |
+
pthread_join(load_thread, 0);
|
735 |
+
train = buffer;
|
736 |
+
|
737 |
+
//translate_data_rows(train, -.5);
|
738 |
+
//scale_data_rows(train, 2);
|
739 |
+
|
740 |
+
load_thread = load_data_in_thread(args);
|
741 |
+
|
742 |
+
printf("Loaded: %lf seconds\n", sec(clock()-time));
|
743 |
+
|
744 |
+
data gen = copy_data(train);
|
745 |
+
for (j = 0; j < imgs; ++j) {
|
746 |
+
train.y.vals[j][0] = 1;
|
747 |
+
gen.y.vals[j][0] = 0;
|
748 |
+
}
|
749 |
+
time=clock();
|
750 |
+
|
751 |
+
for(j = 0; j < gnet->subdivisions; ++j){
|
752 |
+
get_next_batch(train, gnet->batch, j*gnet->batch, gnet->truth, 0);
|
753 |
+
int z;
|
754 |
+
for(z = 0; z < x_size; ++z){
|
755 |
+
gnet->input[z] = rand_normal();
|
756 |
+
}
|
757 |
+
for(z = 0; z < gnet->batch; ++z){
|
758 |
+
float mag = mag_array(gnet->input + z*gnet->inputs, gnet->inputs);
|
759 |
+
scale_array(gnet->input + z*gnet->inputs, gnet->inputs, 1./mag);
|
760 |
+
}
|
761 |
+
/*
|
762 |
+
for(z = 0; z < 100; ++z){
|
763 |
+
printf("%f, ", gnet->input[z]);
|
764 |
+
}
|
765 |
+
printf("\n");
|
766 |
+
printf("input: %f %f\n", mean_array(gnet->input, x_size), variance_array(gnet->input, x_size));
|
767 |
+
*/
|
768 |
+
|
769 |
+
//cuda_push_array(gnet->input_gpu, gnet->input, x_size);
|
770 |
+
//cuda_push_array(gnet->truth_gpu, gnet->truth, y_size);
|
771 |
+
*gnet->seen += gnet->batch;
|
772 |
+
forward_network(gnet);
|
773 |
+
|
774 |
+
fill_gpu(imlayer.outputs*imlayer.batch, 0, imerror, 1);
|
775 |
+
fill_cpu(anet->truths*anet->batch, 1, anet->truth, 1);
|
776 |
+
copy_cpu(anet->inputs*anet->batch, imlayer.output, 1, anet->input, 1);
|
777 |
+
anet->delta_gpu = imerror;
|
778 |
+
forward_network(anet);
|
779 |
+
backward_network(anet);
|
780 |
+
|
781 |
+
//float genaloss = *anet->cost / anet->batch;
|
782 |
+
//printf("%f\n", genaloss);
|
783 |
+
|
784 |
+
scal_gpu(imlayer.outputs*imlayer.batch, 1, imerror, 1);
|
785 |
+
scal_gpu(imlayer.outputs*imlayer.batch, 0, gnet->layers[gnet->n-1].delta_gpu, 1);
|
786 |
+
|
787 |
+
//printf("realness %f\n", cuda_mag_array(imerror, imlayer.outputs*imlayer.batch));
|
788 |
+
//printf("features %f\n", cuda_mag_array(gnet->layers[gnet->n-1].delta_gpu, imlayer.outputs*imlayer.batch));
|
789 |
+
|
790 |
+
axpy_gpu(imlayer.outputs*imlayer.batch, 1, imerror, 1, gnet->layers[gnet->n-1].delta_gpu, 1);
|
791 |
+
|
792 |
+
backward_network(gnet);
|
793 |
+
|
794 |
+
/*
|
795 |
+
for(k = 0; k < gnet->n; ++k){
|
796 |
+
layer l = gnet->layers[k];
|
797 |
+
cuda_pull_array(l.output_gpu, l.output, l.outputs*l.batch);
|
798 |
+
printf("%d: %f %f\n", k, mean_array(l.output, l.outputs*l.batch), variance_array(l.output, l.outputs*l.batch));
|
799 |
+
}
|
800 |
+
*/
|
801 |
+
|
802 |
+
for(k = 0; k < gnet->batch; ++k){
|
803 |
+
int index = j*gnet->batch + k;
|
804 |
+
copy_cpu(gnet->outputs, gnet->output + k*gnet->outputs, 1, gen.X.vals[index], 1);
|
805 |
+
}
|
806 |
+
}
|
807 |
+
harmless_update_network_gpu(anet);
|
808 |
+
|
809 |
+
data merge = concat_data(train, gen);
|
810 |
+
//randomize_data(merge);
|
811 |
+
float aloss = train_network(anet, merge);
|
812 |
+
|
813 |
+
//translate_image(im, 1);
|
814 |
+
//scale_image(im, .5);
|
815 |
+
//translate_image(im2, 1);
|
816 |
+
//scale_image(im2, .5);
|
817 |
+
#ifdef OPENCV
|
818 |
+
if(display){
|
819 |
+
image im = float_to_image(anet->w, anet->h, anet->c, gen.X.vals[0]);
|
820 |
+
image im2 = float_to_image(anet->w, anet->h, anet->c, train.X.vals[0]);
|
821 |
+
show_image(im, "gen", 1);
|
822 |
+
show_image(im2, "train", 1);
|
823 |
+
save_image(im, "gen");
|
824 |
+
save_image(im2, "train");
|
825 |
+
}
|
826 |
+
#endif
|
827 |
+
|
828 |
+
/*
|
829 |
+
if(aloss < .1){
|
830 |
+
anet->learning_rate = 0;
|
831 |
+
} else if (aloss > .3){
|
832 |
+
anet->learning_rate = orig_rate;
|
833 |
+
}
|
834 |
+
*/
|
835 |
+
|
836 |
+
update_network_gpu(gnet);
|
837 |
+
|
838 |
+
free_data(merge);
|
839 |
+
free_data(train);
|
840 |
+
free_data(gen);
|
841 |
+
if (aloss_avg < 0) aloss_avg = aloss;
|
842 |
+
aloss_avg = aloss_avg*.9 + aloss*.1;
|
843 |
+
|
844 |
+
printf("%d: adv: %f | adv_avg: %f, %f rate, %lf seconds, %d images\n", i, aloss, aloss_avg, get_current_rate(gnet), sec(clock()-time), i*imgs);
|
845 |
+
if(i%10000==0){
|
846 |
+
char buff[256];
|
847 |
+
sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
|
848 |
+
save_weights(gnet, buff);
|
849 |
+
sprintf(buff, "%s/%s_%d.weights", backup_directory, abase, i);
|
850 |
+
save_weights(anet, buff);
|
851 |
+
}
|
852 |
+
if(i%1000==0){
|
853 |
+
char buff[256];
|
854 |
+
sprintf(buff, "%s/%s.backup", backup_directory, base);
|
855 |
+
save_weights(gnet, buff);
|
856 |
+
sprintf(buff, "%s/%s.backup", backup_directory, abase);
|
857 |
+
save_weights(anet, buff);
|
858 |
+
}
|
859 |
+
}
|
860 |
+
char buff[256];
|
861 |
+
sprintf(buff, "%s/%s_final.weights", backup_directory, base);
|
862 |
+
save_weights(gnet, buff);
|
863 |
+
#endif
|
864 |
+
}
|
865 |
+
|
866 |
+
void train_colorizer(char *cfg, char *weight, char *acfg, char *aweight, int clear, int display)
|
867 |
+
{
|
868 |
+
#ifdef GPU
|
869 |
+
//char *train_images = "/home/pjreddie/data/coco/train1.txt";
|
870 |
+
//char *train_images = "/home/pjreddie/data/coco/trainvalno5k.txt";
|
871 |
+
char *train_images = "/home/pjreddie/data/imagenet/imagenet1k.train.list";
|
872 |
+
char *backup_directory = "/home/pjreddie/backup/";
|
873 |
+
srand(time(0));
|
874 |
+
char *base = basecfg(cfg);
|
875 |
+
char *abase = basecfg(acfg);
|
876 |
+
printf("%s\n", base);
|
877 |
+
network *net = load_network(cfg, weight, clear);
|
878 |
+
network *anet = load_network(acfg, aweight, clear);
|
879 |
+
|
880 |
+
int i, j, k;
|
881 |
+
layer imlayer = {0};
|
882 |
+
for (i = 0; i < net->n; ++i) {
|
883 |
+
if (net->layers[i].out_c == 3) {
|
884 |
+
imlayer = net->layers[i];
|
885 |
+
break;
|
886 |
+
}
|
887 |
+
}
|
888 |
+
|
889 |
+
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
|
890 |
+
int imgs = net->batch*net->subdivisions;
|
891 |
+
i = *net->seen/imgs;
|
892 |
+
data train, buffer;
|
893 |
+
|
894 |
+
|
895 |
+
list *plist = get_paths(train_images);
|
896 |
+
//int N = plist->size;
|
897 |
+
char **paths = (char **)list_to_array(plist);
|
898 |
+
|
899 |
+
load_args args= get_base_args(net);
|
900 |
+
args.paths = paths;
|
901 |
+
args.n = imgs;
|
902 |
+
args.m = plist->size;
|
903 |
+
args.d = &buffer;
|
904 |
+
|
905 |
+
args.type = CLASSIFICATION_DATA;
|
906 |
+
args.classes = 1;
|
907 |
+
char *ls[2] = {"imagenet"};
|
908 |
+
args.labels = ls;
|
909 |
+
|
910 |
+
pthread_t load_thread = load_data_in_thread(args);
|
911 |
+
clock_t time;
|
912 |
+
|
913 |
+
int x_size = net->inputs*net->batch;
|
914 |
+
//int y_size = x_size;
|
915 |
+
net->delta = 0;
|
916 |
+
net->train = 1;
|
917 |
+
float *pixs = calloc(x_size, sizeof(float));
|
918 |
+
float *graypixs = calloc(x_size, sizeof(float));
|
919 |
+
//float *y = calloc(y_size, sizeof(float));
|
920 |
+
|
921 |
+
//int ay_size = anet->outputs*anet->batch;
|
922 |
+
anet->delta = 0;
|
923 |
+
anet->train = 1;
|
924 |
+
|
925 |
+
float *imerror = cuda_make_array(0, imlayer.outputs*imlayer.batch);
|
926 |
+
|
927 |
+
float aloss_avg = -1;
|
928 |
+
float gloss_avg = -1;
|
929 |
+
|
930 |
+
//data generated = copy_data(train);
|
931 |
+
|
932 |
+
while (get_current_batch(net) < net->max_batches) {
|
933 |
+
i += 1;
|
934 |
+
time=clock();
|
935 |
+
pthread_join(load_thread, 0);
|
936 |
+
train = buffer;
|
937 |
+
load_thread = load_data_in_thread(args);
|
938 |
+
|
939 |
+
printf("Loaded: %lf seconds\n", sec(clock()-time));
|
940 |
+
|
941 |
+
data gray = copy_data(train);
|
942 |
+
for(j = 0; j < imgs; ++j){
|
943 |
+
image gim = float_to_image(net->w, net->h, net->c, gray.X.vals[j]);
|
944 |
+
grayscale_image_3c(gim);
|
945 |
+
train.y.vals[j][0] = .95;
|
946 |
+
gray.y.vals[j][0] = .05;
|
947 |
+
}
|
948 |
+
time=clock();
|
949 |
+
float gloss = 0;
|
950 |
+
|
951 |
+
for(j = 0; j < net->subdivisions; ++j){
|
952 |
+
get_next_batch(train, net->batch, j*net->batch, pixs, 0);
|
953 |
+
get_next_batch(gray, net->batch, j*net->batch, graypixs, 0);
|
954 |
+
cuda_push_array(net->input_gpu, graypixs, net->inputs*net->batch);
|
955 |
+
cuda_push_array(net->truth_gpu, pixs, net->truths*net->batch);
|
956 |
+
/*
|
957 |
+
image origi = float_to_image(net->w, net->h, 3, pixs);
|
958 |
+
image grayi = float_to_image(net->w, net->h, 3, graypixs);
|
959 |
+
show_image(grayi, "gray");
|
960 |
+
show_image(origi, "orig");
|
961 |
+
cvWaitKey(0);
|
962 |
+
*/
|
963 |
+
*net->seen += net->batch;
|
964 |
+
forward_network_gpu(net);
|
965 |
+
|
966 |
+
fill_gpu(imlayer.outputs*imlayer.batch, 0, imerror, 1);
|
967 |
+
copy_gpu(anet->inputs*anet->batch, imlayer.output_gpu, 1, anet->input_gpu, 1);
|
968 |
+
fill_gpu(anet->inputs*anet->batch, .95, anet->truth_gpu, 1);
|
969 |
+
anet->delta_gpu = imerror;
|
970 |
+
forward_network_gpu(anet);
|
971 |
+
backward_network_gpu(anet);
|
972 |
+
|
973 |
+
scal_gpu(imlayer.outputs*imlayer.batch, 1./100., net->layers[net->n-1].delta_gpu, 1);
|
974 |
+
|
975 |
+
scal_gpu(imlayer.outputs*imlayer.batch, 1, imerror, 1);
|
976 |
+
|
977 |
+
printf("realness %f\n", cuda_mag_array(imerror, imlayer.outputs*imlayer.batch));
|
978 |
+
printf("features %f\n", cuda_mag_array(net->layers[net->n-1].delta_gpu, imlayer.outputs*imlayer.batch));
|
979 |
+
|
980 |
+
axpy_gpu(imlayer.outputs*imlayer.batch, 1, imerror, 1, net->layers[net->n-1].delta_gpu, 1);
|
981 |
+
|
982 |
+
backward_network_gpu(net);
|
983 |
+
|
984 |
+
|
985 |
+
gloss += *net->cost /(net->subdivisions*net->batch);
|
986 |
+
|
987 |
+
for(k = 0; k < net->batch; ++k){
|
988 |
+
int index = j*net->batch + k;
|
989 |
+
copy_cpu(imlayer.outputs, imlayer.output + k*imlayer.outputs, 1, gray.X.vals[index], 1);
|
990 |
+
}
|
991 |
+
}
|
992 |
+
harmless_update_network_gpu(anet);
|
993 |
+
|
994 |
+
data merge = concat_data(train, gray);
|
995 |
+
//randomize_data(merge);
|
996 |
+
float aloss = train_network(anet, merge);
|
997 |
+
|
998 |
+
update_network_gpu(net);
|
999 |
+
|
1000 |
+
#ifdef OPENCV
|
1001 |
+
if(display){
|
1002 |
+
image im = float_to_image(anet->w, anet->h, anet->c, gray.X.vals[0]);
|
1003 |
+
image im2 = float_to_image(anet->w, anet->h, anet->c, train.X.vals[0]);
|
1004 |
+
show_image(im, "gen", 1);
|
1005 |
+
show_image(im2, "train", 1);
|
1006 |
+
}
|
1007 |
+
#endif
|
1008 |
+
free_data(merge);
|
1009 |
+
free_data(train);
|
1010 |
+
free_data(gray);
|
1011 |
+
if (aloss_avg < 0) aloss_avg = aloss;
|
1012 |
+
aloss_avg = aloss_avg*.9 + aloss*.1;
|
1013 |
+
gloss_avg = gloss_avg*.9 + gloss*.1;
|
1014 |
+
|
1015 |
+
printf("%d: gen: %f, adv: %f | gen_avg: %f, adv_avg: %f, %f rate, %lf seconds, %d images\n", i, gloss, aloss, gloss_avg, aloss_avg, get_current_rate(net), sec(clock()-time), i*imgs);
|
1016 |
+
if(i%1000==0){
|
1017 |
+
char buff[256];
|
1018 |
+
sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
|
1019 |
+
save_weights(net, buff);
|
1020 |
+
sprintf(buff, "%s/%s_%d.weights", backup_directory, abase, i);
|
1021 |
+
save_weights(anet, buff);
|
1022 |
+
}
|
1023 |
+
if(i%100==0){
|
1024 |
+
char buff[256];
|
1025 |
+
sprintf(buff, "%s/%s.backup", backup_directory, base);
|
1026 |
+
save_weights(net, buff);
|
1027 |
+
sprintf(buff, "%s/%s.backup", backup_directory, abase);
|
1028 |
+
save_weights(anet, buff);
|
1029 |
+
}
|
1030 |
+
}
|
1031 |
+
char buff[256];
|
1032 |
+
sprintf(buff, "%s/%s_final.weights", backup_directory, base);
|
1033 |
+
save_weights(net, buff);
|
1034 |
+
#endif
|
1035 |
+
}
|
1036 |
+
|
1037 |
+
/*
|
1038 |
+
void train_lsd2(char *cfgfile, char *weightfile, char *acfgfile, char *aweightfile, int clear)
|
1039 |
+
{
|
1040 |
+
#ifdef GPU
|
1041 |
+
char *train_images = "/home/pjreddie/data/coco/trainvalno5k.txt";
|
1042 |
+
char *backup_directory = "/home/pjreddie/backup/";
|
1043 |
+
srand(time(0));
|
1044 |
+
char *base = basecfg(cfgfile);
|
1045 |
+
printf("%s\n", base);
|
1046 |
+
network net = parse_network_cfg(cfgfile);
|
1047 |
+
if(weightfile){
|
1048 |
+
load_weights(&net, weightfile);
|
1049 |
+
}
|
1050 |
+
if(clear) *net->seen = 0;
|
1051 |
+
|
1052 |
+
char *abase = basecfg(acfgfile);
|
1053 |
+
network anet = parse_network_cfg(acfgfile);
|
1054 |
+
if(aweightfile){
|
1055 |
+
load_weights(&anet, aweightfile);
|
1056 |
+
}
|
1057 |
+
if(clear) *anet->seen = 0;
|
1058 |
+
|
1059 |
+
int i, j, k;
|
1060 |
+
layer imlayer = {0};
|
1061 |
+
for (i = 0; i < net->n; ++i) {
|
1062 |
+
if (net->layers[i].out_c == 3) {
|
1063 |
+
imlayer = net->layers[i];
|
1064 |
+
break;
|
1065 |
+
}
|
1066 |
+
}
|
1067 |
+
|
1068 |
+
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
|
1069 |
+
int imgs = net->batch*net->subdivisions;
|
1070 |
+
i = *net->seen/imgs;
|
1071 |
+
data train, buffer;
|
1072 |
+
|
1073 |
+
|
1074 |
+
list *plist = get_paths(train_images);
|
1075 |
+
//int N = plist->size;
|
1076 |
+
char **paths = (char **)list_to_array(plist);
|
1077 |
+
|
1078 |
+
load_args args = {0};
|
1079 |
+
args.w = net->w;
|
1080 |
+
args.h = net->h;
|
1081 |
+
args.paths = paths;
|
1082 |
+
args.n = imgs;
|
1083 |
+
args.m = plist->size;
|
1084 |
+
args.d = &buffer;
|
1085 |
+
|
1086 |
+
args.min = net->min_crop;
|
1087 |
+
args.max = net->max_crop;
|
1088 |
+
args.angle = net->angle;
|
1089 |
+
args.aspect = net->aspect;
|
1090 |
+
args.exposure = net->exposure;
|
1091 |
+
args.saturation = net->saturation;
|
1092 |
+
args.hue = net->hue;
|
1093 |
+
args.size = net->w;
|
1094 |
+
args.type = CLASSIFICATION_DATA;
|
1095 |
+
args.classes = 1;
|
1096 |
+
char *ls[1] = {"coco"};
|
1097 |
+
args.labels = ls;
|
1098 |
+
|
1099 |
+
pthread_t load_thread = load_data_in_thread(args);
|
1100 |
+
clock_t time;
|
1101 |
+
|
1102 |
+
network_state gstate = {0};
|
1103 |
+
gstate.index = 0;
|
1104 |
+
gstate.net = net;
|
1105 |
+
int x_size = get_network_input_size(net)*net->batch;
|
1106 |
+
int y_size = 1*net->batch;
|
1107 |
+
gstate.input = cuda_make_array(0, x_size);
|
1108 |
+
gstate.truth = 0;
|
1109 |
+
gstate.delta = 0;
|
1110 |
+
gstate.train = 1;
|
1111 |
+
float *X = calloc(x_size, sizeof(float));
|
1112 |
+
float *y = calloc(y_size, sizeof(float));
|
1113 |
+
|
1114 |
+
network_state astate = {0};
|
1115 |
+
astate.index = 0;
|
1116 |
+
astate.net = anet;
|
1117 |
+
int ay_size = get_network_output_size(anet)*anet->batch;
|
1118 |
+
astate.input = 0;
|
1119 |
+
astate.truth = 0;
|
1120 |
+
astate.delta = 0;
|
1121 |
+
astate.train = 1;
|
1122 |
+
|
1123 |
+
float *imerror = cuda_make_array(0, imlayer.outputs);
|
1124 |
+
float *ones_gpu = cuda_make_array(0, ay_size);
|
1125 |
+
fill_gpu(ay_size, 1, ones_gpu, 1);
|
1126 |
+
|
1127 |
+
float aloss_avg = -1;
|
1128 |
+
float gloss_avg = -1;
|
1129 |
+
|
1130 |
+
//data generated = copy_data(train);
|
1131 |
+
|
1132 |
+
while (get_current_batch(net) < net->max_batches) {
|
1133 |
+
i += 1;
|
1134 |
+
time=clock();
|
1135 |
+
pthread_join(load_thread, 0);
|
1136 |
+
train = buffer;
|
1137 |
+
load_thread = load_data_in_thread(args);
|
1138 |
+
|
1139 |
+
printf("Loaded: %lf seconds\n", sec(clock()-time));
|
1140 |
+
|
1141 |
+
data generated = copy_data(train);
|
1142 |
+
time=clock();
|
1143 |
+
float gloss = 0;
|
1144 |
+
|
1145 |
+
for(j = 0; j < net->subdivisions; ++j){
|
1146 |
+
get_next_batch(train, net->batch, j*net->batch, X, y);
|
1147 |
+
cuda_push_array(gstate.input, X, x_size);
|
1148 |
+
*net->seen += net->batch;
|
1149 |
+
forward_network_gpu(net, gstate);
|
1150 |
+
|
1151 |
+
fill_gpu(imlayer.outputs, 0, imerror, 1);
|
1152 |
+
astate.input = imlayer.output_gpu;
|
1153 |
+
astate.delta = imerror;
|
1154 |
+
astate.truth = ones_gpu;
|
1155 |
+
forward_network_gpu(anet, astate);
|
1156 |
+
backward_network_gpu(anet, astate);
|
1157 |
+
|
1158 |
+
scal_gpu(imlayer.outputs, 1, imerror, 1);
|
1159 |
+
axpy_gpu(imlayer.outputs, 1, imerror, 1, imlayer.delta_gpu, 1);
|
1160 |
+
|
1161 |
+
backward_network_gpu(net, gstate);
|
1162 |
+
|
1163 |
+
printf("features %f\n", cuda_mag_array(imlayer.delta_gpu, imlayer.outputs));
|
1164 |
+
printf("realness %f\n", cuda_mag_array(imerror, imlayer.outputs));
|
1165 |
+
|
1166 |
+
gloss += get_network_cost(net) /(net->subdivisions*net->batch);
|
1167 |
+
|
1168 |
+
cuda_pull_array(imlayer.output_gpu, imlayer.output, imlayer.outputs*imlayer.batch);
|
1169 |
+
for(k = 0; k < net->batch; ++k){
|
1170 |
+
int index = j*net->batch + k;
|
1171 |
+
copy_cpu(imlayer.outputs, imlayer.output + k*imlayer.outputs, 1, generated.X.vals[index], 1);
|
1172 |
+
generated.y.vals[index][0] = 0;
|
1173 |
+
}
|
1174 |
+
}
|
1175 |
+
harmless_update_network_gpu(anet);
|
1176 |
+
|
1177 |
+
data merge = concat_data(train, generated);
|
1178 |
+
randomize_data(merge);
|
1179 |
+
float aloss = train_network(anet, merge);
|
1180 |
+
|
1181 |
+
update_network_gpu(net);
|
1182 |
+
update_network_gpu(anet);
|
1183 |
+
free_data(merge);
|
1184 |
+
free_data(train);
|
1185 |
+
free_data(generated);
|
1186 |
+
if (aloss_avg < 0) aloss_avg = aloss;
|
1187 |
+
aloss_avg = aloss_avg*.9 + aloss*.1;
|
1188 |
+
gloss_avg = gloss_avg*.9 + gloss*.1;
|
1189 |
+
|
1190 |
+
printf("%d: gen: %f, adv: %f | gen_avg: %f, adv_avg: %f, %f rate, %lf seconds, %d images\n", i, gloss, aloss, gloss_avg, aloss_avg, get_current_rate(net), sec(clock()-time), i*imgs);
|
1191 |
+
if(i%1000==0){
|
1192 |
+
char buff[256];
|
1193 |
+
sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
|
1194 |
+
save_weights(net, buff);
|
1195 |
+
sprintf(buff, "%s/%s_%d.weights", backup_directory, abase, i);
|
1196 |
+
save_weights(anet, buff);
|
1197 |
+
}
|
1198 |
+
if(i%100==0){
|
1199 |
+
char buff[256];
|
1200 |
+
sprintf(buff, "%s/%s.backup", backup_directory, base);
|
1201 |
+
save_weights(net, buff);
|
1202 |
+
sprintf(buff, "%s/%s.backup", backup_directory, abase);
|
1203 |
+
save_weights(anet, buff);
|
1204 |
+
}
|
1205 |
+
}
|
1206 |
+
char buff[256];
|
1207 |
+
sprintf(buff, "%s/%s_final.weights", backup_directory, base);
|
1208 |
+
save_weights(net, buff);
|
1209 |
+
#endif
|
1210 |
+
}
|
1211 |
+
*/
|
1212 |
+
|
1213 |
+
/*
|
1214 |
+
void train_lsd(char *cfgfile, char *weightfile, int clear)
|
1215 |
+
{
|
1216 |
+
char *train_images = "/home/pjreddie/data/coco/trainvalno5k.txt";
|
1217 |
+
char *backup_directory = "/home/pjreddie/backup/";
|
1218 |
+
srand(time(0));
|
1219 |
+
char *base = basecfg(cfgfile);
|
1220 |
+
printf("%s\n", base);
|
1221 |
+
float avg_loss = -1;
|
1222 |
+
network net = parse_network_cfg(cfgfile);
|
1223 |
+
if(weightfile){
|
1224 |
+
load_weights(&net, weightfile);
|
1225 |
+
}
|
1226 |
+
if(clear) *net->seen = 0;
|
1227 |
+
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
|
1228 |
+
int imgs = net->batch*net->subdivisions;
|
1229 |
+
int i = *net->seen/imgs;
|
1230 |
+
data train, buffer;
|
1231 |
+
|
1232 |
+
|
1233 |
+
list *plist = get_paths(train_images);
|
1234 |
+
//int N = plist->size;
|
1235 |
+
char **paths = (char **)list_to_array(plist);
|
1236 |
+
|
1237 |
+
load_args args = {0};
|
1238 |
+
args.w = net->w;
|
1239 |
+
args.h = net->h;
|
1240 |
+
args.paths = paths;
|
1241 |
+
args.n = imgs;
|
1242 |
+
args.m = plist->size;
|
1243 |
+
args.d = &buffer;
|
1244 |
+
|
1245 |
+
args.min = net->min_crop;
|
1246 |
+
args.max = net->max_crop;
|
1247 |
+
args.angle = net->angle;
|
1248 |
+
args.aspect = net->aspect;
|
1249 |
+
args.exposure = net->exposure;
|
1250 |
+
args.saturation = net->saturation;
|
1251 |
+
args.hue = net->hue;
|
1252 |
+
args.size = net->w;
|
1253 |
+
args.type = CLASSIFICATION_DATA;
|
1254 |
+
args.classes = 1;
|
1255 |
+
char *ls[1] = {"coco"};
|
1256 |
+
args.labels = ls;
|
1257 |
+
|
1258 |
+
pthread_t load_thread = load_data_in_thread(args);
|
1259 |
+
clock_t time;
|
1260 |
+
//while(i*imgs < N*120){
|
1261 |
+
while(get_current_batch(net) < net->max_batches){
|
1262 |
+
i += 1;
|
1263 |
+
time=clock();
|
1264 |
+
pthread_join(load_thread, 0);
|
1265 |
+
train = buffer;
|
1266 |
+
load_thread = load_data_in_thread(args);
|
1267 |
+
|
1268 |
+
printf("Loaded: %lf seconds\n", sec(clock()-time));
|
1269 |
+
|
1270 |
+
time=clock();
|
1271 |
+
float loss = train_network(net, train);
|
1272 |
+
if (avg_loss < 0) avg_loss = loss;
|
1273 |
+
avg_loss = avg_loss*.9 + loss*.1;
|
1274 |
+
|
1275 |
+
printf("%d: %f, %f avg, %f rate, %lf seconds, %d images\n", i, loss, avg_loss, get_current_rate(net), sec(clock()-time), i*imgs);
|
1276 |
+
if(i%1000==0){
|
1277 |
+
char buff[256];
|
1278 |
+
sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
|
1279 |
+
save_weights(net, buff);
|
1280 |
+
}
|
1281 |
+
if(i%100==0){
|
1282 |
+
char buff[256];
|
1283 |
+
sprintf(buff, "%s/%s.backup", backup_directory, base);
|
1284 |
+
save_weights(net, buff);
|
1285 |
+
}
|
1286 |
+
free_data(train);
|
1287 |
+
}
|
1288 |
+
char buff[256];
|
1289 |
+
sprintf(buff, "%s/%s_final.weights", backup_directory, base);
|
1290 |
+
save_weights(net, buff);
|
1291 |
+
}
|
1292 |
+
*/
|
1293 |
+
|
1294 |
+
void test_lsd(char *cfg, char *weights, char *filename, int gray)
|
1295 |
+
{
|
1296 |
+
network *net = load_network(cfg, weights, 0);
|
1297 |
+
set_batch_network(net, 1);
|
1298 |
+
srand(2222222);
|
1299 |
+
|
1300 |
+
clock_t time;
|
1301 |
+
char buff[256];
|
1302 |
+
char *input = buff;
|
1303 |
+
int i, imlayer = 0;
|
1304 |
+
|
1305 |
+
for (i = 0; i < net->n; ++i) {
|
1306 |
+
if (net->layers[i].out_c == 3) {
|
1307 |
+
imlayer = i;
|
1308 |
+
printf("%d\n", i);
|
1309 |
+
break;
|
1310 |
+
}
|
1311 |
+
}
|
1312 |
+
|
1313 |
+
while(1){
|
1314 |
+
if(filename){
|
1315 |
+
strncpy(input, filename, 256);
|
1316 |
+
}else{
|
1317 |
+
printf("Enter Image Path: ");
|
1318 |
+
fflush(stdout);
|
1319 |
+
input = fgets(input, 256, stdin);
|
1320 |
+
if(!input) return;
|
1321 |
+
strtok(input, "\n");
|
1322 |
+
}
|
1323 |
+
image im = load_image_color(input, 0, 0);
|
1324 |
+
image resized = resize_min(im, net->w);
|
1325 |
+
image crop = crop_image(resized, (resized.w - net->w)/2, (resized.h - net->h)/2, net->w, net->h);
|
1326 |
+
if(gray) grayscale_image_3c(crop);
|
1327 |
+
|
1328 |
+
float *X = crop.data;
|
1329 |
+
time=clock();
|
1330 |
+
network_predict(net, X);
|
1331 |
+
image out = get_network_image_layer(net, imlayer);
|
1332 |
+
//yuv_to_rgb(out);
|
1333 |
+
constrain_image(out);
|
1334 |
+
printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
|
1335 |
+
save_image(out, "out");
|
1336 |
+
show_image(out, "out", 1);
|
1337 |
+
show_image(crop, "crop", 0);
|
1338 |
+
|
1339 |
+
free_image(im);
|
1340 |
+
free_image(resized);
|
1341 |
+
free_image(crop);
|
1342 |
+
if (filename) break;
|
1343 |
+
}
|
1344 |
+
}
|
1345 |
+
|
1346 |
+
|
1347 |
+
void run_lsd(int argc, char **argv)
|
1348 |
+
{
|
1349 |
+
if(argc < 4){
|
1350 |
+
fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
|
1351 |
+
return;
|
1352 |
+
}
|
1353 |
+
|
1354 |
+
int clear = find_arg(argc, argv, "-clear");
|
1355 |
+
int display = find_arg(argc, argv, "-display");
|
1356 |
+
int batches = find_int_arg(argc, argv, "-b", 0);
|
1357 |
+
char *file = find_char_arg(argc, argv, "-file", "/home/pjreddie/data/imagenet/imagenet1k.train.list");
|
1358 |
+
|
1359 |
+
char *cfg = argv[3];
|
1360 |
+
char *weights = (argc > 4) ? argv[4] : 0;
|
1361 |
+
char *filename = (argc > 5) ? argv[5] : 0;
|
1362 |
+
char *acfg = argv[5];
|
1363 |
+
char *aweights = (argc > 6) ? argv[6] : 0;
|
1364 |
+
//if(0==strcmp(argv[2], "train")) train_lsd(cfg, weights, clear);
|
1365 |
+
//else if(0==strcmp(argv[2], "train2")) train_lsd2(cfg, weights, acfg, aweights, clear);
|
1366 |
+
//else if(0==strcmp(argv[2], "traincolor")) train_colorizer(cfg, weights, acfg, aweights, clear);
|
1367 |
+
//else if(0==strcmp(argv[2], "train3")) train_lsd3(argv[3], argv[4], argv[5], argv[6], argv[7], argv[8], clear);
|
1368 |
+
if(0==strcmp(argv[2], "traingan")) train_dcgan(cfg, weights, acfg, aweights, clear, display, file, batches);
|
1369 |
+
else if(0==strcmp(argv[2], "trainprog")) train_prog(cfg, weights, acfg, aweights, clear, display, file, batches);
|
1370 |
+
else if(0==strcmp(argv[2], "traincolor")) train_colorizer(cfg, weights, acfg, aweights, clear, display);
|
1371 |
+
else if(0==strcmp(argv[2], "gan")) test_dcgan(cfg, weights);
|
1372 |
+
else if(0==strcmp(argv[2], "inter")) inter_dcgan(cfg, weights);
|
1373 |
+
else if(0==strcmp(argv[2], "test")) test_lsd(cfg, weights, filename, 0);
|
1374 |
+
else if(0==strcmp(argv[2], "color")) test_lsd(cfg, weights, filename, 1);
|
1375 |
+
/*
|
1376 |
+
else if(0==strcmp(argv[2], "valid")) validate_lsd(cfg, weights);
|
1377 |
+
*/
|
1378 |
+
}
|
model/examples/nightmare.c
ADDED
@@ -0,0 +1,414 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include "darknet.h"
|
2 |
+
|
3 |
+
#include <math.h>
|
4 |
+
|
5 |
+
// ./darknet nightmare cfg/extractor.recon.cfg ~/trained/yolo-coco.conv frame6.png -reconstruct -iters 500 -i 3 -lambda .1 -rate .01 -smooth 2
|
6 |
+
|
7 |
+
float abs_mean(float *x, int n)
|
8 |
+
{
|
9 |
+
int i;
|
10 |
+
float sum = 0;
|
11 |
+
for (i = 0; i < n; ++i){
|
12 |
+
sum += fabs(x[i]);
|
13 |
+
}
|
14 |
+
return sum/n;
|
15 |
+
}
|
16 |
+
|
17 |
+
void calculate_loss(float *output, float *delta, int n, float thresh)
|
18 |
+
{
|
19 |
+
int i;
|
20 |
+
float mean = mean_array(output, n);
|
21 |
+
float var = variance_array(output, n);
|
22 |
+
for(i = 0; i < n; ++i){
|
23 |
+
if(delta[i] > mean + thresh*sqrt(var)) delta[i] = output[i];
|
24 |
+
else delta[i] = 0;
|
25 |
+
}
|
26 |
+
}
|
27 |
+
|
28 |
+
void optimize_picture(network *net, image orig, int max_layer, float scale, float rate, float thresh, int norm)
|
29 |
+
{
|
30 |
+
//scale_image(orig, 2);
|
31 |
+
//translate_image(orig, -1);
|
32 |
+
net->n = max_layer + 1;
|
33 |
+
|
34 |
+
int dx = rand()%16 - 8;
|
35 |
+
int dy = rand()%16 - 8;
|
36 |
+
int flip = rand()%2;
|
37 |
+
|
38 |
+
image crop = crop_image(orig, dx, dy, orig.w, orig.h);
|
39 |
+
image im = resize_image(crop, (int)(orig.w * scale), (int)(orig.h * scale));
|
40 |
+
if(flip) flip_image(im);
|
41 |
+
|
42 |
+
resize_network(net, im.w, im.h);
|
43 |
+
layer last = net->layers[net->n-1];
|
44 |
+
//net->layers[net->n - 1].activation = LINEAR;
|
45 |
+
|
46 |
+
image delta = make_image(im.w, im.h, im.c);
|
47 |
+
|
48 |
+
#ifdef GPU
|
49 |
+
net->delta_gpu = cuda_make_array(delta.data, im.w*im.h*im.c);
|
50 |
+
copy_cpu(net->inputs, im.data, 1, net->input, 1);
|
51 |
+
|
52 |
+
forward_network_gpu(net);
|
53 |
+
copy_gpu(last.outputs, last.output_gpu, 1, last.delta_gpu, 1);
|
54 |
+
|
55 |
+
cuda_pull_array(last.delta_gpu, last.delta, last.outputs);
|
56 |
+
calculate_loss(last.delta, last.delta, last.outputs, thresh);
|
57 |
+
cuda_push_array(last.delta_gpu, last.delta, last.outputs);
|
58 |
+
|
59 |
+
backward_network_gpu(net);
|
60 |
+
|
61 |
+
cuda_pull_array(net->delta_gpu, delta.data, im.w*im.h*im.c);
|
62 |
+
cuda_free(net->delta_gpu);
|
63 |
+
net->delta_gpu = 0;
|
64 |
+
#else
|
65 |
+
printf("\nnet: %d %d %d im: %d %d %d\n", net->w, net->h, net->inputs, im.w, im.h, im.c);
|
66 |
+
copy_cpu(net->inputs, im.data, 1, net->input, 1);
|
67 |
+
net->delta = delta.data;
|
68 |
+
forward_network(net);
|
69 |
+
copy_cpu(last.outputs, last.output, 1, last.delta, 1);
|
70 |
+
calculate_loss(last.output, last.delta, last.outputs, thresh);
|
71 |
+
backward_network(net);
|
72 |
+
#endif
|
73 |
+
|
74 |
+
if(flip) flip_image(delta);
|
75 |
+
//normalize_array(delta.data, delta.w*delta.h*delta.c);
|
76 |
+
image resized = resize_image(delta, orig.w, orig.h);
|
77 |
+
image out = crop_image(resized, -dx, -dy, orig.w, orig.h);
|
78 |
+
|
79 |
+
/*
|
80 |
+
image g = grayscale_image(out);
|
81 |
+
free_image(out);
|
82 |
+
out = g;
|
83 |
+
*/
|
84 |
+
|
85 |
+
//rate = rate / abs_mean(out.data, out.w*out.h*out.c);
|
86 |
+
image gray = make_image(out.w, out.h, out.c);
|
87 |
+
fill_image(gray, .5);
|
88 |
+
axpy_cpu(orig.w*orig.h*orig.c, -1, orig.data, 1, gray.data, 1);
|
89 |
+
axpy_cpu(orig.w*orig.h*orig.c, .1, gray.data, 1, out.data, 1);
|
90 |
+
|
91 |
+
if(norm) normalize_array(out.data, out.w*out.h*out.c);
|
92 |
+
axpy_cpu(orig.w*orig.h*orig.c, rate, out.data, 1, orig.data, 1);
|
93 |
+
|
94 |
+
/*
|
95 |
+
normalize_array(orig.data, orig.w*orig.h*orig.c);
|
96 |
+
scale_image(orig, sqrt(var));
|
97 |
+
translate_image(orig, mean);
|
98 |
+
*/
|
99 |
+
|
100 |
+
//translate_image(orig, 1);
|
101 |
+
//scale_image(orig, .5);
|
102 |
+
//normalize_image(orig);
|
103 |
+
|
104 |
+
constrain_image(orig);
|
105 |
+
|
106 |
+
free_image(crop);
|
107 |
+
free_image(im);
|
108 |
+
free_image(delta);
|
109 |
+
free_image(resized);
|
110 |
+
free_image(out);
|
111 |
+
|
112 |
+
}
|
113 |
+
|
114 |
+
void smooth(image recon, image update, float lambda, int num)
|
115 |
+
{
|
116 |
+
int i, j, k;
|
117 |
+
int ii, jj;
|
118 |
+
for(k = 0; k < recon.c; ++k){
|
119 |
+
for(j = 0; j < recon.h; ++j){
|
120 |
+
for(i = 0; i < recon.w; ++i){
|
121 |
+
int out_index = i + recon.w*(j + recon.h*k);
|
122 |
+
for(jj = j-num; jj <= j + num && jj < recon.h; ++jj){
|
123 |
+
if (jj < 0) continue;
|
124 |
+
for(ii = i-num; ii <= i + num && ii < recon.w; ++ii){
|
125 |
+
if (ii < 0) continue;
|
126 |
+
int in_index = ii + recon.w*(jj + recon.h*k);
|
127 |
+
update.data[out_index] += lambda * (recon.data[in_index] - recon.data[out_index]);
|
128 |
+
}
|
129 |
+
}
|
130 |
+
}
|
131 |
+
}
|
132 |
+
}
|
133 |
+
}
|
134 |
+
|
135 |
+
void reconstruct_picture(network *net, float *features, image recon, image update, float rate, float momentum, float lambda, int smooth_size, int iters)
|
136 |
+
{
|
137 |
+
int iter = 0;
|
138 |
+
for (iter = 0; iter < iters; ++iter) {
|
139 |
+
image delta = make_image(recon.w, recon.h, recon.c);
|
140 |
+
|
141 |
+
#ifdef GPU
|
142 |
+
layer l = get_network_output_layer(net);
|
143 |
+
cuda_push_array(net->input_gpu, recon.data, recon.w*recon.h*recon.c);
|
144 |
+
//cuda_push_array(net->truth_gpu, features, net->truths);
|
145 |
+
net->delta_gpu = cuda_make_array(delta.data, delta.w*delta.h*delta.c);
|
146 |
+
|
147 |
+
forward_network_gpu(net);
|
148 |
+
cuda_push_array(l.delta_gpu, features, l.outputs);
|
149 |
+
axpy_gpu(l.outputs, -1, l.output_gpu, 1, l.delta_gpu, 1);
|
150 |
+
backward_network_gpu(net);
|
151 |
+
|
152 |
+
cuda_pull_array(net->delta_gpu, delta.data, delta.w*delta.h*delta.c);
|
153 |
+
|
154 |
+
cuda_free(net->delta_gpu);
|
155 |
+
#else
|
156 |
+
net->input = recon.data;
|
157 |
+
net->delta = delta.data;
|
158 |
+
net->truth = features;
|
159 |
+
|
160 |
+
forward_network(net);
|
161 |
+
backward_network(net);
|
162 |
+
#endif
|
163 |
+
|
164 |
+
//normalize_array(delta.data, delta.w*delta.h*delta.c);
|
165 |
+
axpy_cpu(recon.w*recon.h*recon.c, 1, delta.data, 1, update.data, 1);
|
166 |
+
//smooth(recon, update, lambda, smooth_size);
|
167 |
+
|
168 |
+
axpy_cpu(recon.w*recon.h*recon.c, rate, update.data, 1, recon.data, 1);
|
169 |
+
scal_cpu(recon.w*recon.h*recon.c, momentum, update.data, 1);
|
170 |
+
|
171 |
+
float mag = mag_array(delta.data, recon.w*recon.h*recon.c);
|
172 |
+
printf("mag: %f\n", mag);
|
173 |
+
//scal_cpu(recon.w*recon.h*recon.c, 600/mag, recon.data, 1);
|
174 |
+
|
175 |
+
constrain_image(recon);
|
176 |
+
free_image(delta);
|
177 |
+
}
|
178 |
+
}
|
179 |
+
|
180 |
+
/*
|
181 |
+
void run_lsd(int argc, char **argv)
|
182 |
+
{
|
183 |
+
srand(0);
|
184 |
+
if(argc < 3){
|
185 |
+
fprintf(stderr, "usage: %s %s [cfg] [weights] [image] [options! (optional)]\n", argv[0], argv[1]);
|
186 |
+
return;
|
187 |
+
}
|
188 |
+
|
189 |
+
char *cfg = argv[2];
|
190 |
+
char *weights = argv[3];
|
191 |
+
char *input = argv[4];
|
192 |
+
|
193 |
+
int norm = find_int_arg(argc, argv, "-norm", 1);
|
194 |
+
int rounds = find_int_arg(argc, argv, "-rounds", 1);
|
195 |
+
int iters = find_int_arg(argc, argv, "-iters", 10);
|
196 |
+
float rate = find_float_arg(argc, argv, "-rate", .04);
|
197 |
+
float momentum = find_float_arg(argc, argv, "-momentum", .9);
|
198 |
+
float lambda = find_float_arg(argc, argv, "-lambda", .01);
|
199 |
+
char *prefix = find_char_arg(argc, argv, "-prefix", 0);
|
200 |
+
int reconstruct = find_arg(argc, argv, "-reconstruct");
|
201 |
+
int smooth_size = find_int_arg(argc, argv, "-smooth", 1);
|
202 |
+
|
203 |
+
network net = parse_network_cfg(cfg);
|
204 |
+
load_weights(&net, weights);
|
205 |
+
char *cfgbase = basecfg(cfg);
|
206 |
+
char *imbase = basecfg(input);
|
207 |
+
|
208 |
+
set_batch_network(&net, 1);
|
209 |
+
image im = load_image_color(input, 0, 0);
|
210 |
+
|
211 |
+
float *features = 0;
|
212 |
+
image update;
|
213 |
+
if (reconstruct){
|
214 |
+
im = letterbox_image(im, net->w, net->h);
|
215 |
+
|
216 |
+
int zz = 0;
|
217 |
+
network_predict(net, im.data);
|
218 |
+
image out_im = get_network_image(net);
|
219 |
+
image crop = crop_image(out_im, zz, zz, out_im.w-2*zz, out_im.h-2*zz);
|
220 |
+
//flip_image(crop);
|
221 |
+
image f_im = resize_image(crop, out_im.w, out_im.h);
|
222 |
+
free_image(crop);
|
223 |
+
printf("%d features\n", out_im.w*out_im.h*out_im.c);
|
224 |
+
|
225 |
+
|
226 |
+
im = resize_image(im, im.w, im.h);
|
227 |
+
f_im = resize_image(f_im, f_im.w, f_im.h);
|
228 |
+
features = f_im.data;
|
229 |
+
|
230 |
+
int i;
|
231 |
+
for(i = 0; i < 14*14*512; ++i){
|
232 |
+
features[i] += rand_uniform(-.19, .19);
|
233 |
+
}
|
234 |
+
|
235 |
+
free_image(im);
|
236 |
+
im = make_random_image(im.w, im.h, im.c);
|
237 |
+
update = make_image(im.w, im.h, im.c);
|
238 |
+
|
239 |
+
}
|
240 |
+
|
241 |
+
int e;
|
242 |
+
int n;
|
243 |
+
for(e = 0; e < rounds; ++e){
|
244 |
+
fprintf(stderr, "Iteration: ");
|
245 |
+
fflush(stderr);
|
246 |
+
for(n = 0; n < iters; ++n){
|
247 |
+
fprintf(stderr, "%d, ", n);
|
248 |
+
fflush(stderr);
|
249 |
+
if(reconstruct){
|
250 |
+
reconstruct_picture(net, features, im, update, rate, momentum, lambda, smooth_size, 1);
|
251 |
+
//if ((n+1)%30 == 0) rate *= .5;
|
252 |
+
show_image(im, "reconstruction");
|
253 |
+
#ifdef OPENCV
|
254 |
+
cvWaitKey(10);
|
255 |
+
#endif
|
256 |
+
}else{
|
257 |
+
int layer = max_layer + rand()%range - range/2;
|
258 |
+
int octave = rand()%octaves;
|
259 |
+
optimize_picture(&net, im, layer, 1/pow(1.33333333, octave), rate, thresh, norm);
|
260 |
+
}
|
261 |
+
}
|
262 |
+
fprintf(stderr, "done\n");
|
263 |
+
char buff[256];
|
264 |
+
if (prefix){
|
265 |
+
sprintf(buff, "%s/%s_%s_%d_%06d",prefix, imbase, cfgbase, max_layer, e);
|
266 |
+
}else{
|
267 |
+
sprintf(buff, "%s_%s_%d_%06d",imbase, cfgbase, max_layer, e);
|
268 |
+
}
|
269 |
+
printf("%d %s\n", e, buff);
|
270 |
+
save_image(im, buff);
|
271 |
+
//show_image(im, buff);
|
272 |
+
//cvWaitKey(0);
|
273 |
+
|
274 |
+
if(rotate){
|
275 |
+
image rot = rotate_image(im, rotate);
|
276 |
+
free_image(im);
|
277 |
+
im = rot;
|
278 |
+
}
|
279 |
+
image crop = crop_image(im, im.w * (1. - zoom)/2., im.h * (1.-zoom)/2., im.w*zoom, im.h*zoom);
|
280 |
+
image resized = resize_image(crop, im.w, im.h);
|
281 |
+
free_image(im);
|
282 |
+
free_image(crop);
|
283 |
+
im = resized;
|
284 |
+
}
|
285 |
+
}
|
286 |
+
*/
|
287 |
+
|
288 |
+
void run_nightmare(int argc, char **argv)
|
289 |
+
{
|
290 |
+
srand(0);
|
291 |
+
if(argc < 4){
|
292 |
+
fprintf(stderr, "usage: %s %s [cfg] [weights] [image] [layer] [options! (optional)]\n", argv[0], argv[1]);
|
293 |
+
return;
|
294 |
+
}
|
295 |
+
|
296 |
+
char *cfg = argv[2];
|
297 |
+
char *weights = argv[3];
|
298 |
+
char *input = argv[4];
|
299 |
+
int max_layer = atoi(argv[5]);
|
300 |
+
|
301 |
+
int range = find_int_arg(argc, argv, "-range", 1);
|
302 |
+
int norm = find_int_arg(argc, argv, "-norm", 1);
|
303 |
+
int rounds = find_int_arg(argc, argv, "-rounds", 1);
|
304 |
+
int iters = find_int_arg(argc, argv, "-iters", 10);
|
305 |
+
int octaves = find_int_arg(argc, argv, "-octaves", 4);
|
306 |
+
float zoom = find_float_arg(argc, argv, "-zoom", 1.);
|
307 |
+
float rate = find_float_arg(argc, argv, "-rate", .04);
|
308 |
+
float thresh = find_float_arg(argc, argv, "-thresh", 1.);
|
309 |
+
float rotate = find_float_arg(argc, argv, "-rotate", 0);
|
310 |
+
float momentum = find_float_arg(argc, argv, "-momentum", .9);
|
311 |
+
float lambda = find_float_arg(argc, argv, "-lambda", .01);
|
312 |
+
char *prefix = find_char_arg(argc, argv, "-prefix", 0);
|
313 |
+
int reconstruct = find_arg(argc, argv, "-reconstruct");
|
314 |
+
int smooth_size = find_int_arg(argc, argv, "-smooth", 1);
|
315 |
+
|
316 |
+
network *net = load_network(cfg, weights, 0);
|
317 |
+
char *cfgbase = basecfg(cfg);
|
318 |
+
char *imbase = basecfg(input);
|
319 |
+
|
320 |
+
set_batch_network(net, 1);
|
321 |
+
image im = load_image_color(input, 0, 0);
|
322 |
+
if(0){
|
323 |
+
float scale = 1;
|
324 |
+
if(im.w > 512 || im.h > 512){
|
325 |
+
if(im.w > im.h) scale = 512.0/im.w;
|
326 |
+
else scale = 512.0/im.h;
|
327 |
+
}
|
328 |
+
image resized = resize_image(im, scale*im.w, scale*im.h);
|
329 |
+
free_image(im);
|
330 |
+
im = resized;
|
331 |
+
}
|
332 |
+
//im = letterbox_image(im, net->w, net->h);
|
333 |
+
|
334 |
+
float *features = 0;
|
335 |
+
image update;
|
336 |
+
if (reconstruct){
|
337 |
+
net->n = max_layer;
|
338 |
+
im = letterbox_image(im, net->w, net->h);
|
339 |
+
//resize_network(&net, im.w, im.h);
|
340 |
+
|
341 |
+
network_predict(net, im.data);
|
342 |
+
if(net->layers[net->n-1].type == REGION){
|
343 |
+
printf("region!\n");
|
344 |
+
zero_objectness(net->layers[net->n-1]);
|
345 |
+
}
|
346 |
+
image out_im = copy_image(get_network_image(net));
|
347 |
+
/*
|
348 |
+
image crop = crop_image(out_im, zz, zz, out_im.w-2*zz, out_im.h-2*zz);
|
349 |
+
//flip_image(crop);
|
350 |
+
image f_im = resize_image(crop, out_im.w, out_im.h);
|
351 |
+
free_image(crop);
|
352 |
+
*/
|
353 |
+
printf("%d features\n", out_im.w*out_im.h*out_im.c);
|
354 |
+
|
355 |
+
features = out_im.data;
|
356 |
+
|
357 |
+
/*
|
358 |
+
int i;
|
359 |
+
for(i = 0; i < 14*14*512; ++i){
|
360 |
+
//features[i] += rand_uniform(-.19, .19);
|
361 |
+
}
|
362 |
+
free_image(im);
|
363 |
+
im = make_random_image(im.w, im.h, im.c);
|
364 |
+
*/
|
365 |
+
update = make_image(im.w, im.h, im.c);
|
366 |
+
}
|
367 |
+
|
368 |
+
int e;
|
369 |
+
int n;
|
370 |
+
for(e = 0; e < rounds; ++e){
|
371 |
+
fprintf(stderr, "Iteration: ");
|
372 |
+
fflush(stderr);
|
373 |
+
for(n = 0; n < iters; ++n){
|
374 |
+
fprintf(stderr, "%d, ", n);
|
375 |
+
fflush(stderr);
|
376 |
+
if(reconstruct){
|
377 |
+
reconstruct_picture(net, features, im, update, rate, momentum, lambda, smooth_size, 1);
|
378 |
+
//if ((n+1)%30 == 0) rate *= .5;
|
379 |
+
show_image(im, "reconstruction", 10);
|
380 |
+
}else{
|
381 |
+
int layer = max_layer + rand()%range - range/2;
|
382 |
+
int octave = rand()%octaves;
|
383 |
+
optimize_picture(net, im, layer, 1/pow(1.33333333, octave), rate, thresh, norm);
|
384 |
+
}
|
385 |
+
}
|
386 |
+
fprintf(stderr, "done\n");
|
387 |
+
if(0){
|
388 |
+
image g = grayscale_image(im);
|
389 |
+
free_image(im);
|
390 |
+
im = g;
|
391 |
+
}
|
392 |
+
char buff[256];
|
393 |
+
if (prefix){
|
394 |
+
sprintf(buff, "%s/%s_%s_%d_%06d",prefix, imbase, cfgbase, max_layer, e);
|
395 |
+
}else{
|
396 |
+
sprintf(buff, "%s_%s_%d_%06d",imbase, cfgbase, max_layer, e);
|
397 |
+
}
|
398 |
+
printf("%d %s\n", e, buff);
|
399 |
+
save_image(im, buff);
|
400 |
+
//show_image(im, buff, 0);
|
401 |
+
|
402 |
+
if(rotate){
|
403 |
+
image rot = rotate_image(im, rotate);
|
404 |
+
free_image(im);
|
405 |
+
im = rot;
|
406 |
+
}
|
407 |
+
image crop = crop_image(im, im.w * (1. - zoom)/2., im.h * (1.-zoom)/2., im.w*zoom, im.h*zoom);
|
408 |
+
image resized = resize_image(crop, im.w, im.h);
|
409 |
+
free_image(im);
|
410 |
+
free_image(crop);
|
411 |
+
im = resized;
|
412 |
+
}
|
413 |
+
}
|
414 |
+
|
model/examples/regressor.c
ADDED
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include "darknet.h"
|
2 |
+
#include <sys/time.h>
|
3 |
+
#include <assert.h>
|
4 |
+
|
5 |
+
void train_regressor(char *datacfg, char *cfgfile, char *weightfile, int *gpus, int ngpus, int clear)
|
6 |
+
{
|
7 |
+
int i;
|
8 |
+
|
9 |
+
float avg_loss = -1;
|
10 |
+
char *base = basecfg(cfgfile);
|
11 |
+
printf("%s\n", base);
|
12 |
+
printf("%d\n", ngpus);
|
13 |
+
network **nets = calloc(ngpus, sizeof(network*));
|
14 |
+
|
15 |
+
srand(time(0));
|
16 |
+
int seed = rand();
|
17 |
+
for(i = 0; i < ngpus; ++i){
|
18 |
+
srand(seed);
|
19 |
+
#ifdef GPU
|
20 |
+
cuda_set_device(gpus[i]);
|
21 |
+
#endif
|
22 |
+
nets[i] = load_network(cfgfile, weightfile, clear);
|
23 |
+
nets[i]->learning_rate *= ngpus;
|
24 |
+
}
|
25 |
+
srand(time(0));
|
26 |
+
network *net = nets[0];
|
27 |
+
|
28 |
+
int imgs = net->batch * net->subdivisions * ngpus;
|
29 |
+
|
30 |
+
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
|
31 |
+
list *options = read_data_cfg(datacfg);
|
32 |
+
|
33 |
+
char *backup_directory = option_find_str(options, "backup", "/backup/");
|
34 |
+
char *train_list = option_find_str(options, "train", "data/train.list");
|
35 |
+
int classes = option_find_int(options, "classes", 1);
|
36 |
+
|
37 |
+
list *plist = get_paths(train_list);
|
38 |
+
char **paths = (char **)list_to_array(plist);
|
39 |
+
printf("%d\n", plist->size);
|
40 |
+
int N = plist->size;
|
41 |
+
clock_t time;
|
42 |
+
|
43 |
+
load_args args = {0};
|
44 |
+
args.w = net->w;
|
45 |
+
args.h = net->h;
|
46 |
+
args.threads = 32;
|
47 |
+
args.classes = classes;
|
48 |
+
|
49 |
+
args.min = net->min_ratio*net->w;
|
50 |
+
args.max = net->max_ratio*net->w;
|
51 |
+
args.angle = net->angle;
|
52 |
+
args.aspect = net->aspect;
|
53 |
+
args.exposure = net->exposure;
|
54 |
+
args.saturation = net->saturation;
|
55 |
+
args.hue = net->hue;
|
56 |
+
args.size = net->w;
|
57 |
+
|
58 |
+
args.paths = paths;
|
59 |
+
args.n = imgs;
|
60 |
+
args.m = N;
|
61 |
+
args.type = REGRESSION_DATA;
|
62 |
+
|
63 |
+
data train;
|
64 |
+
data buffer;
|
65 |
+
pthread_t load_thread;
|
66 |
+
args.d = &buffer;
|
67 |
+
load_thread = load_data(args);
|
68 |
+
|
69 |
+
int epoch = (*net->seen)/N;
|
70 |
+
while(get_current_batch(net) < net->max_batches || net->max_batches == 0){
|
71 |
+
time=clock();
|
72 |
+
|
73 |
+
pthread_join(load_thread, 0);
|
74 |
+
train = buffer;
|
75 |
+
load_thread = load_data(args);
|
76 |
+
|
77 |
+
printf("Loaded: %lf seconds\n", sec(clock()-time));
|
78 |
+
time=clock();
|
79 |
+
|
80 |
+
float loss = 0;
|
81 |
+
#ifdef GPU
|
82 |
+
if(ngpus == 1){
|
83 |
+
loss = train_network(net, train);
|
84 |
+
} else {
|
85 |
+
loss = train_networks(nets, ngpus, train, 4);
|
86 |
+
}
|
87 |
+
#else
|
88 |
+
loss = train_network(net, train);
|
89 |
+
#endif
|
90 |
+
if(avg_loss == -1) avg_loss = loss;
|
91 |
+
avg_loss = avg_loss*.9 + loss*.1;
|
92 |
+
printf("%ld, %.3f: %f, %f avg, %f rate, %lf seconds, %ld images\n", get_current_batch(net), (float)(*net->seen)/N, loss, avg_loss, get_current_rate(net), sec(clock()-time), *net->seen);
|
93 |
+
free_data(train);
|
94 |
+
if(*net->seen/N > epoch){
|
95 |
+
epoch = *net->seen/N;
|
96 |
+
char buff[256];
|
97 |
+
sprintf(buff, "%s/%s_%d.weights",backup_directory,base, epoch);
|
98 |
+
save_weights(net, buff);
|
99 |
+
}
|
100 |
+
if(get_current_batch(net)%100 == 0){
|
101 |
+
char buff[256];
|
102 |
+
sprintf(buff, "%s/%s.backup",backup_directory,base);
|
103 |
+
save_weights(net, buff);
|
104 |
+
}
|
105 |
+
}
|
106 |
+
char buff[256];
|
107 |
+
sprintf(buff, "%s/%s.weights", backup_directory, base);
|
108 |
+
save_weights(net, buff);
|
109 |
+
|
110 |
+
free_network(net);
|
111 |
+
free_ptrs((void**)paths, plist->size);
|
112 |
+
free_list(plist);
|
113 |
+
free(base);
|
114 |
+
}
|
115 |
+
|
116 |
+
void predict_regressor(char *cfgfile, char *weightfile, char *filename)
|
117 |
+
{
|
118 |
+
network *net = load_network(cfgfile, weightfile, 0);
|
119 |
+
set_batch_network(net, 1);
|
120 |
+
srand(2222222);
|
121 |
+
|
122 |
+
clock_t time;
|
123 |
+
char buff[256];
|
124 |
+
char *input = buff;
|
125 |
+
while(1){
|
126 |
+
if(filename){
|
127 |
+
strncpy(input, filename, 256);
|
128 |
+
}else{
|
129 |
+
printf("Enter Image Path: ");
|
130 |
+
fflush(stdout);
|
131 |
+
input = fgets(input, 256, stdin);
|
132 |
+
if(!input) return;
|
133 |
+
strtok(input, "\n");
|
134 |
+
}
|
135 |
+
image im = load_image_color(input, 0, 0);
|
136 |
+
image sized = letterbox_image(im, net->w, net->h);
|
137 |
+
|
138 |
+
float *X = sized.data;
|
139 |
+
time=clock();
|
140 |
+
float *predictions = network_predict(net, X);
|
141 |
+
printf("Predicted: %f\n", predictions[0]);
|
142 |
+
printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
|
143 |
+
free_image(im);
|
144 |
+
free_image(sized);
|
145 |
+
if (filename) break;
|
146 |
+
}
|
147 |
+
}
|
148 |
+
|
149 |
+
|
150 |
+
void demo_regressor(char *datacfg, char *cfgfile, char *weightfile, int cam_index, const char *filename)
|
151 |
+
{
|
152 |
+
#ifdef OPENCV
|
153 |
+
printf("Regressor Demo\n");
|
154 |
+
network *net = load_network(cfgfile, weightfile, 0);
|
155 |
+
set_batch_network(net, 1);
|
156 |
+
|
157 |
+
srand(2222222);
|
158 |
+
list *options = read_data_cfg(datacfg);
|
159 |
+
int classes = option_find_int(options, "classes", 1);
|
160 |
+
char *name_list = option_find_str(options, "names", 0);
|
161 |
+
char **names = get_labels(name_list);
|
162 |
+
|
163 |
+
void * cap = open_video_stream(filename, cam_index, 0,0,0);
|
164 |
+
if(!cap) error("Couldn't connect to webcam.\n");
|
165 |
+
float fps = 0;
|
166 |
+
|
167 |
+
while(1){
|
168 |
+
struct timeval tval_before, tval_after, tval_result;
|
169 |
+
gettimeofday(&tval_before, NULL);
|
170 |
+
|
171 |
+
image in = get_image_from_stream(cap);
|
172 |
+
image crop = center_crop_image(in, net->w, net->h);
|
173 |
+
grayscale_image_3c(crop);
|
174 |
+
|
175 |
+
float *predictions = network_predict(net, crop.data);
|
176 |
+
|
177 |
+
printf("\033[2J");
|
178 |
+
printf("\033[1;1H");
|
179 |
+
printf("\nFPS:%.0f\n",fps);
|
180 |
+
|
181 |
+
int i;
|
182 |
+
for(i = 0; i < classes; ++i){
|
183 |
+
printf("%s: %f\n", names[i], predictions[i]);
|
184 |
+
}
|
185 |
+
|
186 |
+
show_image(crop, "Regressor", 10);
|
187 |
+
free_image(in);
|
188 |
+
free_image(crop);
|
189 |
+
|
190 |
+
gettimeofday(&tval_after, NULL);
|
191 |
+
timersub(&tval_after, &tval_before, &tval_result);
|
192 |
+
float curr = 1000000.f/((long int)tval_result.tv_usec);
|
193 |
+
fps = .9*fps + .1*curr;
|
194 |
+
}
|
195 |
+
#endif
|
196 |
+
}
|
197 |
+
|
198 |
+
|
199 |
+
void run_regressor(int argc, char **argv)
|
200 |
+
{
|
201 |
+
if(argc < 4){
|
202 |
+
fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
|
203 |
+
return;
|
204 |
+
}
|
205 |
+
|
206 |
+
char *gpu_list = find_char_arg(argc, argv, "-gpus", 0);
|
207 |
+
int *gpus = 0;
|
208 |
+
int gpu = 0;
|
209 |
+
int ngpus = 0;
|
210 |
+
if(gpu_list){
|
211 |
+
printf("%s\n", gpu_list);
|
212 |
+
int len = strlen(gpu_list);
|
213 |
+
ngpus = 1;
|
214 |
+
int i;
|
215 |
+
for(i = 0; i < len; ++i){
|
216 |
+
if (gpu_list[i] == ',') ++ngpus;
|
217 |
+
}
|
218 |
+
gpus = calloc(ngpus, sizeof(int));
|
219 |
+
for(i = 0; i < ngpus; ++i){
|
220 |
+
gpus[i] = atoi(gpu_list);
|
221 |
+
gpu_list = strchr(gpu_list, ',')+1;
|
222 |
+
}
|
223 |
+
} else {
|
224 |
+
gpu = gpu_index;
|
225 |
+
gpus = &gpu;
|
226 |
+
ngpus = 1;
|
227 |
+
}
|
228 |
+
|
229 |
+
int cam_index = find_int_arg(argc, argv, "-c", 0);
|
230 |
+
int clear = find_arg(argc, argv, "-clear");
|
231 |
+
char *data = argv[3];
|
232 |
+
char *cfg = argv[4];
|
233 |
+
char *weights = (argc > 5) ? argv[5] : 0;
|
234 |
+
char *filename = (argc > 6) ? argv[6]: 0;
|
235 |
+
if(0==strcmp(argv[2], "test")) predict_regressor(data, cfg, weights);
|
236 |
+
else if(0==strcmp(argv[2], "train")) train_regressor(data, cfg, weights, gpus, ngpus, clear);
|
237 |
+
else if(0==strcmp(argv[2], "demo")) demo_regressor(data, cfg, weights, cam_index, filename);
|
238 |
+
}
|
239 |
+
|
240 |
+
|
model/examples/rnn.c
ADDED
@@ -0,0 +1,542 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include "darknet.h"
|
2 |
+
|
3 |
+
#include <math.h>
|
4 |
+
|
5 |
+
typedef struct {
|
6 |
+
float *x;
|
7 |
+
float *y;
|
8 |
+
} float_pair;
|
9 |
+
|
10 |
+
unsigned char **load_files(char *filename, int *n)
|
11 |
+
{
|
12 |
+
list *paths = get_paths(filename);
|
13 |
+
*n = paths->size;
|
14 |
+
unsigned char **contents = calloc(*n, sizeof(char *));
|
15 |
+
int i;
|
16 |
+
node *x = paths->front;
|
17 |
+
for(i = 0; i < *n; ++i){
|
18 |
+
contents[i] = read_file((char *)x->val);
|
19 |
+
x = x->next;
|
20 |
+
}
|
21 |
+
return contents;
|
22 |
+
}
|
23 |
+
|
24 |
+
int *read_tokenized_data(char *filename, size_t *read)
|
25 |
+
{
|
26 |
+
size_t size = 512;
|
27 |
+
size_t count = 0;
|
28 |
+
FILE *fp = fopen(filename, "r");
|
29 |
+
int *d = calloc(size, sizeof(int));
|
30 |
+
int n, one;
|
31 |
+
one = fscanf(fp, "%d", &n);
|
32 |
+
while(one == 1){
|
33 |
+
++count;
|
34 |
+
if(count > size){
|
35 |
+
size = size*2;
|
36 |
+
d = realloc(d, size*sizeof(int));
|
37 |
+
}
|
38 |
+
d[count-1] = n;
|
39 |
+
one = fscanf(fp, "%d", &n);
|
40 |
+
}
|
41 |
+
fclose(fp);
|
42 |
+
d = realloc(d, count*sizeof(int));
|
43 |
+
*read = count;
|
44 |
+
return d;
|
45 |
+
}
|
46 |
+
|
47 |
+
char **read_tokens(char *filename, size_t *read)
|
48 |
+
{
|
49 |
+
size_t size = 512;
|
50 |
+
size_t count = 0;
|
51 |
+
FILE *fp = fopen(filename, "r");
|
52 |
+
char **d = calloc(size, sizeof(char *));
|
53 |
+
char *line;
|
54 |
+
while((line=fgetl(fp)) != 0){
|
55 |
+
++count;
|
56 |
+
if(count > size){
|
57 |
+
size = size*2;
|
58 |
+
d = realloc(d, size*sizeof(char *));
|
59 |
+
}
|
60 |
+
if(0==strcmp(line, "<NEWLINE>")) line = "\n";
|
61 |
+
d[count-1] = line;
|
62 |
+
}
|
63 |
+
fclose(fp);
|
64 |
+
d = realloc(d, count*sizeof(char *));
|
65 |
+
*read = count;
|
66 |
+
return d;
|
67 |
+
}
|
68 |
+
|
69 |
+
|
70 |
+
float_pair get_rnn_token_data(int *tokens, size_t *offsets, int characters, size_t len, int batch, int steps)
|
71 |
+
{
|
72 |
+
float *x = calloc(batch * steps * characters, sizeof(float));
|
73 |
+
float *y = calloc(batch * steps * characters, sizeof(float));
|
74 |
+
int i,j;
|
75 |
+
for(i = 0; i < batch; ++i){
|
76 |
+
for(j = 0; j < steps; ++j){
|
77 |
+
int curr = tokens[(offsets[i])%len];
|
78 |
+
int next = tokens[(offsets[i] + 1)%len];
|
79 |
+
|
80 |
+
x[(j*batch + i)*characters + curr] = 1;
|
81 |
+
y[(j*batch + i)*characters + next] = 1;
|
82 |
+
|
83 |
+
offsets[i] = (offsets[i] + 1) % len;
|
84 |
+
|
85 |
+
if(curr >= characters || curr < 0 || next >= characters || next < 0){
|
86 |
+
error("Bad char");
|
87 |
+
}
|
88 |
+
}
|
89 |
+
}
|
90 |
+
float_pair p;
|
91 |
+
p.x = x;
|
92 |
+
p.y = y;
|
93 |
+
return p;
|
94 |
+
}
|
95 |
+
|
96 |
+
float_pair get_seq2seq_data(char **source, char **dest, int n, int characters, size_t len, int batch, int steps)
|
97 |
+
{
|
98 |
+
int i,j;
|
99 |
+
float *x = calloc(batch * steps * characters, sizeof(float));
|
100 |
+
float *y = calloc(batch * steps * characters, sizeof(float));
|
101 |
+
for(i = 0; i < batch; ++i){
|
102 |
+
int index = rand()%n;
|
103 |
+
//int slen = strlen(source[index]);
|
104 |
+
//int dlen = strlen(dest[index]);
|
105 |
+
for(j = 0; j < steps; ++j){
|
106 |
+
unsigned char curr = source[index][j];
|
107 |
+
unsigned char next = dest[index][j];
|
108 |
+
|
109 |
+
x[(j*batch + i)*characters + curr] = 1;
|
110 |
+
y[(j*batch + i)*characters + next] = 1;
|
111 |
+
|
112 |
+
if(curr > 255 || curr <= 0 || next > 255 || next <= 0){
|
113 |
+
/*text[(index+j+2)%len] = 0;
|
114 |
+
printf("%ld %d %d %d %d\n", index, j, len, (int)text[index+j], (int)text[index+j+1]);
|
115 |
+
printf("%s", text+index);
|
116 |
+
*/
|
117 |
+
error("Bad char");
|
118 |
+
}
|
119 |
+
}
|
120 |
+
}
|
121 |
+
float_pair p;
|
122 |
+
p.x = x;
|
123 |
+
p.y = y;
|
124 |
+
return p;
|
125 |
+
}
|
126 |
+
|
127 |
+
float_pair get_rnn_data(unsigned char *text, size_t *offsets, int characters, size_t len, int batch, int steps)
|
128 |
+
{
|
129 |
+
float *x = calloc(batch * steps * characters, sizeof(float));
|
130 |
+
float *y = calloc(batch * steps * characters, sizeof(float));
|
131 |
+
int i,j;
|
132 |
+
for(i = 0; i < batch; ++i){
|
133 |
+
for(j = 0; j < steps; ++j){
|
134 |
+
unsigned char curr = text[(offsets[i])%len];
|
135 |
+
unsigned char next = text[(offsets[i] + 1)%len];
|
136 |
+
|
137 |
+
x[(j*batch + i)*characters + curr] = 1;
|
138 |
+
y[(j*batch + i)*characters + next] = 1;
|
139 |
+
|
140 |
+
offsets[i] = (offsets[i] + 1) % len;
|
141 |
+
|
142 |
+
if(curr > 255 || curr <= 0 || next > 255 || next <= 0){
|
143 |
+
/*text[(index+j+2)%len] = 0;
|
144 |
+
printf("%ld %d %d %d %d\n", index, j, len, (int)text[index+j], (int)text[index+j+1]);
|
145 |
+
printf("%s", text+index);
|
146 |
+
*/
|
147 |
+
error("Bad char");
|
148 |
+
}
|
149 |
+
}
|
150 |
+
}
|
151 |
+
float_pair p;
|
152 |
+
p.x = x;
|
153 |
+
p.y = y;
|
154 |
+
return p;
|
155 |
+
}
|
156 |
+
|
157 |
+
void train_char_rnn(char *cfgfile, char *weightfile, char *filename, int clear, int tokenized)
|
158 |
+
{
|
159 |
+
srand(time(0));
|
160 |
+
unsigned char *text = 0;
|
161 |
+
int *tokens = 0;
|
162 |
+
size_t size;
|
163 |
+
if(tokenized){
|
164 |
+
tokens = read_tokenized_data(filename, &size);
|
165 |
+
} else {
|
166 |
+
text = read_file(filename);
|
167 |
+
size = strlen((const char*)text);
|
168 |
+
}
|
169 |
+
|
170 |
+
char *backup_directory = "/home/pjreddie/backup/";
|
171 |
+
char *base = basecfg(cfgfile);
|
172 |
+
fprintf(stderr, "%s\n", base);
|
173 |
+
float avg_loss = -1;
|
174 |
+
network *net = load_network(cfgfile, weightfile, clear);
|
175 |
+
|
176 |
+
int inputs = net->inputs;
|
177 |
+
fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g, Inputs: %d %d %d\n", net->learning_rate, net->momentum, net->decay, inputs, net->batch, net->time_steps);
|
178 |
+
int batch = net->batch;
|
179 |
+
int steps = net->time_steps;
|
180 |
+
if(clear) *net->seen = 0;
|
181 |
+
int i = (*net->seen)/net->batch;
|
182 |
+
|
183 |
+
int streams = batch/steps;
|
184 |
+
size_t *offsets = calloc(streams, sizeof(size_t));
|
185 |
+
int j;
|
186 |
+
for(j = 0; j < streams; ++j){
|
187 |
+
offsets[j] = rand_size_t()%size;
|
188 |
+
}
|
189 |
+
|
190 |
+
clock_t time;
|
191 |
+
while(get_current_batch(net) < net->max_batches){
|
192 |
+
i += 1;
|
193 |
+
time=clock();
|
194 |
+
float_pair p;
|
195 |
+
if(tokenized){
|
196 |
+
p = get_rnn_token_data(tokens, offsets, inputs, size, streams, steps);
|
197 |
+
}else{
|
198 |
+
p = get_rnn_data(text, offsets, inputs, size, streams, steps);
|
199 |
+
}
|
200 |
+
|
201 |
+
copy_cpu(net->inputs*net->batch, p.x, 1, net->input, 1);
|
202 |
+
copy_cpu(net->truths*net->batch, p.y, 1, net->truth, 1);
|
203 |
+
float loss = train_network_datum(net) / (batch);
|
204 |
+
free(p.x);
|
205 |
+
free(p.y);
|
206 |
+
if (avg_loss < 0) avg_loss = loss;
|
207 |
+
avg_loss = avg_loss*.9 + loss*.1;
|
208 |
+
|
209 |
+
size_t chars = get_current_batch(net)*batch;
|
210 |
+
fprintf(stderr, "%d: %f, %f avg, %f rate, %lf seconds, %f epochs\n", i, loss, avg_loss, get_current_rate(net), sec(clock()-time), (float) chars/size);
|
211 |
+
|
212 |
+
for(j = 0; j < streams; ++j){
|
213 |
+
//printf("%d\n", j);
|
214 |
+
if(rand()%64 == 0){
|
215 |
+
//fprintf(stderr, "Reset\n");
|
216 |
+
offsets[j] = rand_size_t()%size;
|
217 |
+
reset_network_state(net, j);
|
218 |
+
}
|
219 |
+
}
|
220 |
+
|
221 |
+
if(i%10000==0){
|
222 |
+
char buff[256];
|
223 |
+
sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
|
224 |
+
save_weights(net, buff);
|
225 |
+
}
|
226 |
+
if(i%100==0){
|
227 |
+
char buff[256];
|
228 |
+
sprintf(buff, "%s/%s.backup", backup_directory, base);
|
229 |
+
save_weights(net, buff);
|
230 |
+
}
|
231 |
+
}
|
232 |
+
char buff[256];
|
233 |
+
sprintf(buff, "%s/%s_final.weights", backup_directory, base);
|
234 |
+
save_weights(net, buff);
|
235 |
+
}
|
236 |
+
|
237 |
+
void print_symbol(int n, char **tokens){
|
238 |
+
if(tokens){
|
239 |
+
printf("%s ", tokens[n]);
|
240 |
+
} else {
|
241 |
+
printf("%c", n);
|
242 |
+
}
|
243 |
+
}
|
244 |
+
|
245 |
+
void test_char_rnn(char *cfgfile, char *weightfile, int num, char *seed, float temp, int rseed, char *token_file)
|
246 |
+
{
|
247 |
+
char **tokens = 0;
|
248 |
+
if(token_file){
|
249 |
+
size_t n;
|
250 |
+
tokens = read_tokens(token_file, &n);
|
251 |
+
}
|
252 |
+
|
253 |
+
srand(rseed);
|
254 |
+
char *base = basecfg(cfgfile);
|
255 |
+
fprintf(stderr, "%s\n", base);
|
256 |
+
|
257 |
+
network *net = load_network(cfgfile, weightfile, 0);
|
258 |
+
int inputs = net->inputs;
|
259 |
+
|
260 |
+
int i, j;
|
261 |
+
for(i = 0; i < net->n; ++i) net->layers[i].temperature = temp;
|
262 |
+
int c = 0;
|
263 |
+
int len = strlen(seed);
|
264 |
+
float *input = calloc(inputs, sizeof(float));
|
265 |
+
|
266 |
+
/*
|
267 |
+
fill_cpu(inputs, 0, input, 1);
|
268 |
+
for(i = 0; i < 10; ++i){
|
269 |
+
network_predict(net, input);
|
270 |
+
}
|
271 |
+
fill_cpu(inputs, 0, input, 1);
|
272 |
+
*/
|
273 |
+
|
274 |
+
for(i = 0; i < len-1; ++i){
|
275 |
+
c = seed[i];
|
276 |
+
input[c] = 1;
|
277 |
+
network_predict(net, input);
|
278 |
+
input[c] = 0;
|
279 |
+
print_symbol(c, tokens);
|
280 |
+
}
|
281 |
+
if(len) c = seed[len-1];
|
282 |
+
print_symbol(c, tokens);
|
283 |
+
for(i = 0; i < num; ++i){
|
284 |
+
input[c] = 1;
|
285 |
+
float *out = network_predict(net, input);
|
286 |
+
input[c] = 0;
|
287 |
+
for(j = 32; j < 127; ++j){
|
288 |
+
//printf("%d %c %f\n",j, j, out[j]);
|
289 |
+
}
|
290 |
+
for(j = 0; j < inputs; ++j){
|
291 |
+
if (out[j] < .0001) out[j] = 0;
|
292 |
+
}
|
293 |
+
c = sample_array(out, inputs);
|
294 |
+
print_symbol(c, tokens);
|
295 |
+
}
|
296 |
+
printf("\n");
|
297 |
+
}
|
298 |
+
|
299 |
+
void test_tactic_rnn_multi(char *cfgfile, char *weightfile, int num, float temp, int rseed, char *token_file)
|
300 |
+
{
|
301 |
+
char **tokens = 0;
|
302 |
+
if(token_file){
|
303 |
+
size_t n;
|
304 |
+
tokens = read_tokens(token_file, &n);
|
305 |
+
}
|
306 |
+
|
307 |
+
srand(rseed);
|
308 |
+
char *base = basecfg(cfgfile);
|
309 |
+
fprintf(stderr, "%s\n", base);
|
310 |
+
|
311 |
+
network *net = load_network(cfgfile, weightfile, 0);
|
312 |
+
int inputs = net->inputs;
|
313 |
+
|
314 |
+
int i, j;
|
315 |
+
for(i = 0; i < net->n; ++i) net->layers[i].temperature = temp;
|
316 |
+
int c = 0;
|
317 |
+
float *input = calloc(inputs, sizeof(float));
|
318 |
+
float *out = 0;
|
319 |
+
|
320 |
+
while(1){
|
321 |
+
reset_network_state(net, 0);
|
322 |
+
while((c = getc(stdin)) != EOF && c != 0){
|
323 |
+
input[c] = 1;
|
324 |
+
out = network_predict(net, input);
|
325 |
+
input[c] = 0;
|
326 |
+
}
|
327 |
+
for(i = 0; i < num; ++i){
|
328 |
+
for(j = 0; j < inputs; ++j){
|
329 |
+
if (out[j] < .0001) out[j] = 0;
|
330 |
+
}
|
331 |
+
int next = sample_array(out, inputs);
|
332 |
+
if(c == '.' && next == '\n') break;
|
333 |
+
c = next;
|
334 |
+
print_symbol(c, tokens);
|
335 |
+
|
336 |
+
input[c] = 1;
|
337 |
+
out = network_predict(net, input);
|
338 |
+
input[c] = 0;
|
339 |
+
}
|
340 |
+
printf("\n");
|
341 |
+
}
|
342 |
+
}
|
343 |
+
|
344 |
+
void test_tactic_rnn(char *cfgfile, char *weightfile, int num, float temp, int rseed, char *token_file)
|
345 |
+
{
|
346 |
+
char **tokens = 0;
|
347 |
+
if(token_file){
|
348 |
+
size_t n;
|
349 |
+
tokens = read_tokens(token_file, &n);
|
350 |
+
}
|
351 |
+
|
352 |
+
srand(rseed);
|
353 |
+
char *base = basecfg(cfgfile);
|
354 |
+
fprintf(stderr, "%s\n", base);
|
355 |
+
|
356 |
+
network *net = load_network(cfgfile, weightfile, 0);
|
357 |
+
int inputs = net->inputs;
|
358 |
+
|
359 |
+
int i, j;
|
360 |
+
for(i = 0; i < net->n; ++i) net->layers[i].temperature = temp;
|
361 |
+
int c = 0;
|
362 |
+
float *input = calloc(inputs, sizeof(float));
|
363 |
+
float *out = 0;
|
364 |
+
|
365 |
+
while((c = getc(stdin)) != EOF){
|
366 |
+
input[c] = 1;
|
367 |
+
out = network_predict(net, input);
|
368 |
+
input[c] = 0;
|
369 |
+
}
|
370 |
+
for(i = 0; i < num; ++i){
|
371 |
+
for(j = 0; j < inputs; ++j){
|
372 |
+
if (out[j] < .0001) out[j] = 0;
|
373 |
+
}
|
374 |
+
int next = sample_array(out, inputs);
|
375 |
+
if(c == '.' && next == '\n') break;
|
376 |
+
c = next;
|
377 |
+
print_symbol(c, tokens);
|
378 |
+
|
379 |
+
input[c] = 1;
|
380 |
+
out = network_predict(net, input);
|
381 |
+
input[c] = 0;
|
382 |
+
}
|
383 |
+
printf("\n");
|
384 |
+
}
|
385 |
+
|
386 |
+
void valid_tactic_rnn(char *cfgfile, char *weightfile, char *seed)
|
387 |
+
{
|
388 |
+
char *base = basecfg(cfgfile);
|
389 |
+
fprintf(stderr, "%s\n", base);
|
390 |
+
|
391 |
+
network *net = load_network(cfgfile, weightfile, 0);
|
392 |
+
int inputs = net->inputs;
|
393 |
+
|
394 |
+
int count = 0;
|
395 |
+
int words = 1;
|
396 |
+
int c;
|
397 |
+
int len = strlen(seed);
|
398 |
+
float *input = calloc(inputs, sizeof(float));
|
399 |
+
int i;
|
400 |
+
for(i = 0; i < len; ++i){
|
401 |
+
c = seed[i];
|
402 |
+
input[(int)c] = 1;
|
403 |
+
network_predict(net, input);
|
404 |
+
input[(int)c] = 0;
|
405 |
+
}
|
406 |
+
float sum = 0;
|
407 |
+
c = getc(stdin);
|
408 |
+
float log2 = log(2);
|
409 |
+
int in = 0;
|
410 |
+
while(c != EOF){
|
411 |
+
int next = getc(stdin);
|
412 |
+
if(next == EOF) break;
|
413 |
+
if(next < 0 || next >= 255) error("Out of range character");
|
414 |
+
|
415 |
+
input[c] = 1;
|
416 |
+
float *out = network_predict(net, input);
|
417 |
+
input[c] = 0;
|
418 |
+
|
419 |
+
if(c == '.' && next == '\n') in = 0;
|
420 |
+
if(!in) {
|
421 |
+
if(c == '>' && next == '>'){
|
422 |
+
in = 1;
|
423 |
+
++words;
|
424 |
+
}
|
425 |
+
c = next;
|
426 |
+
continue;
|
427 |
+
}
|
428 |
+
++count;
|
429 |
+
sum += log(out[next])/log2;
|
430 |
+
c = next;
|
431 |
+
printf("%d %d Perplexity: %4.4f Word Perplexity: %4.4f\n", count, words, pow(2, -sum/count), pow(2, -sum/words));
|
432 |
+
}
|
433 |
+
}
|
434 |
+
|
435 |
+
void valid_char_rnn(char *cfgfile, char *weightfile, char *seed)
|
436 |
+
{
|
437 |
+
char *base = basecfg(cfgfile);
|
438 |
+
fprintf(stderr, "%s\n", base);
|
439 |
+
|
440 |
+
network *net = load_network(cfgfile, weightfile, 0);
|
441 |
+
int inputs = net->inputs;
|
442 |
+
|
443 |
+
int count = 0;
|
444 |
+
int words = 1;
|
445 |
+
int c;
|
446 |
+
int len = strlen(seed);
|
447 |
+
float *input = calloc(inputs, sizeof(float));
|
448 |
+
int i;
|
449 |
+
for(i = 0; i < len; ++i){
|
450 |
+
c = seed[i];
|
451 |
+
input[(int)c] = 1;
|
452 |
+
network_predict(net, input);
|
453 |
+
input[(int)c] = 0;
|
454 |
+
}
|
455 |
+
float sum = 0;
|
456 |
+
c = getc(stdin);
|
457 |
+
float log2 = log(2);
|
458 |
+
while(c != EOF){
|
459 |
+
int next = getc(stdin);
|
460 |
+
if(next == EOF) break;
|
461 |
+
if(next < 0 || next >= 255) error("Out of range character");
|
462 |
+
++count;
|
463 |
+
if(next == ' ' || next == '\n' || next == '\t') ++words;
|
464 |
+
input[c] = 1;
|
465 |
+
float *out = network_predict(net, input);
|
466 |
+
input[c] = 0;
|
467 |
+
sum += log(out[next])/log2;
|
468 |
+
c = next;
|
469 |
+
printf("%d BPC: %4.4f Perplexity: %4.4f Word Perplexity: %4.4f\n", count, -sum/count, pow(2, -sum/count), pow(2, -sum/words));
|
470 |
+
}
|
471 |
+
}
|
472 |
+
|
473 |
+
void vec_char_rnn(char *cfgfile, char *weightfile, char *seed)
|
474 |
+
{
|
475 |
+
char *base = basecfg(cfgfile);
|
476 |
+
fprintf(stderr, "%s\n", base);
|
477 |
+
|
478 |
+
network *net = load_network(cfgfile, weightfile, 0);
|
479 |
+
int inputs = net->inputs;
|
480 |
+
|
481 |
+
int c;
|
482 |
+
int seed_len = strlen(seed);
|
483 |
+
float *input = calloc(inputs, sizeof(float));
|
484 |
+
int i;
|
485 |
+
char *line;
|
486 |
+
while((line=fgetl(stdin)) != 0){
|
487 |
+
reset_network_state(net, 0);
|
488 |
+
for(i = 0; i < seed_len; ++i){
|
489 |
+
c = seed[i];
|
490 |
+
input[(int)c] = 1;
|
491 |
+
network_predict(net, input);
|
492 |
+
input[(int)c] = 0;
|
493 |
+
}
|
494 |
+
strip(line);
|
495 |
+
int str_len = strlen(line);
|
496 |
+
for(i = 0; i < str_len; ++i){
|
497 |
+
c = line[i];
|
498 |
+
input[(int)c] = 1;
|
499 |
+
network_predict(net, input);
|
500 |
+
input[(int)c] = 0;
|
501 |
+
}
|
502 |
+
c = ' ';
|
503 |
+
input[(int)c] = 1;
|
504 |
+
network_predict(net, input);
|
505 |
+
input[(int)c] = 0;
|
506 |
+
|
507 |
+
layer l = net->layers[0];
|
508 |
+
#ifdef GPU
|
509 |
+
cuda_pull_array(l.output_gpu, l.output, l.outputs);
|
510 |
+
#endif
|
511 |
+
printf("%s", line);
|
512 |
+
for(i = 0; i < l.outputs; ++i){
|
513 |
+
printf(",%g", l.output[i]);
|
514 |
+
}
|
515 |
+
printf("\n");
|
516 |
+
}
|
517 |
+
}
|
518 |
+
|
519 |
+
void run_char_rnn(int argc, char **argv)
|
520 |
+
{
|
521 |
+
if(argc < 4){
|
522 |
+
fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
|
523 |
+
return;
|
524 |
+
}
|
525 |
+
char *filename = find_char_arg(argc, argv, "-file", "data/shakespeare.txt");
|
526 |
+
char *seed = find_char_arg(argc, argv, "-seed", "\n\n");
|
527 |
+
int len = find_int_arg(argc, argv, "-len", 1000);
|
528 |
+
float temp = find_float_arg(argc, argv, "-temp", .7);
|
529 |
+
int rseed = find_int_arg(argc, argv, "-srand", time(0));
|
530 |
+
int clear = find_arg(argc, argv, "-clear");
|
531 |
+
int tokenized = find_arg(argc, argv, "-tokenized");
|
532 |
+
char *tokens = find_char_arg(argc, argv, "-tokens", 0);
|
533 |
+
|
534 |
+
char *cfg = argv[3];
|
535 |
+
char *weights = (argc > 4) ? argv[4] : 0;
|
536 |
+
if(0==strcmp(argv[2], "train")) train_char_rnn(cfg, weights, filename, clear, tokenized);
|
537 |
+
else if(0==strcmp(argv[2], "valid")) valid_char_rnn(cfg, weights, seed);
|
538 |
+
else if(0==strcmp(argv[2], "validtactic")) valid_tactic_rnn(cfg, weights, seed);
|
539 |
+
else if(0==strcmp(argv[2], "vec")) vec_char_rnn(cfg, weights, seed);
|
540 |
+
else if(0==strcmp(argv[2], "generate")) test_char_rnn(cfg, weights, len, seed, temp, rseed, tokens);
|
541 |
+
else if(0==strcmp(argv[2], "generatetactic")) test_tactic_rnn(cfg, weights, len, temp, rseed, tokens);
|
542 |
+
}
|
model/examples/rnn_vid.c
ADDED
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include "darknet.h"
|
2 |
+
|
3 |
+
#ifdef OPENCV
|
4 |
+
image get_image_from_stream(CvCapture *cap);
|
5 |
+
image ipl_to_image(IplImage* src);
|
6 |
+
|
7 |
+
void reconstruct_picture(network net, float *features, image recon, image update, float rate, float momentum, float lambda, int smooth_size, int iters);
|
8 |
+
|
9 |
+
|
10 |
+
typedef struct {
|
11 |
+
float *x;
|
12 |
+
float *y;
|
13 |
+
} float_pair;
|
14 |
+
|
15 |
+
float_pair get_rnn_vid_data(network net, char **files, int n, int batch, int steps)
|
16 |
+
{
|
17 |
+
int b;
|
18 |
+
assert(net.batch == steps + 1);
|
19 |
+
image out_im = get_network_image(net);
|
20 |
+
int output_size = out_im.w*out_im.h*out_im.c;
|
21 |
+
printf("%d %d %d\n", out_im.w, out_im.h, out_im.c);
|
22 |
+
float *feats = calloc(net.batch*batch*output_size, sizeof(float));
|
23 |
+
for(b = 0; b < batch; ++b){
|
24 |
+
int input_size = net.w*net.h*net.c;
|
25 |
+
float *input = calloc(input_size*net.batch, sizeof(float));
|
26 |
+
char *filename = files[rand()%n];
|
27 |
+
CvCapture *cap = cvCaptureFromFile(filename);
|
28 |
+
int frames = cvGetCaptureProperty(cap, CV_CAP_PROP_FRAME_COUNT);
|
29 |
+
int index = rand() % (frames - steps - 2);
|
30 |
+
if (frames < (steps + 4)){
|
31 |
+
--b;
|
32 |
+
free(input);
|
33 |
+
continue;
|
34 |
+
}
|
35 |
+
|
36 |
+
printf("frames: %d, index: %d\n", frames, index);
|
37 |
+
cvSetCaptureProperty(cap, CV_CAP_PROP_POS_FRAMES, index);
|
38 |
+
|
39 |
+
int i;
|
40 |
+
for(i = 0; i < net.batch; ++i){
|
41 |
+
IplImage* src = cvQueryFrame(cap);
|
42 |
+
image im = ipl_to_image(src);
|
43 |
+
rgbgr_image(im);
|
44 |
+
image re = resize_image(im, net.w, net.h);
|
45 |
+
//show_image(re, "loaded");
|
46 |
+
//cvWaitKey(10);
|
47 |
+
memcpy(input + i*input_size, re.data, input_size*sizeof(float));
|
48 |
+
free_image(im);
|
49 |
+
free_image(re);
|
50 |
+
}
|
51 |
+
float *output = network_predict(net, input);
|
52 |
+
|
53 |
+
free(input);
|
54 |
+
|
55 |
+
for(i = 0; i < net.batch; ++i){
|
56 |
+
memcpy(feats + (b + i*batch)*output_size, output + i*output_size, output_size*sizeof(float));
|
57 |
+
}
|
58 |
+
|
59 |
+
cvReleaseCapture(&cap);
|
60 |
+
}
|
61 |
+
|
62 |
+
//printf("%d %d %d\n", out_im.w, out_im.h, out_im.c);
|
63 |
+
float_pair p = {0};
|
64 |
+
p.x = feats;
|
65 |
+
p.y = feats + output_size*batch; //+ out_im.w*out_im.h*out_im.c;
|
66 |
+
|
67 |
+
return p;
|
68 |
+
}
|
69 |
+
|
70 |
+
|
71 |
+
void train_vid_rnn(char *cfgfile, char *weightfile)
|
72 |
+
{
|
73 |
+
char *train_videos = "data/vid/train.txt";
|
74 |
+
char *backup_directory = "/home/pjreddie/backup/";
|
75 |
+
srand(time(0));
|
76 |
+
char *base = basecfg(cfgfile);
|
77 |
+
printf("%s\n", base);
|
78 |
+
float avg_loss = -1;
|
79 |
+
network net = parse_network_cfg(cfgfile);
|
80 |
+
if(weightfile){
|
81 |
+
load_weights(&net, weightfile);
|
82 |
+
}
|
83 |
+
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
|
84 |
+
int imgs = net.batch*net.subdivisions;
|
85 |
+
int i = *net.seen/imgs;
|
86 |
+
|
87 |
+
list *plist = get_paths(train_videos);
|
88 |
+
int N = plist->size;
|
89 |
+
char **paths = (char **)list_to_array(plist);
|
90 |
+
clock_t time;
|
91 |
+
int steps = net.time_steps;
|
92 |
+
int batch = net.batch / net.time_steps;
|
93 |
+
|
94 |
+
network extractor = parse_network_cfg("cfg/extractor.cfg");
|
95 |
+
load_weights(&extractor, "/home/pjreddie/trained/yolo-coco.conv");
|
96 |
+
|
97 |
+
while(get_current_batch(net) < net.max_batches){
|
98 |
+
i += 1;
|
99 |
+
time=clock();
|
100 |
+
float_pair p = get_rnn_vid_data(extractor, paths, N, batch, steps);
|
101 |
+
|
102 |
+
copy_cpu(net.inputs*net.batch, p.x, 1, net.input, 1);
|
103 |
+
copy_cpu(net.truths*net.batch, p.y, 1, net.truth, 1);
|
104 |
+
float loss = train_network_datum(net) / (net.batch);
|
105 |
+
|
106 |
+
|
107 |
+
free(p.x);
|
108 |
+
if (avg_loss < 0) avg_loss = loss;
|
109 |
+
avg_loss = avg_loss*.9 + loss*.1;
|
110 |
+
|
111 |
+
fprintf(stderr, "%d: %f, %f avg, %f rate, %lf seconds\n", i, loss, avg_loss, get_current_rate(net), sec(clock()-time));
|
112 |
+
if(i%100==0){
|
113 |
+
char buff[256];
|
114 |
+
sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
|
115 |
+
save_weights(net, buff);
|
116 |
+
}
|
117 |
+
if(i%10==0){
|
118 |
+
char buff[256];
|
119 |
+
sprintf(buff, "%s/%s.backup", backup_directory, base);
|
120 |
+
save_weights(net, buff);
|
121 |
+
}
|
122 |
+
}
|
123 |
+
char buff[256];
|
124 |
+
sprintf(buff, "%s/%s_final.weights", backup_directory, base);
|
125 |
+
save_weights(net, buff);
|
126 |
+
}
|
127 |
+
|
128 |
+
|
129 |
+
image save_reconstruction(network net, image *init, float *feat, char *name, int i)
|
130 |
+
{
|
131 |
+
image recon;
|
132 |
+
if (init) {
|
133 |
+
recon = copy_image(*init);
|
134 |
+
} else {
|
135 |
+
recon = make_random_image(net.w, net.h, 3);
|
136 |
+
}
|
137 |
+
|
138 |
+
image update = make_image(net.w, net.h, 3);
|
139 |
+
reconstruct_picture(net, feat, recon, update, .01, .9, .1, 2, 50);
|
140 |
+
char buff[256];
|
141 |
+
sprintf(buff, "%s%d", name, i);
|
142 |
+
save_image(recon, buff);
|
143 |
+
free_image(update);
|
144 |
+
return recon;
|
145 |
+
}
|
146 |
+
|
147 |
+
void generate_vid_rnn(char *cfgfile, char *weightfile)
|
148 |
+
{
|
149 |
+
network extractor = parse_network_cfg("cfg/extractor.recon.cfg");
|
150 |
+
load_weights(&extractor, "/home/pjreddie/trained/yolo-coco.conv");
|
151 |
+
|
152 |
+
network net = parse_network_cfg(cfgfile);
|
153 |
+
if(weightfile){
|
154 |
+
load_weights(&net, weightfile);
|
155 |
+
}
|
156 |
+
set_batch_network(&extractor, 1);
|
157 |
+
set_batch_network(&net, 1);
|
158 |
+
|
159 |
+
int i;
|
160 |
+
CvCapture *cap = cvCaptureFromFile("/extra/vid/ILSVRC2015/Data/VID/snippets/val/ILSVRC2015_val_00007030.mp4");
|
161 |
+
float *feat;
|
162 |
+
float *next;
|
163 |
+
image last;
|
164 |
+
for(i = 0; i < 25; ++i){
|
165 |
+
image im = get_image_from_stream(cap);
|
166 |
+
image re = resize_image(im, extractor.w, extractor.h);
|
167 |
+
feat = network_predict(extractor, re.data);
|
168 |
+
if(i > 0){
|
169 |
+
printf("%f %f\n", mean_array(feat, 14*14*512), variance_array(feat, 14*14*512));
|
170 |
+
printf("%f %f\n", mean_array(next, 14*14*512), variance_array(next, 14*14*512));
|
171 |
+
printf("%f\n", mse_array(feat, 14*14*512));
|
172 |
+
axpy_cpu(14*14*512, -1, feat, 1, next, 1);
|
173 |
+
printf("%f\n", mse_array(next, 14*14*512));
|
174 |
+
}
|
175 |
+
next = network_predict(net, feat);
|
176 |
+
|
177 |
+
free_image(im);
|
178 |
+
|
179 |
+
free_image(save_reconstruction(extractor, 0, feat, "feat", i));
|
180 |
+
free_image(save_reconstruction(extractor, 0, next, "next", i));
|
181 |
+
if (i==24) last = copy_image(re);
|
182 |
+
free_image(re);
|
183 |
+
}
|
184 |
+
for(i = 0; i < 30; ++i){
|
185 |
+
next = network_predict(net, next);
|
186 |
+
image new = save_reconstruction(extractor, &last, next, "new", i);
|
187 |
+
free_image(last);
|
188 |
+
last = new;
|
189 |
+
}
|
190 |
+
}
|
191 |
+
|
192 |
+
void run_vid_rnn(int argc, char **argv)
|
193 |
+
{
|
194 |
+
if(argc < 4){
|
195 |
+
fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
|
196 |
+
return;
|
197 |
+
}
|
198 |
+
|
199 |
+
char *cfg = argv[3];
|
200 |
+
char *weights = (argc > 4) ? argv[4] : 0;
|
201 |
+
//char *filename = (argc > 5) ? argv[5]: 0;
|
202 |
+
if(0==strcmp(argv[2], "train")) train_vid_rnn(cfg, weights);
|
203 |
+
else if(0==strcmp(argv[2], "generate")) generate_vid_rnn(cfg, weights);
|
204 |
+
}
|
205 |
+
#else
|
206 |
+
void run_vid_rnn(int argc, char **argv){}
|
207 |
+
#endif
|
208 |
+
|
model/examples/segmenter.c
ADDED
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include "darknet.h"
|
2 |
+
#include <sys/time.h>
|
3 |
+
#include <assert.h>
|
4 |
+
|
5 |
+
void train_segmenter(char *datacfg, char *cfgfile, char *weightfile, int *gpus, int ngpus, int clear, int display)
|
6 |
+
{
|
7 |
+
int i;
|
8 |
+
|
9 |
+
float avg_loss = -1;
|
10 |
+
char *base = basecfg(cfgfile);
|
11 |
+
printf("%s\n", base);
|
12 |
+
printf("%d\n", ngpus);
|
13 |
+
network **nets = calloc(ngpus, sizeof(network*));
|
14 |
+
|
15 |
+
srand(time(0));
|
16 |
+
int seed = rand();
|
17 |
+
for(i = 0; i < ngpus; ++i){
|
18 |
+
srand(seed);
|
19 |
+
#ifdef GPU
|
20 |
+
cuda_set_device(gpus[i]);
|
21 |
+
#endif
|
22 |
+
nets[i] = load_network(cfgfile, weightfile, clear);
|
23 |
+
nets[i]->learning_rate *= ngpus;
|
24 |
+
}
|
25 |
+
srand(time(0));
|
26 |
+
network *net = nets[0];
|
27 |
+
image pred = get_network_image(net);
|
28 |
+
|
29 |
+
int div = net->w/pred.w;
|
30 |
+
assert(pred.w * div == net->w);
|
31 |
+
assert(pred.h * div == net->h);
|
32 |
+
|
33 |
+
int imgs = net->batch * net->subdivisions * ngpus;
|
34 |
+
|
35 |
+
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
|
36 |
+
list *options = read_data_cfg(datacfg);
|
37 |
+
|
38 |
+
char *backup_directory = option_find_str(options, "backup", "/backup/");
|
39 |
+
char *train_list = option_find_str(options, "train", "data/train.list");
|
40 |
+
|
41 |
+
list *plist = get_paths(train_list);
|
42 |
+
char **paths = (char **)list_to_array(plist);
|
43 |
+
printf("%d\n", plist->size);
|
44 |
+
int N = plist->size;
|
45 |
+
|
46 |
+
load_args args = {0};
|
47 |
+
args.w = net->w;
|
48 |
+
args.h = net->h;
|
49 |
+
args.threads = 32;
|
50 |
+
args.scale = div;
|
51 |
+
|
52 |
+
args.min = net->min_crop;
|
53 |
+
args.max = net->max_crop;
|
54 |
+
args.angle = net->angle;
|
55 |
+
args.aspect = net->aspect;
|
56 |
+
args.exposure = net->exposure;
|
57 |
+
args.saturation = net->saturation;
|
58 |
+
args.hue = net->hue;
|
59 |
+
args.size = net->w;
|
60 |
+
args.classes = 80;
|
61 |
+
|
62 |
+
args.paths = paths;
|
63 |
+
args.n = imgs;
|
64 |
+
args.m = N;
|
65 |
+
args.type = SEGMENTATION_DATA;
|
66 |
+
|
67 |
+
data train;
|
68 |
+
data buffer;
|
69 |
+
pthread_t load_thread;
|
70 |
+
args.d = &buffer;
|
71 |
+
load_thread = load_data(args);
|
72 |
+
|
73 |
+
int epoch = (*net->seen)/N;
|
74 |
+
while(get_current_batch(net) < net->max_batches || net->max_batches == 0){
|
75 |
+
double time = what_time_is_it_now();
|
76 |
+
|
77 |
+
pthread_join(load_thread, 0);
|
78 |
+
train = buffer;
|
79 |
+
load_thread = load_data(args);
|
80 |
+
|
81 |
+
printf("Loaded: %lf seconds\n", what_time_is_it_now()-time);
|
82 |
+
time = what_time_is_it_now();
|
83 |
+
|
84 |
+
float loss = 0;
|
85 |
+
#ifdef GPU
|
86 |
+
if(ngpus == 1){
|
87 |
+
loss = train_network(net, train);
|
88 |
+
} else {
|
89 |
+
loss = train_networks(nets, ngpus, train, 4);
|
90 |
+
}
|
91 |
+
#else
|
92 |
+
loss = train_network(net, train);
|
93 |
+
#endif
|
94 |
+
if(display){
|
95 |
+
image tr = float_to_image(net->w/div, net->h/div, 80, train.y.vals[net->batch*(net->subdivisions-1)]);
|
96 |
+
image im = float_to_image(net->w, net->h, net->c, train.X.vals[net->batch*(net->subdivisions-1)]);
|
97 |
+
image mask = mask_to_rgb(tr);
|
98 |
+
image prmask = mask_to_rgb(pred);
|
99 |
+
show_image(im, "input", 1);
|
100 |
+
show_image(prmask, "pred", 1);
|
101 |
+
show_image(mask, "truth", 100);
|
102 |
+
free_image(mask);
|
103 |
+
free_image(prmask);
|
104 |
+
}
|
105 |
+
if(avg_loss == -1) avg_loss = loss;
|
106 |
+
avg_loss = avg_loss*.9 + loss*.1;
|
107 |
+
printf("%ld, %.3f: %f, %f avg, %f rate, %lf seconds, %ld images\n", get_current_batch(net), (float)(*net->seen)/N, loss, avg_loss, get_current_rate(net), what_time_is_it_now()-time, *net->seen);
|
108 |
+
free_data(train);
|
109 |
+
if(*net->seen/N > epoch){
|
110 |
+
epoch = *net->seen/N;
|
111 |
+
char buff[256];
|
112 |
+
sprintf(buff, "%s/%s_%d.weights",backup_directory,base, epoch);
|
113 |
+
save_weights(net, buff);
|
114 |
+
}
|
115 |
+
if(get_current_batch(net)%100 == 0){
|
116 |
+
char buff[256];
|
117 |
+
sprintf(buff, "%s/%s.backup",backup_directory,base);
|
118 |
+
save_weights(net, buff);
|
119 |
+
}
|
120 |
+
}
|
121 |
+
char buff[256];
|
122 |
+
sprintf(buff, "%s/%s.weights", backup_directory, base);
|
123 |
+
save_weights(net, buff);
|
124 |
+
|
125 |
+
free_network(net);
|
126 |
+
free_ptrs((void**)paths, plist->size);
|
127 |
+
free_list(plist);
|
128 |
+
free(base);
|
129 |
+
}
|
130 |
+
|
131 |
+
void predict_segmenter(char *datafile, char *cfg, char *weights, char *filename)
|
132 |
+
{
|
133 |
+
network *net = load_network(cfg, weights, 0);
|
134 |
+
set_batch_network(net, 1);
|
135 |
+
srand(2222222);
|
136 |
+
|
137 |
+
clock_t time;
|
138 |
+
char buff[256];
|
139 |
+
char *input = buff;
|
140 |
+
while(1){
|
141 |
+
if(filename){
|
142 |
+
strncpy(input, filename, 256);
|
143 |
+
}else{
|
144 |
+
printf("Enter Image Path: ");
|
145 |
+
fflush(stdout);
|
146 |
+
input = fgets(input, 256, stdin);
|
147 |
+
if(!input) return;
|
148 |
+
strtok(input, "\n");
|
149 |
+
}
|
150 |
+
image im = load_image_color(input, 0, 0);
|
151 |
+
image sized = letterbox_image(im, net->w, net->h);
|
152 |
+
|
153 |
+
float *X = sized.data;
|
154 |
+
time=clock();
|
155 |
+
float *predictions = network_predict(net, X);
|
156 |
+
image pred = get_network_image(net);
|
157 |
+
image prmask = mask_to_rgb(pred);
|
158 |
+
printf("Predicted: %f\n", predictions[0]);
|
159 |
+
printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
|
160 |
+
show_image(sized, "orig", 1);
|
161 |
+
show_image(prmask, "pred", 0);
|
162 |
+
free_image(im);
|
163 |
+
free_image(sized);
|
164 |
+
free_image(prmask);
|
165 |
+
if (filename) break;
|
166 |
+
}
|
167 |
+
}
|
168 |
+
|
169 |
+
|
170 |
+
void demo_segmenter(char *datacfg, char *cfg, char *weights, int cam_index, const char *filename)
|
171 |
+
{
|
172 |
+
#ifdef OPENCV
|
173 |
+
printf("Classifier Demo\n");
|
174 |
+
network *net = load_network(cfg, weights, 0);
|
175 |
+
set_batch_network(net, 1);
|
176 |
+
|
177 |
+
srand(2222222);
|
178 |
+
void * cap = open_video_stream(filename, cam_index, 0,0,0);
|
179 |
+
|
180 |
+
if(!cap) error("Couldn't connect to webcam.\n");
|
181 |
+
float fps = 0;
|
182 |
+
|
183 |
+
while(1){
|
184 |
+
struct timeval tval_before, tval_after, tval_result;
|
185 |
+
gettimeofday(&tval_before, NULL);
|
186 |
+
|
187 |
+
image in = get_image_from_stream(cap);
|
188 |
+
image in_s = letterbox_image(in, net->w, net->h);
|
189 |
+
|
190 |
+
network_predict(net, in_s.data);
|
191 |
+
|
192 |
+
printf("\033[2J");
|
193 |
+
printf("\033[1;1H");
|
194 |
+
printf("\nFPS:%.0f\n",fps);
|
195 |
+
|
196 |
+
image pred = get_network_image(net);
|
197 |
+
image prmask = mask_to_rgb(pred);
|
198 |
+
show_image(prmask, "Segmenter", 10);
|
199 |
+
|
200 |
+
free_image(in_s);
|
201 |
+
free_image(in);
|
202 |
+
free_image(prmask);
|
203 |
+
|
204 |
+
gettimeofday(&tval_after, NULL);
|
205 |
+
timersub(&tval_after, &tval_before, &tval_result);
|
206 |
+
float curr = 1000000.f/((long int)tval_result.tv_usec);
|
207 |
+
fps = .9*fps + .1*curr;
|
208 |
+
}
|
209 |
+
#endif
|
210 |
+
}
|
211 |
+
|
212 |
+
|
213 |
+
void run_segmenter(int argc, char **argv)
|
214 |
+
{
|
215 |
+
if(argc < 4){
|
216 |
+
fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
|
217 |
+
return;
|
218 |
+
}
|
219 |
+
|
220 |
+
char *gpu_list = find_char_arg(argc, argv, "-gpus", 0);
|
221 |
+
int *gpus = 0;
|
222 |
+
int gpu = 0;
|
223 |
+
int ngpus = 0;
|
224 |
+
if(gpu_list){
|
225 |
+
printf("%s\n", gpu_list);
|
226 |
+
int len = strlen(gpu_list);
|
227 |
+
ngpus = 1;
|
228 |
+
int i;
|
229 |
+
for(i = 0; i < len; ++i){
|
230 |
+
if (gpu_list[i] == ',') ++ngpus;
|
231 |
+
}
|
232 |
+
gpus = calloc(ngpus, sizeof(int));
|
233 |
+
for(i = 0; i < ngpus; ++i){
|
234 |
+
gpus[i] = atoi(gpu_list);
|
235 |
+
gpu_list = strchr(gpu_list, ',')+1;
|
236 |
+
}
|
237 |
+
} else {
|
238 |
+
gpu = gpu_index;
|
239 |
+
gpus = &gpu;
|
240 |
+
ngpus = 1;
|
241 |
+
}
|
242 |
+
|
243 |
+
int cam_index = find_int_arg(argc, argv, "-c", 0);
|
244 |
+
int clear = find_arg(argc, argv, "-clear");
|
245 |
+
int display = find_arg(argc, argv, "-display");
|
246 |
+
char *data = argv[3];
|
247 |
+
char *cfg = argv[4];
|
248 |
+
char *weights = (argc > 5) ? argv[5] : 0;
|
249 |
+
char *filename = (argc > 6) ? argv[6]: 0;
|
250 |
+
if(0==strcmp(argv[2], "test")) predict_segmenter(data, cfg, weights, filename);
|
251 |
+
else if(0==strcmp(argv[2], "train")) train_segmenter(data, cfg, weights, gpus, ngpus, clear, display);
|
252 |
+
else if(0==strcmp(argv[2], "demo")) demo_segmenter(data, cfg, weights, cam_index, filename);
|
253 |
+
}
|
254 |
+
|
255 |
+
|
model/examples/super.c
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include "darknet.h"
|
2 |
+
|
3 |
+
void train_super(char *cfgfile, char *weightfile, int clear)
|
4 |
+
{
|
5 |
+
char *train_images = "/data/imagenet/imagenet1k.train.list";
|
6 |
+
char *backup_directory = "/home/pjreddie/backup/";
|
7 |
+
srand(time(0));
|
8 |
+
char *base = basecfg(cfgfile);
|
9 |
+
printf("%s\n", base);
|
10 |
+
float avg_loss = -1;
|
11 |
+
network *net = load_network(cfgfile, weightfile, clear);
|
12 |
+
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
|
13 |
+
int imgs = net->batch*net->subdivisions;
|
14 |
+
int i = *net->seen/imgs;
|
15 |
+
data train, buffer;
|
16 |
+
|
17 |
+
|
18 |
+
list *plist = get_paths(train_images);
|
19 |
+
//int N = plist->size;
|
20 |
+
char **paths = (char **)list_to_array(plist);
|
21 |
+
|
22 |
+
load_args args = {0};
|
23 |
+
args.w = net->w;
|
24 |
+
args.h = net->h;
|
25 |
+
args.scale = 4;
|
26 |
+
args.paths = paths;
|
27 |
+
args.n = imgs;
|
28 |
+
args.m = plist->size;
|
29 |
+
args.d = &buffer;
|
30 |
+
args.type = SUPER_DATA;
|
31 |
+
|
32 |
+
pthread_t load_thread = load_data_in_thread(args);
|
33 |
+
clock_t time;
|
34 |
+
//while(i*imgs < N*120){
|
35 |
+
while(get_current_batch(net) < net->max_batches){
|
36 |
+
i += 1;
|
37 |
+
time=clock();
|
38 |
+
pthread_join(load_thread, 0);
|
39 |
+
train = buffer;
|
40 |
+
load_thread = load_data_in_thread(args);
|
41 |
+
|
42 |
+
printf("Loaded: %lf seconds\n", sec(clock()-time));
|
43 |
+
|
44 |
+
time=clock();
|
45 |
+
float loss = train_network(net, train);
|
46 |
+
if (avg_loss < 0) avg_loss = loss;
|
47 |
+
avg_loss = avg_loss*.9 + loss*.1;
|
48 |
+
|
49 |
+
printf("%d: %f, %f avg, %f rate, %lf seconds, %d images\n", i, loss, avg_loss, get_current_rate(net), sec(clock()-time), i*imgs);
|
50 |
+
if(i%1000==0){
|
51 |
+
char buff[256];
|
52 |
+
sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
|
53 |
+
save_weights(net, buff);
|
54 |
+
}
|
55 |
+
if(i%100==0){
|
56 |
+
char buff[256];
|
57 |
+
sprintf(buff, "%s/%s.backup", backup_directory, base);
|
58 |
+
save_weights(net, buff);
|
59 |
+
}
|
60 |
+
free_data(train);
|
61 |
+
}
|
62 |
+
char buff[256];
|
63 |
+
sprintf(buff, "%s/%s_final.weights", backup_directory, base);
|
64 |
+
save_weights(net, buff);
|
65 |
+
}
|
66 |
+
|
67 |
+
void test_super(char *cfgfile, char *weightfile, char *filename)
|
68 |
+
{
|
69 |
+
network *net = load_network(cfgfile, weightfile, 0);
|
70 |
+
set_batch_network(net, 1);
|
71 |
+
srand(2222222);
|
72 |
+
|
73 |
+
clock_t time;
|
74 |
+
char buff[256];
|
75 |
+
char *input = buff;
|
76 |
+
while(1){
|
77 |
+
if(filename){
|
78 |
+
strncpy(input, filename, 256);
|
79 |
+
}else{
|
80 |
+
printf("Enter Image Path: ");
|
81 |
+
fflush(stdout);
|
82 |
+
input = fgets(input, 256, stdin);
|
83 |
+
if(!input) return;
|
84 |
+
strtok(input, "\n");
|
85 |
+
}
|
86 |
+
image im = load_image_color(input, 0, 0);
|
87 |
+
resize_network(net, im.w, im.h);
|
88 |
+
printf("%d %d\n", im.w, im.h);
|
89 |
+
|
90 |
+
float *X = im.data;
|
91 |
+
time=clock();
|
92 |
+
network_predict(net, X);
|
93 |
+
image out = get_network_image(net);
|
94 |
+
printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
|
95 |
+
save_image(out, "out");
|
96 |
+
show_image(out, "out", 0);
|
97 |
+
|
98 |
+
free_image(im);
|
99 |
+
if (filename) break;
|
100 |
+
}
|
101 |
+
}
|
102 |
+
|
103 |
+
|
104 |
+
void run_super(int argc, char **argv)
|
105 |
+
{
|
106 |
+
if(argc < 4){
|
107 |
+
fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
|
108 |
+
return;
|
109 |
+
}
|
110 |
+
|
111 |
+
char *cfg = argv[3];
|
112 |
+
char *weights = (argc > 4) ? argv[4] : 0;
|
113 |
+
char *filename = (argc > 5) ? argv[5] : 0;
|
114 |
+
int clear = find_arg(argc, argv, "-clear");
|
115 |
+
if(0==strcmp(argv[2], "train")) train_super(cfg, weights, clear);
|
116 |
+
else if(0==strcmp(argv[2], "test")) test_super(cfg, weights, filename);
|
117 |
+
/*
|
118 |
+
else if(0==strcmp(argv[2], "valid")) validate_super(cfg, weights);
|
119 |
+
*/
|
120 |
+
}
|
model/examples/swag.c
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include "darknet.h"
|
2 |
+
#include <sys/time.h>
|
3 |
+
|
4 |
+
void train_swag(char *cfgfile, char *weightfile)
|
5 |
+
{
|
6 |
+
char *train_images = "data/voc.0712.trainval";
|
7 |
+
char *backup_directory = "/home/pjreddie/backup/";
|
8 |
+
srand(time(0));
|
9 |
+
char *base = basecfg(cfgfile);
|
10 |
+
printf("%s\n", base);
|
11 |
+
float avg_loss = -1;
|
12 |
+
network net = parse_network_cfg(cfgfile);
|
13 |
+
if(weightfile){
|
14 |
+
load_weights(&net, weightfile);
|
15 |
+
}
|
16 |
+
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
|
17 |
+
int imgs = net.batch*net.subdivisions;
|
18 |
+
int i = *net.seen/imgs;
|
19 |
+
data train, buffer;
|
20 |
+
|
21 |
+
layer l = net.layers[net.n - 1];
|
22 |
+
|
23 |
+
int side = l.side;
|
24 |
+
int classes = l.classes;
|
25 |
+
float jitter = l.jitter;
|
26 |
+
|
27 |
+
list *plist = get_paths(train_images);
|
28 |
+
//int N = plist->size;
|
29 |
+
char **paths = (char **)list_to_array(plist);
|
30 |
+
|
31 |
+
load_args args = {0};
|
32 |
+
args.w = net.w;
|
33 |
+
args.h = net.h;
|
34 |
+
args.paths = paths;
|
35 |
+
args.n = imgs;
|
36 |
+
args.m = plist->size;
|
37 |
+
args.classes = classes;
|
38 |
+
args.jitter = jitter;
|
39 |
+
args.num_boxes = side;
|
40 |
+
args.d = &buffer;
|
41 |
+
args.type = REGION_DATA;
|
42 |
+
|
43 |
+
pthread_t load_thread = load_data_in_thread(args);
|
44 |
+
clock_t time;
|
45 |
+
//while(i*imgs < N*120){
|
46 |
+
while(get_current_batch(net) < net.max_batches){
|
47 |
+
i += 1;
|
48 |
+
time=clock();
|
49 |
+
pthread_join(load_thread, 0);
|
50 |
+
train = buffer;
|
51 |
+
load_thread = load_data_in_thread(args);
|
52 |
+
|
53 |
+
printf("Loaded: %lf seconds\n", sec(clock()-time));
|
54 |
+
|
55 |
+
time=clock();
|
56 |
+
float loss = train_network(net, train);
|
57 |
+
if (avg_loss < 0) avg_loss = loss;
|
58 |
+
avg_loss = avg_loss*.9 + loss*.1;
|
59 |
+
|
60 |
+
printf("%d: %f, %f avg, %f rate, %lf seconds, %d images\n", i, loss, avg_loss, get_current_rate(net), sec(clock()-time), i*imgs);
|
61 |
+
if(i%1000==0 || i == 600){
|
62 |
+
char buff[256];
|
63 |
+
sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
|
64 |
+
save_weights(net, buff);
|
65 |
+
}
|
66 |
+
free_data(train);
|
67 |
+
}
|
68 |
+
char buff[256];
|
69 |
+
sprintf(buff, "%s/%s_final.weights", backup_directory, base);
|
70 |
+
save_weights(net, buff);
|
71 |
+
}
|
72 |
+
|
73 |
+
void run_swag(int argc, char **argv)
|
74 |
+
{
|
75 |
+
if(argc < 4){
|
76 |
+
fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
|
77 |
+
return;
|
78 |
+
}
|
79 |
+
|
80 |
+
char *cfg = argv[3];
|
81 |
+
char *weights = (argc > 4) ? argv[4] : 0;
|
82 |
+
if(0==strcmp(argv[2], "train")) train_swag(cfg, weights);
|
83 |
+
}
|
model/examples/tag.c
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include "darknet.h"
|
2 |
+
|
3 |
+
void train_tag(char *cfgfile, char *weightfile, int clear)
|
4 |
+
{
|
5 |
+
srand(time(0));
|
6 |
+
float avg_loss = -1;
|
7 |
+
char *base = basecfg(cfgfile);
|
8 |
+
char *backup_directory = "/home/pjreddie/backup/";
|
9 |
+
printf("%s\n", base);
|
10 |
+
network *net = load_network(cfgfile, weightfile, clear);
|
11 |
+
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
|
12 |
+
int imgs = 1024;
|
13 |
+
list *plist = get_paths("/home/pjreddie/tag/train.list");
|
14 |
+
char **paths = (char **)list_to_array(plist);
|
15 |
+
printf("%d\n", plist->size);
|
16 |
+
int N = plist->size;
|
17 |
+
clock_t time;
|
18 |
+
pthread_t load_thread;
|
19 |
+
data train;
|
20 |
+
data buffer;
|
21 |
+
|
22 |
+
load_args args = {0};
|
23 |
+
args.w = net->w;
|
24 |
+
args.h = net->h;
|
25 |
+
|
26 |
+
args.min = net->w;
|
27 |
+
args.max = net->max_crop;
|
28 |
+
args.size = net->w;
|
29 |
+
|
30 |
+
args.paths = paths;
|
31 |
+
args.classes = net->outputs;
|
32 |
+
args.n = imgs;
|
33 |
+
args.m = N;
|
34 |
+
args.d = &buffer;
|
35 |
+
args.type = TAG_DATA;
|
36 |
+
|
37 |
+
args.angle = net->angle;
|
38 |
+
args.exposure = net->exposure;
|
39 |
+
args.saturation = net->saturation;
|
40 |
+
args.hue = net->hue;
|
41 |
+
|
42 |
+
fprintf(stderr, "%d classes\n", net->outputs);
|
43 |
+
|
44 |
+
load_thread = load_data_in_thread(args);
|
45 |
+
int epoch = (*net->seen)/N;
|
46 |
+
while(get_current_batch(net) < net->max_batches || net->max_batches == 0){
|
47 |
+
time=clock();
|
48 |
+
pthread_join(load_thread, 0);
|
49 |
+
train = buffer;
|
50 |
+
|
51 |
+
load_thread = load_data_in_thread(args);
|
52 |
+
printf("Loaded: %lf seconds\n", sec(clock()-time));
|
53 |
+
time=clock();
|
54 |
+
float loss = train_network(net, train);
|
55 |
+
if(avg_loss == -1) avg_loss = loss;
|
56 |
+
avg_loss = avg_loss*.9 + loss*.1;
|
57 |
+
printf("%ld, %.3f: %f, %f avg, %f rate, %lf seconds, %ld images\n", get_current_batch(net), (float)(*net->seen)/N, loss, avg_loss, get_current_rate(net), sec(clock()-time), *net->seen);
|
58 |
+
free_data(train);
|
59 |
+
if(*net->seen/N > epoch){
|
60 |
+
epoch = *net->seen/N;
|
61 |
+
char buff[256];
|
62 |
+
sprintf(buff, "%s/%s_%d.weights",backup_directory,base, epoch);
|
63 |
+
save_weights(net, buff);
|
64 |
+
}
|
65 |
+
if(get_current_batch(net)%100 == 0){
|
66 |
+
char buff[256];
|
67 |
+
sprintf(buff, "%s/%s.backup",backup_directory,base);
|
68 |
+
save_weights(net, buff);
|
69 |
+
}
|
70 |
+
}
|
71 |
+
char buff[256];
|
72 |
+
sprintf(buff, "%s/%s.weights", backup_directory, base);
|
73 |
+
save_weights(net, buff);
|
74 |
+
|
75 |
+
pthread_join(load_thread, 0);
|
76 |
+
free_data(buffer);
|
77 |
+
free_network(net);
|
78 |
+
free_ptrs((void**)paths, plist->size);
|
79 |
+
free_list(plist);
|
80 |
+
free(base);
|
81 |
+
}
|
82 |
+
|
83 |
+
void test_tag(char *cfgfile, char *weightfile, char *filename)
|
84 |
+
{
|
85 |
+
network *net = load_network(cfgfile, weightfile, 0);
|
86 |
+
set_batch_network(net, 1);
|
87 |
+
srand(2222222);
|
88 |
+
int i = 0;
|
89 |
+
char **names = get_labels("data/tags.txt");
|
90 |
+
clock_t time;
|
91 |
+
int indexes[10];
|
92 |
+
char buff[256];
|
93 |
+
char *input = buff;
|
94 |
+
int size = net->w;
|
95 |
+
while(1){
|
96 |
+
if(filename){
|
97 |
+
strncpy(input, filename, 256);
|
98 |
+
}else{
|
99 |
+
printf("Enter Image Path: ");
|
100 |
+
fflush(stdout);
|
101 |
+
input = fgets(input, 256, stdin);
|
102 |
+
if(!input) return;
|
103 |
+
strtok(input, "\n");
|
104 |
+
}
|
105 |
+
image im = load_image_color(input, 0, 0);
|
106 |
+
image r = resize_min(im, size);
|
107 |
+
resize_network(net, r.w, r.h);
|
108 |
+
printf("%d %d\n", r.w, r.h);
|
109 |
+
|
110 |
+
float *X = r.data;
|
111 |
+
time=clock();
|
112 |
+
float *predictions = network_predict(net, X);
|
113 |
+
top_predictions(net, 10, indexes);
|
114 |
+
printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
|
115 |
+
for(i = 0; i < 10; ++i){
|
116 |
+
int index = indexes[i];
|
117 |
+
printf("%.1f%%: %s\n", predictions[index]*100, names[index]);
|
118 |
+
}
|
119 |
+
if(r.data != im.data) free_image(r);
|
120 |
+
free_image(im);
|
121 |
+
if (filename) break;
|
122 |
+
}
|
123 |
+
}
|
124 |
+
|
125 |
+
|
126 |
+
void run_tag(int argc, char **argv)
|
127 |
+
{
|
128 |
+
if(argc < 4){
|
129 |
+
fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
|
130 |
+
return;
|
131 |
+
}
|
132 |
+
|
133 |
+
int clear = find_arg(argc, argv, "-clear");
|
134 |
+
char *cfg = argv[3];
|
135 |
+
char *weights = (argc > 4) ? argv[4] : 0;
|
136 |
+
char *filename = (argc > 5) ? argv[5] : 0;
|
137 |
+
if(0==strcmp(argv[2], "train")) train_tag(cfg, weights, clear);
|
138 |
+
else if(0==strcmp(argv[2], "test")) test_tag(cfg, weights, filename);
|
139 |
+
}
|
140 |
+
|
model/examples/voxel.c
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include "darknet.h"
|
2 |
+
|
3 |
+
void extract_voxel(char *lfile, char *rfile, char *prefix)
|
4 |
+
{
|
5 |
+
#ifdef OPENCV
|
6 |
+
int w = 1920;
|
7 |
+
int h = 1080;
|
8 |
+
int shift = 0;
|
9 |
+
int count = 0;
|
10 |
+
CvCapture *lcap = cvCaptureFromFile(lfile);
|
11 |
+
CvCapture *rcap = cvCaptureFromFile(rfile);
|
12 |
+
while(1){
|
13 |
+
image l = get_image_from_stream(lcap);
|
14 |
+
image r = get_image_from_stream(rcap);
|
15 |
+
if(!l.w || !r.w) break;
|
16 |
+
if(count%100 == 0) {
|
17 |
+
shift = best_3d_shift_r(l, r, -l.h/100, l.h/100);
|
18 |
+
printf("%d\n", shift);
|
19 |
+
}
|
20 |
+
image ls = crop_image(l, (l.w - w)/2, (l.h - h)/2, w, h);
|
21 |
+
image rs = crop_image(r, 105 + (r.w - w)/2, (r.h - h)/2 + shift, w, h);
|
22 |
+
char buff[256];
|
23 |
+
sprintf(buff, "%s_%05d_l", prefix, count);
|
24 |
+
save_image(ls, buff);
|
25 |
+
sprintf(buff, "%s_%05d_r", prefix, count);
|
26 |
+
save_image(rs, buff);
|
27 |
+
free_image(l);
|
28 |
+
free_image(r);
|
29 |
+
free_image(ls);
|
30 |
+
free_image(rs);
|
31 |
+
++count;
|
32 |
+
}
|
33 |
+
|
34 |
+
#else
|
35 |
+
printf("need OpenCV for extraction\n");
|
36 |
+
#endif
|
37 |
+
}
|
38 |
+
|
39 |
+
void train_voxel(char *cfgfile, char *weightfile)
|
40 |
+
{
|
41 |
+
char *train_images = "/data/imagenet/imagenet1k.train.list";
|
42 |
+
char *backup_directory = "/home/pjreddie/backup/";
|
43 |
+
srand(time(0));
|
44 |
+
char *base = basecfg(cfgfile);
|
45 |
+
printf("%s\n", base);
|
46 |
+
float avg_loss = -1;
|
47 |
+
network net = parse_network_cfg(cfgfile);
|
48 |
+
if(weightfile){
|
49 |
+
load_weights(&net, weightfile);
|
50 |
+
}
|
51 |
+
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
|
52 |
+
int imgs = net.batch*net.subdivisions;
|
53 |
+
int i = *net.seen/imgs;
|
54 |
+
data train, buffer;
|
55 |
+
|
56 |
+
|
57 |
+
list *plist = get_paths(train_images);
|
58 |
+
//int N = plist->size;
|
59 |
+
char **paths = (char **)list_to_array(plist);
|
60 |
+
|
61 |
+
load_args args = {0};
|
62 |
+
args.w = net.w;
|
63 |
+
args.h = net.h;
|
64 |
+
args.scale = 4;
|
65 |
+
args.paths = paths;
|
66 |
+
args.n = imgs;
|
67 |
+
args.m = plist->size;
|
68 |
+
args.d = &buffer;
|
69 |
+
args.type = SUPER_DATA;
|
70 |
+
|
71 |
+
pthread_t load_thread = load_data_in_thread(args);
|
72 |
+
clock_t time;
|
73 |
+
//while(i*imgs < N*120){
|
74 |
+
while(get_current_batch(net) < net.max_batches){
|
75 |
+
i += 1;
|
76 |
+
time=clock();
|
77 |
+
pthread_join(load_thread, 0);
|
78 |
+
train = buffer;
|
79 |
+
load_thread = load_data_in_thread(args);
|
80 |
+
|
81 |
+
printf("Loaded: %lf seconds\n", sec(clock()-time));
|
82 |
+
|
83 |
+
time=clock();
|
84 |
+
float loss = train_network(net, train);
|
85 |
+
if (avg_loss < 0) avg_loss = loss;
|
86 |
+
avg_loss = avg_loss*.9 + loss*.1;
|
87 |
+
|
88 |
+
printf("%d: %f, %f avg, %f rate, %lf seconds, %d images\n", i, loss, avg_loss, get_current_rate(net), sec(clock()-time), i*imgs);
|
89 |
+
if(i%1000==0){
|
90 |
+
char buff[256];
|
91 |
+
sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
|
92 |
+
save_weights(net, buff);
|
93 |
+
}
|
94 |
+
if(i%100==0){
|
95 |
+
char buff[256];
|
96 |
+
sprintf(buff, "%s/%s.backup", backup_directory, base);
|
97 |
+
save_weights(net, buff);
|
98 |
+
}
|
99 |
+
free_data(train);
|
100 |
+
}
|
101 |
+
char buff[256];
|
102 |
+
sprintf(buff, "%s/%s_final.weights", backup_directory, base);
|
103 |
+
save_weights(net, buff);
|
104 |
+
}
|
105 |
+
|
106 |
+
void test_voxel(char *cfgfile, char *weightfile, char *filename)
|
107 |
+
{
|
108 |
+
network net = parse_network_cfg(cfgfile);
|
109 |
+
if(weightfile){
|
110 |
+
load_weights(&net, weightfile);
|
111 |
+
}
|
112 |
+
set_batch_network(&net, 1);
|
113 |
+
srand(2222222);
|
114 |
+
|
115 |
+
clock_t time;
|
116 |
+
char buff[256];
|
117 |
+
char *input = buff;
|
118 |
+
while(1){
|
119 |
+
if(filename){
|
120 |
+
strncpy(input, filename, 256);
|
121 |
+
}else{
|
122 |
+
printf("Enter Image Path: ");
|
123 |
+
fflush(stdout);
|
124 |
+
input = fgets(input, 256, stdin);
|
125 |
+
if(!input) return;
|
126 |
+
strtok(input, "\n");
|
127 |
+
}
|
128 |
+
image im = load_image_color(input, 0, 0);
|
129 |
+
resize_network(&net, im.w, im.h);
|
130 |
+
printf("%d %d\n", im.w, im.h);
|
131 |
+
|
132 |
+
float *X = im.data;
|
133 |
+
time=clock();
|
134 |
+
network_predict(net, X);
|
135 |
+
image out = get_network_image(net);
|
136 |
+
printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
|
137 |
+
save_image(out, "out");
|
138 |
+
|
139 |
+
free_image(im);
|
140 |
+
if (filename) break;
|
141 |
+
}
|
142 |
+
}
|
143 |
+
|
144 |
+
|
145 |
+
void run_voxel(int argc, char **argv)
|
146 |
+
{
|
147 |
+
if(argc < 4){
|
148 |
+
fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
|
149 |
+
return;
|
150 |
+
}
|
151 |
+
|
152 |
+
char *cfg = argv[3];
|
153 |
+
char *weights = (argc > 4) ? argv[4] : 0;
|
154 |
+
char *filename = (argc > 5) ? argv[5] : 0;
|
155 |
+
if(0==strcmp(argv[2], "train")) train_voxel(cfg, weights);
|
156 |
+
else if(0==strcmp(argv[2], "test")) test_voxel(cfg, weights, filename);
|
157 |
+
else if(0==strcmp(argv[2], "extract")) extract_voxel(argv[3], argv[4], argv[5]);
|
158 |
+
/*
|
159 |
+
else if(0==strcmp(argv[2], "valid")) validate_voxel(cfg, weights);
|
160 |
+
*/
|
161 |
+
}
|
model/examples/writing.c
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include "darknet.h"
|
2 |
+
|
3 |
+
void train_writing(char *cfgfile, char *weightfile)
|
4 |
+
{
|
5 |
+
char *backup_directory = "/home/pjreddie/backup/";
|
6 |
+
srand(time(0));
|
7 |
+
float avg_loss = -1;
|
8 |
+
char *base = basecfg(cfgfile);
|
9 |
+
printf("%s\n", base);
|
10 |
+
network net = parse_network_cfg(cfgfile);
|
11 |
+
if(weightfile){
|
12 |
+
load_weights(&net, weightfile);
|
13 |
+
}
|
14 |
+
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
|
15 |
+
int imgs = net.batch*net.subdivisions;
|
16 |
+
list *plist = get_paths("figures.list");
|
17 |
+
char **paths = (char **)list_to_array(plist);
|
18 |
+
clock_t time;
|
19 |
+
int N = plist->size;
|
20 |
+
printf("N: %d\n", N);
|
21 |
+
image out = get_network_image(net);
|
22 |
+
|
23 |
+
data train, buffer;
|
24 |
+
|
25 |
+
load_args args = {0};
|
26 |
+
args.w = net.w;
|
27 |
+
args.h = net.h;
|
28 |
+
args.out_w = out.w;
|
29 |
+
args.out_h = out.h;
|
30 |
+
args.paths = paths;
|
31 |
+
args.n = imgs;
|
32 |
+
args.m = N;
|
33 |
+
args.d = &buffer;
|
34 |
+
args.type = WRITING_DATA;
|
35 |
+
|
36 |
+
pthread_t load_thread = load_data_in_thread(args);
|
37 |
+
int epoch = (*net.seen)/N;
|
38 |
+
while(get_current_batch(net) < net.max_batches || net.max_batches == 0){
|
39 |
+
time=clock();
|
40 |
+
pthread_join(load_thread, 0);
|
41 |
+
train = buffer;
|
42 |
+
load_thread = load_data_in_thread(args);
|
43 |
+
printf("Loaded %lf seconds\n",sec(clock()-time));
|
44 |
+
|
45 |
+
time=clock();
|
46 |
+
float loss = train_network(net, train);
|
47 |
+
|
48 |
+
/*
|
49 |
+
image pred = float_to_image(64, 64, 1, out);
|
50 |
+
print_image(pred);
|
51 |
+
*/
|
52 |
+
|
53 |
+
/*
|
54 |
+
image im = float_to_image(256, 256, 3, train.X.vals[0]);
|
55 |
+
image lab = float_to_image(64, 64, 1, train.y.vals[0]);
|
56 |
+
image pred = float_to_image(64, 64, 1, out);
|
57 |
+
show_image(im, "image");
|
58 |
+
show_image(lab, "label");
|
59 |
+
print_image(lab);
|
60 |
+
show_image(pred, "pred");
|
61 |
+
cvWaitKey(0);
|
62 |
+
*/
|
63 |
+
|
64 |
+
if(avg_loss == -1) avg_loss = loss;
|
65 |
+
avg_loss = avg_loss*.9 + loss*.1;
|
66 |
+
printf("%ld, %.3f: %f, %f avg, %f rate, %lf seconds, %ld images\n", get_current_batch(net), (float)(*net.seen)/N, loss, avg_loss, get_current_rate(net), sec(clock()-time), *net.seen);
|
67 |
+
free_data(train);
|
68 |
+
if(get_current_batch(net)%100 == 0){
|
69 |
+
char buff[256];
|
70 |
+
sprintf(buff, "%s/%s_batch_%ld.weights", backup_directory, base, get_current_batch(net));
|
71 |
+
save_weights(net, buff);
|
72 |
+
}
|
73 |
+
if(*net.seen/N > epoch){
|
74 |
+
epoch = *net.seen/N;
|
75 |
+
char buff[256];
|
76 |
+
sprintf(buff, "%s/%s_%d.weights",backup_directory,base, epoch);
|
77 |
+
save_weights(net, buff);
|
78 |
+
}
|
79 |
+
}
|
80 |
+
}
|
81 |
+
|
82 |
+
void test_writing(char *cfgfile, char *weightfile, char *filename)
|
83 |
+
{
|
84 |
+
network net = parse_network_cfg(cfgfile);
|
85 |
+
if(weightfile){
|
86 |
+
load_weights(&net, weightfile);
|
87 |
+
}
|
88 |
+
set_batch_network(&net, 1);
|
89 |
+
srand(2222222);
|
90 |
+
clock_t time;
|
91 |
+
char buff[256];
|
92 |
+
char *input = buff;
|
93 |
+
while(1){
|
94 |
+
if(filename){
|
95 |
+
strncpy(input, filename, 256);
|
96 |
+
}else{
|
97 |
+
printf("Enter Image Path: ");
|
98 |
+
fflush(stdout);
|
99 |
+
input = fgets(input, 256, stdin);
|
100 |
+
if(!input) return;
|
101 |
+
strtok(input, "\n");
|
102 |
+
}
|
103 |
+
|
104 |
+
image im = load_image_color(input, 0, 0);
|
105 |
+
resize_network(&net, im.w, im.h);
|
106 |
+
printf("%d %d %d\n", im.h, im.w, im.c);
|
107 |
+
float *X = im.data;
|
108 |
+
time=clock();
|
109 |
+
network_predict(net, X);
|
110 |
+
printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
|
111 |
+
image pred = get_network_image(net);
|
112 |
+
|
113 |
+
image upsampled = resize_image(pred, im.w, im.h);
|
114 |
+
image thresh = threshold_image(upsampled, .5);
|
115 |
+
pred = thresh;
|
116 |
+
|
117 |
+
show_image(pred, "prediction");
|
118 |
+
show_image(im, "orig");
|
119 |
+
#ifdef OPENCV
|
120 |
+
cvWaitKey(0);
|
121 |
+
cvDestroyAllWindows();
|
122 |
+
#endif
|
123 |
+
|
124 |
+
free_image(upsampled);
|
125 |
+
free_image(thresh);
|
126 |
+
free_image(im);
|
127 |
+
if (filename) break;
|
128 |
+
}
|
129 |
+
}
|
130 |
+
|
131 |
+
void run_writing(int argc, char **argv)
|
132 |
+
{
|
133 |
+
if(argc < 4){
|
134 |
+
fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
|
135 |
+
return;
|
136 |
+
}
|
137 |
+
|
138 |
+
char *cfg = argv[3];
|
139 |
+
char *weights = (argc > 4) ? argv[4] : 0;
|
140 |
+
char *filename = (argc > 5) ? argv[5] : 0;
|
141 |
+
if(0==strcmp(argv[2], "train")) train_writing(cfg, weights);
|
142 |
+
else if(0==strcmp(argv[2], "test")) test_writing(cfg, weights, filename);
|
143 |
+
}
|
144 |
+
|
model/examples/yolo.c
ADDED
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include "darknet.h"
|
2 |
+
|
3 |
+
char *voc_names[] = {"aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"};
|
4 |
+
|
5 |
+
void train_yolo(char *cfgfile, char *weightfile)
|
6 |
+
{
|
7 |
+
char *train_images = "/data/voc/train.txt";
|
8 |
+
char *backup_directory = "/home/pjreddie/backup/";
|
9 |
+
srand(time(0));
|
10 |
+
char *base = basecfg(cfgfile);
|
11 |
+
printf("%s\n", base);
|
12 |
+
float avg_loss = -1;
|
13 |
+
network *net = load_network(cfgfile, weightfile, 0);
|
14 |
+
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
|
15 |
+
int imgs = net->batch*net->subdivisions;
|
16 |
+
int i = *net->seen/imgs;
|
17 |
+
data train, buffer;
|
18 |
+
|
19 |
+
|
20 |
+
layer l = net->layers[net->n - 1];
|
21 |
+
|
22 |
+
int side = l.side;
|
23 |
+
int classes = l.classes;
|
24 |
+
float jitter = l.jitter;
|
25 |
+
|
26 |
+
list *plist = get_paths(train_images);
|
27 |
+
//int N = plist->size;
|
28 |
+
char **paths = (char **)list_to_array(plist);
|
29 |
+
|
30 |
+
load_args args = {0};
|
31 |
+
args.w = net->w;
|
32 |
+
args.h = net->h;
|
33 |
+
args.paths = paths;
|
34 |
+
args.n = imgs;
|
35 |
+
args.m = plist->size;
|
36 |
+
args.classes = classes;
|
37 |
+
args.jitter = jitter;
|
38 |
+
args.num_boxes = side;
|
39 |
+
args.d = &buffer;
|
40 |
+
args.type = REGION_DATA;
|
41 |
+
|
42 |
+
args.angle = net->angle;
|
43 |
+
args.exposure = net->exposure;
|
44 |
+
args.saturation = net->saturation;
|
45 |
+
args.hue = net->hue;
|
46 |
+
|
47 |
+
pthread_t load_thread = load_data_in_thread(args);
|
48 |
+
clock_t time;
|
49 |
+
//while(i*imgs < N*120){
|
50 |
+
while(get_current_batch(net) < net->max_batches){
|
51 |
+
i += 1;
|
52 |
+
time=clock();
|
53 |
+
pthread_join(load_thread, 0);
|
54 |
+
train = buffer;
|
55 |
+
load_thread = load_data_in_thread(args);
|
56 |
+
|
57 |
+
printf("Loaded: %lf seconds\n", sec(clock()-time));
|
58 |
+
|
59 |
+
time=clock();
|
60 |
+
float loss = train_network(net, train);
|
61 |
+
if (avg_loss < 0) avg_loss = loss;
|
62 |
+
avg_loss = avg_loss*.9 + loss*.1;
|
63 |
+
|
64 |
+
printf("%d: %f, %f avg, %f rate, %lf seconds, %d images\n", i, loss, avg_loss, get_current_rate(net), sec(clock()-time), i*imgs);
|
65 |
+
if(i%1000==0 || (i < 1000 && i%100 == 0)){
|
66 |
+
char buff[256];
|
67 |
+
sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
|
68 |
+
save_weights(net, buff);
|
69 |
+
}
|
70 |
+
free_data(train);
|
71 |
+
}
|
72 |
+
char buff[256];
|
73 |
+
sprintf(buff, "%s/%s_final.weights", backup_directory, base);
|
74 |
+
save_weights(net, buff);
|
75 |
+
}
|
76 |
+
|
77 |
+
void print_yolo_detections(FILE **fps, char *id, int total, int classes, int w, int h, detection *dets)
|
78 |
+
{
|
79 |
+
int i, j;
|
80 |
+
for(i = 0; i < total; ++i){
|
81 |
+
float xmin = dets[i].bbox.x - dets[i].bbox.w/2.;
|
82 |
+
float xmax = dets[i].bbox.x + dets[i].bbox.w/2.;
|
83 |
+
float ymin = dets[i].bbox.y - dets[i].bbox.h/2.;
|
84 |
+
float ymax = dets[i].bbox.y + dets[i].bbox.h/2.;
|
85 |
+
|
86 |
+
if (xmin < 0) xmin = 0;
|
87 |
+
if (ymin < 0) ymin = 0;
|
88 |
+
if (xmax > w) xmax = w;
|
89 |
+
if (ymax > h) ymax = h;
|
90 |
+
|
91 |
+
for(j = 0; j < classes; ++j){
|
92 |
+
if (dets[i].prob[j]) fprintf(fps[j], "%s %f %f %f %f %f\n", id, dets[i].prob[j],
|
93 |
+
xmin, ymin, xmax, ymax);
|
94 |
+
}
|
95 |
+
}
|
96 |
+
}
|
97 |
+
|
98 |
+
void validate_yolo(char *cfg, char *weights)
|
99 |
+
{
|
100 |
+
network *net = load_network(cfg, weights, 0);
|
101 |
+
set_batch_network(net, 1);
|
102 |
+
fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
|
103 |
+
srand(time(0));
|
104 |
+
|
105 |
+
char *base = "results/comp4_det_test_";
|
106 |
+
//list *plist = get_paths("data/voc.2007.test");
|
107 |
+
list *plist = get_paths("/home/pjreddie/data/voc/2007_test.txt");
|
108 |
+
//list *plist = get_paths("data/voc.2012.test");
|
109 |
+
char **paths = (char **)list_to_array(plist);
|
110 |
+
|
111 |
+
layer l = net->layers[net->n-1];
|
112 |
+
int classes = l.classes;
|
113 |
+
|
114 |
+
int j;
|
115 |
+
FILE **fps = calloc(classes, sizeof(FILE *));
|
116 |
+
for(j = 0; j < classes; ++j){
|
117 |
+
char buff[1024];
|
118 |
+
snprintf(buff, 1024, "%s%s.txt", base, voc_names[j]);
|
119 |
+
fps[j] = fopen(buff, "w");
|
120 |
+
}
|
121 |
+
|
122 |
+
int m = plist->size;
|
123 |
+
int i=0;
|
124 |
+
int t;
|
125 |
+
|
126 |
+
float thresh = .001;
|
127 |
+
int nms = 1;
|
128 |
+
float iou_thresh = .5;
|
129 |
+
|
130 |
+
int nthreads = 8;
|
131 |
+
image *val = calloc(nthreads, sizeof(image));
|
132 |
+
image *val_resized = calloc(nthreads, sizeof(image));
|
133 |
+
image *buf = calloc(nthreads, sizeof(image));
|
134 |
+
image *buf_resized = calloc(nthreads, sizeof(image));
|
135 |
+
pthread_t *thr = calloc(nthreads, sizeof(pthread_t));
|
136 |
+
|
137 |
+
load_args args = {0};
|
138 |
+
args.w = net->w;
|
139 |
+
args.h = net->h;
|
140 |
+
args.type = IMAGE_DATA;
|
141 |
+
|
142 |
+
for(t = 0; t < nthreads; ++t){
|
143 |
+
args.path = paths[i+t];
|
144 |
+
args.im = &buf[t];
|
145 |
+
args.resized = &buf_resized[t];
|
146 |
+
thr[t] = load_data_in_thread(args);
|
147 |
+
}
|
148 |
+
time_t start = time(0);
|
149 |
+
for(i = nthreads; i < m+nthreads; i += nthreads){
|
150 |
+
fprintf(stderr, "%d\n", i);
|
151 |
+
for(t = 0; t < nthreads && i+t-nthreads < m; ++t){
|
152 |
+
pthread_join(thr[t], 0);
|
153 |
+
val[t] = buf[t];
|
154 |
+
val_resized[t] = buf_resized[t];
|
155 |
+
}
|
156 |
+
for(t = 0; t < nthreads && i+t < m; ++t){
|
157 |
+
args.path = paths[i+t];
|
158 |
+
args.im = &buf[t];
|
159 |
+
args.resized = &buf_resized[t];
|
160 |
+
thr[t] = load_data_in_thread(args);
|
161 |
+
}
|
162 |
+
for(t = 0; t < nthreads && i+t-nthreads < m; ++t){
|
163 |
+
char *path = paths[i+t-nthreads];
|
164 |
+
char *id = basecfg(path);
|
165 |
+
float *X = val_resized[t].data;
|
166 |
+
network_predict(net, X);
|
167 |
+
int w = val[t].w;
|
168 |
+
int h = val[t].h;
|
169 |
+
int nboxes = 0;
|
170 |
+
detection *dets = get_network_boxes(net, w, h, thresh, 0, 0, 0, &nboxes);
|
171 |
+
if (nms) do_nms_sort(dets, l.side*l.side*l.n, classes, iou_thresh);
|
172 |
+
print_yolo_detections(fps, id, l.side*l.side*l.n, classes, w, h, dets);
|
173 |
+
free_detections(dets, nboxes);
|
174 |
+
free(id);
|
175 |
+
free_image(val[t]);
|
176 |
+
free_image(val_resized[t]);
|
177 |
+
}
|
178 |
+
}
|
179 |
+
fprintf(stderr, "Total Detection Time: %f Seconds\n", (double)(time(0) - start));
|
180 |
+
}
|
181 |
+
|
182 |
+
void validate_yolo_recall(char *cfg, char *weights)
|
183 |
+
{
|
184 |
+
network *net = load_network(cfg, weights, 0);
|
185 |
+
set_batch_network(net, 1);
|
186 |
+
fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
|
187 |
+
srand(time(0));
|
188 |
+
|
189 |
+
char *base = "results/comp4_det_test_";
|
190 |
+
list *plist = get_paths("data/voc.2007.test");
|
191 |
+
char **paths = (char **)list_to_array(plist);
|
192 |
+
|
193 |
+
layer l = net->layers[net->n-1];
|
194 |
+
int classes = l.classes;
|
195 |
+
int side = l.side;
|
196 |
+
|
197 |
+
int j, k;
|
198 |
+
FILE **fps = calloc(classes, sizeof(FILE *));
|
199 |
+
for(j = 0; j < classes; ++j){
|
200 |
+
char buff[1024];
|
201 |
+
snprintf(buff, 1024, "%s%s.txt", base, voc_names[j]);
|
202 |
+
fps[j] = fopen(buff, "w");
|
203 |
+
}
|
204 |
+
|
205 |
+
int m = plist->size;
|
206 |
+
int i=0;
|
207 |
+
|
208 |
+
float thresh = .001;
|
209 |
+
float iou_thresh = .5;
|
210 |
+
float nms = 0;
|
211 |
+
|
212 |
+
int total = 0;
|
213 |
+
int correct = 0;
|
214 |
+
int proposals = 0;
|
215 |
+
float avg_iou = 0;
|
216 |
+
|
217 |
+
for(i = 0; i < m; ++i){
|
218 |
+
char *path = paths[i];
|
219 |
+
image orig = load_image_color(path, 0, 0);
|
220 |
+
image sized = resize_image(orig, net->w, net->h);
|
221 |
+
char *id = basecfg(path);
|
222 |
+
network_predict(net, sized.data);
|
223 |
+
|
224 |
+
int nboxes = 0;
|
225 |
+
detection *dets = get_network_boxes(net, orig.w, orig.h, thresh, 0, 0, 1, &nboxes);
|
226 |
+
if (nms) do_nms_obj(dets, side*side*l.n, 1, nms);
|
227 |
+
|
228 |
+
char labelpath[4096];
|
229 |
+
find_replace(path, "images", "labels", labelpath);
|
230 |
+
find_replace(labelpath, "JPEGImages", "labels", labelpath);
|
231 |
+
find_replace(labelpath, ".jpg", ".txt", labelpath);
|
232 |
+
find_replace(labelpath, ".JPEG", ".txt", labelpath);
|
233 |
+
|
234 |
+
int num_labels = 0;
|
235 |
+
box_label *truth = read_boxes(labelpath, &num_labels);
|
236 |
+
for(k = 0; k < side*side*l.n; ++k){
|
237 |
+
if(dets[k].objectness > thresh){
|
238 |
+
++proposals;
|
239 |
+
}
|
240 |
+
}
|
241 |
+
for (j = 0; j < num_labels; ++j) {
|
242 |
+
++total;
|
243 |
+
box t = {truth[j].x, truth[j].y, truth[j].w, truth[j].h};
|
244 |
+
float best_iou = 0;
|
245 |
+
for(k = 0; k < side*side*l.n; ++k){
|
246 |
+
float iou = box_iou(dets[k].bbox, t);
|
247 |
+
if(dets[k].objectness > thresh && iou > best_iou){
|
248 |
+
best_iou = iou;
|
249 |
+
}
|
250 |
+
}
|
251 |
+
avg_iou += best_iou;
|
252 |
+
if(best_iou > iou_thresh){
|
253 |
+
++correct;
|
254 |
+
}
|
255 |
+
}
|
256 |
+
|
257 |
+
fprintf(stderr, "%5d %5d %5d\tRPs/Img: %.2f\tIOU: %.2f%%\tRecall:%.2f%%\n", i, correct, total, (float)proposals/(i+1), avg_iou*100/total, 100.*correct/total);
|
258 |
+
free_detections(dets, nboxes);
|
259 |
+
free(id);
|
260 |
+
free_image(orig);
|
261 |
+
free_image(sized);
|
262 |
+
}
|
263 |
+
}
|
264 |
+
|
265 |
+
void test_yolo(char *cfgfile, char *weightfile, char *filename, float thresh)
|
266 |
+
{
|
267 |
+
image **alphabet = load_alphabet();
|
268 |
+
network *net = load_network(cfgfile, weightfile, 0);
|
269 |
+
layer l = net->layers[net->n-1];
|
270 |
+
set_batch_network(net, 1);
|
271 |
+
srand(2222222);
|
272 |
+
clock_t time;
|
273 |
+
char buff[256];
|
274 |
+
char *input = buff;
|
275 |
+
float nms=.4;
|
276 |
+
while(1){
|
277 |
+
if(filename){
|
278 |
+
strncpy(input, filename, 256);
|
279 |
+
} else {
|
280 |
+
printf("Enter Image Path: ");
|
281 |
+
fflush(stdout);
|
282 |
+
input = fgets(input, 256, stdin);
|
283 |
+
if(!input) return;
|
284 |
+
strtok(input, "\n");
|
285 |
+
}
|
286 |
+
image im = load_image_color(input,0,0);
|
287 |
+
image sized = resize_image(im, net->w, net->h);
|
288 |
+
float *X = sized.data;
|
289 |
+
time=clock();
|
290 |
+
network_predict(net, X);
|
291 |
+
printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
|
292 |
+
|
293 |
+
int nboxes = 0;
|
294 |
+
detection *dets = get_network_boxes(net, 1, 1, thresh, 0, 0, 0, &nboxes);
|
295 |
+
if (nms) do_nms_sort(dets, l.side*l.side*l.n, l.classes, nms);
|
296 |
+
|
297 |
+
draw_detections(im, dets, l.side*l.side*l.n, thresh, voc_names, alphabet, 20);
|
298 |
+
save_image(im, "predictions");
|
299 |
+
show_image(im, "predictions", 0);
|
300 |
+
free_detections(dets, nboxes);
|
301 |
+
free_image(im);
|
302 |
+
free_image(sized);
|
303 |
+
if (filename) break;
|
304 |
+
}
|
305 |
+
}
|
306 |
+
|
307 |
+
void run_yolo(int argc, char **argv)
|
308 |
+
{
|
309 |
+
char *prefix = find_char_arg(argc, argv, "-prefix", 0);
|
310 |
+
float thresh = find_float_arg(argc, argv, "-thresh", .2);
|
311 |
+
int cam_index = find_int_arg(argc, argv, "-c", 0);
|
312 |
+
int frame_skip = find_int_arg(argc, argv, "-s", 0);
|
313 |
+
if(argc < 4){
|
314 |
+
fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
|
315 |
+
return;
|
316 |
+
}
|
317 |
+
|
318 |
+
int avg = find_int_arg(argc, argv, "-avg", 1);
|
319 |
+
char *cfg = argv[3];
|
320 |
+
char *weights = (argc > 4) ? argv[4] : 0;
|
321 |
+
char *filename = (argc > 5) ? argv[5]: 0;
|
322 |
+
if(0==strcmp(argv[2], "test")) test_yolo(cfg, weights, filename, thresh);
|
323 |
+
else if(0==strcmp(argv[2], "train")) train_yolo(cfg, weights);
|
324 |
+
else if(0==strcmp(argv[2], "valid")) validate_yolo(cfg, weights);
|
325 |
+
else if(0==strcmp(argv[2], "recall")) validate_yolo_recall(cfg, weights);
|
326 |
+
else if(0==strcmp(argv[2], "demo")) demo(cfg, weights, thresh, cam_index, filename, voc_names, 20, frame_skip, prefix, avg, .5, 0,0,0,0);
|
327 |
+
}
|
model/include/darknet.h
ADDED
@@ -0,0 +1,805 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#ifndef DARKNET_API
|
2 |
+
#define DARKNET_API
|
3 |
+
#include <stdlib.h>
|
4 |
+
#include <stdio.h>
|
5 |
+
#include <string.h>
|
6 |
+
#include <pthread.h>
|
7 |
+
|
8 |
+
#ifdef GPU
|
9 |
+
#define BLOCK 512
|
10 |
+
|
11 |
+
#include "cuda_runtime.h"
|
12 |
+
#include "curand.h"
|
13 |
+
#include "cublas_v2.h"
|
14 |
+
|
15 |
+
#ifdef CUDNN
|
16 |
+
#include "cudnn.h"
|
17 |
+
#endif
|
18 |
+
#endif
|
19 |
+
|
20 |
+
#ifdef __cplusplus
|
21 |
+
extern "C" {
|
22 |
+
#endif
|
23 |
+
|
24 |
+
#define SECRET_NUM -1234
|
25 |
+
extern int gpu_index;
|
26 |
+
|
27 |
+
typedef struct{
|
28 |
+
int classes;
|
29 |
+
char **names;
|
30 |
+
} metadata;
|
31 |
+
|
32 |
+
metadata get_metadata(char *file);
|
33 |
+
|
34 |
+
typedef struct{
|
35 |
+
int *leaf;
|
36 |
+
int n;
|
37 |
+
int *parent;
|
38 |
+
int *child;
|
39 |
+
int *group;
|
40 |
+
char **name;
|
41 |
+
|
42 |
+
int groups;
|
43 |
+
int *group_size;
|
44 |
+
int *group_offset;
|
45 |
+
} tree;
|
46 |
+
tree *read_tree(char *filename);
|
47 |
+
|
48 |
+
typedef enum{
|
49 |
+
LOGISTIC, RELU, RELIE, LINEAR, RAMP, TANH, PLSE, LEAKY, ELU, LOGGY, STAIR, HARDTAN, LHTAN, SELU
|
50 |
+
} ACTIVATION;
|
51 |
+
|
52 |
+
typedef enum{
|
53 |
+
PNG, BMP, TGA, JPG
|
54 |
+
} IMTYPE;
|
55 |
+
|
56 |
+
typedef enum{
|
57 |
+
MULT, ADD, SUB, DIV
|
58 |
+
} BINARY_ACTIVATION;
|
59 |
+
|
60 |
+
typedef enum {
|
61 |
+
CONVOLUTIONAL,
|
62 |
+
DECONVOLUTIONAL,
|
63 |
+
CONNECTED,
|
64 |
+
MAXPOOL,
|
65 |
+
SOFTMAX,
|
66 |
+
DETECTION,
|
67 |
+
DROPOUT,
|
68 |
+
CROP,
|
69 |
+
ROUTE,
|
70 |
+
COST,
|
71 |
+
NORMALIZATION,
|
72 |
+
AVGPOOL,
|
73 |
+
LOCAL,
|
74 |
+
SHORTCUT,
|
75 |
+
ACTIVE,
|
76 |
+
RNN,
|
77 |
+
GRU,
|
78 |
+
LSTM,
|
79 |
+
CRNN,
|
80 |
+
BATCHNORM,
|
81 |
+
NETWORK,
|
82 |
+
XNOR,
|
83 |
+
REGION,
|
84 |
+
YOLO,
|
85 |
+
ISEG,
|
86 |
+
REORG,
|
87 |
+
UPSAMPLE,
|
88 |
+
LOGXENT,
|
89 |
+
L2NORM,
|
90 |
+
BLANK
|
91 |
+
} LAYER_TYPE;
|
92 |
+
|
93 |
+
typedef enum{
|
94 |
+
SSE, MASKED, L1, SEG, SMOOTH,WGAN
|
95 |
+
} COST_TYPE;
|
96 |
+
|
97 |
+
typedef struct{
|
98 |
+
int batch;
|
99 |
+
float learning_rate;
|
100 |
+
float momentum;
|
101 |
+
float decay;
|
102 |
+
int adam;
|
103 |
+
float B1;
|
104 |
+
float B2;
|
105 |
+
float eps;
|
106 |
+
int t;
|
107 |
+
} update_args;
|
108 |
+
|
109 |
+
struct network;
|
110 |
+
typedef struct network network;
|
111 |
+
|
112 |
+
struct layer;
|
113 |
+
typedef struct layer layer;
|
114 |
+
|
115 |
+
struct layer{
|
116 |
+
LAYER_TYPE type;
|
117 |
+
ACTIVATION activation;
|
118 |
+
COST_TYPE cost_type;
|
119 |
+
void (*forward) (struct layer, struct network);
|
120 |
+
void (*backward) (struct layer, struct network);
|
121 |
+
void (*update) (struct layer, update_args);
|
122 |
+
void (*forward_gpu) (struct layer, struct network);
|
123 |
+
void (*backward_gpu) (struct layer, struct network);
|
124 |
+
void (*update_gpu) (struct layer, update_args);
|
125 |
+
int batch_normalize;
|
126 |
+
int shortcut;
|
127 |
+
int batch;
|
128 |
+
int forced;
|
129 |
+
int flipped;
|
130 |
+
int inputs;
|
131 |
+
int outputs;
|
132 |
+
int nweights;
|
133 |
+
int nbiases;
|
134 |
+
int extra;
|
135 |
+
int truths;
|
136 |
+
int h,w,c;
|
137 |
+
int out_h, out_w, out_c;
|
138 |
+
int n;
|
139 |
+
int max_boxes;
|
140 |
+
int groups;
|
141 |
+
int size;
|
142 |
+
int side;
|
143 |
+
int stride;
|
144 |
+
int reverse;
|
145 |
+
int flatten;
|
146 |
+
int spatial;
|
147 |
+
int pad;
|
148 |
+
int sqrt;
|
149 |
+
int flip;
|
150 |
+
int index;
|
151 |
+
int binary;
|
152 |
+
int xnor;
|
153 |
+
int steps;
|
154 |
+
int hidden;
|
155 |
+
int truth;
|
156 |
+
float smooth;
|
157 |
+
float dot;
|
158 |
+
float angle;
|
159 |
+
float jitter;
|
160 |
+
float saturation;
|
161 |
+
float exposure;
|
162 |
+
float shift;
|
163 |
+
float ratio;
|
164 |
+
float learning_rate_scale;
|
165 |
+
float clip;
|
166 |
+
int noloss;
|
167 |
+
int softmax;
|
168 |
+
int classes;
|
169 |
+
int coords;
|
170 |
+
int background;
|
171 |
+
int rescore;
|
172 |
+
int objectness;
|
173 |
+
int joint;
|
174 |
+
int noadjust;
|
175 |
+
int reorg;
|
176 |
+
int log;
|
177 |
+
int tanh;
|
178 |
+
int *mask;
|
179 |
+
int total;
|
180 |
+
|
181 |
+
float alpha;
|
182 |
+
float beta;
|
183 |
+
float kappa;
|
184 |
+
|
185 |
+
float coord_scale;
|
186 |
+
float object_scale;
|
187 |
+
float noobject_scale;
|
188 |
+
float mask_scale;
|
189 |
+
float class_scale;
|
190 |
+
int bias_match;
|
191 |
+
int random;
|
192 |
+
float ignore_thresh;
|
193 |
+
float truth_thresh;
|
194 |
+
float thresh;
|
195 |
+
float focus;
|
196 |
+
int classfix;
|
197 |
+
int absolute;
|
198 |
+
|
199 |
+
int onlyforward;
|
200 |
+
int stopbackward;
|
201 |
+
int dontload;
|
202 |
+
int dontsave;
|
203 |
+
int dontloadscales;
|
204 |
+
int numload;
|
205 |
+
|
206 |
+
float temperature;
|
207 |
+
float probability;
|
208 |
+
float scale;
|
209 |
+
|
210 |
+
char * cweights;
|
211 |
+
int * indexes;
|
212 |
+
int * input_layers;
|
213 |
+
int * input_sizes;
|
214 |
+
int * map;
|
215 |
+
int * counts;
|
216 |
+
float ** sums;
|
217 |
+
float * rand;
|
218 |
+
float * cost;
|
219 |
+
float * state;
|
220 |
+
float * prev_state;
|
221 |
+
float * forgot_state;
|
222 |
+
float * forgot_delta;
|
223 |
+
float * state_delta;
|
224 |
+
float * combine_cpu;
|
225 |
+
float * combine_delta_cpu;
|
226 |
+
|
227 |
+
float * concat;
|
228 |
+
float * concat_delta;
|
229 |
+
|
230 |
+
float * binary_weights;
|
231 |
+
|
232 |
+
float * biases;
|
233 |
+
float * bias_updates;
|
234 |
+
|
235 |
+
float * scales;
|
236 |
+
float * scale_updates;
|
237 |
+
|
238 |
+
float * weights;
|
239 |
+
float * weight_updates;
|
240 |
+
|
241 |
+
float * delta;
|
242 |
+
float * output;
|
243 |
+
float * loss;
|
244 |
+
float * squared;
|
245 |
+
float * norms;
|
246 |
+
|
247 |
+
float * spatial_mean;
|
248 |
+
float * mean;
|
249 |
+
float * variance;
|
250 |
+
|
251 |
+
float * mean_delta;
|
252 |
+
float * variance_delta;
|
253 |
+
|
254 |
+
float * rolling_mean;
|
255 |
+
float * rolling_variance;
|
256 |
+
|
257 |
+
float * x;
|
258 |
+
float * x_norm;
|
259 |
+
|
260 |
+
float * m;
|
261 |
+
float * v;
|
262 |
+
|
263 |
+
float * bias_m;
|
264 |
+
float * bias_v;
|
265 |
+
float * scale_m;
|
266 |
+
float * scale_v;
|
267 |
+
|
268 |
+
|
269 |
+
float *z_cpu;
|
270 |
+
float *r_cpu;
|
271 |
+
float *h_cpu;
|
272 |
+
float * prev_state_cpu;
|
273 |
+
|
274 |
+
float *temp_cpu;
|
275 |
+
float *temp2_cpu;
|
276 |
+
float *temp3_cpu;
|
277 |
+
|
278 |
+
float *dh_cpu;
|
279 |
+
float *hh_cpu;
|
280 |
+
float *prev_cell_cpu;
|
281 |
+
float *cell_cpu;
|
282 |
+
float *f_cpu;
|
283 |
+
float *i_cpu;
|
284 |
+
float *g_cpu;
|
285 |
+
float *o_cpu;
|
286 |
+
float *c_cpu;
|
287 |
+
float *dc_cpu;
|
288 |
+
|
289 |
+
float * binary_input;
|
290 |
+
|
291 |
+
struct layer *input_layer;
|
292 |
+
struct layer *self_layer;
|
293 |
+
struct layer *output_layer;
|
294 |
+
|
295 |
+
struct layer *reset_layer;
|
296 |
+
struct layer *update_layer;
|
297 |
+
struct layer *state_layer;
|
298 |
+
|
299 |
+
struct layer *input_gate_layer;
|
300 |
+
struct layer *state_gate_layer;
|
301 |
+
struct layer *input_save_layer;
|
302 |
+
struct layer *state_save_layer;
|
303 |
+
struct layer *input_state_layer;
|
304 |
+
struct layer *state_state_layer;
|
305 |
+
|
306 |
+
struct layer *input_z_layer;
|
307 |
+
struct layer *state_z_layer;
|
308 |
+
|
309 |
+
struct layer *input_r_layer;
|
310 |
+
struct layer *state_r_layer;
|
311 |
+
|
312 |
+
struct layer *input_h_layer;
|
313 |
+
struct layer *state_h_layer;
|
314 |
+
|
315 |
+
struct layer *wz;
|
316 |
+
struct layer *uz;
|
317 |
+
struct layer *wr;
|
318 |
+
struct layer *ur;
|
319 |
+
struct layer *wh;
|
320 |
+
struct layer *uh;
|
321 |
+
struct layer *uo;
|
322 |
+
struct layer *wo;
|
323 |
+
struct layer *uf;
|
324 |
+
struct layer *wf;
|
325 |
+
struct layer *ui;
|
326 |
+
struct layer *wi;
|
327 |
+
struct layer *ug;
|
328 |
+
struct layer *wg;
|
329 |
+
|
330 |
+
tree *softmax_tree;
|
331 |
+
|
332 |
+
size_t workspace_size;
|
333 |
+
|
334 |
+
#ifdef GPU
|
335 |
+
int *indexes_gpu;
|
336 |
+
|
337 |
+
float *z_gpu;
|
338 |
+
float *r_gpu;
|
339 |
+
float *h_gpu;
|
340 |
+
|
341 |
+
float *temp_gpu;
|
342 |
+
float *temp2_gpu;
|
343 |
+
float *temp3_gpu;
|
344 |
+
|
345 |
+
float *dh_gpu;
|
346 |
+
float *hh_gpu;
|
347 |
+
float *prev_cell_gpu;
|
348 |
+
float *cell_gpu;
|
349 |
+
float *f_gpu;
|
350 |
+
float *i_gpu;
|
351 |
+
float *g_gpu;
|
352 |
+
float *o_gpu;
|
353 |
+
float *c_gpu;
|
354 |
+
float *dc_gpu;
|
355 |
+
|
356 |
+
float *m_gpu;
|
357 |
+
float *v_gpu;
|
358 |
+
float *bias_m_gpu;
|
359 |
+
float *scale_m_gpu;
|
360 |
+
float *bias_v_gpu;
|
361 |
+
float *scale_v_gpu;
|
362 |
+
|
363 |
+
float * combine_gpu;
|
364 |
+
float * combine_delta_gpu;
|
365 |
+
|
366 |
+
float * prev_state_gpu;
|
367 |
+
float * forgot_state_gpu;
|
368 |
+
float * forgot_delta_gpu;
|
369 |
+
float * state_gpu;
|
370 |
+
float * state_delta_gpu;
|
371 |
+
float * gate_gpu;
|
372 |
+
float * gate_delta_gpu;
|
373 |
+
float * save_gpu;
|
374 |
+
float * save_delta_gpu;
|
375 |
+
float * concat_gpu;
|
376 |
+
float * concat_delta_gpu;
|
377 |
+
|
378 |
+
float * binary_input_gpu;
|
379 |
+
float * binary_weights_gpu;
|
380 |
+
|
381 |
+
float * mean_gpu;
|
382 |
+
float * variance_gpu;
|
383 |
+
|
384 |
+
float * rolling_mean_gpu;
|
385 |
+
float * rolling_variance_gpu;
|
386 |
+
|
387 |
+
float * variance_delta_gpu;
|
388 |
+
float * mean_delta_gpu;
|
389 |
+
|
390 |
+
float * x_gpu;
|
391 |
+
float * x_norm_gpu;
|
392 |
+
float * weights_gpu;
|
393 |
+
float * weight_updates_gpu;
|
394 |
+
float * weight_change_gpu;
|
395 |
+
|
396 |
+
float * biases_gpu;
|
397 |
+
float * bias_updates_gpu;
|
398 |
+
float * bias_change_gpu;
|
399 |
+
|
400 |
+
float * scales_gpu;
|
401 |
+
float * scale_updates_gpu;
|
402 |
+
float * scale_change_gpu;
|
403 |
+
|
404 |
+
float * output_gpu;
|
405 |
+
float * loss_gpu;
|
406 |
+
float * delta_gpu;
|
407 |
+
float * rand_gpu;
|
408 |
+
float * squared_gpu;
|
409 |
+
float * norms_gpu;
|
410 |
+
#ifdef CUDNN
|
411 |
+
cudnnTensorDescriptor_t srcTensorDesc, dstTensorDesc;
|
412 |
+
cudnnTensorDescriptor_t dsrcTensorDesc, ddstTensorDesc;
|
413 |
+
cudnnTensorDescriptor_t normTensorDesc;
|
414 |
+
cudnnFilterDescriptor_t weightDesc;
|
415 |
+
cudnnFilterDescriptor_t dweightDesc;
|
416 |
+
cudnnConvolutionDescriptor_t convDesc;
|
417 |
+
cudnnConvolutionFwdAlgo_t fw_algo;
|
418 |
+
cudnnConvolutionBwdDataAlgo_t bd_algo;
|
419 |
+
cudnnConvolutionBwdFilterAlgo_t bf_algo;
|
420 |
+
#endif
|
421 |
+
#endif
|
422 |
+
};
|
423 |
+
|
424 |
+
void free_layer(layer);
|
425 |
+
|
426 |
+
typedef enum {
|
427 |
+
CONSTANT, STEP, EXP, POLY, STEPS, SIG, RANDOM
|
428 |
+
} learning_rate_policy;
|
429 |
+
|
430 |
+
typedef struct network{
|
431 |
+
int n;
|
432 |
+
int batch;
|
433 |
+
size_t *seen;
|
434 |
+
int *t;
|
435 |
+
float epoch;
|
436 |
+
int subdivisions;
|
437 |
+
layer *layers;
|
438 |
+
float *output;
|
439 |
+
learning_rate_policy policy;
|
440 |
+
|
441 |
+
float learning_rate;
|
442 |
+
float momentum;
|
443 |
+
float decay;
|
444 |
+
float gamma;
|
445 |
+
float scale;
|
446 |
+
float power;
|
447 |
+
int time_steps;
|
448 |
+
int step;
|
449 |
+
int max_batches;
|
450 |
+
float *scales;
|
451 |
+
int *steps;
|
452 |
+
int num_steps;
|
453 |
+
int burn_in;
|
454 |
+
|
455 |
+
int adam;
|
456 |
+
float B1;
|
457 |
+
float B2;
|
458 |
+
float eps;
|
459 |
+
|
460 |
+
int inputs;
|
461 |
+
int outputs;
|
462 |
+
int truths;
|
463 |
+
int notruth;
|
464 |
+
int h, w, c;
|
465 |
+
int max_crop;
|
466 |
+
int min_crop;
|
467 |
+
float max_ratio;
|
468 |
+
float min_ratio;
|
469 |
+
int center;
|
470 |
+
float angle;
|
471 |
+
float aspect;
|
472 |
+
float exposure;
|
473 |
+
float saturation;
|
474 |
+
float hue;
|
475 |
+
int random;
|
476 |
+
|
477 |
+
int gpu_index;
|
478 |
+
tree *hierarchy;
|
479 |
+
|
480 |
+
float *input;
|
481 |
+
float *truth;
|
482 |
+
float *delta;
|
483 |
+
float *workspace;
|
484 |
+
int train;
|
485 |
+
int index;
|
486 |
+
float *cost;
|
487 |
+
float clip;
|
488 |
+
|
489 |
+
#ifdef GPU
|
490 |
+
float *input_gpu;
|
491 |
+
float *truth_gpu;
|
492 |
+
float *delta_gpu;
|
493 |
+
float *output_gpu;
|
494 |
+
#endif
|
495 |
+
|
496 |
+
} network;
|
497 |
+
|
498 |
+
typedef struct {
|
499 |
+
int w;
|
500 |
+
int h;
|
501 |
+
float scale;
|
502 |
+
float rad;
|
503 |
+
float dx;
|
504 |
+
float dy;
|
505 |
+
float aspect;
|
506 |
+
} augment_args;
|
507 |
+
|
508 |
+
typedef struct {
|
509 |
+
int w;
|
510 |
+
int h;
|
511 |
+
int c;
|
512 |
+
float *data;
|
513 |
+
} image;
|
514 |
+
|
515 |
+
typedef struct{
|
516 |
+
float x, y, w, h;
|
517 |
+
} box;
|
518 |
+
|
519 |
+
typedef struct detection{
|
520 |
+
box bbox;
|
521 |
+
int classes;
|
522 |
+
float *prob;
|
523 |
+
float *mask;
|
524 |
+
float objectness;
|
525 |
+
int sort_class;
|
526 |
+
} detection;
|
527 |
+
|
528 |
+
typedef struct matrix{
|
529 |
+
int rows, cols;
|
530 |
+
float **vals;
|
531 |
+
} matrix;
|
532 |
+
|
533 |
+
|
534 |
+
typedef struct{
|
535 |
+
int w, h;
|
536 |
+
matrix X;
|
537 |
+
matrix y;
|
538 |
+
int shallow;
|
539 |
+
int *num_boxes;
|
540 |
+
box **boxes;
|
541 |
+
} data;
|
542 |
+
|
543 |
+
typedef enum {
|
544 |
+
CLASSIFICATION_DATA, DETECTION_DATA, CAPTCHA_DATA, REGION_DATA, IMAGE_DATA, COMPARE_DATA, WRITING_DATA, SWAG_DATA, TAG_DATA, OLD_CLASSIFICATION_DATA, STUDY_DATA, DET_DATA, SUPER_DATA, LETTERBOX_DATA, REGRESSION_DATA, SEGMENTATION_DATA, INSTANCE_DATA, ISEG_DATA
|
545 |
+
} data_type;
|
546 |
+
|
547 |
+
typedef struct load_args{
|
548 |
+
int threads;
|
549 |
+
char **paths;
|
550 |
+
char *path;
|
551 |
+
int n;
|
552 |
+
int m;
|
553 |
+
char **labels;
|
554 |
+
int h;
|
555 |
+
int w;
|
556 |
+
int out_w;
|
557 |
+
int out_h;
|
558 |
+
int nh;
|
559 |
+
int nw;
|
560 |
+
int num_boxes;
|
561 |
+
int min, max, size;
|
562 |
+
int classes;
|
563 |
+
int background;
|
564 |
+
int scale;
|
565 |
+
int center;
|
566 |
+
int coords;
|
567 |
+
float jitter;
|
568 |
+
float angle;
|
569 |
+
float aspect;
|
570 |
+
float saturation;
|
571 |
+
float exposure;
|
572 |
+
float hue;
|
573 |
+
data *d;
|
574 |
+
image *im;
|
575 |
+
image *resized;
|
576 |
+
data_type type;
|
577 |
+
tree *hierarchy;
|
578 |
+
} load_args;
|
579 |
+
|
580 |
+
typedef struct{
|
581 |
+
int id;
|
582 |
+
float x,y,w,h;
|
583 |
+
float left, right, top, bottom;
|
584 |
+
} box_label;
|
585 |
+
|
586 |
+
|
587 |
+
network *load_network(char *cfg, char *weights, int clear);
|
588 |
+
load_args get_base_args(network *net);
|
589 |
+
|
590 |
+
void free_data(data d);
|
591 |
+
|
592 |
+
typedef struct node{
|
593 |
+
void *val;
|
594 |
+
struct node *next;
|
595 |
+
struct node *prev;
|
596 |
+
} node;
|
597 |
+
|
598 |
+
typedef struct list{
|
599 |
+
int size;
|
600 |
+
node *front;
|
601 |
+
node *back;
|
602 |
+
} list;
|
603 |
+
|
604 |
+
pthread_t load_data(load_args args);
|
605 |
+
list *read_data_cfg(char *filename);
|
606 |
+
list *read_cfg(char *filename);
|
607 |
+
unsigned char *read_file(char *filename);
|
608 |
+
data resize_data(data orig, int w, int h);
|
609 |
+
data *tile_data(data orig, int divs, int size);
|
610 |
+
data select_data(data *orig, int *inds);
|
611 |
+
|
612 |
+
void forward_network(network *net);
|
613 |
+
void backward_network(network *net);
|
614 |
+
void update_network(network *net);
|
615 |
+
|
616 |
+
|
617 |
+
float dot_cpu(int N, float *X, int INCX, float *Y, int INCY);
|
618 |
+
void axpy_cpu(int N, float ALPHA, float *X, int INCX, float *Y, int INCY);
|
619 |
+
void copy_cpu(int N, float *X, int INCX, float *Y, int INCY);
|
620 |
+
void scal_cpu(int N, float ALPHA, float *X, int INCX);
|
621 |
+
void fill_cpu(int N, float ALPHA, float * X, int INCX);
|
622 |
+
void normalize_cpu(float *x, float *mean, float *variance, int batch, int filters, int spatial);
|
623 |
+
void softmax(float *input, int n, float temp, int stride, float *output);
|
624 |
+
|
625 |
+
int best_3d_shift_r(image a, image b, int min, int max);
|
626 |
+
#ifdef GPU
|
627 |
+
void axpy_gpu(int N, float ALPHA, float * X, int INCX, float * Y, int INCY);
|
628 |
+
void fill_gpu(int N, float ALPHA, float * X, int INCX);
|
629 |
+
void scal_gpu(int N, float ALPHA, float * X, int INCX);
|
630 |
+
void copy_gpu(int N, float * X, int INCX, float * Y, int INCY);
|
631 |
+
|
632 |
+
void cuda_set_device(int n);
|
633 |
+
void cuda_free(float *x_gpu);
|
634 |
+
float *cuda_make_array(float *x, size_t n);
|
635 |
+
void cuda_pull_array(float *x_gpu, float *x, size_t n);
|
636 |
+
float cuda_mag_array(float *x_gpu, size_t n);
|
637 |
+
void cuda_push_array(float *x_gpu, float *x, size_t n);
|
638 |
+
|
639 |
+
void forward_network_gpu(network *net);
|
640 |
+
void backward_network_gpu(network *net);
|
641 |
+
void update_network_gpu(network *net);
|
642 |
+
|
643 |
+
float train_networks(network **nets, int n, data d, int interval);
|
644 |
+
void sync_nets(network **nets, int n, int interval);
|
645 |
+
void harmless_update_network_gpu(network *net);
|
646 |
+
#endif
|
647 |
+
image get_label(image **characters, char *string, int size);
|
648 |
+
void draw_label(image a, int r, int c, image label, const float *rgb);
|
649 |
+
void save_image(image im, const char *name);
|
650 |
+
void save_image_options(image im, const char *name, IMTYPE f, int quality);
|
651 |
+
void get_next_batch(data d, int n, int offset, float *X, float *y);
|
652 |
+
void grayscale_image_3c(image im);
|
653 |
+
void normalize_image(image p);
|
654 |
+
void matrix_to_csv(matrix m);
|
655 |
+
float train_network_sgd(network *net, data d, int n);
|
656 |
+
void rgbgr_image(image im);
|
657 |
+
data copy_data(data d);
|
658 |
+
data concat_data(data d1, data d2);
|
659 |
+
data load_cifar10_data(char *filename);
|
660 |
+
float matrix_topk_accuracy(matrix truth, matrix guess, int k);
|
661 |
+
void matrix_add_matrix(matrix from, matrix to);
|
662 |
+
void scale_matrix(matrix m, float scale);
|
663 |
+
matrix csv_to_matrix(char *filename);
|
664 |
+
float *network_accuracies(network *net, data d, int n);
|
665 |
+
float train_network_datum(network *net);
|
666 |
+
image make_random_image(int w, int h, int c);
|
667 |
+
|
668 |
+
void denormalize_connected_layer(layer l);
|
669 |
+
void denormalize_convolutional_layer(layer l);
|
670 |
+
void statistics_connected_layer(layer l);
|
671 |
+
void rescale_weights(layer l, float scale, float trans);
|
672 |
+
void rgbgr_weights(layer l);
|
673 |
+
image *get_weights(layer l);
|
674 |
+
|
675 |
+
void demo(char *cfgfile, char *weightfile, float thresh, int cam_index, const char *filename, char **names, int classes, int frame_skip, char *prefix, int avg, float hier_thresh, int w, int h, int fps, int fullscreen);
|
676 |
+
void get_detection_detections(layer l, int w, int h, float thresh, detection *dets);
|
677 |
+
|
678 |
+
char *option_find_str(list *l, char *key, char *def);
|
679 |
+
int option_find_int(list *l, char *key, int def);
|
680 |
+
int option_find_int_quiet(list *l, char *key, int def);
|
681 |
+
|
682 |
+
network *parse_network_cfg(char *filename);
|
683 |
+
void save_weights(network *net, char *filename);
|
684 |
+
void load_weights(network *net, char *filename);
|
685 |
+
void save_weights_upto(network *net, char *filename, int cutoff);
|
686 |
+
void load_weights_upto(network *net, char *filename, int start, int cutoff);
|
687 |
+
|
688 |
+
void zero_objectness(layer l);
|
689 |
+
void get_region_detections(layer l, int w, int h, int netw, int neth, float thresh, int *map, float tree_thresh, int relative, detection *dets);
|
690 |
+
int get_yolo_detections(layer l, int w, int h, int netw, int neth, float thresh, int *map, int relative, detection *dets);
|
691 |
+
void free_network(network *net);
|
692 |
+
void set_batch_network(network *net, int b);
|
693 |
+
void set_temp_network(network *net, float t);
|
694 |
+
image load_image(char *filename, int w, int h, int c);
|
695 |
+
image load_image_color(char *filename, int w, int h);
|
696 |
+
image make_image(int w, int h, int c);
|
697 |
+
image resize_image(image im, int w, int h);
|
698 |
+
void censor_image(image im, int dx, int dy, int w, int h);
|
699 |
+
image letterbox_image(image im, int w, int h);
|
700 |
+
image crop_image(image im, int dx, int dy, int w, int h);
|
701 |
+
image center_crop_image(image im, int w, int h);
|
702 |
+
image resize_min(image im, int min);
|
703 |
+
image resize_max(image im, int max);
|
704 |
+
image threshold_image(image im, float thresh);
|
705 |
+
image mask_to_rgb(image mask);
|
706 |
+
int resize_network(network *net, int w, int h);
|
707 |
+
void free_matrix(matrix m);
|
708 |
+
void test_resize(char *filename);
|
709 |
+
int show_image(image p, const char *name, int ms);
|
710 |
+
image copy_image(image p);
|
711 |
+
void draw_box_width(image a, int x1, int y1, int x2, int y2, int w, float r, float g, float b);
|
712 |
+
float get_current_rate(network *net);
|
713 |
+
void composite_3d(char *f1, char *f2, char *out, int delta);
|
714 |
+
data load_data_old(char **paths, int n, int m, char **labels, int k, int w, int h);
|
715 |
+
size_t get_current_batch(network *net);
|
716 |
+
void constrain_image(image im);
|
717 |
+
image get_network_image_layer(network *net, int i);
|
718 |
+
layer get_network_output_layer(network *net);
|
719 |
+
void top_predictions(network *net, int n, int *index);
|
720 |
+
void flip_image(image a);
|
721 |
+
image float_to_image(int w, int h, int c, float *data);
|
722 |
+
void ghost_image(image source, image dest, int dx, int dy);
|
723 |
+
float network_accuracy(network *net, data d);
|
724 |
+
void random_distort_image(image im, float hue, float saturation, float exposure);
|
725 |
+
void fill_image(image m, float s);
|
726 |
+
image grayscale_image(image im);
|
727 |
+
void rotate_image_cw(image im, int times);
|
728 |
+
double what_time_is_it_now();
|
729 |
+
image rotate_image(image m, float rad);
|
730 |
+
void visualize_network(network *net);
|
731 |
+
float box_iou(box a, box b);
|
732 |
+
data load_all_cifar10();
|
733 |
+
box_label *read_boxes(char *filename, int *n);
|
734 |
+
box float_to_box(float *f, int stride);
|
735 |
+
void draw_detections(image im, detection *dets, int num, float thresh, char **names, image **alphabet, int classes);
|
736 |
+
|
737 |
+
matrix network_predict_data(network *net, data test);
|
738 |
+
image **load_alphabet();
|
739 |
+
image get_network_image(network *net);
|
740 |
+
float *network_predict(network *net, float *input);
|
741 |
+
|
742 |
+
int network_width(network *net);
|
743 |
+
int network_height(network *net);
|
744 |
+
float *network_predict_image(network *net, image im);
|
745 |
+
void network_detect(network *net, image im, float thresh, float hier_thresh, float nms, detection *dets);
|
746 |
+
detection *get_network_boxes(network *net, int w, int h, float thresh, float hier, int *map, int relative, int *num);
|
747 |
+
void free_detections(detection *dets, int n);
|
748 |
+
|
749 |
+
void reset_network_state(network *net, int b);
|
750 |
+
|
751 |
+
char **get_labels(char *filename);
|
752 |
+
void do_nms_obj(detection *dets, int total, int classes, float thresh);
|
753 |
+
void do_nms_sort(detection *dets, int total, int classes, float thresh);
|
754 |
+
|
755 |
+
matrix make_matrix(int rows, int cols);
|
756 |
+
|
757 |
+
#ifdef OPENCV
|
758 |
+
void *open_video_stream(const char *f, int c, int w, int h, int fps);
|
759 |
+
image get_image_from_stream(void *p);
|
760 |
+
void make_window(char *name, int w, int h, int fullscreen);
|
761 |
+
#endif
|
762 |
+
|
763 |
+
void free_image(image m);
|
764 |
+
float train_network(network *net, data d);
|
765 |
+
pthread_t load_data_in_thread(load_args args);
|
766 |
+
void load_data_blocking(load_args args);
|
767 |
+
list *get_paths(char *filename);
|
768 |
+
void hierarchy_predictions(float *predictions, int n, tree *hier, int only_leaves, int stride);
|
769 |
+
void change_leaves(tree *t, char *leaf_list);
|
770 |
+
|
771 |
+
int find_int_arg(int argc, char **argv, char *arg, int def);
|
772 |
+
float find_float_arg(int argc, char **argv, char *arg, float def);
|
773 |
+
int find_arg(int argc, char* argv[], char *arg);
|
774 |
+
char *find_char_arg(int argc, char **argv, char *arg, char *def);
|
775 |
+
char *basecfg(char *cfgfile);
|
776 |
+
void find_replace(char *str, char *orig, char *rep, char *output);
|
777 |
+
void free_ptrs(void **ptrs, int n);
|
778 |
+
char *fgetl(FILE *fp);
|
779 |
+
void strip(char *s);
|
780 |
+
float sec(clock_t clocks);
|
781 |
+
void **list_to_array(list *l);
|
782 |
+
void top_k(float *a, int n, int k, int *index);
|
783 |
+
int *read_map(char *filename);
|
784 |
+
void error(const char *s);
|
785 |
+
int max_index(float *a, int n);
|
786 |
+
int max_int_index(int *a, int n);
|
787 |
+
int sample_array(float *a, int n);
|
788 |
+
int *random_index_order(int min, int max);
|
789 |
+
void free_list(list *l);
|
790 |
+
float mse_array(float *a, int n);
|
791 |
+
float variance_array(float *a, int n);
|
792 |
+
float mag_array(float *a, int n);
|
793 |
+
void scale_array(float *a, int n, float s);
|
794 |
+
float mean_array(float *a, int n);
|
795 |
+
float sum_array(float *a, int n);
|
796 |
+
void normalize_array(float *a, int n);
|
797 |
+
int *read_intlist(char *s, int *n, int d);
|
798 |
+
size_t rand_size_t();
|
799 |
+
float rand_normal();
|
800 |
+
float rand_uniform(float min, float max);
|
801 |
+
|
802 |
+
#ifdef __cplusplus
|
803 |
+
}
|
804 |
+
#endif
|
805 |
+
#endif
|
model/libdarknet.a
ADDED
Binary file (637 kB). View file
|
|
model/libdarknet.so
ADDED
Binary file (512 kB). View file
|
|
model/python/darknet.py
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ctypes import *
|
2 |
+
import math
|
3 |
+
import random
|
4 |
+
|
5 |
+
def sample(probs):
|
6 |
+
s = sum(probs)
|
7 |
+
probs = [a/s for a in probs]
|
8 |
+
r = random.uniform(0, 1)
|
9 |
+
for i in range(len(probs)):
|
10 |
+
r = r - probs[i]
|
11 |
+
if r <= 0:
|
12 |
+
return i
|
13 |
+
return len(probs)-1
|
14 |
+
|
15 |
+
def c_array(ctype, values):
|
16 |
+
arr = (ctype*len(values))()
|
17 |
+
arr[:] = values
|
18 |
+
return arr
|
19 |
+
|
20 |
+
class BOX(Structure):
|
21 |
+
_fields_ = [("x", c_float),
|
22 |
+
("y", c_float),
|
23 |
+
("w", c_float),
|
24 |
+
("h", c_float)]
|
25 |
+
|
26 |
+
class DETECTION(Structure):
|
27 |
+
_fields_ = [("bbox", BOX),
|
28 |
+
("classes", c_int),
|
29 |
+
("prob", POINTER(c_float)),
|
30 |
+
("mask", POINTER(c_float)),
|
31 |
+
("objectness", c_float),
|
32 |
+
("sort_class", c_int)]
|
33 |
+
|
34 |
+
|
35 |
+
class IMAGE(Structure):
|
36 |
+
_fields_ = [("w", c_int),
|
37 |
+
("h", c_int),
|
38 |
+
("c", c_int),
|
39 |
+
("data", POINTER(c_float))]
|
40 |
+
|
41 |
+
class METADATA(Structure):
|
42 |
+
_fields_ = [("classes", c_int),
|
43 |
+
("names", POINTER(c_char_p))]
|
44 |
+
|
45 |
+
|
46 |
+
|
47 |
+
#lib = CDLL("/home/pjreddie/documents/darknet/libdarknet.so", RTLD_GLOBAL)
|
48 |
+
lib = CDLL("libdarknet.so", RTLD_GLOBAL)
|
49 |
+
lib.network_width.argtypes = [c_void_p]
|
50 |
+
lib.network_width.restype = c_int
|
51 |
+
lib.network_height.argtypes = [c_void_p]
|
52 |
+
lib.network_height.restype = c_int
|
53 |
+
|
54 |
+
predict = lib.network_predict
|
55 |
+
predict.argtypes = [c_void_p, POINTER(c_float)]
|
56 |
+
predict.restype = POINTER(c_float)
|
57 |
+
|
58 |
+
set_gpu = lib.cuda_set_device
|
59 |
+
set_gpu.argtypes = [c_int]
|
60 |
+
|
61 |
+
make_image = lib.make_image
|
62 |
+
make_image.argtypes = [c_int, c_int, c_int]
|
63 |
+
make_image.restype = IMAGE
|
64 |
+
|
65 |
+
get_network_boxes = lib.get_network_boxes
|
66 |
+
get_network_boxes.argtypes = [c_void_p, c_int, c_int, c_float, c_float, POINTER(c_int), c_int, POINTER(c_int)]
|
67 |
+
get_network_boxes.restype = POINTER(DETECTION)
|
68 |
+
|
69 |
+
make_network_boxes = lib.make_network_boxes
|
70 |
+
make_network_boxes.argtypes = [c_void_p]
|
71 |
+
make_network_boxes.restype = POINTER(DETECTION)
|
72 |
+
|
73 |
+
free_detections = lib.free_detections
|
74 |
+
free_detections.argtypes = [POINTER(DETECTION), c_int]
|
75 |
+
|
76 |
+
free_ptrs = lib.free_ptrs
|
77 |
+
free_ptrs.argtypes = [POINTER(c_void_p), c_int]
|
78 |
+
|
79 |
+
network_predict = lib.network_predict
|
80 |
+
network_predict.argtypes = [c_void_p, POINTER(c_float)]
|
81 |
+
|
82 |
+
reset_rnn = lib.reset_rnn
|
83 |
+
reset_rnn.argtypes = [c_void_p]
|
84 |
+
|
85 |
+
load_net = lib.load_network
|
86 |
+
load_net.argtypes = [c_char_p, c_char_p, c_int]
|
87 |
+
load_net.restype = c_void_p
|
88 |
+
|
89 |
+
do_nms_obj = lib.do_nms_obj
|
90 |
+
do_nms_obj.argtypes = [POINTER(DETECTION), c_int, c_int, c_float]
|
91 |
+
|
92 |
+
do_nms_sort = lib.do_nms_sort
|
93 |
+
do_nms_sort.argtypes = [POINTER(DETECTION), c_int, c_int, c_float]
|
94 |
+
|
95 |
+
free_image = lib.free_image
|
96 |
+
free_image.argtypes = [IMAGE]
|
97 |
+
|
98 |
+
letterbox_image = lib.letterbox_image
|
99 |
+
letterbox_image.argtypes = [IMAGE, c_int, c_int]
|
100 |
+
letterbox_image.restype = IMAGE
|
101 |
+
|
102 |
+
load_meta = lib.get_metadata
|
103 |
+
lib.get_metadata.argtypes = [c_char_p]
|
104 |
+
lib.get_metadata.restype = METADATA
|
105 |
+
|
106 |
+
load_image = lib.load_image_color
|
107 |
+
load_image.argtypes = [c_char_p, c_int, c_int]
|
108 |
+
load_image.restype = IMAGE
|
109 |
+
|
110 |
+
rgbgr_image = lib.rgbgr_image
|
111 |
+
rgbgr_image.argtypes = [IMAGE]
|
112 |
+
|
113 |
+
predict_image = lib.network_predict_image
|
114 |
+
predict_image.argtypes = [c_void_p, IMAGE]
|
115 |
+
predict_image.restype = POINTER(c_float)
|
116 |
+
|
117 |
+
def classify(net, meta, im):
|
118 |
+
out = predict_image(net, im)
|
119 |
+
res = []
|
120 |
+
for i in range(meta.classes):
|
121 |
+
res.append((meta.names[i], out[i]))
|
122 |
+
res = sorted(res, key=lambda x: -x[1])
|
123 |
+
return res
|
124 |
+
|
125 |
+
def detect(net, meta, image, thresh=.5, hier_thresh=.5, nms=.45):
|
126 |
+
im = load_image(image, 0, 0)
|
127 |
+
num = c_int(0)
|
128 |
+
pnum = pointer(num)
|
129 |
+
predict_image(net, im)
|
130 |
+
dets = get_network_boxes(net, im.w, im.h, thresh, hier_thresh, None, 0, pnum)
|
131 |
+
num = pnum[0]
|
132 |
+
if (nms): do_nms_obj(dets, num, meta.classes, nms);
|
133 |
+
|
134 |
+
res = []
|
135 |
+
for j in range(num):
|
136 |
+
for i in range(meta.classes):
|
137 |
+
if dets[j].prob[i] > 0:
|
138 |
+
b = dets[j].bbox
|
139 |
+
res.append((meta.names[i], dets[j].prob[i], (b.x, b.y, b.w, b.h)))
|
140 |
+
res = sorted(res, key=lambda x: -x[1])
|
141 |
+
free_image(im)
|
142 |
+
free_detections(dets, num)
|
143 |
+
return res
|
144 |
+
|
145 |
+
if __name__ == "__main__":
|
146 |
+
#net = load_net("cfg/densenet201.cfg", "/home/pjreddie/trained/densenet201.weights", 0)
|
147 |
+
#im = load_image("data/wolf.jpg", 0, 0)
|
148 |
+
#meta = load_meta("cfg/imagenet1k.data")
|
149 |
+
#r = classify(net, meta, im)
|
150 |
+
#print r[:10]
|
151 |
+
net = load_net("cfg/tiny-yolo.cfg", "tiny-yolo.weights", 0)
|
152 |
+
meta = load_meta("cfg/coco.data")
|
153 |
+
r = detect(net, meta, "data/dog.jpg")
|
154 |
+
print(r)
|
155 |
+
|
156 |
+
|
model/python/proverbot.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from darknet import *
|
2 |
+
|
3 |
+
def predict_tactic(net, s):
|
4 |
+
prob = 0
|
5 |
+
d = c_array(c_float, [0.0]*256)
|
6 |
+
tac = ''
|
7 |
+
if not len(s):
|
8 |
+
s = '\n'
|
9 |
+
for c in s[:-1]:
|
10 |
+
d[ord(c)] = 1
|
11 |
+
pred = predict(net, d)
|
12 |
+
d[ord(c)] = 0
|
13 |
+
c = s[-1]
|
14 |
+
while 1:
|
15 |
+
d[ord(c)] = 1
|
16 |
+
pred = predict(net, d)
|
17 |
+
d[ord(c)] = 0
|
18 |
+
pred = [pred[i] for i in range(256)]
|
19 |
+
ind = sample(pred)
|
20 |
+
c = chr(ind)
|
21 |
+
prob += math.log(pred[ind])
|
22 |
+
if len(tac) and tac[-1] == '.':
|
23 |
+
break
|
24 |
+
tac = tac + c
|
25 |
+
return (tac, prob)
|
26 |
+
|
27 |
+
def predict_tactics(net, s, n):
|
28 |
+
tacs = []
|
29 |
+
for i in range(n):
|
30 |
+
reset_rnn(net)
|
31 |
+
tacs.append(predict_tactic(net, s))
|
32 |
+
tacs = sorted(tacs, key=lambda x: -x[1])
|
33 |
+
return tacs
|
34 |
+
|
35 |
+
net = load_net("cfg/coq.test.cfg", "/home/pjreddie/backup/coq.backup", 0)
|
36 |
+
t = predict_tactics(net, "+++++\n", 10)
|
37 |
+
print t
|
model/scripts/dice_label.sh
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
mkdir -p images
|
2 |
+
mkdir -p images/orig
|
3 |
+
mkdir -p images/train
|
4 |
+
mkdir -p images/val
|
5 |
+
|
6 |
+
ffmpeg -i Face1.mp4 images/orig/face1_%6d.jpg
|
7 |
+
ffmpeg -i Face2.mp4 images/orig/face2_%6d.jpg
|
8 |
+
ffmpeg -i Face3.mp4 images/orig/face3_%6d.jpg
|
9 |
+
ffmpeg -i Face4.mp4 images/orig/face4_%6d.jpg
|
10 |
+
ffmpeg -i Face5.mp4 images/orig/face5_%6d.jpg
|
11 |
+
ffmpeg -i Face6.mp4 images/orig/face6_%6d.jpg
|
12 |
+
|
13 |
+
mogrify -resize 100x100^ -gravity center -crop 100x100+0+0 +repage images/orig/*
|
14 |
+
|
15 |
+
ls images/orig/* | shuf | head -n 1000 | xargs mv -t images/val
|
16 |
+
mv images/orig/* images/train
|
17 |
+
|
18 |
+
find `pwd`/images/train > dice.train.list -name \*.jpg
|
19 |
+
find `pwd`/images/val > dice.val.list -name \*.jpg
|
20 |
+
|
model/scripts/gen_tactic.sh
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
# Usage:
|
3 |
+
# wget http://pjreddie.com/media/files/peek.weights
|
4 |
+
# scripts/gen_tactic.sh < data/goal.txt
|
5 |
+
./darknet rnn generatetactic cfg/gru.cfg peek.weights 2>/dev/null
|
model/scripts/get_coco_dataset.sh
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# Clone COCO API
|
4 |
+
git clone https://github.com/pdollar/coco
|
5 |
+
cd coco
|
6 |
+
|
7 |
+
mkdir images
|
8 |
+
cd images
|
9 |
+
|
10 |
+
# Download Images
|
11 |
+
wget -c https://pjreddie.com/media/files/train2014.zip
|
12 |
+
wget -c https://pjreddie.com/media/files/val2014.zip
|
13 |
+
|
14 |
+
# Unzip
|
15 |
+
unzip -q train2014.zip
|
16 |
+
unzip -q val2014.zip
|
17 |
+
|
18 |
+
cd ..
|
19 |
+
|
20 |
+
# Download COCO Metadata
|
21 |
+
wget -c https://pjreddie.com/media/files/instances_train-val2014.zip
|
22 |
+
wget -c https://pjreddie.com/media/files/coco/5k.part
|
23 |
+
wget -c https://pjreddie.com/media/files/coco/trainvalno5k.part
|
24 |
+
wget -c https://pjreddie.com/media/files/coco/labels.tgz
|
25 |
+
tar xzf labels.tgz
|
26 |
+
unzip -q instances_train-val2014.zip
|
27 |
+
|
28 |
+
# Set Up Image Lists
|
29 |
+
paste <(awk "{print \"$PWD\"}" <5k.part) 5k.part | tr -d '\t' > 5k.txt
|
30 |
+
paste <(awk "{print \"$PWD\"}" <trainvalno5k.part) trainvalno5k.part | tr -d '\t' > trainvalno5k.txt
|
31 |
+
|
model/scripts/imagenet_label.sh
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
mkdir -p labelled
|
4 |
+
wd=`pwd`
|
5 |
+
|
6 |
+
for f in val/*.xml;
|
7 |
+
do
|
8 |
+
label=`grep -m1 "<name>" $f | grep -oP '<name>\K[^<]*'`
|
9 |
+
im=`echo $f | sed 's/val/imgs/; s/xml/JPEG/'`
|
10 |
+
out=`echo $im | sed 's/JPEG/'${label}'.JPEG/; s/imgs/labelled/'`
|
11 |
+
ln -s ${wd}/$im ${wd}/$out
|
12 |
+
done
|
13 |
+
|
14 |
+
find ${wd}/labelled -name \*.JPEG > inet.val.list
|
15 |
+
|
model/scripts/voc_label.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import xml.etree.ElementTree as ET
|
2 |
+
import pickle
|
3 |
+
import os
|
4 |
+
from os import listdir, getcwd
|
5 |
+
from os.path import join
|
6 |
+
|
7 |
+
sets=[('2012', 'train'), ('2012', 'val'), ('2007', 'train'), ('2007', 'val'), ('2007', 'test')]
|
8 |
+
|
9 |
+
classes = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]
|
10 |
+
|
11 |
+
|
12 |
+
def convert(size, box):
|
13 |
+
dw = 1./(size[0])
|
14 |
+
dh = 1./(size[1])
|
15 |
+
x = (box[0] + box[1])/2.0 - 1
|
16 |
+
y = (box[2] + box[3])/2.0 - 1
|
17 |
+
w = box[1] - box[0]
|
18 |
+
h = box[3] - box[2]
|
19 |
+
x = x*dw
|
20 |
+
w = w*dw
|
21 |
+
y = y*dh
|
22 |
+
h = h*dh
|
23 |
+
return (x,y,w,h)
|
24 |
+
|
25 |
+
def convert_annotation(year, image_id):
|
26 |
+
in_file = open('VOCdevkit/VOC%s/Annotations/%s.xml'%(year, image_id))
|
27 |
+
out_file = open('VOCdevkit/VOC%s/labels/%s.txt'%(year, image_id), 'w')
|
28 |
+
tree=ET.parse(in_file)
|
29 |
+
root = tree.getroot()
|
30 |
+
size = root.find('size')
|
31 |
+
w = int(size.find('width').text)
|
32 |
+
h = int(size.find('height').text)
|
33 |
+
|
34 |
+
for obj in root.iter('object'):
|
35 |
+
difficult = obj.find('difficult').text
|
36 |
+
cls = obj.find('name').text
|
37 |
+
if cls not in classes or int(difficult)==1:
|
38 |
+
continue
|
39 |
+
cls_id = classes.index(cls)
|
40 |
+
xmlbox = obj.find('bndbox')
|
41 |
+
b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text), float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text))
|
42 |
+
bb = convert((w,h), b)
|
43 |
+
out_file.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')
|
44 |
+
|
45 |
+
wd = getcwd()
|
46 |
+
|
47 |
+
for year, image_set in sets:
|
48 |
+
if not os.path.exists('VOCdevkit/VOC%s/labels/'%(year)):
|
49 |
+
os.makedirs('VOCdevkit/VOC%s/labels/'%(year))
|
50 |
+
image_ids = open('VOCdevkit/VOC%s/ImageSets/Main/%s.txt'%(year, image_set)).read().strip().split()
|
51 |
+
list_file = open('%s_%s.txt'%(year, image_set), 'w')
|
52 |
+
for image_id in image_ids:
|
53 |
+
list_file.write('%s/VOCdevkit/VOC%s/JPEGImages/%s.jpg\n'%(wd, year, image_id))
|
54 |
+
convert_annotation(year, image_id)
|
55 |
+
list_file.close()
|
56 |
+
|
57 |
+
os.system("cat 2007_train.txt 2007_val.txt 2012_train.txt 2012_val.txt > train.txt")
|
58 |
+
os.system("cat 2007_train.txt 2007_val.txt 2007_test.txt 2012_train.txt 2012_val.txt > train.all.txt")
|
59 |
+
|
model/src/activation_kernels.cu
ADDED
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include "cuda_runtime.h"
|
2 |
+
#include "curand.h"
|
3 |
+
#include "cublas_v2.h"
|
4 |
+
|
5 |
+
extern "C" {
|
6 |
+
#include "activations.h"
|
7 |
+
#include "cuda.h"
|
8 |
+
}
|
9 |
+
|
10 |
+
|
11 |
+
__device__ float lhtan_activate_kernel(float x)
|
12 |
+
{
|
13 |
+
if(x < 0) return .001f*x;
|
14 |
+
if(x > 1) return .001f*(x-1.f) + 1.f;
|
15 |
+
return x;
|
16 |
+
}
|
17 |
+
__device__ float lhtan_gradient_kernel(float x)
|
18 |
+
{
|
19 |
+
if(x > 0 && x < 1) return 1;
|
20 |
+
return .001;
|
21 |
+
}
|
22 |
+
|
23 |
+
__device__ float hardtan_activate_kernel(float x)
|
24 |
+
{
|
25 |
+
if (x < -1) return -1;
|
26 |
+
if (x > 1) return 1;
|
27 |
+
return x;
|
28 |
+
}
|
29 |
+
__device__ float linear_activate_kernel(float x){return x;}
|
30 |
+
__device__ float logistic_activate_kernel(float x){return 1.f/(1.f + expf(-x));}
|
31 |
+
__device__ float loggy_activate_kernel(float x){return 2.f/(1.f + expf(-x)) - 1;}
|
32 |
+
__device__ float relu_activate_kernel(float x){return x*(x>0);}
|
33 |
+
__device__ float elu_activate_kernel(float x){return (x >= 0)*x + (x < 0)*(expf(x)-1);}
|
34 |
+
__device__ float selu_activate_kernel(float x){return (x >= 0)*1.0507f*x + (x < 0)*1.0507f*1.6732f*(expf(x)-1);}
|
35 |
+
__device__ float relie_activate_kernel(float x){return (x>0) ? x : .01f*x;}
|
36 |
+
__device__ float ramp_activate_kernel(float x){return x*(x>0)+.1f*x;}
|
37 |
+
__device__ float leaky_activate_kernel(float x){return (x>0) ? x : .1f*x;}
|
38 |
+
__device__ float tanh_activate_kernel(float x){return (2.f/(1 + expf(-2*x)) - 1);}
|
39 |
+
__device__ float plse_activate_kernel(float x)
|
40 |
+
{
|
41 |
+
if(x < -4) return .01f * (x + 4);
|
42 |
+
if(x > 4) return .01f * (x - 4) + 1;
|
43 |
+
return .125f*x + .5f;
|
44 |
+
}
|
45 |
+
__device__ float stair_activate_kernel(float x)
|
46 |
+
{
|
47 |
+
int n = floorf(x);
|
48 |
+
if (n%2 == 0) return floorf(x/2);
|
49 |
+
else return (x - n) + floorf(x/2);
|
50 |
+
}
|
51 |
+
|
52 |
+
|
53 |
+
__device__ float hardtan_gradient_kernel(float x)
|
54 |
+
{
|
55 |
+
if (x > -1 && x < 1) return 1;
|
56 |
+
return 0;
|
57 |
+
}
|
58 |
+
__device__ float linear_gradient_kernel(float x){return 1;}
|
59 |
+
__device__ float logistic_gradient_kernel(float x){return (1-x)*x;}
|
60 |
+
__device__ float loggy_gradient_kernel(float x)
|
61 |
+
{
|
62 |
+
float y = (x+1)/2;
|
63 |
+
return 2*(1-y)*y;
|
64 |
+
}
|
65 |
+
__device__ float relu_gradient_kernel(float x){return (x>0);}
|
66 |
+
__device__ float elu_gradient_kernel(float x){return (x >= 0) + (x < 0)*(x + 1);}
|
67 |
+
__device__ float selu_gradient_kernel(float x){return (x >= 0)*1.0507 + (x < 0)*(x + 1.0507*1.6732);}
|
68 |
+
__device__ float relie_gradient_kernel(float x){return (x>0) ? 1 : .01f;}
|
69 |
+
__device__ float ramp_gradient_kernel(float x){return (x>0)+.1f;}
|
70 |
+
__device__ float leaky_gradient_kernel(float x){return (x>0) ? 1 : .1f;}
|
71 |
+
__device__ float tanh_gradient_kernel(float x){return 1-x*x;}
|
72 |
+
__device__ float plse_gradient_kernel(float x){return (x < 0 || x > 1) ? .01f : .125f;}
|
73 |
+
__device__ float stair_gradient_kernel(float x)
|
74 |
+
{
|
75 |
+
if (floorf(x) == x) return 0;
|
76 |
+
return 1;
|
77 |
+
}
|
78 |
+
|
79 |
+
__device__ float activate_kernel(float x, ACTIVATION a)
|
80 |
+
{
|
81 |
+
switch(a){
|
82 |
+
case LINEAR:
|
83 |
+
return linear_activate_kernel(x);
|
84 |
+
case LOGISTIC:
|
85 |
+
return logistic_activate_kernel(x);
|
86 |
+
case LOGGY:
|
87 |
+
return loggy_activate_kernel(x);
|
88 |
+
case RELU:
|
89 |
+
return relu_activate_kernel(x);
|
90 |
+
case ELU:
|
91 |
+
return elu_activate_kernel(x);
|
92 |
+
case SELU:
|
93 |
+
return selu_activate_kernel(x);
|
94 |
+
case RELIE:
|
95 |
+
return relie_activate_kernel(x);
|
96 |
+
case RAMP:
|
97 |
+
return ramp_activate_kernel(x);
|
98 |
+
case LEAKY:
|
99 |
+
return leaky_activate_kernel(x);
|
100 |
+
case TANH:
|
101 |
+
return tanh_activate_kernel(x);
|
102 |
+
case PLSE:
|
103 |
+
return plse_activate_kernel(x);
|
104 |
+
case STAIR:
|
105 |
+
return stair_activate_kernel(x);
|
106 |
+
case HARDTAN:
|
107 |
+
return hardtan_activate_kernel(x);
|
108 |
+
case LHTAN:
|
109 |
+
return lhtan_activate_kernel(x);
|
110 |
+
}
|
111 |
+
return 0;
|
112 |
+
}
|
113 |
+
|
114 |
+
__device__ float gradient_kernel(float x, ACTIVATION a)
|
115 |
+
{
|
116 |
+
switch(a){
|
117 |
+
case LINEAR:
|
118 |
+
return linear_gradient_kernel(x);
|
119 |
+
case LOGISTIC:
|
120 |
+
return logistic_gradient_kernel(x);
|
121 |
+
case LOGGY:
|
122 |
+
return loggy_gradient_kernel(x);
|
123 |
+
case RELU:
|
124 |
+
return relu_gradient_kernel(x);
|
125 |
+
case ELU:
|
126 |
+
return elu_gradient_kernel(x);
|
127 |
+
case SELU:
|
128 |
+
return selu_gradient_kernel(x);
|
129 |
+
case RELIE:
|
130 |
+
return relie_gradient_kernel(x);
|
131 |
+
case RAMP:
|
132 |
+
return ramp_gradient_kernel(x);
|
133 |
+
case LEAKY:
|
134 |
+
return leaky_gradient_kernel(x);
|
135 |
+
case TANH:
|
136 |
+
return tanh_gradient_kernel(x);
|
137 |
+
case PLSE:
|
138 |
+
return plse_gradient_kernel(x);
|
139 |
+
case STAIR:
|
140 |
+
return stair_gradient_kernel(x);
|
141 |
+
case HARDTAN:
|
142 |
+
return hardtan_gradient_kernel(x);
|
143 |
+
case LHTAN:
|
144 |
+
return lhtan_gradient_kernel(x);
|
145 |
+
}
|
146 |
+
return 0;
|
147 |
+
}
|
148 |
+
|
149 |
+
__global__ void binary_gradient_array_kernel(float *x, float *dy, int n, int s, BINARY_ACTIVATION a, float *dx)
|
150 |
+
{
|
151 |
+
int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
|
152 |
+
int i = id % s;
|
153 |
+
int b = id / s;
|
154 |
+
float x1 = x[b*s + i];
|
155 |
+
float x2 = x[b*s + s/2 + i];
|
156 |
+
if(id < n) {
|
157 |
+
float de = dy[id];
|
158 |
+
dx[b*s + i] = x2*de;
|
159 |
+
dx[b*s + s/2 + i] = x1*de;
|
160 |
+
}
|
161 |
+
}
|
162 |
+
|
163 |
+
extern "C" void binary_gradient_array_gpu(float *x, float *dx, int n, int size, BINARY_ACTIVATION a, float *y)
|
164 |
+
{
|
165 |
+
binary_gradient_array_kernel<<<cuda_gridsize(n/2), BLOCK>>>(x, dx, n/2, size, a, y);
|
166 |
+
check_error(cudaPeekAtLastError());
|
167 |
+
}
|
168 |
+
__global__ void binary_activate_array_kernel(float *x, int n, int s, BINARY_ACTIVATION a, float *y)
|
169 |
+
{
|
170 |
+
int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
|
171 |
+
int i = id % s;
|
172 |
+
int b = id / s;
|
173 |
+
float x1 = x[b*s + i];
|
174 |
+
float x2 = x[b*s + s/2 + i];
|
175 |
+
if(id < n) y[id] = x1*x2;
|
176 |
+
}
|
177 |
+
|
178 |
+
extern "C" void binary_activate_array_gpu(float *x, int n, int size, BINARY_ACTIVATION a, float *y)
|
179 |
+
{
|
180 |
+
binary_activate_array_kernel<<<cuda_gridsize(n/2), BLOCK>>>(x, n/2, size, a, y);
|
181 |
+
check_error(cudaPeekAtLastError());
|
182 |
+
}
|
183 |
+
|
184 |
+
__global__ void activate_array_kernel(float *x, int n, ACTIVATION a)
|
185 |
+
{
|
186 |
+
int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
|
187 |
+
if(i < n) x[i] = activate_kernel(x[i], a);
|
188 |
+
}
|
189 |
+
|
190 |
+
__global__ void gradient_array_kernel(float *x, int n, ACTIVATION a, float *delta)
|
191 |
+
{
|
192 |
+
int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
|
193 |
+
if(i < n) delta[i] *= gradient_kernel(x[i], a);
|
194 |
+
}
|
195 |
+
|
196 |
+
extern "C" void activate_array_gpu(float *x, int n, ACTIVATION a)
|
197 |
+
{
|
198 |
+
activate_array_kernel<<<cuda_gridsize(n), BLOCK>>>(x, n, a);
|
199 |
+
check_error(cudaPeekAtLastError());
|
200 |
+
}
|
201 |
+
|
202 |
+
extern "C" void gradient_array_gpu(float *x, int n, ACTIVATION a, float *delta)
|
203 |
+
{
|
204 |
+
gradient_array_kernel<<<cuda_gridsize(n), BLOCK>>>(x, n, a, delta);
|
205 |
+
check_error(cudaPeekAtLastError());
|
206 |
+
}
|
model/src/activation_layer.c
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include "activation_layer.h"
|
2 |
+
#include "utils.h"
|
3 |
+
#include "cuda.h"
|
4 |
+
#include "blas.h"
|
5 |
+
#include "gemm.h"
|
6 |
+
|
7 |
+
#include <math.h>
|
8 |
+
#include <stdio.h>
|
9 |
+
#include <stdlib.h>
|
10 |
+
#include <string.h>
|
11 |
+
|
12 |
+
layer make_activation_layer(int batch, int inputs, ACTIVATION activation)
|
13 |
+
{
|
14 |
+
layer l = {0};
|
15 |
+
l.type = ACTIVE;
|
16 |
+
|
17 |
+
l.inputs = inputs;
|
18 |
+
l.outputs = inputs;
|
19 |
+
l.batch=batch;
|
20 |
+
|
21 |
+
l.output = calloc(batch*inputs, sizeof(float*));
|
22 |
+
l.delta = calloc(batch*inputs, sizeof(float*));
|
23 |
+
|
24 |
+
l.forward = forward_activation_layer;
|
25 |
+
l.backward = backward_activation_layer;
|
26 |
+
#ifdef GPU
|
27 |
+
l.forward_gpu = forward_activation_layer_gpu;
|
28 |
+
l.backward_gpu = backward_activation_layer_gpu;
|
29 |
+
|
30 |
+
l.output_gpu = cuda_make_array(l.output, inputs*batch);
|
31 |
+
l.delta_gpu = cuda_make_array(l.delta, inputs*batch);
|
32 |
+
#endif
|
33 |
+
l.activation = activation;
|
34 |
+
fprintf(stderr, "Activation Layer: %d inputs\n", inputs);
|
35 |
+
return l;
|
36 |
+
}
|
37 |
+
|
38 |
+
void forward_activation_layer(layer l, network net)
|
39 |
+
{
|
40 |
+
copy_cpu(l.outputs*l.batch, net.input, 1, l.output, 1);
|
41 |
+
activate_array(l.output, l.outputs*l.batch, l.activation);
|
42 |
+
}
|
43 |
+
|
44 |
+
void backward_activation_layer(layer l, network net)
|
45 |
+
{
|
46 |
+
gradient_array(l.output, l.outputs*l.batch, l.activation, l.delta);
|
47 |
+
copy_cpu(l.outputs*l.batch, l.delta, 1, net.delta, 1);
|
48 |
+
}
|
49 |
+
|
50 |
+
#ifdef GPU
|
51 |
+
|
52 |
+
void forward_activation_layer_gpu(layer l, network net)
|
53 |
+
{
|
54 |
+
copy_gpu(l.outputs*l.batch, net.input_gpu, 1, l.output_gpu, 1);
|
55 |
+
activate_array_gpu(l.output_gpu, l.outputs*l.batch, l.activation);
|
56 |
+
}
|
57 |
+
|
58 |
+
void backward_activation_layer_gpu(layer l, network net)
|
59 |
+
{
|
60 |
+
gradient_array_gpu(l.output_gpu, l.outputs*l.batch, l.activation, l.delta_gpu);
|
61 |
+
copy_gpu(l.outputs*l.batch, l.delta_gpu, 1, net.delta_gpu, 1);
|
62 |
+
}
|
63 |
+
#endif
|
model/src/activation_layer.h
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#ifndef ACTIVATION_LAYER_H
|
2 |
+
#define ACTIVATION_LAYER_H
|
3 |
+
|
4 |
+
#include "activations.h"
|
5 |
+
#include "layer.h"
|
6 |
+
#include "network.h"
|
7 |
+
|
8 |
+
layer make_activation_layer(int batch, int inputs, ACTIVATION activation);
|
9 |
+
|
10 |
+
void forward_activation_layer(layer l, network net);
|
11 |
+
void backward_activation_layer(layer l, network net);
|
12 |
+
|
13 |
+
#ifdef GPU
|
14 |
+
void forward_activation_layer_gpu(layer l, network net);
|
15 |
+
void backward_activation_layer_gpu(layer l, network net);
|
16 |
+
#endif
|
17 |
+
|
18 |
+
#endif
|
19 |
+
|
model/src/activations.c
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include "activations.h"
|
2 |
+
|
3 |
+
#include <math.h>
|
4 |
+
#include <stdio.h>
|
5 |
+
#include <stdlib.h>
|
6 |
+
#include <string.h>
|
7 |
+
|
8 |
+
char *get_activation_string(ACTIVATION a)
|
9 |
+
{
|
10 |
+
switch(a){
|
11 |
+
case LOGISTIC:
|
12 |
+
return "logistic";
|
13 |
+
case LOGGY:
|
14 |
+
return "loggy";
|
15 |
+
case RELU:
|
16 |
+
return "relu";
|
17 |
+
case ELU:
|
18 |
+
return "elu";
|
19 |
+
case SELU:
|
20 |
+
return "selu";
|
21 |
+
case RELIE:
|
22 |
+
return "relie";
|
23 |
+
case RAMP:
|
24 |
+
return "ramp";
|
25 |
+
case LINEAR:
|
26 |
+
return "linear";
|
27 |
+
case TANH:
|
28 |
+
return "tanh";
|
29 |
+
case PLSE:
|
30 |
+
return "plse";
|
31 |
+
case LEAKY:
|
32 |
+
return "leaky";
|
33 |
+
case STAIR:
|
34 |
+
return "stair";
|
35 |
+
case HARDTAN:
|
36 |
+
return "hardtan";
|
37 |
+
case LHTAN:
|
38 |
+
return "lhtan";
|
39 |
+
default:
|
40 |
+
break;
|
41 |
+
}
|
42 |
+
return "relu";
|
43 |
+
}
|
44 |
+
|
45 |
+
ACTIVATION get_activation(char *s)
|
46 |
+
{
|
47 |
+
if (strcmp(s, "logistic")==0) return LOGISTIC;
|
48 |
+
if (strcmp(s, "loggy")==0) return LOGGY;
|
49 |
+
if (strcmp(s, "relu")==0) return RELU;
|
50 |
+
if (strcmp(s, "elu")==0) return ELU;
|
51 |
+
if (strcmp(s, "selu")==0) return SELU;
|
52 |
+
if (strcmp(s, "relie")==0) return RELIE;
|
53 |
+
if (strcmp(s, "plse")==0) return PLSE;
|
54 |
+
if (strcmp(s, "hardtan")==0) return HARDTAN;
|
55 |
+
if (strcmp(s, "lhtan")==0) return LHTAN;
|
56 |
+
if (strcmp(s, "linear")==0) return LINEAR;
|
57 |
+
if (strcmp(s, "ramp")==0) return RAMP;
|
58 |
+
if (strcmp(s, "leaky")==0) return LEAKY;
|
59 |
+
if (strcmp(s, "tanh")==0) return TANH;
|
60 |
+
if (strcmp(s, "stair")==0) return STAIR;
|
61 |
+
fprintf(stderr, "Couldn't find activation function %s, going with ReLU\n", s);
|
62 |
+
return RELU;
|
63 |
+
}
|
64 |
+
|
65 |
+
float activate(float x, ACTIVATION a)
|
66 |
+
{
|
67 |
+
switch(a){
|
68 |
+
case LINEAR:
|
69 |
+
return linear_activate(x);
|
70 |
+
case LOGISTIC:
|
71 |
+
return logistic_activate(x);
|
72 |
+
case LOGGY:
|
73 |
+
return loggy_activate(x);
|
74 |
+
case RELU:
|
75 |
+
return relu_activate(x);
|
76 |
+
case ELU:
|
77 |
+
return elu_activate(x);
|
78 |
+
case SELU:
|
79 |
+
return selu_activate(x);
|
80 |
+
case RELIE:
|
81 |
+
return relie_activate(x);
|
82 |
+
case RAMP:
|
83 |
+
return ramp_activate(x);
|
84 |
+
case LEAKY:
|
85 |
+
return leaky_activate(x);
|
86 |
+
case TANH:
|
87 |
+
return tanh_activate(x);
|
88 |
+
case PLSE:
|
89 |
+
return plse_activate(x);
|
90 |
+
case STAIR:
|
91 |
+
return stair_activate(x);
|
92 |
+
case HARDTAN:
|
93 |
+
return hardtan_activate(x);
|
94 |
+
case LHTAN:
|
95 |
+
return lhtan_activate(x);
|
96 |
+
}
|
97 |
+
return 0;
|
98 |
+
}
|
99 |
+
|
100 |
+
void activate_array(float *x, const int n, const ACTIVATION a)
|
101 |
+
{
|
102 |
+
int i;
|
103 |
+
for(i = 0; i < n; ++i){
|
104 |
+
x[i] = activate(x[i], a);
|
105 |
+
}
|
106 |
+
}
|
107 |
+
|
108 |
+
float gradient(float x, ACTIVATION a)
|
109 |
+
{
|
110 |
+
switch(a){
|
111 |
+
case LINEAR:
|
112 |
+
return linear_gradient(x);
|
113 |
+
case LOGISTIC:
|
114 |
+
return logistic_gradient(x);
|
115 |
+
case LOGGY:
|
116 |
+
return loggy_gradient(x);
|
117 |
+
case RELU:
|
118 |
+
return relu_gradient(x);
|
119 |
+
case ELU:
|
120 |
+
return elu_gradient(x);
|
121 |
+
case SELU:
|
122 |
+
return selu_gradient(x);
|
123 |
+
case RELIE:
|
124 |
+
return relie_gradient(x);
|
125 |
+
case RAMP:
|
126 |
+
return ramp_gradient(x);
|
127 |
+
case LEAKY:
|
128 |
+
return leaky_gradient(x);
|
129 |
+
case TANH:
|
130 |
+
return tanh_gradient(x);
|
131 |
+
case PLSE:
|
132 |
+
return plse_gradient(x);
|
133 |
+
case STAIR:
|
134 |
+
return stair_gradient(x);
|
135 |
+
case HARDTAN:
|
136 |
+
return hardtan_gradient(x);
|
137 |
+
case LHTAN:
|
138 |
+
return lhtan_gradient(x);
|
139 |
+
}
|
140 |
+
return 0;
|
141 |
+
}
|
142 |
+
|
143 |
+
void gradient_array(const float *x, const int n, const ACTIVATION a, float *delta)
|
144 |
+
{
|
145 |
+
int i;
|
146 |
+
for(i = 0; i < n; ++i){
|
147 |
+
delta[i] *= gradient(x[i], a);
|
148 |
+
}
|
149 |
+
}
|
150 |
+
|
model/src/activations.h
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#ifndef ACTIVATIONS_H
|
2 |
+
#define ACTIVATIONS_H
|
3 |
+
#include "darknet.h"
|
4 |
+
#include "cuda.h"
|
5 |
+
#include "math.h"
|
6 |
+
|
7 |
+
ACTIVATION get_activation(char *s);
|
8 |
+
|
9 |
+
char *get_activation_string(ACTIVATION a);
|
10 |
+
float activate(float x, ACTIVATION a);
|
11 |
+
float gradient(float x, ACTIVATION a);
|
12 |
+
void gradient_array(const float *x, const int n, const ACTIVATION a, float *delta);
|
13 |
+
void activate_array(float *x, const int n, const ACTIVATION a);
|
14 |
+
#ifdef GPU
|
15 |
+
void activate_array_gpu(float *x, int n, ACTIVATION a);
|
16 |
+
void gradient_array_gpu(float *x, int n, ACTIVATION a, float *delta);
|
17 |
+
#endif
|
18 |
+
|
19 |
+
static inline float stair_activate(float x)
|
20 |
+
{
|
21 |
+
int n = floor(x);
|
22 |
+
if (n%2 == 0) return floor(x/2.);
|
23 |
+
else return (x - n) + floor(x/2.);
|
24 |
+
}
|
25 |
+
static inline float hardtan_activate(float x)
|
26 |
+
{
|
27 |
+
if (x < -1) return -1;
|
28 |
+
if (x > 1) return 1;
|
29 |
+
return x;
|
30 |
+
}
|
31 |
+
static inline float linear_activate(float x){return x;}
|
32 |
+
static inline float logistic_activate(float x){return 1./(1. + exp(-x));}
|
33 |
+
static inline float loggy_activate(float x){return 2./(1. + exp(-x)) - 1;}
|
34 |
+
static inline float relu_activate(float x){return x*(x>0);}
|
35 |
+
static inline float elu_activate(float x){return (x >= 0)*x + (x < 0)*(exp(x)-1);}
|
36 |
+
static inline float selu_activate(float x){return (x >= 0)*1.0507*x + (x < 0)*1.0507*1.6732*(exp(x)-1);}
|
37 |
+
static inline float relie_activate(float x){return (x>0) ? x : .01*x;}
|
38 |
+
static inline float ramp_activate(float x){return x*(x>0)+.1*x;}
|
39 |
+
static inline float leaky_activate(float x){return (x>0) ? x : .1*x;}
|
40 |
+
static inline float tanh_activate(float x){return (exp(2*x)-1)/(exp(2*x)+1);}
|
41 |
+
static inline float plse_activate(float x)
|
42 |
+
{
|
43 |
+
if(x < -4) return .01 * (x + 4);
|
44 |
+
if(x > 4) return .01 * (x - 4) + 1;
|
45 |
+
return .125*x + .5;
|
46 |
+
}
|
47 |
+
|
48 |
+
static inline float lhtan_activate(float x)
|
49 |
+
{
|
50 |
+
if(x < 0) return .001*x;
|
51 |
+
if(x > 1) return .001*(x-1) + 1;
|
52 |
+
return x;
|
53 |
+
}
|
54 |
+
static inline float lhtan_gradient(float x)
|
55 |
+
{
|
56 |
+
if(x > 0 && x < 1) return 1;
|
57 |
+
return .001;
|
58 |
+
}
|
59 |
+
|
60 |
+
static inline float hardtan_gradient(float x)
|
61 |
+
{
|
62 |
+
if (x > -1 && x < 1) return 1;
|
63 |
+
return 0;
|
64 |
+
}
|
65 |
+
static inline float linear_gradient(float x){return 1;}
|
66 |
+
static inline float logistic_gradient(float x){return (1-x)*x;}
|
67 |
+
static inline float loggy_gradient(float x)
|
68 |
+
{
|
69 |
+
float y = (x+1.)/2.;
|
70 |
+
return 2*(1-y)*y;
|
71 |
+
}
|
72 |
+
static inline float stair_gradient(float x)
|
73 |
+
{
|
74 |
+
if (floor(x) == x) return 0;
|
75 |
+
return 1;
|
76 |
+
}
|
77 |
+
static inline float relu_gradient(float x){return (x>0);}
|
78 |
+
static inline float elu_gradient(float x){return (x >= 0) + (x < 0)*(x + 1);}
|
79 |
+
static inline float selu_gradient(float x){return (x >= 0)*1.0507 + (x < 0)*(x + 1.0507*1.6732);}
|
80 |
+
static inline float relie_gradient(float x){return (x>0) ? 1 : .01;}
|
81 |
+
static inline float ramp_gradient(float x){return (x>0)+.1;}
|
82 |
+
static inline float leaky_gradient(float x){return (x>0) ? 1 : .1;}
|
83 |
+
static inline float tanh_gradient(float x){return 1-x*x;}
|
84 |
+
static inline float plse_gradient(float x){return (x < 0 || x > 1) ? .01 : .125;}
|
85 |
+
|
86 |
+
#endif
|
87 |
+
|