Spaces:
Build error
Build error
mshukor
commited on
Commit
·
87d7283
1
Parent(s):
d41c3ca
init
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- LICENSE +21 -0
- README.md +8 -9
- TimeSformer/.gitignore +143 -0
- TimeSformer/CODE_OF_CONDUCT.md +5 -0
- TimeSformer/CONTRIBUTING.md +25 -0
- TimeSformer/LICENSE +399 -0
- TimeSformer/README.md +248 -0
- TimeSformer/configs/Kinetics/SLOWFAST_4x16_R50.yaml +63 -0
- TimeSformer/configs/Kinetics/SLOWFAST_8x8_R101.yaml +63 -0
- TimeSformer/configs/Kinetics/SLOWFAST_8x8_R50.yaml +63 -0
- TimeSformer/configs/Kinetics/TimeSformer_divST_16x16_448.yaml +45 -0
- TimeSformer/configs/Kinetics/TimeSformer_divST_8x32_224.yaml +45 -0
- TimeSformer/configs/Kinetics/TimeSformer_divST_8x32_224_4gpus.yaml +45 -0
- TimeSformer/configs/Kinetics/TimeSformer_divST_8x32_224_TEST.yaml +46 -0
- TimeSformer/configs/Kinetics/TimeSformer_divST_96x4_224.yaml +45 -0
- TimeSformer/configs/Kinetics/TimeSformer_jointST_8x32_224.yaml +45 -0
- TimeSformer/configs/Kinetics/TimeSformer_spaceOnly_8x32_224.yaml +45 -0
- TimeSformer/configs/SSv2/SLOWFAST_16x8_R50.yaml +83 -0
- TimeSformer/configs/SSv2/TimeSformer_divST_16_448.yaml +48 -0
- TimeSformer/configs/SSv2/TimeSformer_divST_64_224.yaml +48 -0
- TimeSformer/configs/SSv2/TimeSformer_divST_8_224.yaml +48 -0
- TimeSformer/environment.yml +26 -0
- TimeSformer/example.ipynb +84 -0
- TimeSformer/setup.cfg +23 -0
- TimeSformer/setup.py +23 -0
- TimeSformer/slurm_scripts/run_multi_node_job.sh +25 -0
- TimeSformer/slurm_scripts/run_single_node_job.sh +35 -0
- TimeSformer/timesformer/__init__.py +5 -0
- TimeSformer/timesformer/config/__init__.py +1 -0
- TimeSformer/timesformer/config/defaults.py +820 -0
- TimeSformer/timesformer/datasets/DATASET.md +26 -0
- TimeSformer/timesformer/datasets/__init__.py +5 -0
- TimeSformer/timesformer/datasets/build.py +30 -0
- TimeSformer/timesformer/datasets/cv2_transform.py +796 -0
- TimeSformer/timesformer/datasets/decoder.py +392 -0
- TimeSformer/timesformer/datasets/kinetics.py +294 -0
- TimeSformer/timesformer/datasets/loader.py +134 -0
- TimeSformer/timesformer/datasets/multigrid_helper.py +78 -0
- TimeSformer/timesformer/datasets/ssv2.py +278 -0
- TimeSformer/timesformer/datasets/transform.py +459 -0
- TimeSformer/timesformer/datasets/utils.py +380 -0
- TimeSformer/timesformer/datasets/video_container.py +31 -0
- TimeSformer/timesformer/models/__init__.py +5 -0
- TimeSformer/timesformer/models/batchnorm_helper.py +217 -0
- TimeSformer/timesformer/models/build.py +54 -0
- TimeSformer/timesformer/models/conv2d_same.py +74 -0
- TimeSformer/timesformer/models/custom_video_model_builder.py +4 -0
- TimeSformer/timesformer/models/features.py +266 -0
- TimeSformer/timesformer/models/head_helper.py +235 -0
- TimeSformer/timesformer/models/helpers.py +360 -0
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2023 mshukor
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
CHANGED
@@ -1,12 +1,11 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 3.
|
8 |
app_file: app.py
|
9 |
-
pinned:
|
10 |
-
|
11 |
-
|
12 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
+
title: eP-ALM
|
3 |
+
emoji: 🌍
|
4 |
+
colorFrom: purple
|
5 |
+
colorTo: pink
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 3.12.0
|
8 |
app_file: app.py
|
9 |
+
pinned: true
|
10 |
+
license: apache-2.0
|
11 |
+
---
|
|
TimeSformer/.gitignore
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
|
4 |
+
|
5 |
+
# Docker file from Python is inspired from here :
|
6 |
+
# https://github.com/github/gitignore/blob/master/Python.gitignore
|
7 |
+
|
8 |
+
# Byte-compiled / optimized / DLL files
|
9 |
+
__pycache__/
|
10 |
+
*.py[cod]
|
11 |
+
*$py.class
|
12 |
+
|
13 |
+
# C extensions
|
14 |
+
*.so
|
15 |
+
|
16 |
+
# Distribution / packaging
|
17 |
+
.Python
|
18 |
+
build/
|
19 |
+
develop-eggs/
|
20 |
+
dist/
|
21 |
+
downloads/
|
22 |
+
eggs/
|
23 |
+
.eggs/
|
24 |
+
lib/
|
25 |
+
lib64/
|
26 |
+
parts/
|
27 |
+
sdist/
|
28 |
+
var/
|
29 |
+
wheels/
|
30 |
+
share/python-wheels/
|
31 |
+
*.egg-info/
|
32 |
+
.installed.cfg
|
33 |
+
*.egg
|
34 |
+
MANIFEST
|
35 |
+
|
36 |
+
# PyInstaller
|
37 |
+
# Usually these files are written by a python script from a template
|
38 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
39 |
+
*.manifest
|
40 |
+
*.spec
|
41 |
+
|
42 |
+
# Installer logs
|
43 |
+
pip-log.txt
|
44 |
+
pip-delete-this-directory.txt
|
45 |
+
|
46 |
+
# Unit test / coverage reports
|
47 |
+
tests/report/
|
48 |
+
.coverage
|
49 |
+
.coverage.*
|
50 |
+
.cache
|
51 |
+
nosetests.xml
|
52 |
+
coverage.xml
|
53 |
+
*.cover
|
54 |
+
*.py,cover
|
55 |
+
.hypothesis/
|
56 |
+
.pytest_cache/
|
57 |
+
|
58 |
+
# Translations
|
59 |
+
*.mo
|
60 |
+
*.pot
|
61 |
+
|
62 |
+
# Django stuff:
|
63 |
+
*.log
|
64 |
+
local_settings.py
|
65 |
+
db.sqlite3
|
66 |
+
db.sqlite3-journal
|
67 |
+
|
68 |
+
# Flask stuff:
|
69 |
+
instance/
|
70 |
+
.webassets-cache
|
71 |
+
|
72 |
+
# Scrapy stuff:
|
73 |
+
.scrapy
|
74 |
+
|
75 |
+
# Sphinx documentation
|
76 |
+
docs/_build/
|
77 |
+
|
78 |
+
# PyBuilder
|
79 |
+
.pybuilder/
|
80 |
+
target/
|
81 |
+
|
82 |
+
# Jupyter Notebook
|
83 |
+
.ipynb_checkpoints
|
84 |
+
|
85 |
+
# IPython
|
86 |
+
profile_default/
|
87 |
+
ipython_config.py
|
88 |
+
|
89 |
+
# pyenv
|
90 |
+
# For a library or package, you might want to ignore these files since the code is
|
91 |
+
# intended to run in multiple environments; otherwise, check them in:
|
92 |
+
# .python-version
|
93 |
+
|
94 |
+
# pipenv
|
95 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
96 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
97 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
98 |
+
# install all needed dependencies.
|
99 |
+
#Pipfile.lock
|
100 |
+
|
101 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
102 |
+
__pypackages__/
|
103 |
+
|
104 |
+
# Celery stuff
|
105 |
+
celerybeat-schedule
|
106 |
+
celerybeat.pid
|
107 |
+
|
108 |
+
# SageMath parsed files
|
109 |
+
*.sage.py
|
110 |
+
|
111 |
+
# Environments
|
112 |
+
.env
|
113 |
+
.venv
|
114 |
+
env/
|
115 |
+
venv/
|
116 |
+
ENV/
|
117 |
+
env.bak/
|
118 |
+
venv.bak/
|
119 |
+
|
120 |
+
# Spyder project settings
|
121 |
+
.spyderproject
|
122 |
+
.spyproject
|
123 |
+
|
124 |
+
# Rope project settings
|
125 |
+
.ropeproject
|
126 |
+
|
127 |
+
# mkdocs documentation
|
128 |
+
/site
|
129 |
+
|
130 |
+
# mypy
|
131 |
+
.mypy_cache/
|
132 |
+
.dmypy.json
|
133 |
+
dmypy.json
|
134 |
+
|
135 |
+
# Pyre type checker
|
136 |
+
.pyre/
|
137 |
+
|
138 |
+
# pytype static type analyzer
|
139 |
+
.pytype/
|
140 |
+
|
141 |
+
|
142 |
+
# Cython debug symbols
|
143 |
+
cython_debug/
|
TimeSformer/CODE_OF_CONDUCT.md
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Code of Conduct
|
2 |
+
|
3 |
+
Facebook has adopted a Code of Conduct that we expect project participants to adhere to.
|
4 |
+
Please read the [full text](https://code.fb.com/codeofconduct/)
|
5 |
+
so that you can understand what actions will and will not be tolerated.
|
TimeSformer/CONTRIBUTING.md
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Contributing to TimeSformer
|
2 |
+
|
3 |
+
## Pull Requests
|
4 |
+
We actively welcome your pull requests.
|
5 |
+
|
6 |
+
1. Fork the repo and create your branch from `master`.
|
7 |
+
2. If you've added code that should be tested, add tests.
|
8 |
+
3. If you've changed APIs, update the documentation.
|
9 |
+
4. Ensure the test suite passes.
|
10 |
+
5. Make sure your code lints.
|
11 |
+
6. If you haven't already, complete the Contributor License Agreement ("CLA").
|
12 |
+
|
13 |
+
## Contributor License Agreement ("CLA")
|
14 |
+
In order to accept your pull request, we need you to submit a CLA. You only need
|
15 |
+
to do this once to work on any of Facebook's open source projects.
|
16 |
+
|
17 |
+
Complete your CLA here: <https://code.facebook.com/cla>
|
18 |
+
|
19 |
+
## Issues
|
20 |
+
We use GitHub issues to track public bugs. Please ensure your description is
|
21 |
+
clear and has sufficient instructions to be able to reproduce the issue.
|
22 |
+
|
23 |
+
## License
|
24 |
+
By contributing to TimeSformer, you agree that your contributions will be licensed
|
25 |
+
under the [LICENSE.md](LICENSE.md) file in the root directory of this source tree.
|
TimeSformer/LICENSE
ADDED
@@ -0,0 +1,399 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Attribution-NonCommercial 4.0 International
|
2 |
+
|
3 |
+
=======================================================================
|
4 |
+
|
5 |
+
Creative Commons Corporation ("Creative Commons") is not a law firm and
|
6 |
+
does not provide legal services or legal advice. Distribution of
|
7 |
+
Creative Commons public licenses does not create a lawyer-client or
|
8 |
+
other relationship. Creative Commons makes its licenses and related
|
9 |
+
information available on an "as-is" basis. Creative Commons gives no
|
10 |
+
warranties regarding its licenses, any material licensed under their
|
11 |
+
terms and conditions, or any related information. Creative Commons
|
12 |
+
disclaims all liability for damages resulting from their use to the
|
13 |
+
fullest extent possible.
|
14 |
+
|
15 |
+
Using Creative Commons Public Licenses
|
16 |
+
|
17 |
+
Creative Commons public licenses provide a standard set of terms and
|
18 |
+
conditions that creators and other rights holders may use to share
|
19 |
+
original works of authorship and other material subject to copyright
|
20 |
+
and certain other rights specified in the public license below. The
|
21 |
+
following considerations are for informational purposes only, are not
|
22 |
+
exhaustive, and do not form part of our licenses.
|
23 |
+
|
24 |
+
Considerations for licensors: Our public licenses are
|
25 |
+
intended for use by those authorized to give the public
|
26 |
+
permission to use material in ways otherwise restricted by
|
27 |
+
copyright and certain other rights. Our licenses are
|
28 |
+
irrevocable. Licensors should read and understand the terms
|
29 |
+
and conditions of the license they choose before applying it.
|
30 |
+
Licensors should also secure all rights necessary before
|
31 |
+
applying our licenses so that the public can reuse the
|
32 |
+
material as expected. Licensors should clearly mark any
|
33 |
+
material not subject to the license. This includes other CC-
|
34 |
+
licensed material, or material used under an exception or
|
35 |
+
limitation to copyright. More considerations for licensors:
|
36 |
+
wiki.creativecommons.org/Considerations_for_licensors
|
37 |
+
|
38 |
+
Considerations for the public: By using one of our public
|
39 |
+
licenses, a licensor grants the public permission to use the
|
40 |
+
licensed material under specified terms and conditions. If
|
41 |
+
the licensor's permission is not necessary for any reason--for
|
42 |
+
example, because of any applicable exception or limitation to
|
43 |
+
copyright--then that use is not regulated by the license. Our
|
44 |
+
licenses grant only permissions under copyright and certain
|
45 |
+
other rights that a licensor has authority to grant. Use of
|
46 |
+
the licensed material may still be restricted for other
|
47 |
+
reasons, including because others have copyright or other
|
48 |
+
rights in the material. A licensor may make special requests,
|
49 |
+
such as asking that all changes be marked or described.
|
50 |
+
Although not required by our licenses, you are encouraged to
|
51 |
+
respect those requests where reasonable. More_considerations
|
52 |
+
for the public:
|
53 |
+
wiki.creativecommons.org/Considerations_for_licensees
|
54 |
+
|
55 |
+
=======================================================================
|
56 |
+
|
57 |
+
Creative Commons Attribution-NonCommercial 4.0 International Public
|
58 |
+
License
|
59 |
+
|
60 |
+
By exercising the Licensed Rights (defined below), You accept and agree
|
61 |
+
to be bound by the terms and conditions of this Creative Commons
|
62 |
+
Attribution-NonCommercial 4.0 International Public License ("Public
|
63 |
+
License"). To the extent this Public License may be interpreted as a
|
64 |
+
contract, You are granted the Licensed Rights in consideration of Your
|
65 |
+
acceptance of these terms and conditions, and the Licensor grants You
|
66 |
+
such rights in consideration of benefits the Licensor receives from
|
67 |
+
making the Licensed Material available under these terms and
|
68 |
+
conditions.
|
69 |
+
|
70 |
+
Section 1 -- Definitions.
|
71 |
+
|
72 |
+
a. Adapted Material means material subject to Copyright and Similar
|
73 |
+
Rights that is derived from or based upon the Licensed Material
|
74 |
+
and in which the Licensed Material is translated, altered,
|
75 |
+
arranged, transformed, or otherwise modified in a manner requiring
|
76 |
+
permission under the Copyright and Similar Rights held by the
|
77 |
+
Licensor. For purposes of this Public License, where the Licensed
|
78 |
+
Material is a musical work, performance, or sound recording,
|
79 |
+
Adapted Material is always produced where the Licensed Material is
|
80 |
+
synched in timed relation with a moving image.
|
81 |
+
|
82 |
+
b. Adapter's License means the license You apply to Your Copyright
|
83 |
+
and Similar Rights in Your contributions to Adapted Material in
|
84 |
+
accordance with the terms and conditions of this Public License.
|
85 |
+
|
86 |
+
c. Copyright and Similar Rights means copyright and/or similar rights
|
87 |
+
closely related to copyright including, without limitation,
|
88 |
+
performance, broadcast, sound recording, and Sui Generis Database
|
89 |
+
Rights, without regard to how the rights are labeled or
|
90 |
+
categorized. For purposes of this Public License, the rights
|
91 |
+
specified in Section 2(b)(1)-(2) are not Copyright and Similar
|
92 |
+
Rights.
|
93 |
+
d. Effective Technological Measures means those measures that, in the
|
94 |
+
absence of proper authority, may not be circumvented under laws
|
95 |
+
fulfilling obligations under Article 11 of the WIPO Copyright
|
96 |
+
Treaty adopted on December 20, 1996, and/or similar international
|
97 |
+
agreements.
|
98 |
+
|
99 |
+
e. Exceptions and Limitations means fair use, fair dealing, and/or
|
100 |
+
any other exception or limitation to Copyright and Similar Rights
|
101 |
+
that applies to Your use of the Licensed Material.
|
102 |
+
|
103 |
+
f. Licensed Material means the artistic or literary work, database,
|
104 |
+
or other material to which the Licensor applied this Public
|
105 |
+
License.
|
106 |
+
|
107 |
+
g. Licensed Rights means the rights granted to You subject to the
|
108 |
+
terms and conditions of this Public License, which are limited to
|
109 |
+
all Copyright and Similar Rights that apply to Your use of the
|
110 |
+
Licensed Material and that the Licensor has authority to license.
|
111 |
+
|
112 |
+
h. Licensor means the individual(s) or entity(ies) granting rights
|
113 |
+
under this Public License.
|
114 |
+
|
115 |
+
i. NonCommercial means not primarily intended for or directed towards
|
116 |
+
commercial advantage or monetary compensation. For purposes of
|
117 |
+
this Public License, the exchange of the Licensed Material for
|
118 |
+
other material subject to Copyright and Similar Rights by digital
|
119 |
+
file-sharing or similar means is NonCommercial provided there is
|
120 |
+
no payment of monetary compensation in connection with the
|
121 |
+
exchange.
|
122 |
+
|
123 |
+
j. Share means to provide material to the public by any means or
|
124 |
+
process that requires permission under the Licensed Rights, such
|
125 |
+
as reproduction, public display, public performance, distribution,
|
126 |
+
dissemination, communication, or importation, and to make material
|
127 |
+
available to the public including in ways that members of the
|
128 |
+
public may access the material from a place and at a time
|
129 |
+
individually chosen by them.
|
130 |
+
|
131 |
+
k. Sui Generis Database Rights means rights other than copyright
|
132 |
+
resulting from Directive 96/9/EC of the European Parliament and of
|
133 |
+
the Council of 11 March 1996 on the legal protection of databases,
|
134 |
+
as amended and/or succeeded, as well as other essentially
|
135 |
+
equivalent rights anywhere in the world.
|
136 |
+
|
137 |
+
l. You means the individual or entity exercising the Licensed Rights
|
138 |
+
under this Public License. Your has a corresponding meaning.
|
139 |
+
|
140 |
+
Section 2 -- Scope.
|
141 |
+
|
142 |
+
a. License grant.
|
143 |
+
|
144 |
+
1. Subject to the terms and conditions of this Public License,
|
145 |
+
the Licensor hereby grants You a worldwide, royalty-free,
|
146 |
+
non-sublicensable, non-exclusive, irrevocable license to
|
147 |
+
exercise the Licensed Rights in the Licensed Material to:
|
148 |
+
|
149 |
+
a. reproduce and Share the Licensed Material, in whole or
|
150 |
+
in part, for NonCommercial purposes only; and
|
151 |
+
|
152 |
+
b. produce, reproduce, and Share Adapted Material for
|
153 |
+
NonCommercial purposes only.
|
154 |
+
|
155 |
+
2. Exceptions and Limitations. For the avoidance of doubt, where
|
156 |
+
Exceptions and Limitations apply to Your use, this Public
|
157 |
+
License does not apply, and You do not need to comply with
|
158 |
+
its terms and conditions.
|
159 |
+
|
160 |
+
3. Term. The term of this Public License is specified in Section
|
161 |
+
6(a).
|
162 |
+
|
163 |
+
4. Media and formats; technical modifications allowed. The
|
164 |
+
Licensor authorizes You to exercise the Licensed Rights in
|
165 |
+
all media and formats whether now known or hereafter created,
|
166 |
+
and to make technical modifications necessary to do so. The
|
167 |
+
Licensor waives and/or agrees not to assert any right or
|
168 |
+
authority to forbid You from making technical modifications
|
169 |
+
necessary to exercise the Licensed Rights, including
|
170 |
+
technical modifications necessary to circumvent Effective
|
171 |
+
Technological Measures. For purposes of this Public License,
|
172 |
+
simply making modifications authorized by this Section 2(a)
|
173 |
+
(4) never produces Adapted Material.
|
174 |
+
|
175 |
+
5. Downstream recipients.
|
176 |
+
|
177 |
+
a. Offer from the Licensor -- Licensed Material. Every
|
178 |
+
recipient of the Licensed Material automatically
|
179 |
+
receives an offer from the Licensor to exercise the
|
180 |
+
Licensed Rights under the terms and conditions of this
|
181 |
+
Public License.
|
182 |
+
|
183 |
+
b. No downstream restrictions. You may not offer or impose
|
184 |
+
any additional or different terms or conditions on, or
|
185 |
+
apply any Effective Technological Measures to, the
|
186 |
+
Licensed Material if doing so restricts exercise of the
|
187 |
+
Licensed Rights by any recipient of the Licensed
|
188 |
+
Material.
|
189 |
+
|
190 |
+
6. No endorsement. Nothing in this Public License constitutes or
|
191 |
+
may be construed as permission to assert or imply that You
|
192 |
+
are, or that Your use of the Licensed Material is, connected
|
193 |
+
with, or sponsored, endorsed, or granted official status by,
|
194 |
+
the Licensor or others designated to receive attribution as
|
195 |
+
provided in Section 3(a)(1)(A)(i).
|
196 |
+
|
197 |
+
b. Other rights.
|
198 |
+
|
199 |
+
1. Moral rights, such as the right of integrity, are not
|
200 |
+
licensed under this Public License, nor are publicity,
|
201 |
+
privacy, and/or other similar personality rights; however, to
|
202 |
+
the extent possible, the Licensor waives and/or agrees not to
|
203 |
+
assert any such rights held by the Licensor to the limited
|
204 |
+
extent necessary to allow You to exercise the Licensed
|
205 |
+
Rights, but not otherwise.
|
206 |
+
|
207 |
+
2. Patent and trademark rights are not licensed under this
|
208 |
+
Public License.
|
209 |
+
|
210 |
+
3. To the extent possible, the Licensor waives any right to
|
211 |
+
collect royalties from You for the exercise of the Licensed
|
212 |
+
Rights, whether directly or through a collecting society
|
213 |
+
under any voluntary or waivable statutory or compulsory
|
214 |
+
licensing scheme. In all other cases the Licensor expressly
|
215 |
+
reserves any right to collect such royalties, including when
|
216 |
+
the Licensed Material is used other than for NonCommercial
|
217 |
+
purposes.
|
218 |
+
|
219 |
+
Section 3 -- License Conditions.
|
220 |
+
|
221 |
+
Your exercise of the Licensed Rights is expressly made subject to the
|
222 |
+
following conditions.
|
223 |
+
|
224 |
+
a. Attribution.
|
225 |
+
|
226 |
+
1. If You Share the Licensed Material (including in modified
|
227 |
+
form), You must:
|
228 |
+
|
229 |
+
a. retain the following if it is supplied by the Licensor
|
230 |
+
with the Licensed Material:
|
231 |
+
|
232 |
+
i. identification of the creator(s) of the Licensed
|
233 |
+
Material and any others designated to receive
|
234 |
+
attribution, in any reasonable manner requested by
|
235 |
+
the Licensor (including by pseudonym if
|
236 |
+
designated);
|
237 |
+
|
238 |
+
ii. a copyright notice;
|
239 |
+
|
240 |
+
iii. a notice that refers to this Public License;
|
241 |
+
|
242 |
+
iv. a notice that refers to the disclaimer of
|
243 |
+
warranties;
|
244 |
+
|
245 |
+
v. a URI or hyperlink to the Licensed Material to the
|
246 |
+
extent reasonably practicable;
|
247 |
+
|
248 |
+
b. indicate if You modified the Licensed Material and
|
249 |
+
retain an indication of any previous modifications; and
|
250 |
+
|
251 |
+
c. indicate the Licensed Material is licensed under this
|
252 |
+
Public License, and include the text of, or the URI or
|
253 |
+
hyperlink to, this Public License.
|
254 |
+
|
255 |
+
2. You may satisfy the conditions in Section 3(a)(1) in any
|
256 |
+
reasonable manner based on the medium, means, and context in
|
257 |
+
which You Share the Licensed Material. For example, it may be
|
258 |
+
reasonable to satisfy the conditions by providing a URI or
|
259 |
+
hyperlink to a resource that includes the required
|
260 |
+
information.
|
261 |
+
|
262 |
+
3. If requested by the Licensor, You must remove any of the
|
263 |
+
information required by Section 3(a)(1)(A) to the extent
|
264 |
+
reasonably practicable.
|
265 |
+
|
266 |
+
4. If You Share Adapted Material You produce, the Adapter's
|
267 |
+
License You apply must not prevent recipients of the Adapted
|
268 |
+
Material from complying with this Public License.
|
269 |
+
|
270 |
+
Section 4 -- Sui Generis Database Rights.
|
271 |
+
|
272 |
+
Where the Licensed Rights include Sui Generis Database Rights that
|
273 |
+
apply to Your use of the Licensed Material:
|
274 |
+
|
275 |
+
a. for the avoidance of doubt, Section 2(a)(1) grants You the right
|
276 |
+
to extract, reuse, reproduce, and Share all or a substantial
|
277 |
+
portion of the contents of the database for NonCommercial purposes
|
278 |
+
only;
|
279 |
+
|
280 |
+
b. if You include all or a substantial portion of the database
|
281 |
+
contents in a database in which You have Sui Generis Database
|
282 |
+
Rights, then the database in which You have Sui Generis Database
|
283 |
+
Rights (but not its individual contents) is Adapted Material; and
|
284 |
+
|
285 |
+
c. You must comply with the conditions in Section 3(a) if You Share
|
286 |
+
all or a substantial portion of the contents of the database.
|
287 |
+
|
288 |
+
For the avoidance of doubt, this Section 4 supplements and does not
|
289 |
+
replace Your obligations under this Public License where the Licensed
|
290 |
+
Rights include other Copyright and Similar Rights.
|
291 |
+
|
292 |
+
Section 5 -- Disclaimer of Warranties and Limitation of Liability.
|
293 |
+
|
294 |
+
a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
|
295 |
+
EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
|
296 |
+
AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
|
297 |
+
ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
|
298 |
+
IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
|
299 |
+
WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
|
300 |
+
PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
|
301 |
+
ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
|
302 |
+
KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
|
303 |
+
ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
|
304 |
+
|
305 |
+
b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
|
306 |
+
TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
|
307 |
+
NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
|
308 |
+
INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
|
309 |
+
COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
|
310 |
+
USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
|
311 |
+
ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
|
312 |
+
DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
|
313 |
+
IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
|
314 |
+
|
315 |
+
c. The disclaimer of warranties and limitation of liability provided
|
316 |
+
above shall be interpreted in a manner that, to the extent
|
317 |
+
possible, most closely approximates an absolute disclaimer and
|
318 |
+
waiver of all liability.
|
319 |
+
|
320 |
+
Section 6 -- Term and Termination.
|
321 |
+
|
322 |
+
a. This Public License applies for the term of the Copyright and
|
323 |
+
Similar Rights licensed here. However, if You fail to comply with
|
324 |
+
this Public License, then Your rights under this Public License
|
325 |
+
terminate automatically.
|
326 |
+
|
327 |
+
b. Where Your right to use the Licensed Material has terminated under
|
328 |
+
Section 6(a), it reinstates:
|
329 |
+
|
330 |
+
1. automatically as of the date the violation is cured, provided
|
331 |
+
it is cured within 30 days of Your discovery of the
|
332 |
+
violation; or
|
333 |
+
|
334 |
+
2. upon express reinstatement by the Licensor.
|
335 |
+
|
336 |
+
For the avoidance of doubt, this Section 6(b) does not affect any
|
337 |
+
right the Licensor may have to seek remedies for Your violations
|
338 |
+
of this Public License.
|
339 |
+
|
340 |
+
c. For the avoidance of doubt, the Licensor may also offer the
|
341 |
+
Licensed Material under separate terms or conditions or stop
|
342 |
+
distributing the Licensed Material at any time; however, doing so
|
343 |
+
will not terminate this Public License.
|
344 |
+
|
345 |
+
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
|
346 |
+
License.
|
347 |
+
|
348 |
+
Section 7 -- Other Terms and Conditions.
|
349 |
+
|
350 |
+
a. The Licensor shall not be bound by any additional or different
|
351 |
+
terms or conditions communicated by You unless expressly agreed.
|
352 |
+
|
353 |
+
b. Any arrangements, understandings, or agreements regarding the
|
354 |
+
Licensed Material not stated herein are separate from and
|
355 |
+
independent of the terms and conditions of this Public License.
|
356 |
+
|
357 |
+
Section 8 -- Interpretation.
|
358 |
+
|
359 |
+
a. For the avoidance of doubt, this Public License does not, and
|
360 |
+
shall not be interpreted to, reduce, limit, restrict, or impose
|
361 |
+
conditions on any use of the Licensed Material that could lawfully
|
362 |
+
be made without permission under this Public License.
|
363 |
+
|
364 |
+
b. To the extent possible, if any provision of this Public License is
|
365 |
+
deemed unenforceable, it shall be automatically reformed to the
|
366 |
+
minimum extent necessary to make it enforceable. If the provision
|
367 |
+
cannot be reformed, it shall be severed from this Public License
|
368 |
+
without affecting the enforceability of the remaining terms and
|
369 |
+
conditions.
|
370 |
+
|
371 |
+
c. No term or condition of this Public License will be waived and no
|
372 |
+
failure to comply consented to unless expressly agreed to by the
|
373 |
+
Licensor.
|
374 |
+
|
375 |
+
d. Nothing in this Public License constitutes or may be interpreted
|
376 |
+
as a limitation upon, or waiver of, any privileges and immunities
|
377 |
+
that apply to the Licensor or You, including from the legal
|
378 |
+
processes of any jurisdiction or authority.
|
379 |
+
|
380 |
+
=======================================================================
|
381 |
+
|
382 |
+
Creative Commons is not a party to its public
|
383 |
+
licenses. Notwithstanding, Creative Commons may elect to apply one of
|
384 |
+
its public licenses to material it publishes and in those instances
|
385 |
+
will be considered the “Licensor.” The text of the Creative Commons
|
386 |
+
public licenses is dedicated to the public domain under the CC0 Public
|
387 |
+
Domain Dedication. Except for the limited purpose of indicating that
|
388 |
+
material is shared under a Creative Commons public license or as
|
389 |
+
otherwise permitted by the Creative Commons policies published at
|
390 |
+
creativecommons.org/policies, Creative Commons does not authorize the
|
391 |
+
use of the trademark "Creative Commons" or any other trademark or logo
|
392 |
+
of Creative Commons without its prior written consent including,
|
393 |
+
without limitation, in connection with any unauthorized modifications
|
394 |
+
to any of its public licenses or any other arrangements,
|
395 |
+
understandings, or agreements concerning use of licensed material. For
|
396 |
+
the avoidance of doubt, this paragraph does not form part of the
|
397 |
+
public licenses.
|
398 |
+
|
399 |
+
Creative Commons may be contacted at creativecommons.org.
|
TimeSformer/README.md
ADDED
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# TimeSformer
|
2 |
+
|
3 |
+
This is an official pytorch implementation of our ICML 2021 paper [Is Space-Time Attention All You Need for Video Understanding?](https://arxiv.org/pdf/2102.05095.pdf). In this repository, we provide PyTorch code for training and testing our proposed TimeSformer model. TimeSformer provides an efficient video classification framework that achieves state-of-the-art results on several video action recognition benchmarks such as Kinetics-400.
|
4 |
+
|
5 |
+
If you find TimeSformer useful in your research, please use the following BibTeX entry for citation.
|
6 |
+
|
7 |
+
```BibTeX
|
8 |
+
@inproceedings{gberta_2021_ICML,
|
9 |
+
author = {Gedas Bertasius and Heng Wang and Lorenzo Torresani},
|
10 |
+
title = {Is Space-Time Attention All You Need for Video Understanding?},
|
11 |
+
booktitle = {Proceedings of the International Conference on Machine Learning (ICML)},
|
12 |
+
month = {July},
|
13 |
+
year = {2021}
|
14 |
+
}
|
15 |
+
```
|
16 |
+
|
17 |
+
# Model Zoo
|
18 |
+
|
19 |
+
We provide TimeSformer models pretrained on Kinetics-400 (K400), Kinetics-600 (K600), Something-Something-V2 (SSv2), and HowTo100M datasets.
|
20 |
+
|
21 |
+
| name | dataset | # of frames | spatial crop | acc@1 | acc@5 | url |
|
22 |
+
| --- | --- | --- | --- | --- | --- | --- |
|
23 |
+
| TimeSformer | K400 | 8 | 224 | 77.9 | 93.2 | [model](https://www.dropbox.com/s/g5t24we9gl5yk88/TimeSformer_divST_8x32_224_K400.pyth?dl=0) |
|
24 |
+
| TimeSformer-HR | K400 | 16 | 448 | 79.6 | 94.0 | [model](https://www.dropbox.com/s/6f0x172lpqy3oxt/TimeSformer_divST_16x16_448_K400.pyth?dl=0) |
|
25 |
+
| TimeSformer-L | K400 | 96 | 224 | 80.6 | 94.7 | [model](https://www.dropbox.com/s/r1iuxahif3sgimo/TimeSformer_divST_96x4_224_K400.pyth?dl=0) |
|
26 |
+
|
27 |
+
| name | dataset | # of frames | spatial crop | acc@1 | acc@5 | url |
|
28 |
+
| --- | --- | --- | --- | --- | --- | --- |
|
29 |
+
| TimeSformer | K600 | 8 | 224 | 79.1 | 94.4 | [model](https://www.dropbox.com/s/4h2qt41m2z3aqrb/TimeSformer_divST_8x32_224_K600.pyth?dl=0) |
|
30 |
+
| TimeSformer-HR | K600 | 16 | 448 | 81.8 | 95.8 | [model](https://www.dropbox.com/s/ft1e92g2vhvxecv/TimeSformer_divST_16x16_448_K600.pyth?dl=0) |
|
31 |
+
| TimeSformer-L | K600 | 96 | 224 | 82.2 | 95.6 | [model](https://www.dropbox.com/s/857rx6xeclxfhdg/TimeSformer_divST_96x4_224_K600.pyth?dl=0) |
|
32 |
+
|
33 |
+
| name | dataset | # of frames | spatial crop | acc@1 | acc@5 | url |
|
34 |
+
| --- | --- | --- | --- | --- | --- | --- |
|
35 |
+
| TimeSformer | SSv2 | 8 | 224 | 59.1 | 85.6 | [model](https://www.dropbox.com/s/tybhuml57y24wpm/TimeSformer_divST_8_224_SSv2.pyth?dl=0) |
|
36 |
+
| TimeSformer-HR | SSv2 | 16 | 448 | 61.8 | 86.9 | [model](https://www.dropbox.com/s/9t68uzk8w2fpfnv/TimeSformer_divST_16_448_SSv2.pyth?dl=0) |
|
37 |
+
| TimeSformer-L | SSv2 | 64 | 224 | 62.0 | 87.5 | [model](https://www.dropbox.com/s/3f1rm2al8mhprwa/TimeSformer_divST_64_224_SSv2.pyth?dl=0) |
|
38 |
+
|
39 |
+
| name | dataset | # of frames | spatial crop | single clip coverage | acc@1 | url |
|
40 |
+
| --- | --- | --- | --- | --- | --- | --- |
|
41 |
+
| TimeSformer | HowTo100M | 8 | 224 | 8.5s | 56.8 | [model](https://www.dropbox.com/s/9v8hcm88b9tc6ff/TimeSformer_divST_8x32_224_HowTo100M.pyth?dl=0) |
|
42 |
+
| TimeSformer | HowTo100M | 32 | 224 | 34.1s | 61.2 | [model](https://www.dropbox.com/s/4roflx4q1gscu85/TimeSformer_divST_32x32_224_HowTo100M.pyth?dl=0) |
|
43 |
+
| TimeSformer | HowTo100M | 64 | 448 | 68.3s | 62.2 | [model](https://www.dropbox.com/s/15bvqltl1j5vyp3/TimeSformer_divST_64x32_224_HowTo100M.pyth?dl=0) |
|
44 |
+
| TimeSformer | HowTo100M | 96 | 224 | 102.4s | 62.6 | [model](https://www.dropbox.com/s/t2mzgahnfhgakma/TimeSformer_divST_96x32_224_HowTo100M.pyth?dl=0) |
|
45 |
+
|
46 |
+
We note that these models were re-trained using a slightly different implementation than the one used in the paper. Therefore, there might be a small difference in performance compared to the results reported in the paper.
|
47 |
+
|
48 |
+
You can load the pretrained models as follows:
|
49 |
+
|
50 |
+
```python
|
51 |
+
import torch
|
52 |
+
from timesformer.models.vit import TimeSformer
|
53 |
+
|
54 |
+
model = TimeSformer(img_size=224, num_classes=400, num_frames=8, attention_type='divided_space_time', pretrained_model='/path/to/pretrained/model.pyth')
|
55 |
+
|
56 |
+
dummy_video = torch.randn(2, 3, 8, 224, 224) # (batch x channels x frames x height x width)
|
57 |
+
|
58 |
+
pred = model(dummy_video,) # (2, 400)
|
59 |
+
```
|
60 |
+
|
61 |
+
# Installation
|
62 |
+
|
63 |
+
First, create a conda virtual environment and activate it:
|
64 |
+
```
|
65 |
+
conda create -n timesformer python=3.7 -y
|
66 |
+
source activate timesformer
|
67 |
+
```
|
68 |
+
|
69 |
+
Then, install the following packages:
|
70 |
+
|
71 |
+
- torchvision: `pip install torchvision` or `conda install torchvision -c pytorch`
|
72 |
+
- [fvcore](https://github.com/facebookresearch/fvcore/): `pip install 'git+https://github.com/facebookresearch/fvcore'`
|
73 |
+
- simplejson: `pip install simplejson`
|
74 |
+
- einops: `pip install einops`
|
75 |
+
- timm: `pip install timm`
|
76 |
+
- PyAV: `conda install av -c conda-forge`
|
77 |
+
- psutil: `pip install psutil`
|
78 |
+
- scikit-learn: `pip install scikit-learn`
|
79 |
+
- OpenCV: `pip install opencv-python`
|
80 |
+
- tensorboard: `pip install tensorboard`
|
81 |
+
|
82 |
+
Lastly, build the TimeSformer codebase by running:
|
83 |
+
```
|
84 |
+
git clone https://github.com/facebookresearch/TimeSformer
|
85 |
+
cd TimeSformer
|
86 |
+
python setup.py build develop
|
87 |
+
```
|
88 |
+
|
89 |
+
# Usage
|
90 |
+
|
91 |
+
## Dataset Preparation
|
92 |
+
|
93 |
+
Please use the dataset preparation instructions provided in [DATASET.md](timesformer/datasets/DATASET.md).
|
94 |
+
|
95 |
+
## Training the Default TimeSformer
|
96 |
+
|
97 |
+
Training the default TimeSformer that uses divided space-time attention, and operates on 8-frame clips cropped at 224x224 spatial resolution, can be done using the following command:
|
98 |
+
|
99 |
+
```
|
100 |
+
python tools/run_net.py \
|
101 |
+
--cfg configs/Kinetics/TimeSformer_divST_8x32_224.yaml \
|
102 |
+
DATA.PATH_TO_DATA_DIR path_to_your_dataset \
|
103 |
+
NUM_GPUS 8 \
|
104 |
+
TRAIN.BATCH_SIZE 8 \
|
105 |
+
```
|
106 |
+
You may need to pass location of your dataset in the command line by adding `DATA.PATH_TO_DATA_DIR path_to_your_dataset`, or you can simply add
|
107 |
+
|
108 |
+
```
|
109 |
+
DATA:
|
110 |
+
PATH_TO_DATA_DIR: path_to_your_dataset
|
111 |
+
```
|
112 |
+
|
113 |
+
To the yaml configs file, then you do not need to pass it to the command line every time.
|
114 |
+
|
115 |
+
## Using a Different Number of GPUs
|
116 |
+
|
117 |
+
If you want to use a smaller number of GPUs, you need to modify .yaml configuration files in [`configs/`](configs/). Specifically, you need to modify the NUM_GPUS, TRAIN.BATCH_SIZE, TEST.BATCH_SIZE, DATA_LOADER.NUM_WORKERS entries in each configuration file. The BATCH_SIZE entry should be the same or higher as the NUM_GPUS entry. In [`configs/Kinetics/TimeSformer_divST_8x32_224_4gpus.yaml`](configs/Kinetics/TimeSformer_divST_8x32_224_4gpus.yaml), we provide a sample configuration file for a 4 GPU setup.
|
118 |
+
|
119 |
+
|
120 |
+
## Using Different Self-Attention Schemes
|
121 |
+
|
122 |
+
If you want to experiment with different space-time self-attention schemes, e.g., space-only or joint space-time attention, use the following commands:
|
123 |
+
|
124 |
+
|
125 |
+
```
|
126 |
+
python tools/run_net.py \
|
127 |
+
--cfg configs/Kinetics/TimeSformer_spaceOnly_8x32_224.yaml \
|
128 |
+
DATA.PATH_TO_DATA_DIR path_to_your_dataset \
|
129 |
+
NUM_GPUS 8 \
|
130 |
+
TRAIN.BATCH_SIZE 8 \
|
131 |
+
```
|
132 |
+
|
133 |
+
and
|
134 |
+
|
135 |
+
```
|
136 |
+
python tools/run_net.py \
|
137 |
+
--cfg configs/Kinetics/TimeSformer_jointST_8x32_224.yaml \
|
138 |
+
DATA.PATH_TO_DATA_DIR path_to_your_dataset \
|
139 |
+
NUM_GPUS 8 \
|
140 |
+
TRAIN.BATCH_SIZE 8 \
|
141 |
+
```
|
142 |
+
|
143 |
+
## Training Different TimeSformer Variants
|
144 |
+
|
145 |
+
If you want to train more powerful TimeSformer variants, e.g., TimeSformer-HR (operating on 16-frame clips sampled at 448x448 spatial resolution), and TimeSformer-L (operating on 96-frame clips sampled at 224x224 spatial resolution), use the following commands:
|
146 |
+
|
147 |
+
```
|
148 |
+
python tools/run_net.py \
|
149 |
+
--cfg configs/Kinetics/TimeSformer_divST_16x16_448.yaml \
|
150 |
+
DATA.PATH_TO_DATA_DIR path_to_your_dataset \
|
151 |
+
NUM_GPUS 8 \
|
152 |
+
TRAIN.BATCH_SIZE 8 \
|
153 |
+
```
|
154 |
+
|
155 |
+
and
|
156 |
+
|
157 |
+
```
|
158 |
+
python tools/run_net.py \
|
159 |
+
--cfg configs/Kinetics/TimeSformer_divST_96x4_224.yaml \
|
160 |
+
DATA.PATH_TO_DATA_DIR path_to_your_dataset \
|
161 |
+
NUM_GPUS 8 \
|
162 |
+
TRAIN.BATCH_SIZE 8 \
|
163 |
+
```
|
164 |
+
|
165 |
+
Note that for these models you will need a set of GPUs with ~32GB of memory.
|
166 |
+
|
167 |
+
## Inference
|
168 |
+
|
169 |
+
Use `TRAIN.ENABLE` and `TEST.ENABLE` to control whether training or testing is required for a given run. When testing, you also have to provide the path to the checkpoint model via TEST.CHECKPOINT_FILE_PATH.
|
170 |
+
```
|
171 |
+
python tools/run_net.py \
|
172 |
+
--cfg configs/Kinetics/TimeSformer_divST_8x32_224_TEST.yaml \
|
173 |
+
DATA.PATH_TO_DATA_DIR path_to_your_dataset \
|
174 |
+
TEST.CHECKPOINT_FILE_PATH path_to_your_checkpoint \
|
175 |
+
TRAIN.ENABLE False \
|
176 |
+
```
|
177 |
+
|
178 |
+
## Single-Node Training via Slurm
|
179 |
+
|
180 |
+
To train TimeSformer via Slurm, please check out our single node Slurm training script [`slurm_scripts/run_single_node_job.sh`](slurm_scripts/run_single_node_job.sh).
|
181 |
+
|
182 |
+
|
183 |
+
## Multi-Node Training via Submitit
|
184 |
+
|
185 |
+
Distributed training is available via Slurm and submitit
|
186 |
+
|
187 |
+
```
|
188 |
+
pip install submitit
|
189 |
+
```
|
190 |
+
|
191 |
+
To train TimeSformer model on Kinetics using 4 nodes with 8 gpus each use the following command:
|
192 |
+
```
|
193 |
+
python tools/submit.py --cfg configs/Kinetics/TimeSformer_divST_8x32_224.yaml --job_dir /your/job/dir/${JOB_NAME}/ --num_shards 4 --name ${JOB_NAME} --use_volta32
|
194 |
+
```
|
195 |
+
|
196 |
+
We provide a script for launching slurm jobs in [`slurm_scripts/run_multi_node_job.sh`](slurm_scripts/run_multi_node_job.sh).
|
197 |
+
|
198 |
+
## Finetuning
|
199 |
+
|
200 |
+
To finetune from an existing PyTorch checkpoint add the following line in the command line, or you can also add it in the YAML config:
|
201 |
+
|
202 |
+
```
|
203 |
+
TRAIN.CHECKPOINT_FILE_PATH path_to_your_PyTorch_checkpoint
|
204 |
+
TRAIN.FINETUNE True
|
205 |
+
```
|
206 |
+
|
207 |
+
## HowTo100M Dataset Split
|
208 |
+
|
209 |
+
If you want to experiment with the long-term video modeling task on HowTo100M, please download the train/test split files from [here](https://www.dropbox.com/sh/ttvsxwqypijjuda/AACmJx1CnddW6cVBoc21eSuva?dl=0).
|
210 |
+
|
211 |
+
|
212 |
+
# Environment
|
213 |
+
|
214 |
+
The code was developed using python 3.7 on Ubuntu 20.04. For training, we used four GPU compute nodes each node containing 8 Tesla V100 GPUs (32 GPUs in total). Other platforms or GPU cards have not been fully tested.
|
215 |
+
|
216 |
+
# License
|
217 |
+
|
218 |
+
The majority of this work is licensed under [CC-NC 4.0 International license](LICENSE). However portions of the project are available under separate license terms: [SlowFast](https://github.com/facebookresearch/SlowFast) and [pytorch-image-models](https://github.com/rwightman/pytorch-image-models) are licensed under the Apache 2.0 license.
|
219 |
+
|
220 |
+
# Contributing
|
221 |
+
|
222 |
+
We actively welcome your pull requests. Please see [CONTRIBUTING.md](CONTRIBUTING.md) and [CODE_OF_CONDUCT.md](CODE_OF_CONDUCT.md) for more info.
|
223 |
+
|
224 |
+
# Acknowledgements
|
225 |
+
|
226 |
+
TimeSformer is built on top of [PySlowFast](https://github.com/facebookresearch/SlowFast) and [pytorch-image-models](https://github.com/rwightman/pytorch-image-models) by [Ross Wightman](https://github.com/rwightman). We thank the authors for releasing their code. If you use our model, please consider citing these works as well:
|
227 |
+
|
228 |
+
```BibTeX
|
229 |
+
@misc{fan2020pyslowfast,
|
230 |
+
author = {Haoqi Fan and Yanghao Li and Bo Xiong and Wan-Yen Lo and
|
231 |
+
Christoph Feichtenhofer},
|
232 |
+
title = {PySlowFast},
|
233 |
+
howpublished = {\url{https://github.com/facebookresearch/slowfast}},
|
234 |
+
year = {2020}
|
235 |
+
}
|
236 |
+
```
|
237 |
+
|
238 |
+
```BibTeX
|
239 |
+
@misc{rw2019timm,
|
240 |
+
author = {Ross Wightman},
|
241 |
+
title = {PyTorch Image Models},
|
242 |
+
year = {2019},
|
243 |
+
publisher = {GitHub},
|
244 |
+
journal = {GitHub repository},
|
245 |
+
doi = {10.5281/zenodo.4414861},
|
246 |
+
howpublished = {\url{https://github.com/rwightman/pytorch-image-models}}
|
247 |
+
}
|
248 |
+
```
|
TimeSformer/configs/Kinetics/SLOWFAST_4x16_R50.yaml
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
TRAIN:
|
2 |
+
ENABLE: True
|
3 |
+
DATASET: kinetics
|
4 |
+
BATCH_SIZE: 64
|
5 |
+
EVAL_PERIOD: 10
|
6 |
+
CHECKPOINT_PERIOD: 5
|
7 |
+
AUTO_RESUME: True
|
8 |
+
DATA:
|
9 |
+
PATH_TO_DATA_DIR: /path/to/kinetics/
|
10 |
+
NUM_FRAMES: 32
|
11 |
+
SAMPLING_RATE: 2
|
12 |
+
TRAIN_JITTER_SCALES: [256, 320]
|
13 |
+
TRAIN_CROP_SIZE: 224
|
14 |
+
TEST_CROP_SIZE: 256
|
15 |
+
INPUT_CHANNEL_NUM: [3, 3]
|
16 |
+
SLOWFAST:
|
17 |
+
ALPHA: 8
|
18 |
+
BETA_INV: 8
|
19 |
+
FUSION_CONV_CHANNEL_RATIO: 2
|
20 |
+
FUSION_KERNEL_SZ: 5
|
21 |
+
RESNET:
|
22 |
+
ZERO_INIT_FINAL_BN: True
|
23 |
+
WIDTH_PER_GROUP: 64
|
24 |
+
NUM_GROUPS: 1
|
25 |
+
DEPTH: 50
|
26 |
+
TRANS_FUNC: bottleneck_transform
|
27 |
+
STRIDE_1X1: False
|
28 |
+
NUM_BLOCK_TEMP_KERNEL: [[3, 3], [4, 4], [6, 6], [3, 3]]
|
29 |
+
SPATIAL_STRIDES: [[1, 1], [2, 2], [2, 2], [2, 2]]
|
30 |
+
SPATIAL_DILATIONS: [[1, 1], [1, 1], [1, 1], [1, 1]]
|
31 |
+
NONLOCAL:
|
32 |
+
LOCATION: [[[], []], [[], []], [[], []], [[], []]]
|
33 |
+
GROUP: [[1, 1], [1, 1], [1, 1], [1, 1]]
|
34 |
+
INSTANTIATION: dot_product
|
35 |
+
BN:
|
36 |
+
USE_PRECISE_STATS: True
|
37 |
+
NUM_BATCHES_PRECISE: 200
|
38 |
+
SOLVER:
|
39 |
+
BASE_LR: 0.8
|
40 |
+
LR_POLICY: cosine
|
41 |
+
MAX_EPOCH: 196
|
42 |
+
MOMENTUM: 0.9
|
43 |
+
WEIGHT_DECAY: 1e-4
|
44 |
+
WARMUP_EPOCHS: 34.0
|
45 |
+
WARMUP_START_LR: 0.01
|
46 |
+
OPTIMIZING_METHOD: sgd
|
47 |
+
MODEL:
|
48 |
+
NUM_CLASSES: 400
|
49 |
+
ARCH: slowfast
|
50 |
+
MODEL_NAME: SlowFast
|
51 |
+
LOSS_FUNC: cross_entropy
|
52 |
+
DROPOUT_RATE: 0.5
|
53 |
+
TEST:
|
54 |
+
ENABLE: True
|
55 |
+
DATASET: kinetics
|
56 |
+
BATCH_SIZE: 64
|
57 |
+
DATA_LOADER:
|
58 |
+
NUM_WORKERS: 8
|
59 |
+
PIN_MEMORY: True
|
60 |
+
NUM_GPUS: 8
|
61 |
+
NUM_SHARDS: 1
|
62 |
+
RNG_SEED: 0
|
63 |
+
OUTPUT_DIR: .
|
TimeSformer/configs/Kinetics/SLOWFAST_8x8_R101.yaml
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
TRAIN:
|
2 |
+
ENABLE: True
|
3 |
+
DATASET: kinetics
|
4 |
+
BATCH_SIZE: 64
|
5 |
+
EVAL_PERIOD: 10
|
6 |
+
CHECKPOINT_PERIOD: 5
|
7 |
+
AUTO_RESUME: True
|
8 |
+
DATA:
|
9 |
+
PATH_TO_DATA_DIR: /path/to/kinetics/
|
10 |
+
NUM_FRAMES: 32
|
11 |
+
SAMPLING_RATE: 2
|
12 |
+
TRAIN_JITTER_SCALES: [256, 340]
|
13 |
+
TRAIN_CROP_SIZE: 224
|
14 |
+
TEST_CROP_SIZE: 256
|
15 |
+
INPUT_CHANNEL_NUM: [3, 3]
|
16 |
+
SLOWFAST:
|
17 |
+
ALPHA: 4
|
18 |
+
BETA_INV: 8
|
19 |
+
FUSION_CONV_CHANNEL_RATIO: 2
|
20 |
+
FUSION_KERNEL_SZ: 5
|
21 |
+
RESNET:
|
22 |
+
ZERO_INIT_FINAL_BN: True
|
23 |
+
WIDTH_PER_GROUP: 64
|
24 |
+
NUM_GROUPS: 1
|
25 |
+
DEPTH: 101
|
26 |
+
TRANS_FUNC: bottleneck_transform
|
27 |
+
STRIDE_1X1: False
|
28 |
+
NUM_BLOCK_TEMP_KERNEL: [[3, 3], [4, 4], [6, 6], [3, 3]]
|
29 |
+
SPATIAL_STRIDES: [[1, 1], [2, 2], [2, 2], [2, 2]]
|
30 |
+
SPATIAL_DILATIONS: [[1, 1], [1, 1], [1, 1], [1, 1]]
|
31 |
+
NONLOCAL:
|
32 |
+
LOCATION: [[[], []], [[], []], [[], []], [[], []]]
|
33 |
+
GROUP: [[1, 1], [1, 1], [1, 1], [1, 1]]
|
34 |
+
INSTANTIATION: dot_product
|
35 |
+
BN:
|
36 |
+
USE_PRECISE_STATS: True
|
37 |
+
NUM_BATCHES_PRECISE: 200
|
38 |
+
SOLVER:
|
39 |
+
BASE_LR: 0.8 ## 8 nodes
|
40 |
+
LR_POLICY: cosine
|
41 |
+
MAX_EPOCH: 196
|
42 |
+
MOMENTUM: 0.9
|
43 |
+
WEIGHT_DECAY: 1e-4
|
44 |
+
WARMUP_EPOCHS: 34.0
|
45 |
+
WARMUP_START_LR: 0.01
|
46 |
+
OPTIMIZING_METHOD: sgd
|
47 |
+
MODEL:
|
48 |
+
NUM_CLASSES: 400
|
49 |
+
ARCH: slowfast
|
50 |
+
MODEL_NAME: SlowFast
|
51 |
+
LOSS_FUNC: cross_entropy
|
52 |
+
DROPOUT_RATE: 0.5
|
53 |
+
TEST:
|
54 |
+
ENABLE: True
|
55 |
+
DATASET: kinetics
|
56 |
+
BATCH_SIZE: 64
|
57 |
+
DATA_LOADER:
|
58 |
+
NUM_WORKERS: 8
|
59 |
+
PIN_MEMORY: True
|
60 |
+
NUM_GPUS: 8
|
61 |
+
NUM_SHARDS: 1
|
62 |
+
RNG_SEED: 0
|
63 |
+
OUTPUT_DIR: .
|
TimeSformer/configs/Kinetics/SLOWFAST_8x8_R50.yaml
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
TRAIN:
|
2 |
+
ENABLE: True
|
3 |
+
DATASET: kinetics
|
4 |
+
BATCH_SIZE: 64
|
5 |
+
EVAL_PERIOD: 10
|
6 |
+
CHECKPOINT_PERIOD: 5
|
7 |
+
AUTO_RESUME: True
|
8 |
+
DATA:
|
9 |
+
PATH_TO_DATA_DIR: /path/to/kinetics/
|
10 |
+
NUM_FRAMES: 32
|
11 |
+
SAMPLING_RATE: 2
|
12 |
+
TRAIN_JITTER_SCALES: [256, 320]
|
13 |
+
TRAIN_CROP_SIZE: 224
|
14 |
+
TEST_CROP_SIZE: 256
|
15 |
+
INPUT_CHANNEL_NUM: [3, 3]
|
16 |
+
SLOWFAST:
|
17 |
+
ALPHA: 4
|
18 |
+
BETA_INV: 8
|
19 |
+
FUSION_CONV_CHANNEL_RATIO: 2
|
20 |
+
FUSION_KERNEL_SZ: 7
|
21 |
+
RESNET:
|
22 |
+
ZERO_INIT_FINAL_BN: True
|
23 |
+
WIDTH_PER_GROUP: 64
|
24 |
+
NUM_GROUPS: 1
|
25 |
+
DEPTH: 50
|
26 |
+
TRANS_FUNC: bottleneck_transform
|
27 |
+
STRIDE_1X1: False
|
28 |
+
NUM_BLOCK_TEMP_KERNEL: [[3, 3], [4, 4], [6, 6], [3, 3]]
|
29 |
+
SPATIAL_STRIDES: [[1, 1], [2, 2], [2, 2], [2, 2]]
|
30 |
+
SPATIAL_DILATIONS: [[1, 1], [1, 1], [1, 1], [1, 1]]
|
31 |
+
NONLOCAL:
|
32 |
+
LOCATION: [[[], []], [[], []], [[], []], [[], []]]
|
33 |
+
GROUP: [[1, 1], [1, 1], [1, 1], [1, 1]]
|
34 |
+
INSTANTIATION: dot_product
|
35 |
+
BN:
|
36 |
+
USE_PRECISE_STATS: True
|
37 |
+
NUM_BATCHES_PRECISE: 200
|
38 |
+
SOLVER:
|
39 |
+
BASE_LR: 0.8
|
40 |
+
LR_POLICY: cosine
|
41 |
+
MAX_EPOCH: 196
|
42 |
+
MOMENTUM: 0.9
|
43 |
+
WEIGHT_DECAY: 1e-4
|
44 |
+
WARMUP_EPOCHS: 34.0
|
45 |
+
WARMUP_START_LR: 0.01
|
46 |
+
OPTIMIZING_METHOD: sgd
|
47 |
+
MODEL:
|
48 |
+
NUM_CLASSES: 400
|
49 |
+
ARCH: slowfast
|
50 |
+
MODEL_NAME: SlowFast
|
51 |
+
LOSS_FUNC: cross_entropy
|
52 |
+
DROPOUT_RATE: 0.5
|
53 |
+
TEST:
|
54 |
+
ENABLE: True
|
55 |
+
DATASET: kinetics
|
56 |
+
BATCH_SIZE: 64
|
57 |
+
DATA_LOADER:
|
58 |
+
NUM_WORKERS: 8
|
59 |
+
PIN_MEMORY: True
|
60 |
+
NUM_GPUS: 8
|
61 |
+
NUM_SHARDS: 1
|
62 |
+
RNG_SEED: 0
|
63 |
+
OUTPUT_DIR: .
|
TimeSformer/configs/Kinetics/TimeSformer_divST_16x16_448.yaml
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
TRAIN:
|
2 |
+
ENABLE: True
|
3 |
+
DATASET: kinetics
|
4 |
+
BATCH_SIZE: 8
|
5 |
+
EVAL_PERIOD: 5
|
6 |
+
CHECKPOINT_PERIOD: 5
|
7 |
+
AUTO_RESUME: True
|
8 |
+
DATA:
|
9 |
+
PATH_TO_DATA_DIR: /path/to/kinetics/
|
10 |
+
NUM_FRAMES: 16
|
11 |
+
SAMPLING_RATE: 16
|
12 |
+
TRAIN_JITTER_SCALES: [448, 512]
|
13 |
+
TRAIN_CROP_SIZE: 448
|
14 |
+
TEST_CROP_SIZE: 448
|
15 |
+
INPUT_CHANNEL_NUM: [3]
|
16 |
+
TIMESFORMER:
|
17 |
+
ATTENTION_TYPE: 'divided_space_time'
|
18 |
+
SOLVER:
|
19 |
+
BASE_LR: 0.005
|
20 |
+
LR_POLICY: steps_with_relative_lrs
|
21 |
+
STEPS: [0, 11, 14]
|
22 |
+
LRS: [1, 0.1, 0.01]
|
23 |
+
MAX_EPOCH: 15
|
24 |
+
MOMENTUM: 0.9
|
25 |
+
WEIGHT_DECAY: 1e-4
|
26 |
+
OPTIMIZING_METHOD: sgd
|
27 |
+
MODEL:
|
28 |
+
MODEL_NAME: vit_base_patch16_224
|
29 |
+
NUM_CLASSES: 400
|
30 |
+
ARCH: vit
|
31 |
+
LOSS_FUNC: cross_entropy
|
32 |
+
DROPOUT_RATE: 0.5
|
33 |
+
TEST:
|
34 |
+
ENABLE: True
|
35 |
+
DATASET: kinetics
|
36 |
+
BATCH_SIZE: 8
|
37 |
+
NUM_ENSEMBLE_VIEWS: 1
|
38 |
+
NUM_SPATIAL_CROPS: 3
|
39 |
+
DATA_LOADER:
|
40 |
+
NUM_WORKERS: 8
|
41 |
+
PIN_MEMORY: True
|
42 |
+
NUM_GPUS: 8
|
43 |
+
NUM_SHARDS: 1
|
44 |
+
RNG_SEED: 0
|
45 |
+
OUTPUT_DIR: .
|
TimeSformer/configs/Kinetics/TimeSformer_divST_8x32_224.yaml
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
TRAIN:
|
2 |
+
ENABLE: True
|
3 |
+
DATASET: kinetics
|
4 |
+
BATCH_SIZE: 8
|
5 |
+
EVAL_PERIOD: 5
|
6 |
+
CHECKPOINT_PERIOD: 5
|
7 |
+
AUTO_RESUME: True
|
8 |
+
DATA:
|
9 |
+
PATH_TO_DATA_DIR: /path/to/kinetics/
|
10 |
+
NUM_FRAMES: 8
|
11 |
+
SAMPLING_RATE: 32
|
12 |
+
TRAIN_JITTER_SCALES: [256, 320]
|
13 |
+
TRAIN_CROP_SIZE: 224
|
14 |
+
TEST_CROP_SIZE: 224
|
15 |
+
INPUT_CHANNEL_NUM: [3]
|
16 |
+
TIMESFORMER:
|
17 |
+
ATTENTION_TYPE: 'divided_space_time'
|
18 |
+
SOLVER:
|
19 |
+
BASE_LR: 0.005
|
20 |
+
LR_POLICY: steps_with_relative_lrs
|
21 |
+
STEPS: [0, 11, 14]
|
22 |
+
LRS: [1, 0.1, 0.01]
|
23 |
+
MAX_EPOCH: 15
|
24 |
+
MOMENTUM: 0.9
|
25 |
+
WEIGHT_DECAY: 1e-4
|
26 |
+
OPTIMIZING_METHOD: sgd
|
27 |
+
MODEL:
|
28 |
+
MODEL_NAME: vit_base_patch16_224
|
29 |
+
NUM_CLASSES: 400
|
30 |
+
ARCH: vit
|
31 |
+
LOSS_FUNC: cross_entropy
|
32 |
+
DROPOUT_RATE: 0.5
|
33 |
+
TEST:
|
34 |
+
ENABLE: True
|
35 |
+
DATASET: kinetics
|
36 |
+
BATCH_SIZE: 8
|
37 |
+
NUM_ENSEMBLE_VIEWS: 1
|
38 |
+
NUM_SPATIAL_CROPS: 3
|
39 |
+
DATA_LOADER:
|
40 |
+
NUM_WORKERS: 8
|
41 |
+
PIN_MEMORY: True
|
42 |
+
NUM_GPUS: 8
|
43 |
+
NUM_SHARDS: 1
|
44 |
+
RNG_SEED: 0
|
45 |
+
OUTPUT_DIR: .
|
TimeSformer/configs/Kinetics/TimeSformer_divST_8x32_224_4gpus.yaml
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
TRAIN:
|
2 |
+
ENABLE: True
|
3 |
+
DATASET: kinetics
|
4 |
+
BATCH_SIZE: 4
|
5 |
+
EVAL_PERIOD: 5
|
6 |
+
CHECKPOINT_PERIOD: 5
|
7 |
+
AUTO_RESUME: True
|
8 |
+
DATA:
|
9 |
+
PATH_TO_DATA_DIR: /path/to/kinetics/
|
10 |
+
NUM_FRAMES: 8
|
11 |
+
SAMPLING_RATE: 32
|
12 |
+
TRAIN_JITTER_SCALES: [256, 320]
|
13 |
+
TRAIN_CROP_SIZE: 224
|
14 |
+
TEST_CROP_SIZE: 224
|
15 |
+
INPUT_CHANNEL_NUM: [3]
|
16 |
+
TIMESFORMER:
|
17 |
+
ATTENTION_TYPE: 'divided_space_time'
|
18 |
+
SOLVER:
|
19 |
+
BASE_LR: 0.005
|
20 |
+
LR_POLICY: steps_with_relative_lrs
|
21 |
+
STEPS: [0, 11, 14]
|
22 |
+
LRS: [1, 0.1, 0.01]
|
23 |
+
MAX_EPOCH: 15
|
24 |
+
MOMENTUM: 0.9
|
25 |
+
WEIGHT_DECAY: 1e-4
|
26 |
+
OPTIMIZING_METHOD: sgd
|
27 |
+
MODEL:
|
28 |
+
MODEL_NAME: vit_base_patch16_224
|
29 |
+
NUM_CLASSES: 400
|
30 |
+
ARCH: vit
|
31 |
+
LOSS_FUNC: cross_entropy
|
32 |
+
DROPOUT_RATE: 0.5
|
33 |
+
TEST:
|
34 |
+
ENABLE: True
|
35 |
+
DATASET: kinetics
|
36 |
+
BATCH_SIZE: 4
|
37 |
+
NUM_ENSEMBLE_VIEWS: 1
|
38 |
+
NUM_SPATIAL_CROPS: 3
|
39 |
+
DATA_LOADER:
|
40 |
+
NUM_WORKERS: 4
|
41 |
+
PIN_MEMORY: True
|
42 |
+
NUM_GPUS: 4
|
43 |
+
NUM_SHARDS: 1
|
44 |
+
RNG_SEED: 0
|
45 |
+
OUTPUT_DIR: .
|
TimeSformer/configs/Kinetics/TimeSformer_divST_8x32_224_TEST.yaml
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
TRAIN:
|
2 |
+
ENABLE: False
|
3 |
+
DATASET: kinetics
|
4 |
+
BATCH_SIZE: 8
|
5 |
+
EVAL_PERIOD: 5
|
6 |
+
CHECKPOINT_PERIOD: 5
|
7 |
+
AUTO_RESUME: True
|
8 |
+
DATA:
|
9 |
+
PATH_TO_DATA_DIR: /path/to/kinetics/
|
10 |
+
NUM_FRAMES: 8
|
11 |
+
SAMPLING_RATE: 32
|
12 |
+
TRAIN_JITTER_SCALES: [256, 320]
|
13 |
+
TRAIN_CROP_SIZE: 224
|
14 |
+
TEST_CROP_SIZE: 224
|
15 |
+
INPUT_CHANNEL_NUM: [3]
|
16 |
+
TIMESFORMER:
|
17 |
+
ATTENTION_TYPE: 'divided_space_time'
|
18 |
+
SOLVER:
|
19 |
+
BASE_LR: 0.005
|
20 |
+
LR_POLICY: steps_with_relative_lrs
|
21 |
+
STEPS: [0, 11, 14]
|
22 |
+
LRS: [1, 0.1, 0.01]
|
23 |
+
MAX_EPOCH: 15
|
24 |
+
MOMENTUM: 0.9
|
25 |
+
WEIGHT_DECAY: 1e-4
|
26 |
+
OPTIMIZING_METHOD: sgd
|
27 |
+
MODEL:
|
28 |
+
MODEL_NAME: vit_base_patch16_224
|
29 |
+
NUM_CLASSES: 400
|
30 |
+
ARCH: vit
|
31 |
+
LOSS_FUNC: cross_entropy
|
32 |
+
DROPOUT_RATE: 0.5
|
33 |
+
TEST:
|
34 |
+
ENABLE: True
|
35 |
+
DATASET: kinetics
|
36 |
+
BATCH_SIZE: 8
|
37 |
+
NUM_ENSEMBLE_VIEWS: 1
|
38 |
+
NUM_SPATIAL_CROPS: 3
|
39 |
+
CHECKPOINT_FILE_PATH: '/checkpoint/gedas/jobs/timesformer/kinetics_400/TimeSformer_divST_8x32_224/checkpoints/checkpoint_epoch_00025.pyth'
|
40 |
+
DATA_LOADER:
|
41 |
+
NUM_WORKERS: 8
|
42 |
+
PIN_MEMORY: True
|
43 |
+
NUM_GPUS: 8
|
44 |
+
NUM_SHARDS: 1
|
45 |
+
RNG_SEED: 0
|
46 |
+
OUTPUT_DIR: .
|
TimeSformer/configs/Kinetics/TimeSformer_divST_96x4_224.yaml
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
TRAIN:
|
2 |
+
ENABLE: True
|
3 |
+
DATASET: kinetics
|
4 |
+
BATCH_SIZE: 8
|
5 |
+
EVAL_PERIOD: 5
|
6 |
+
CHECKPOINT_PERIOD: 5
|
7 |
+
AUTO_RESUME: True
|
8 |
+
DATA:
|
9 |
+
PATH_TO_DATA_DIR: /path/to/kinetics/
|
10 |
+
NUM_FRAMES: 96
|
11 |
+
SAMPLING_RATE: 4
|
12 |
+
TRAIN_JITTER_SCALES: [256, 320]
|
13 |
+
TRAIN_CROP_SIZE: 224
|
14 |
+
TEST_CROP_SIZE: 224
|
15 |
+
INPUT_CHANNEL_NUM: [3]
|
16 |
+
TIMESFORMER:
|
17 |
+
ATTENTION_TYPE: 'divided_space_time'
|
18 |
+
SOLVER:
|
19 |
+
BASE_LR: 0.005
|
20 |
+
LR_POLICY: steps_with_relative_lrs
|
21 |
+
STEPS: [0, 11, 14]
|
22 |
+
LRS: [1, 0.1, 0.01]
|
23 |
+
MAX_EPOCH: 15
|
24 |
+
MOMENTUM: 0.9
|
25 |
+
WEIGHT_DECAY: 1e-4
|
26 |
+
OPTIMIZING_METHOD: sgd
|
27 |
+
MODEL:
|
28 |
+
MODEL_NAME: vit_base_patch16_224
|
29 |
+
NUM_CLASSES: 400
|
30 |
+
ARCH: vit
|
31 |
+
LOSS_FUNC: cross_entropy
|
32 |
+
DROPOUT_RATE: 0.5
|
33 |
+
TEST:
|
34 |
+
ENABLE: True
|
35 |
+
DATASET: kinetics
|
36 |
+
BATCH_SIZE: 8
|
37 |
+
NUM_ENSEMBLE_VIEWS: 1
|
38 |
+
NUM_SPATIAL_CROPS: 3
|
39 |
+
DATA_LOADER:
|
40 |
+
NUM_WORKERS: 8
|
41 |
+
PIN_MEMORY: True
|
42 |
+
NUM_GPUS: 8
|
43 |
+
NUM_SHARDS: 1
|
44 |
+
RNG_SEED: 0
|
45 |
+
OUTPUT_DIR: .
|
TimeSformer/configs/Kinetics/TimeSformer_jointST_8x32_224.yaml
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
TRAIN:
|
2 |
+
ENABLE: True
|
3 |
+
DATASET: kinetics
|
4 |
+
BATCH_SIZE: 8
|
5 |
+
EVAL_PERIOD: 5
|
6 |
+
CHECKPOINT_PERIOD: 5
|
7 |
+
AUTO_RESUME: True
|
8 |
+
DATA:
|
9 |
+
PATH_TO_DATA_DIR: /path/to/kinetics/
|
10 |
+
NUM_FRAMES: 8
|
11 |
+
SAMPLING_RATE: 32
|
12 |
+
TRAIN_JITTER_SCALES: [256, 320]
|
13 |
+
TRAIN_CROP_SIZE: 224
|
14 |
+
TEST_CROP_SIZE: 224
|
15 |
+
INPUT_CHANNEL_NUM: [3]
|
16 |
+
TIMESFORMER:
|
17 |
+
ATTENTION_TYPE: 'joint_space_time'
|
18 |
+
SOLVER:
|
19 |
+
BASE_LR: 0.005
|
20 |
+
LR_POLICY: steps_with_relative_lrs
|
21 |
+
STEPS: [0, 11, 14]
|
22 |
+
LRS: [1, 0.1, 0.01]
|
23 |
+
MAX_EPOCH: 15
|
24 |
+
MOMENTUM: 0.9
|
25 |
+
WEIGHT_DECAY: 1e-4
|
26 |
+
OPTIMIZING_METHOD: sgd
|
27 |
+
MODEL:
|
28 |
+
MODEL_NAME: vit_base_patch16_224
|
29 |
+
NUM_CLASSES: 400
|
30 |
+
ARCH: vit
|
31 |
+
LOSS_FUNC: cross_entropy
|
32 |
+
DROPOUT_RATE: 0.5
|
33 |
+
TEST:
|
34 |
+
ENABLE: True
|
35 |
+
DATASET: kinetics
|
36 |
+
BATCH_SIZE: 8
|
37 |
+
NUM_ENSEMBLE_VIEWS: 1
|
38 |
+
NUM_SPATIAL_CROPS: 3
|
39 |
+
DATA_LOADER:
|
40 |
+
NUM_WORKERS: 8
|
41 |
+
PIN_MEMORY: True
|
42 |
+
NUM_GPUS: 8
|
43 |
+
NUM_SHARDS: 1
|
44 |
+
RNG_SEED: 0
|
45 |
+
OUTPUT_DIR: .
|
TimeSformer/configs/Kinetics/TimeSformer_spaceOnly_8x32_224.yaml
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
TRAIN:
|
2 |
+
ENABLE: True
|
3 |
+
DATASET: kinetics
|
4 |
+
BATCH_SIZE: 8
|
5 |
+
EVAL_PERIOD: 5
|
6 |
+
CHECKPOINT_PERIOD: 5
|
7 |
+
AUTO_RESUME: True
|
8 |
+
DATA:
|
9 |
+
PATH_TO_DATA_DIR: /path/to/kinetics/
|
10 |
+
NUM_FRAMES: 8
|
11 |
+
SAMPLING_RATE: 32
|
12 |
+
TRAIN_JITTER_SCALES: [256, 320]
|
13 |
+
TRAIN_CROP_SIZE: 224
|
14 |
+
TEST_CROP_SIZE: 224
|
15 |
+
INPUT_CHANNEL_NUM: [3]
|
16 |
+
TIMESFORMER:
|
17 |
+
ATTENTION_TYPE: 'space_only'
|
18 |
+
SOLVER:
|
19 |
+
BASE_LR: 0.005
|
20 |
+
LR_POLICY: steps_with_relative_lrs
|
21 |
+
STEPS: [0, 11, 14]
|
22 |
+
LRS: [1, 0.1, 0.01]
|
23 |
+
MAX_EPOCH: 15
|
24 |
+
MOMENTUM: 0.9
|
25 |
+
WEIGHT_DECAY: 1e-4
|
26 |
+
OPTIMIZING_METHOD: sgd
|
27 |
+
MODEL:
|
28 |
+
MODEL_NAME: vit_base_patch16_224
|
29 |
+
NUM_CLASSES: 400
|
30 |
+
ARCH: vit
|
31 |
+
LOSS_FUNC: cross_entropy
|
32 |
+
DROPOUT_RATE: 0.5
|
33 |
+
TEST:
|
34 |
+
ENABLE: True
|
35 |
+
DATASET: kinetics
|
36 |
+
BATCH_SIZE: 8
|
37 |
+
NUM_ENSEMBLE_VIEWS: 1
|
38 |
+
NUM_SPATIAL_CROPS: 3
|
39 |
+
DATA_LOADER:
|
40 |
+
NUM_WORKERS: 8
|
41 |
+
PIN_MEMORY: True
|
42 |
+
NUM_GPUS: 8
|
43 |
+
NUM_SHARDS: 1
|
44 |
+
RNG_SEED: 0
|
45 |
+
OUTPUT_DIR: .
|
TimeSformer/configs/SSv2/SLOWFAST_16x8_R50.yaml
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
TRAIN:
|
2 |
+
ENABLE: True
|
3 |
+
DATASET: ssv2
|
4 |
+
BATCH_SIZE: 16
|
5 |
+
EVAL_PERIOD: 5
|
6 |
+
CHECKPOINT_PERIOD: 5
|
7 |
+
AUTO_RESUME: True
|
8 |
+
DATA:
|
9 |
+
PATH_TO_DATA_DIR: " /path/to/ssv2/annotations/"
|
10 |
+
PATH_PREFIX: "/path/to/ssv2/frames/"
|
11 |
+
NUM_FRAMES: 64
|
12 |
+
SAMPLING_RATE: 2
|
13 |
+
TRAIN_JITTER_SCALES: [256, 320]
|
14 |
+
TRAIN_CROP_SIZE: 224
|
15 |
+
TEST_CROP_SIZE: 256
|
16 |
+
INPUT_CHANNEL_NUM: [3, 3]
|
17 |
+
INV_UNIFORM_SAMPLE: True
|
18 |
+
RANDOM_FLIP: False
|
19 |
+
REVERSE_INPUT_CHANNEL: True
|
20 |
+
SLOWFAST:
|
21 |
+
ALPHA: 4
|
22 |
+
BETA_INV: 8
|
23 |
+
FUSION_CONV_CHANNEL_RATIO: 2
|
24 |
+
FUSION_KERNEL_SZ: 7
|
25 |
+
RESNET:
|
26 |
+
SPATIAL_STRIDES: [[1, 1], [2, 2], [2, 2], [2, 2]]
|
27 |
+
SPATIAL_DILATIONS: [[1, 1], [1, 1], [1, 1], [1, 1]]
|
28 |
+
ZERO_INIT_FINAL_BN: True
|
29 |
+
WIDTH_PER_GROUP: 64
|
30 |
+
NUM_GROUPS: 1
|
31 |
+
DEPTH: 50
|
32 |
+
TRANS_FUNC: bottleneck_transform
|
33 |
+
STRIDE_1X1: False
|
34 |
+
NUM_BLOCK_TEMP_KERNEL: [[3, 3], [4, 4], [6, 6], [3, 3]]
|
35 |
+
NONLOCAL:
|
36 |
+
LOCATION: [[[], []], [[], []], [[], []], [[], []]]
|
37 |
+
GROUP: [[1, 1], [1, 1], [1, 1], [1, 1]]
|
38 |
+
INSTANTIATION: dot_product
|
39 |
+
BN:
|
40 |
+
USE_PRECISE_STATS: True
|
41 |
+
NUM_BATCHES_PRECISE: 200
|
42 |
+
NORM_TYPE: sync_batchnorm
|
43 |
+
NUM_SYNC_DEVICES: 4
|
44 |
+
SOLVER:
|
45 |
+
BASE_LR: 0.2 #8 nodes
|
46 |
+
LR_POLICY: cosine
|
47 |
+
MAX_EPOCH: 200
|
48 |
+
MOMENTUM: 0.9
|
49 |
+
WEIGHT_DECAY: 1e-4
|
50 |
+
WARMUP_EPOCHS: 34.0
|
51 |
+
WARMUP_START_LR: 0.01
|
52 |
+
OPTIMIZING_METHOD: sgd
|
53 |
+
#SOLVER:
|
54 |
+
# BASE_LR: 0.03
|
55 |
+
# LR_POLICY: steps_with_relative_lrs
|
56 |
+
# LRS: [1, 0.1, 0.01, 0.001, 0.0001, 0.00001]
|
57 |
+
# STEPS: [0, 14, 18]
|
58 |
+
# MAX_EPOCH: 22
|
59 |
+
# MOMENTUM: 0.9
|
60 |
+
# WEIGHT_DECAY: 1e-6
|
61 |
+
# WARMUP_EPOCHS: 0.19
|
62 |
+
# WARMUP_START_LR: 0.0001
|
63 |
+
# OPTIMIZING_METHOD: sgd
|
64 |
+
MODEL:
|
65 |
+
NUM_CLASSES: 174
|
66 |
+
ARCH: slowfast
|
67 |
+
LOSS_FUNC: cross_entropy
|
68 |
+
DROPOUT_RATE: 0.5
|
69 |
+
TEST:
|
70 |
+
ENABLE: True
|
71 |
+
DATASET: ssv2
|
72 |
+
BATCH_SIZE: 16
|
73 |
+
NUM_ENSEMBLE_VIEWS: 1
|
74 |
+
NUM_SPATIAL_CROPS: 1
|
75 |
+
DATA_LOADER:
|
76 |
+
NUM_WORKERS: 4
|
77 |
+
PIN_MEMORY: True
|
78 |
+
NUM_GPUS: 8
|
79 |
+
NUM_SHARDS: 1
|
80 |
+
RNG_SEED: 0
|
81 |
+
OUTPUT_DIR: .
|
82 |
+
#LOG_MODEL_INFO: False
|
83 |
+
LOG_MODEL_INFO: True
|
TimeSformer/configs/SSv2/TimeSformer_divST_16_448.yaml
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
TRAIN:
|
2 |
+
ENABLE: True
|
3 |
+
DATASET: ssv2
|
4 |
+
BATCH_SIZE: 8
|
5 |
+
EVAL_PERIOD: 5
|
6 |
+
CHECKPOINT_PERIOD: 5
|
7 |
+
AUTO_RESUME: True
|
8 |
+
DATA:
|
9 |
+
PATH_TO_DATA_DIR: " /path/to/ssv2/annotations/"
|
10 |
+
PATH_PREFIX: "/path/to/ssv2/frames/"
|
11 |
+
NUM_FRAMES: 16
|
12 |
+
TRAIN_JITTER_SCALES: [448, 512]
|
13 |
+
TRAIN_CROP_SIZE: 448
|
14 |
+
TEST_CROP_SIZE: 448
|
15 |
+
INPUT_CHANNEL_NUM: [3]
|
16 |
+
INV_UNIFORM_SAMPLE: True
|
17 |
+
RANDOM_FLIP: False
|
18 |
+
REVERSE_INPUT_CHANNEL: True
|
19 |
+
TIMESFORMER:
|
20 |
+
ATTENTION_TYPE: 'divided_space_time'
|
21 |
+
SOLVER:
|
22 |
+
BASE_LR: 0.005
|
23 |
+
LR_POLICY: steps_with_relative_lrs
|
24 |
+
STEPS: [0, 11, 14]
|
25 |
+
LRS: [1, 0.1, 0.01]
|
26 |
+
MAX_EPOCH: 15
|
27 |
+
MOMENTUM: 0.9
|
28 |
+
WEIGHT_DECAY: 1e-4
|
29 |
+
OPTIMIZING_METHOD: sgd
|
30 |
+
MODEL:
|
31 |
+
MODEL_NAME: vit_base_patch16_224
|
32 |
+
NUM_CLASSES: 174
|
33 |
+
ARCH: vit
|
34 |
+
LOSS_FUNC: cross_entropy
|
35 |
+
DROPOUT_RATE: 0.5
|
36 |
+
TEST:
|
37 |
+
ENABLE: True
|
38 |
+
DATASET: ssv2
|
39 |
+
BATCH_SIZE: 8
|
40 |
+
NUM_ENSEMBLE_VIEWS: 1
|
41 |
+
NUM_SPATIAL_CROPS: 3
|
42 |
+
DATA_LOADER:
|
43 |
+
NUM_WORKERS: 4
|
44 |
+
PIN_MEMORY: True
|
45 |
+
NUM_GPUS: 8
|
46 |
+
NUM_SHARDS: 1
|
47 |
+
RNG_SEED: 0
|
48 |
+
OUTPUT_DIR: .
|
TimeSformer/configs/SSv2/TimeSformer_divST_64_224.yaml
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
TRAIN:
|
2 |
+
ENABLE: True
|
3 |
+
DATASET: ssv2
|
4 |
+
BATCH_SIZE: 8
|
5 |
+
EVAL_PERIOD: 5
|
6 |
+
CHECKPOINT_PERIOD: 5
|
7 |
+
AUTO_RESUME: True
|
8 |
+
DATA:
|
9 |
+
PATH_TO_DATA_DIR: " /path/to/ssv2/annotations/"
|
10 |
+
PATH_PREFIX: "/path/to/ssv2/frames/"
|
11 |
+
NUM_FRAMES: 64
|
12 |
+
TRAIN_JITTER_SCALES: [256, 320]
|
13 |
+
TRAIN_CROP_SIZE: 224
|
14 |
+
TEST_CROP_SIZE: 224
|
15 |
+
INPUT_CHANNEL_NUM: [3]
|
16 |
+
INV_UNIFORM_SAMPLE: True
|
17 |
+
RANDOM_FLIP: False
|
18 |
+
REVERSE_INPUT_CHANNEL: True
|
19 |
+
TIMESFORMER:
|
20 |
+
ATTENTION_TYPE: 'divided_space_time'
|
21 |
+
SOLVER:
|
22 |
+
BASE_LR: 0.005
|
23 |
+
LR_POLICY: steps_with_relative_lrs
|
24 |
+
STEPS: [0, 11, 14]
|
25 |
+
LRS: [1, 0.1, 0.01]
|
26 |
+
MAX_EPOCH: 15
|
27 |
+
MOMENTUM: 0.9
|
28 |
+
WEIGHT_DECAY: 1e-4
|
29 |
+
OPTIMIZING_METHOD: sgd
|
30 |
+
MODEL:
|
31 |
+
MODEL_NAME: vit_base_patch16_224
|
32 |
+
NUM_CLASSES: 174
|
33 |
+
ARCH: vit
|
34 |
+
LOSS_FUNC: cross_entropy
|
35 |
+
DROPOUT_RATE: 0.5
|
36 |
+
TEST:
|
37 |
+
ENABLE: True
|
38 |
+
DATASET: ssv2
|
39 |
+
BATCH_SIZE: 8
|
40 |
+
NUM_ENSEMBLE_VIEWS: 1
|
41 |
+
NUM_SPATIAL_CROPS: 3
|
42 |
+
DATA_LOADER:
|
43 |
+
NUM_WORKERS: 4
|
44 |
+
PIN_MEMORY: True
|
45 |
+
NUM_GPUS: 8
|
46 |
+
NUM_SHARDS: 1
|
47 |
+
RNG_SEED: 0
|
48 |
+
OUTPUT_DIR: .
|
TimeSformer/configs/SSv2/TimeSformer_divST_8_224.yaml
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
TRAIN:
|
2 |
+
ENABLE: True
|
3 |
+
DATASET: ssv2
|
4 |
+
BATCH_SIZE: 8
|
5 |
+
EVAL_PERIOD: 5
|
6 |
+
CHECKPOINT_PERIOD: 5
|
7 |
+
AUTO_RESUME: True
|
8 |
+
DATA:
|
9 |
+
PATH_TO_DATA_DIR: " /path/to/ssv2/annotations/"
|
10 |
+
PATH_PREFIX: "/path/to/ssv2/frames/"
|
11 |
+
NUM_FRAMES: 8
|
12 |
+
TRAIN_JITTER_SCALES: [256, 320]
|
13 |
+
TRAIN_CROP_SIZE: 224
|
14 |
+
TEST_CROP_SIZE: 224
|
15 |
+
INPUT_CHANNEL_NUM: [3]
|
16 |
+
INV_UNIFORM_SAMPLE: True
|
17 |
+
RANDOM_FLIP: False
|
18 |
+
REVERSE_INPUT_CHANNEL: True
|
19 |
+
TIMESFORMER:
|
20 |
+
ATTENTION_TYPE: 'divided_space_time'
|
21 |
+
SOLVER:
|
22 |
+
BASE_LR: 0.005
|
23 |
+
LR_POLICY: steps_with_relative_lrs
|
24 |
+
STEPS: [0, 11, 14]
|
25 |
+
LRS: [1, 0.1, 0.01]
|
26 |
+
MAX_EPOCH: 15
|
27 |
+
MOMENTUM: 0.9
|
28 |
+
WEIGHT_DECAY: 1e-4
|
29 |
+
OPTIMIZING_METHOD: sgd
|
30 |
+
MODEL:
|
31 |
+
MODEL_NAME: vit_base_patch16_224
|
32 |
+
NUM_CLASSES: 174
|
33 |
+
ARCH: vit
|
34 |
+
LOSS_FUNC: cross_entropy
|
35 |
+
DROPOUT_RATE: 0.5
|
36 |
+
TEST:
|
37 |
+
ENABLE: True
|
38 |
+
DATASET: ssv2
|
39 |
+
BATCH_SIZE: 8
|
40 |
+
NUM_ENSEMBLE_VIEWS: 1
|
41 |
+
NUM_SPATIAL_CROPS: 3
|
42 |
+
DATA_LOADER:
|
43 |
+
NUM_WORKERS: 4
|
44 |
+
PIN_MEMORY: True
|
45 |
+
NUM_GPUS: 8
|
46 |
+
NUM_SHARDS: 1
|
47 |
+
RNG_SEED: 0
|
48 |
+
OUTPUT_DIR: .
|
TimeSformer/environment.yml
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: timesformer
|
2 |
+
channels:
|
3 |
+
- pytorch
|
4 |
+
- conda-forge
|
5 |
+
- defaults
|
6 |
+
dependencies:
|
7 |
+
- python>3.7
|
8 |
+
- jupyterlab
|
9 |
+
- pandas>=1.2
|
10 |
+
- numpy>1.19
|
11 |
+
- pytorch>=1.6
|
12 |
+
- torchvision>=0.7
|
13 |
+
- scikit-learn>=0.22
|
14 |
+
- opencv>=4.2
|
15 |
+
- pyyaml>=5.1
|
16 |
+
- yacs>=0.1.6
|
17 |
+
- einops>=0.3
|
18 |
+
- tensorboard
|
19 |
+
- psutil
|
20 |
+
- tqdm
|
21 |
+
- matplotlib
|
22 |
+
- simplejson
|
23 |
+
- pip
|
24 |
+
- pip:
|
25 |
+
- fvcore
|
26 |
+
- av
|
TimeSformer/example.ipynb
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"id": "08fe0c59",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [],
|
9 |
+
"source": [
|
10 |
+
"from pathlib import Path\n",
|
11 |
+
"\n",
|
12 |
+
"import torch\n",
|
13 |
+
"from timesformer.models.vit import TimeSformer"
|
14 |
+
]
|
15 |
+
},
|
16 |
+
{
|
17 |
+
"cell_type": "code",
|
18 |
+
"execution_count": 2,
|
19 |
+
"id": "10239d32",
|
20 |
+
"metadata": {},
|
21 |
+
"outputs": [
|
22 |
+
{
|
23 |
+
"data": {
|
24 |
+
"text/plain": [
|
25 |
+
"True"
|
26 |
+
]
|
27 |
+
},
|
28 |
+
"execution_count": 2,
|
29 |
+
"metadata": {},
|
30 |
+
"output_type": "execute_result"
|
31 |
+
}
|
32 |
+
],
|
33 |
+
"source": [
|
34 |
+
"model_file = Path.home()/'TimeSformer/models/TimeSformer_divST_8x32_224_K600.pyth'\n",
|
35 |
+
"model_file.exists()"
|
36 |
+
]
|
37 |
+
},
|
38 |
+
{
|
39 |
+
"cell_type": "code",
|
40 |
+
"execution_count": 3,
|
41 |
+
"id": "652fb03e",
|
42 |
+
"metadata": {},
|
43 |
+
"outputs": [],
|
44 |
+
"source": [
|
45 |
+
"model = TimeSformer(img_size=224, num_classes=600, num_frames=8, attention_type='divided_space_time', pretrained_model=str(model_file))\n",
|
46 |
+
"\n",
|
47 |
+
"dummy_video = torch.randn(2, 3, 8, 224, 224) # (batch x channels x frames x height x width)\n",
|
48 |
+
"\n",
|
49 |
+
"pred = model(dummy_video,) # (2, 600)"
|
50 |
+
]
|
51 |
+
},
|
52 |
+
{
|
53 |
+
"cell_type": "code",
|
54 |
+
"execution_count": 6,
|
55 |
+
"id": "83de13c5-791c-4db7-aba4-6d29ce88584e",
|
56 |
+
"metadata": {},
|
57 |
+
"outputs": [],
|
58 |
+
"source": [
|
59 |
+
"assert pred.shape == (2,600)"
|
60 |
+
]
|
61 |
+
}
|
62 |
+
],
|
63 |
+
"metadata": {
|
64 |
+
"kernelspec": {
|
65 |
+
"display_name": "Python 3",
|
66 |
+
"language": "python",
|
67 |
+
"name": "python3"
|
68 |
+
},
|
69 |
+
"language_info": {
|
70 |
+
"codemirror_mode": {
|
71 |
+
"name": "ipython",
|
72 |
+
"version": 3
|
73 |
+
},
|
74 |
+
"file_extension": ".py",
|
75 |
+
"mimetype": "text/x-python",
|
76 |
+
"name": "python",
|
77 |
+
"nbconvert_exporter": "python",
|
78 |
+
"pygments_lexer": "ipython3",
|
79 |
+
"version": "3.9.4"
|
80 |
+
}
|
81 |
+
},
|
82 |
+
"nbformat": 4,
|
83 |
+
"nbformat_minor": 5
|
84 |
+
}
|
TimeSformer/setup.cfg
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[isort]
|
2 |
+
line_length=100
|
3 |
+
multi_line_output=4
|
4 |
+
known_standard_library=numpy,setuptools
|
5 |
+
known_myself=timesformer
|
6 |
+
known_third_party=fvcore,av,torch,pycocotools,yacs,termcolor,scipy,simplejson,matplotlib,torchvision,yaml,tqdm,psutil,opencv-python,pandas,tensorboard,moviepy,sklearn,cv2
|
7 |
+
no_lines_before=STDLIB,THIRDPARTY
|
8 |
+
sections=FUTURE,STDLIB,THIRDPARTY,myself,FIRSTPARTY,LOCALFOLDER
|
9 |
+
default_section=FIRSTPARTY
|
10 |
+
|
11 |
+
[mypy]
|
12 |
+
python_version=3.6
|
13 |
+
ignore_missing_imports = True
|
14 |
+
warn_unused_configs = True
|
15 |
+
disallow_untyped_defs = True
|
16 |
+
check_untyped_defs = True
|
17 |
+
warn_unused_ignores = True
|
18 |
+
warn_redundant_casts = True
|
19 |
+
show_column_numbers = True
|
20 |
+
follow_imports = silent
|
21 |
+
allow_redefinition = True
|
22 |
+
; Require all functions to be annotated
|
23 |
+
disallow_incomplete_defs = True
|
TimeSformer/setup.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
2 |
+
|
3 |
+
from setuptools import find_packages, setup
|
4 |
+
|
5 |
+
setup(
|
6 |
+
name="timesformer",
|
7 |
+
version="1.0",
|
8 |
+
author="FBAI",
|
9 |
+
url="unknown",
|
10 |
+
description="TimeSformer",
|
11 |
+
keywords = [
|
12 |
+
'artificial intelligence',
|
13 |
+
'attention mechanism',
|
14 |
+
'transformers',
|
15 |
+
'video classification',
|
16 |
+
],
|
17 |
+
install_requires=[
|
18 |
+
'einops>=0.3',
|
19 |
+
'torch>=1.6'
|
20 |
+
],
|
21 |
+
extras_require={"tensorboard_video_visualization": ["moviepy"]},
|
22 |
+
packages=find_packages(exclude=("configs", "tests")),
|
23 |
+
)
|
TimeSformer/slurm_scripts/run_multi_node_job.sh
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
2 |
+
# A script with a list of commands for submitting SLURM jobs
|
3 |
+
|
4 |
+
#### Kinetics training
|
5 |
+
JOB_NAME=TimeSformer_divST_8x32_224
|
6 |
+
python tools/submit.py --cfg configs/Kinetics/TimeSformer_divST_8x32_224.yaml --job_dir /your/job/dir/${JOB_NAME}/ --num_shards 4 --partition dev --comment "" --name ${JOB_NAME} --use_volta32
|
7 |
+
|
8 |
+
#JOB_NAME=TimeSformer_jointST_8x32_224
|
9 |
+
#python tools/submit.py --cfg configs/Kinetics/TimeSformer_jointST_8x32_224.yaml --job_dir /your/job/dir/${JOB_NAME}/ --num_shards 4 --partition learnfair --comment "" --name ${JOB_NAME} --use_volta32
|
10 |
+
|
11 |
+
#JOB_NAME=TimeSformer_spaceOnly_8x32_224
|
12 |
+
#python tools/submit.py --cfg configs/Kinetics/TimeSformer_spaceOnly_8x32_224.yaml --job_dir /your/job/dir/${JOB_NAME}/ --num_shards 4 --partition learnfair --comment "" --name ${JOB_NAME} --use_volta32
|
13 |
+
|
14 |
+
#### Kinetics inference
|
15 |
+
#JOB_NAME=TimeSformer_divST_8x32_224_TEST_3clips
|
16 |
+
#python tools/submit.py --cfg configs/Kinetics/TimeSformer_divST_8x32_224_TEST.yaml --job_dir /your/job/dir/${JOB_NAME}/ --num_shards 4 --partition dev --comment "" --name ${JOB_NAME} --use_volta32
|
17 |
+
|
18 |
+
|
19 |
+
##### SSv2 training
|
20 |
+
#JOB_NAME=TimeSformer_divST_8_224
|
21 |
+
#python tools/submit.py --cfg configs/SSv2/TimeSformer_divST_8_224.yaml --job_dir /your/job/dir/${JOB_NAME}/ --num_shards 4 --partition learnfair --comment "" --name ${JOB_NAME} --use_volta32
|
22 |
+
|
23 |
+
##### Sth-Sth_v2 inference
|
24 |
+
#JOB_NAME=TimeSformer_divST_8_224_TEST_3clips
|
25 |
+
#python tools/submit.py --cfg configs/SSv2/TimeSformer_divST_8_224_TEST.yaml --job_dir /your/job/dir/${JOB_NAME}/ --num_shards 4 --partition learnfair --comment "" --name ${JOB_NAME} --use_volta32
|
TimeSformer/slurm_scripts/run_single_node_job.sh
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
2 |
+
# A script with a list of commands for submitting SLURM jobs
|
3 |
+
|
4 |
+
#SBATCH --job-name=timesformer
|
5 |
+
#SBATCH --mail-type=END,FAIL,REQUEUE
|
6 |
+
#SBATCH [email protected]
|
7 |
+
|
8 |
+
## %j is the job id, %u is the user id
|
9 |
+
#SBATCH --output=/path/to/output/logs/slog-%A-%a.out
|
10 |
+
|
11 |
+
## filename for job standard error output (stderr)
|
12 |
+
#SBATCH --error=/path/to/error/logs/slog-%A-%a.err
|
13 |
+
|
14 |
+
#SBATCH --array=1
|
15 |
+
#SBATCH --partition=partition_of_your_choice
|
16 |
+
#SBATCH --nodes=1 -C volta32gb
|
17 |
+
#SBATCH --ntasks-per-node=1
|
18 |
+
#SBATCH --gpus-per-node=8
|
19 |
+
#SBATCH --cpus-per-task=80
|
20 |
+
#SBATCH --mem=480GB
|
21 |
+
#SBATCH --signal=USR1@600
|
22 |
+
#SBATCH --time=72:00:00
|
23 |
+
#SBATCH --open-mode=append
|
24 |
+
|
25 |
+
module purge
|
26 |
+
module load cuda/10.0
|
27 |
+
module load NCCL/2.4.7-1-cuda.10.0
|
28 |
+
module load cudnn/v7.4-cuda.10.0
|
29 |
+
source activate timesformer
|
30 |
+
|
31 |
+
WORKINGDIR=/path/to/TimeSformer
|
32 |
+
CURPYTHON=/path/to/python
|
33 |
+
|
34 |
+
srun --label ${CURPYTHON} ${WORKINGDIR}/tools/run_net.py --cfg ${WORKINGDIR}/configs/Kinetics/TimeSformer_divST_8x32_224.yaml NUM_GPUS 8 TRAIN.BATCH_SIZE 8
|
35 |
+
|
TimeSformer/timesformer/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
2 |
+
|
3 |
+
from timesformer.utils.env import setup_environment
|
4 |
+
|
5 |
+
setup_environment()
|
TimeSformer/timesformer/config/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
TimeSformer/timesformer/config/defaults.py
ADDED
@@ -0,0 +1,820 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
2 |
+
|
3 |
+
"""Configs."""
|
4 |
+
from fvcore.common.config import CfgNode
|
5 |
+
# -----------------------------------------------------------------------------
|
6 |
+
# Config definition
|
7 |
+
# -----------------------------------------------------------------------------
|
8 |
+
_C = CfgNode()
|
9 |
+
|
10 |
+
# ---------------------------------------------------------------------------- #
|
11 |
+
# Batch norm options
|
12 |
+
# ---------------------------------------------------------------------------- #
|
13 |
+
_C.BN = CfgNode()
|
14 |
+
|
15 |
+
# Precise BN stats.
|
16 |
+
_C.BN.USE_PRECISE_STATS = False
|
17 |
+
|
18 |
+
# Number of samples use to compute precise bn.
|
19 |
+
_C.BN.NUM_BATCHES_PRECISE = 200
|
20 |
+
|
21 |
+
# Weight decay value that applies on BN.
|
22 |
+
_C.BN.WEIGHT_DECAY = 0.0
|
23 |
+
|
24 |
+
# Norm type, options include `batchnorm`, `sub_batchnorm`, `sync_batchnorm`
|
25 |
+
_C.BN.NORM_TYPE = "batchnorm"
|
26 |
+
|
27 |
+
# Parameter for SubBatchNorm, where it splits the batch dimension into
|
28 |
+
# NUM_SPLITS splits, and run BN on each of them separately independently.
|
29 |
+
_C.BN.NUM_SPLITS = 1
|
30 |
+
|
31 |
+
# Parameter for NaiveSyncBatchNorm3d, where the stats across `NUM_SYNC_DEVICES`
|
32 |
+
# devices will be synchronized.
|
33 |
+
_C.BN.NUM_SYNC_DEVICES = 1
|
34 |
+
|
35 |
+
|
36 |
+
# ---------------------------------------------------------------------------- #
|
37 |
+
# Training options.
|
38 |
+
# ---------------------------------------------------------------------------- #
|
39 |
+
_C.TRAIN = CfgNode()
|
40 |
+
|
41 |
+
# If True Train the model, else skip training.
|
42 |
+
_C.TRAIN.ENABLE = True
|
43 |
+
|
44 |
+
# Dataset.
|
45 |
+
_C.TRAIN.DATASET = "kinetics"
|
46 |
+
|
47 |
+
##
|
48 |
+
_C.TRAIN.FINETUNE = False
|
49 |
+
|
50 |
+
# Total mini-batch size.
|
51 |
+
_C.TRAIN.BATCH_SIZE = 64
|
52 |
+
|
53 |
+
# Evaluate model on test data every eval period epochs.
|
54 |
+
_C.TRAIN.EVAL_PERIOD = 10
|
55 |
+
|
56 |
+
# Save model checkpoint every checkpoint period epochs.
|
57 |
+
_C.TRAIN.CHECKPOINT_PERIOD = 10
|
58 |
+
|
59 |
+
# Resume training from the latest checkpoint in the output directory.
|
60 |
+
_C.TRAIN.AUTO_RESUME = True
|
61 |
+
|
62 |
+
# Path to the checkpoint to load the initial weight.
|
63 |
+
_C.TRAIN.CHECKPOINT_FILE_PATH = ""
|
64 |
+
|
65 |
+
# Checkpoint types include `caffe2` or `pytorch`.
|
66 |
+
_C.TRAIN.CHECKPOINT_TYPE = "pytorch"
|
67 |
+
|
68 |
+
# If True, perform inflation when loading checkpoint.
|
69 |
+
_C.TRAIN.CHECKPOINT_INFLATE = False
|
70 |
+
|
71 |
+
# If True, reset epochs when loading checkpoint.
|
72 |
+
_C.TRAIN.CHECKPOINT_EPOCH_RESET = False
|
73 |
+
|
74 |
+
# If set, clear all layer names according to the pattern provided.
|
75 |
+
_C.TRAIN.CHECKPOINT_CLEAR_NAME_PATTERN = () # ("backbone.",)
|
76 |
+
|
77 |
+
# ---------------------------------------------------------------------------- #
|
78 |
+
# Testing options
|
79 |
+
# ---------------------------------------------------------------------------- #
|
80 |
+
_C.TEST = CfgNode()
|
81 |
+
|
82 |
+
# If True test the model, else skip the testing.
|
83 |
+
_C.TEST.ENABLE = True
|
84 |
+
|
85 |
+
# Dataset for testing.
|
86 |
+
_C.TEST.DATASET = "kinetics"
|
87 |
+
|
88 |
+
# Total mini-batch size
|
89 |
+
_C.TEST.BATCH_SIZE = 8
|
90 |
+
|
91 |
+
# Path to the checkpoint to load the initial weight.
|
92 |
+
_C.TEST.CHECKPOINT_FILE_PATH = ""
|
93 |
+
|
94 |
+
# Number of clips to sample from a video uniformly for aggregating the
|
95 |
+
# prediction results.
|
96 |
+
_C.TEST.NUM_ENSEMBLE_VIEWS = 10
|
97 |
+
|
98 |
+
# Number of crops to sample from a frame spatially for aggregating the
|
99 |
+
# prediction results.
|
100 |
+
_C.TEST.NUM_SPATIAL_CROPS = 3
|
101 |
+
|
102 |
+
# Checkpoint types include `caffe2` or `pytorch`.
|
103 |
+
_C.TEST.CHECKPOINT_TYPE = "pytorch"
|
104 |
+
# Path to saving prediction results file.
|
105 |
+
_C.TEST.SAVE_RESULTS_PATH = ""
|
106 |
+
# -----------------------------------------------------------------------------
|
107 |
+
# ResNet options
|
108 |
+
# -----------------------------------------------------------------------------
|
109 |
+
_C.RESNET = CfgNode()
|
110 |
+
|
111 |
+
# Transformation function.
|
112 |
+
_C.RESNET.TRANS_FUNC = "bottleneck_transform"
|
113 |
+
|
114 |
+
# Number of groups. 1 for ResNet, and larger than 1 for ResNeXt).
|
115 |
+
_C.RESNET.NUM_GROUPS = 1
|
116 |
+
|
117 |
+
# Width of each group (64 -> ResNet; 4 -> ResNeXt).
|
118 |
+
_C.RESNET.WIDTH_PER_GROUP = 64
|
119 |
+
|
120 |
+
# Apply relu in a inplace manner.
|
121 |
+
_C.RESNET.INPLACE_RELU = True
|
122 |
+
|
123 |
+
# Apply stride to 1x1 conv.
|
124 |
+
_C.RESNET.STRIDE_1X1 = False
|
125 |
+
|
126 |
+
# If true, initialize the gamma of the final BN of each block to zero.
|
127 |
+
_C.RESNET.ZERO_INIT_FINAL_BN = False
|
128 |
+
|
129 |
+
# Number of weight layers.
|
130 |
+
_C.RESNET.DEPTH = 50
|
131 |
+
|
132 |
+
# If the current block has more than NUM_BLOCK_TEMP_KERNEL blocks, use temporal
|
133 |
+
# kernel of 1 for the rest of the blocks.
|
134 |
+
_C.RESNET.NUM_BLOCK_TEMP_KERNEL = [[3], [4], [6], [3]]
|
135 |
+
|
136 |
+
# Size of stride on different res stages.
|
137 |
+
_C.RESNET.SPATIAL_STRIDES = [[1], [2], [2], [2]]
|
138 |
+
|
139 |
+
# Size of dilation on different res stages.
|
140 |
+
_C.RESNET.SPATIAL_DILATIONS = [[1], [1], [1], [1]]
|
141 |
+
|
142 |
+
# ---------------------------------------------------------------------------- #
|
143 |
+
# X3D options
|
144 |
+
# See https://arxiv.org/abs/2004.04730 for details about X3D Networks.
|
145 |
+
# ---------------------------------------------------------------------------- #
|
146 |
+
_C.X3D = CfgNode()
|
147 |
+
|
148 |
+
# Width expansion factor.
|
149 |
+
_C.X3D.WIDTH_FACTOR = 1.0
|
150 |
+
|
151 |
+
# Depth expansion factor.
|
152 |
+
_C.X3D.DEPTH_FACTOR = 1.0
|
153 |
+
|
154 |
+
# Bottleneck expansion factor for the 3x3x3 conv.
|
155 |
+
_C.X3D.BOTTLENECK_FACTOR = 1.0 #
|
156 |
+
|
157 |
+
# Dimensions of the last linear layer before classificaiton.
|
158 |
+
_C.X3D.DIM_C5 = 2048
|
159 |
+
|
160 |
+
# Dimensions of the first 3x3 conv layer.
|
161 |
+
_C.X3D.DIM_C1 = 12
|
162 |
+
|
163 |
+
# Whether to scale the width of Res2, default is false.
|
164 |
+
_C.X3D.SCALE_RES2 = False
|
165 |
+
|
166 |
+
# Whether to use a BatchNorm (BN) layer before the classifier, default is false.
|
167 |
+
_C.X3D.BN_LIN5 = False
|
168 |
+
|
169 |
+
# Whether to use channelwise (=depthwise) convolution in the center (3x3x3)
|
170 |
+
# convolution operation of the residual blocks.
|
171 |
+
_C.X3D.CHANNELWISE_3x3x3 = True
|
172 |
+
|
173 |
+
# -----------------------------------------------------------------------------
|
174 |
+
# Nonlocal options
|
175 |
+
# -----------------------------------------------------------------------------
|
176 |
+
_C.NONLOCAL = CfgNode()
|
177 |
+
|
178 |
+
# Index of each stage and block to add nonlocal layers.
|
179 |
+
_C.NONLOCAL.LOCATION = [[[]], [[]], [[]], [[]]]
|
180 |
+
|
181 |
+
# Number of group for nonlocal for each stage.
|
182 |
+
_C.NONLOCAL.GROUP = [[1], [1], [1], [1]]
|
183 |
+
|
184 |
+
# Instatiation to use for non-local layer.
|
185 |
+
_C.NONLOCAL.INSTANTIATION = "dot_product"
|
186 |
+
|
187 |
+
|
188 |
+
# Size of pooling layers used in Non-Local.
|
189 |
+
_C.NONLOCAL.POOL = [
|
190 |
+
# Res2
|
191 |
+
[[1, 2, 2], [1, 2, 2]],
|
192 |
+
# Res3
|
193 |
+
[[1, 2, 2], [1, 2, 2]],
|
194 |
+
# Res4
|
195 |
+
[[1, 2, 2], [1, 2, 2]],
|
196 |
+
# Res5
|
197 |
+
[[1, 2, 2], [1, 2, 2]],
|
198 |
+
]
|
199 |
+
|
200 |
+
# -----------------------------------------------------------------------------
|
201 |
+
# Model options
|
202 |
+
# -----------------------------------------------------------------------------
|
203 |
+
_C.MODEL = CfgNode()
|
204 |
+
|
205 |
+
# Model architecture.
|
206 |
+
_C.MODEL.ARCH = "slowfast"
|
207 |
+
|
208 |
+
# Model name
|
209 |
+
_C.MODEL.MODEL_NAME = "SlowFast"
|
210 |
+
|
211 |
+
# The number of classes to predict for the model.
|
212 |
+
_C.MODEL.NUM_CLASSES = 400
|
213 |
+
|
214 |
+
# Loss function.
|
215 |
+
_C.MODEL.LOSS_FUNC = "cross_entropy"
|
216 |
+
|
217 |
+
# Model architectures that has one single pathway.
|
218 |
+
_C.MODEL.SINGLE_PATHWAY_ARCH = ["c2d", "i3d", "slow", "x3d"]
|
219 |
+
|
220 |
+
# Model architectures that has multiple pathways.
|
221 |
+
_C.MODEL.MULTI_PATHWAY_ARCH = ["slowfast"]
|
222 |
+
|
223 |
+
# Dropout rate before final projection in the backbone.
|
224 |
+
_C.MODEL.DROPOUT_RATE = 0.5
|
225 |
+
|
226 |
+
# Randomly drop rate for Res-blocks, linearly increase from res2 to res5
|
227 |
+
_C.MODEL.DROPCONNECT_RATE = 0.0
|
228 |
+
|
229 |
+
# The std to initialize the fc layer(s).
|
230 |
+
_C.MODEL.FC_INIT_STD = 0.01
|
231 |
+
|
232 |
+
# Activation layer for the output head.
|
233 |
+
_C.MODEL.HEAD_ACT = "softmax"
|
234 |
+
|
235 |
+
|
236 |
+
# -----------------------------------------------------------------------------
|
237 |
+
# SlowFast options
|
238 |
+
# -----------------------------------------------------------------------------
|
239 |
+
_C.SLOWFAST = CfgNode()
|
240 |
+
|
241 |
+
# Corresponds to the inverse of the channel reduction ratio, $\beta$ between
|
242 |
+
# the Slow and Fast pathways.
|
243 |
+
_C.SLOWFAST.BETA_INV = 8
|
244 |
+
|
245 |
+
# Corresponds to the frame rate reduction ratio, $\alpha$ between the Slow and
|
246 |
+
# Fast pathways.
|
247 |
+
_C.SLOWFAST.ALPHA = 8
|
248 |
+
|
249 |
+
# Ratio of channel dimensions between the Slow and Fast pathways.
|
250 |
+
_C.SLOWFAST.FUSION_CONV_CHANNEL_RATIO = 2
|
251 |
+
|
252 |
+
# Kernel dimension used for fusing information from Fast pathway to Slow
|
253 |
+
# pathway.
|
254 |
+
_C.SLOWFAST.FUSION_KERNEL_SZ = 5
|
255 |
+
|
256 |
+
####### TimeSformer Options
|
257 |
+
_C.TIMESFORMER = CfgNode()
|
258 |
+
_C.TIMESFORMER.ATTENTION_TYPE = 'divided_space_time'
|
259 |
+
_C.TIMESFORMER.PRETRAINED_MODEL = ''
|
260 |
+
|
261 |
+
## MixUp parameters
|
262 |
+
_C.MIXUP = CfgNode()
|
263 |
+
_C.MIXUP.ENABLED = False
|
264 |
+
_C.MIXUP.ALPHA = 0.8
|
265 |
+
_C.MIXUP.CUTMIX_ALPHA = 1.0
|
266 |
+
_C.MIXUP.CUTMIX_MINMAX = None
|
267 |
+
_C.MIXUP.PROB = 1.0
|
268 |
+
_C.MIXUP.SWITCH_PROB = 0.5
|
269 |
+
_C.MIXUP.MODE = 'batch'
|
270 |
+
|
271 |
+
_C.EMA = CfgNode()
|
272 |
+
_C.EMA.ENABLED = False
|
273 |
+
|
274 |
+
# -----------------------------------------------------------------------------
|
275 |
+
# Data options
|
276 |
+
# -----------------------------------------------------------------------------
|
277 |
+
_C.DATA = CfgNode()
|
278 |
+
|
279 |
+
# The path to the data directory.
|
280 |
+
_C.DATA.PATH_TO_DATA_DIR = ""
|
281 |
+
|
282 |
+
# The separator used between path and label.
|
283 |
+
_C.DATA.PATH_LABEL_SEPARATOR = " "
|
284 |
+
|
285 |
+
# Video path prefix if any.
|
286 |
+
_C.DATA.PATH_PREFIX = ""
|
287 |
+
|
288 |
+
# The spatial crop size of the input clip.
|
289 |
+
_C.DATA.CROP_SIZE = 224
|
290 |
+
|
291 |
+
# The number of frames of the input clip.
|
292 |
+
_C.DATA.NUM_FRAMES = 8
|
293 |
+
|
294 |
+
# The video sampling rate of the input clip.
|
295 |
+
_C.DATA.SAMPLING_RATE = 8
|
296 |
+
|
297 |
+
# The mean value of the video raw pixels across the R G B channels.
|
298 |
+
_C.DATA.MEAN = [0.45, 0.45, 0.45]
|
299 |
+
# List of input frame channel dimensions.
|
300 |
+
|
301 |
+
_C.DATA.INPUT_CHANNEL_NUM = [3, 3]
|
302 |
+
|
303 |
+
# The std value of the video raw pixels across the R G B channels.
|
304 |
+
_C.DATA.STD = [0.225, 0.225, 0.225]
|
305 |
+
|
306 |
+
# The spatial augmentation jitter scales for training.
|
307 |
+
_C.DATA.TRAIN_JITTER_SCALES = [256, 320]
|
308 |
+
|
309 |
+
# The spatial crop size for training.
|
310 |
+
_C.DATA.TRAIN_CROP_SIZE = 224
|
311 |
+
|
312 |
+
# The spatial crop size for testing.
|
313 |
+
_C.DATA.TEST_CROP_SIZE = 256
|
314 |
+
|
315 |
+
# Input videos may has different fps, convert it to the target video fps before
|
316 |
+
# frame sampling.
|
317 |
+
_C.DATA.TARGET_FPS = 30
|
318 |
+
|
319 |
+
# Decoding backend, options include `pyav` or `torchvision`
|
320 |
+
_C.DATA.DECODING_BACKEND = "pyav"
|
321 |
+
|
322 |
+
# if True, sample uniformly in [1 / max_scale, 1 / min_scale] and take a
|
323 |
+
# reciprocal to get the scale. If False, take a uniform sample from
|
324 |
+
# [min_scale, max_scale].
|
325 |
+
_C.DATA.INV_UNIFORM_SAMPLE = False
|
326 |
+
|
327 |
+
# If True, perform random horizontal flip on the video frames during training.
|
328 |
+
_C.DATA.RANDOM_FLIP = True
|
329 |
+
|
330 |
+
# If True, calculdate the map as metric.
|
331 |
+
_C.DATA.MULTI_LABEL = False
|
332 |
+
|
333 |
+
# Method to perform the ensemble, options include "sum" and "max".
|
334 |
+
_C.DATA.ENSEMBLE_METHOD = "sum"
|
335 |
+
|
336 |
+
# If True, revert the default input channel (RBG <-> BGR).
|
337 |
+
_C.DATA.REVERSE_INPUT_CHANNEL = False
|
338 |
+
|
339 |
+
############
|
340 |
+
_C.DATA.TEMPORAL_EXTENT = 8
|
341 |
+
_C.DATA.DEIT_TRANSFORMS = False
|
342 |
+
_C.DATA.COLOR_JITTER = 0.
|
343 |
+
_C.DATA.AUTO_AUGMENT = ''
|
344 |
+
_C.DATA.RE_PROB = 0.0
|
345 |
+
|
346 |
+
# ---------------------------------------------------------------------------- #
|
347 |
+
# Optimizer options
|
348 |
+
# ---------------------------------------------------------------------------- #
|
349 |
+
_C.SOLVER = CfgNode()
|
350 |
+
|
351 |
+
# Base learning rate.
|
352 |
+
_C.SOLVER.BASE_LR = 0.1
|
353 |
+
|
354 |
+
# Learning rate policy (see utils/lr_policy.py for options and examples).
|
355 |
+
_C.SOLVER.LR_POLICY = "cosine"
|
356 |
+
|
357 |
+
# Final learning rates for 'cosine' policy.
|
358 |
+
_C.SOLVER.COSINE_END_LR = 0.0
|
359 |
+
|
360 |
+
# Exponential decay factor.
|
361 |
+
_C.SOLVER.GAMMA = 0.1
|
362 |
+
|
363 |
+
# Step size for 'exp' and 'cos' policies (in epochs).
|
364 |
+
_C.SOLVER.STEP_SIZE = 1
|
365 |
+
|
366 |
+
# Steps for 'steps_' policies (in epochs).
|
367 |
+
_C.SOLVER.STEPS = []
|
368 |
+
|
369 |
+
# Learning rates for 'steps_' policies.
|
370 |
+
_C.SOLVER.LRS = []
|
371 |
+
|
372 |
+
# Maximal number of epochs.
|
373 |
+
_C.SOLVER.MAX_EPOCH = 300
|
374 |
+
|
375 |
+
# Momentum.
|
376 |
+
_C.SOLVER.MOMENTUM = 0.9
|
377 |
+
|
378 |
+
# Momentum dampening.
|
379 |
+
_C.SOLVER.DAMPENING = 0.0
|
380 |
+
|
381 |
+
# Nesterov momentum.
|
382 |
+
_C.SOLVER.NESTEROV = True
|
383 |
+
|
384 |
+
# L2 regularization.
|
385 |
+
_C.SOLVER.WEIGHT_DECAY = 1e-4
|
386 |
+
|
387 |
+
# Start the warm up from SOLVER.BASE_LR * SOLVER.WARMUP_FACTOR.
|
388 |
+
_C.SOLVER.WARMUP_FACTOR = 0.1
|
389 |
+
|
390 |
+
# Gradually warm up the SOLVER.BASE_LR over this number of epochs.
|
391 |
+
_C.SOLVER.WARMUP_EPOCHS = 0.0
|
392 |
+
|
393 |
+
# The start learning rate of the warm up.
|
394 |
+
_C.SOLVER.WARMUP_START_LR = 0.01
|
395 |
+
|
396 |
+
# Optimization method.
|
397 |
+
_C.SOLVER.OPTIMIZING_METHOD = "sgd"
|
398 |
+
|
399 |
+
# Base learning rate is linearly scaled with NUM_SHARDS.
|
400 |
+
_C.SOLVER.BASE_LR_SCALE_NUM_SHARDS = False
|
401 |
+
|
402 |
+
# ---------------------------------------------------------------------------- #
|
403 |
+
# Misc options
|
404 |
+
# ---------------------------------------------------------------------------- #
|
405 |
+
|
406 |
+
# Number of GPUs to use (applies to both training and testing).
|
407 |
+
_C.NUM_GPUS = 1
|
408 |
+
|
409 |
+
# Number of machine to use for the job.
|
410 |
+
_C.NUM_SHARDS = 1
|
411 |
+
|
412 |
+
# The index of the current machine.
|
413 |
+
_C.SHARD_ID = 0
|
414 |
+
|
415 |
+
# Output basedir.
|
416 |
+
_C.OUTPUT_DIR = "./tmp"
|
417 |
+
|
418 |
+
# Note that non-determinism may still be present due to non-deterministic
|
419 |
+
# operator implementations in GPU operator libraries.
|
420 |
+
_C.RNG_SEED = 1
|
421 |
+
|
422 |
+
# Log period in iters.
|
423 |
+
_C.LOG_PERIOD = 10
|
424 |
+
|
425 |
+
# If True, log the model info.
|
426 |
+
_C.LOG_MODEL_INFO = False
|
427 |
+
|
428 |
+
# Distributed backend.
|
429 |
+
_C.DIST_BACKEND = "nccl"
|
430 |
+
|
431 |
+
# Global batch size
|
432 |
+
_C.GLOBAL_BATCH_SIZE = 64
|
433 |
+
|
434 |
+
# ---------------------------------------------------------------------------- #
|
435 |
+
# Benchmark options
|
436 |
+
# ---------------------------------------------------------------------------- #
|
437 |
+
_C.BENCHMARK = CfgNode()
|
438 |
+
|
439 |
+
# Number of epochs for data loading benchmark.
|
440 |
+
_C.BENCHMARK.NUM_EPOCHS = 5
|
441 |
+
|
442 |
+
# Log period in iters for data loading benchmark.
|
443 |
+
_C.BENCHMARK.LOG_PERIOD = 100
|
444 |
+
|
445 |
+
# If True, shuffle dataloader for epoch during benchmark.
|
446 |
+
_C.BENCHMARK.SHUFFLE = True
|
447 |
+
|
448 |
+
|
449 |
+
# ---------------------------------------------------------------------------- #
|
450 |
+
# Common train/test data loader options
|
451 |
+
# ---------------------------------------------------------------------------- #
|
452 |
+
_C.DATA_LOADER = CfgNode()
|
453 |
+
|
454 |
+
# Number of data loader workers per training process.
|
455 |
+
_C.DATA_LOADER.NUM_WORKERS = 8
|
456 |
+
|
457 |
+
# Load data to pinned host memory.
|
458 |
+
_C.DATA_LOADER.PIN_MEMORY = True
|
459 |
+
|
460 |
+
# Enable multi thread decoding.
|
461 |
+
_C.DATA_LOADER.ENABLE_MULTI_THREAD_DECODE = False
|
462 |
+
|
463 |
+
|
464 |
+
# ---------------------------------------------------------------------------- #
|
465 |
+
# Detection options.
|
466 |
+
# ---------------------------------------------------------------------------- #
|
467 |
+
_C.DETECTION = CfgNode()
|
468 |
+
|
469 |
+
# Whether enable video detection.
|
470 |
+
_C.DETECTION.ENABLE = False
|
471 |
+
|
472 |
+
# Aligned version of RoI. More details can be found at slowfast/models/head_helper.py
|
473 |
+
_C.DETECTION.ALIGNED = True
|
474 |
+
|
475 |
+
# Spatial scale factor.
|
476 |
+
_C.DETECTION.SPATIAL_SCALE_FACTOR = 16
|
477 |
+
|
478 |
+
# RoI tranformation resolution.
|
479 |
+
_C.DETECTION.ROI_XFORM_RESOLUTION = 7
|
480 |
+
|
481 |
+
|
482 |
+
# -----------------------------------------------------------------------------
|
483 |
+
# AVA Dataset options
|
484 |
+
# -----------------------------------------------------------------------------
|
485 |
+
_C.AVA = CfgNode()
|
486 |
+
|
487 |
+
# Directory path of frames.
|
488 |
+
_C.AVA.FRAME_DIR = "/mnt/fair-flash3-east/ava_trainval_frames.img/"
|
489 |
+
|
490 |
+
# Directory path for files of frame lists.
|
491 |
+
_C.AVA.FRAME_LIST_DIR = (
|
492 |
+
"/mnt/vol/gfsai-flash3-east/ai-group/users/haoqifan/ava/frame_list/"
|
493 |
+
)
|
494 |
+
|
495 |
+
# Directory path for annotation files.
|
496 |
+
_C.AVA.ANNOTATION_DIR = (
|
497 |
+
"/mnt/vol/gfsai-flash3-east/ai-group/users/haoqifan/ava/frame_list/"
|
498 |
+
)
|
499 |
+
|
500 |
+
# Filenames of training samples list files.
|
501 |
+
_C.AVA.TRAIN_LISTS = ["train.csv"]
|
502 |
+
|
503 |
+
# Filenames of test samples list files.
|
504 |
+
_C.AVA.TEST_LISTS = ["val.csv"]
|
505 |
+
|
506 |
+
# Filenames of box list files for training. Note that we assume files which
|
507 |
+
# contains predicted boxes will have a suffix "predicted_boxes" in the
|
508 |
+
# filename.
|
509 |
+
_C.AVA.TRAIN_GT_BOX_LISTS = ["ava_train_v2.2.csv"]
|
510 |
+
_C.AVA.TRAIN_PREDICT_BOX_LISTS = []
|
511 |
+
|
512 |
+
# Filenames of box list files for test.
|
513 |
+
_C.AVA.TEST_PREDICT_BOX_LISTS = ["ava_val_predicted_boxes.csv"]
|
514 |
+
|
515 |
+
# This option controls the score threshold for the predicted boxes to use.
|
516 |
+
_C.AVA.DETECTION_SCORE_THRESH = 0.9
|
517 |
+
|
518 |
+
# If use BGR as the format of input frames.
|
519 |
+
_C.AVA.BGR = False
|
520 |
+
|
521 |
+
# Training augmentation parameters
|
522 |
+
# Whether to use color augmentation method.
|
523 |
+
_C.AVA.TRAIN_USE_COLOR_AUGMENTATION = False
|
524 |
+
|
525 |
+
# Whether to only use PCA jitter augmentation when using color augmentation
|
526 |
+
# method (otherwise combine with color jitter method).
|
527 |
+
_C.AVA.TRAIN_PCA_JITTER_ONLY = True
|
528 |
+
|
529 |
+
# Eigenvalues for PCA jittering. Note PCA is RGB based.
|
530 |
+
_C.AVA.TRAIN_PCA_EIGVAL = [0.225, 0.224, 0.229]
|
531 |
+
|
532 |
+
# Eigenvectors for PCA jittering.
|
533 |
+
_C.AVA.TRAIN_PCA_EIGVEC = [
|
534 |
+
[-0.5675, 0.7192, 0.4009],
|
535 |
+
[-0.5808, -0.0045, -0.8140],
|
536 |
+
[-0.5836, -0.6948, 0.4203],
|
537 |
+
]
|
538 |
+
|
539 |
+
# Whether to do horizontal flipping during test.
|
540 |
+
_C.AVA.TEST_FORCE_FLIP = False
|
541 |
+
|
542 |
+
# Whether to use full test set for validation split.
|
543 |
+
_C.AVA.FULL_TEST_ON_VAL = False
|
544 |
+
|
545 |
+
# The name of the file to the ava label map.
|
546 |
+
_C.AVA.LABEL_MAP_FILE = "ava_action_list_v2.2_for_activitynet_2019.pbtxt"
|
547 |
+
|
548 |
+
# The name of the file to the ava exclusion.
|
549 |
+
_C.AVA.EXCLUSION_FILE = "ava_val_excluded_timestamps_v2.2.csv"
|
550 |
+
|
551 |
+
# The name of the file to the ava groundtruth.
|
552 |
+
_C.AVA.GROUNDTRUTH_FILE = "ava_val_v2.2.csv"
|
553 |
+
|
554 |
+
# Backend to process image, includes `pytorch` and `cv2`.
|
555 |
+
_C.AVA.IMG_PROC_BACKEND = "cv2"
|
556 |
+
|
557 |
+
# ---------------------------------------------------------------------------- #
|
558 |
+
# Multigrid training options
|
559 |
+
# See https://arxiv.org/abs/1912.00998 for details about multigrid training.
|
560 |
+
# ---------------------------------------------------------------------------- #
|
561 |
+
_C.MULTIGRID = CfgNode()
|
562 |
+
|
563 |
+
# Multigrid training allows us to train for more epochs with fewer iterations.
|
564 |
+
# This hyperparameter specifies how many times more epochs to train.
|
565 |
+
# The default setting in paper trains for 1.5x more epochs than baseline.
|
566 |
+
_C.MULTIGRID.EPOCH_FACTOR = 1.5
|
567 |
+
|
568 |
+
# Enable short cycles.
|
569 |
+
_C.MULTIGRID.SHORT_CYCLE = False
|
570 |
+
# Short cycle additional spatial dimensions relative to the default crop size.
|
571 |
+
_C.MULTIGRID.SHORT_CYCLE_FACTORS = [0.5, 0.5 ** 0.5]
|
572 |
+
|
573 |
+
_C.MULTIGRID.LONG_CYCLE = False
|
574 |
+
# (Temporal, Spatial) dimensions relative to the default shape.
|
575 |
+
_C.MULTIGRID.LONG_CYCLE_FACTORS = [
|
576 |
+
(0.25, 0.5 ** 0.5),
|
577 |
+
(0.5, 0.5 ** 0.5),
|
578 |
+
(0.5, 1),
|
579 |
+
(1, 1),
|
580 |
+
]
|
581 |
+
|
582 |
+
# While a standard BN computes stats across all examples in a GPU,
|
583 |
+
# for multigrid training we fix the number of clips to compute BN stats on.
|
584 |
+
# See https://arxiv.org/abs/1912.00998 for details.
|
585 |
+
_C.MULTIGRID.BN_BASE_SIZE = 8
|
586 |
+
|
587 |
+
# Multigrid training epochs are not proportional to actual training time or
|
588 |
+
# computations, so _C.TRAIN.EVAL_PERIOD leads to too frequent or rare
|
589 |
+
# evaluation. We use a multigrid-specific rule to determine when to evaluate:
|
590 |
+
# This hyperparameter defines how many times to evaluate a model per long
|
591 |
+
# cycle shape.
|
592 |
+
_C.MULTIGRID.EVAL_FREQ = 3
|
593 |
+
|
594 |
+
# No need to specify; Set automatically and used as global variables.
|
595 |
+
_C.MULTIGRID.LONG_CYCLE_SAMPLING_RATE = 0
|
596 |
+
_C.MULTIGRID.DEFAULT_B = 0
|
597 |
+
_C.MULTIGRID.DEFAULT_T = 0
|
598 |
+
_C.MULTIGRID.DEFAULT_S = 0
|
599 |
+
|
600 |
+
# -----------------------------------------------------------------------------
|
601 |
+
# Tensorboard Visualization Options
|
602 |
+
# -----------------------------------------------------------------------------
|
603 |
+
_C.TENSORBOARD = CfgNode()
|
604 |
+
|
605 |
+
# Log to summary writer, this will automatically.
|
606 |
+
# log loss, lr and metrics during train/eval.
|
607 |
+
_C.TENSORBOARD.ENABLE = False
|
608 |
+
# Provide path to prediction results for visualization.
|
609 |
+
# This is a pickle file of [prediction_tensor, label_tensor]
|
610 |
+
_C.TENSORBOARD.PREDICTIONS_PATH = ""
|
611 |
+
# Path to directory for tensorboard logs.
|
612 |
+
# Default to to cfg.OUTPUT_DIR/runs-{cfg.TRAIN.DATASET}.
|
613 |
+
_C.TENSORBOARD.LOG_DIR = ""
|
614 |
+
# Path to a json file providing class_name - id mapping
|
615 |
+
# in the format {"class_name1": id1, "class_name2": id2, ...}.
|
616 |
+
# This file must be provided to enable plotting confusion matrix
|
617 |
+
# by a subset or parent categories.
|
618 |
+
_C.TENSORBOARD.CLASS_NAMES_PATH = ""
|
619 |
+
|
620 |
+
# Path to a json file for categories -> classes mapping
|
621 |
+
# in the format {"parent_class": ["child_class1", "child_class2",...], ...}.
|
622 |
+
_C.TENSORBOARD.CATEGORIES_PATH = ""
|
623 |
+
|
624 |
+
# Config for confusion matrices visualization.
|
625 |
+
_C.TENSORBOARD.CONFUSION_MATRIX = CfgNode()
|
626 |
+
# Visualize confusion matrix.
|
627 |
+
_C.TENSORBOARD.CONFUSION_MATRIX.ENABLE = False
|
628 |
+
# Figure size of the confusion matrices plotted.
|
629 |
+
_C.TENSORBOARD.CONFUSION_MATRIX.FIGSIZE = [8, 8]
|
630 |
+
# Path to a subset of categories to visualize.
|
631 |
+
# File contains class names separated by newline characters.
|
632 |
+
_C.TENSORBOARD.CONFUSION_MATRIX.SUBSET_PATH = ""
|
633 |
+
|
634 |
+
# Config for histogram visualization.
|
635 |
+
_C.TENSORBOARD.HISTOGRAM = CfgNode()
|
636 |
+
# Visualize histograms.
|
637 |
+
_C.TENSORBOARD.HISTOGRAM.ENABLE = False
|
638 |
+
# Path to a subset of classes to plot histograms.
|
639 |
+
# Class names must be separated by newline characters.
|
640 |
+
_C.TENSORBOARD.HISTOGRAM.SUBSET_PATH = ""
|
641 |
+
# Visualize top-k most predicted classes on histograms for each
|
642 |
+
# chosen true label.
|
643 |
+
_C.TENSORBOARD.HISTOGRAM.TOPK = 10
|
644 |
+
# Figure size of the histograms plotted.
|
645 |
+
_C.TENSORBOARD.HISTOGRAM.FIGSIZE = [8, 8]
|
646 |
+
|
647 |
+
# Config for layers' weights and activations visualization.
|
648 |
+
# _C.TENSORBOARD.ENABLE must be True.
|
649 |
+
_C.TENSORBOARD.MODEL_VIS = CfgNode()
|
650 |
+
|
651 |
+
# If False, skip model visualization.
|
652 |
+
_C.TENSORBOARD.MODEL_VIS.ENABLE = False
|
653 |
+
|
654 |
+
# If False, skip visualizing model weights.
|
655 |
+
_C.TENSORBOARD.MODEL_VIS.MODEL_WEIGHTS = False
|
656 |
+
|
657 |
+
# If False, skip visualizing model activations.
|
658 |
+
_C.TENSORBOARD.MODEL_VIS.ACTIVATIONS = False
|
659 |
+
|
660 |
+
# If False, skip visualizing input videos.
|
661 |
+
_C.TENSORBOARD.MODEL_VIS.INPUT_VIDEO = False
|
662 |
+
|
663 |
+
|
664 |
+
# List of strings containing data about layer names and their indexing to
|
665 |
+
# visualize weights and activations for. The indexing is meant for
|
666 |
+
# choosing a subset of activations outputed by a layer for visualization.
|
667 |
+
# If indexing is not specified, visualize all activations outputed by the layer.
|
668 |
+
# For each string, layer name and indexing is separated by whitespaces.
|
669 |
+
# e.g.: [layer1 1,2;1,2, layer2, layer3 150,151;3,4]; this means for each array `arr`
|
670 |
+
# along the batch dimension in `layer1`, we take arr[[1, 2], [1, 2]]
|
671 |
+
_C.TENSORBOARD.MODEL_VIS.LAYER_LIST = []
|
672 |
+
# Top-k predictions to plot on videos
|
673 |
+
_C.TENSORBOARD.MODEL_VIS.TOPK_PREDS = 1
|
674 |
+
# Colormap to for text boxes and bounding boxes colors
|
675 |
+
_C.TENSORBOARD.MODEL_VIS.COLORMAP = "Pastel2"
|
676 |
+
# Config for visualization video inputs with Grad-CAM.
|
677 |
+
# _C.TENSORBOARD.ENABLE must be True.
|
678 |
+
_C.TENSORBOARD.MODEL_VIS.GRAD_CAM = CfgNode()
|
679 |
+
# Whether to run visualization using Grad-CAM technique.
|
680 |
+
_C.TENSORBOARD.MODEL_VIS.GRAD_CAM.ENABLE = True
|
681 |
+
# CNN layers to use for Grad-CAM. The number of layers must be equal to
|
682 |
+
# number of pathway(s).
|
683 |
+
_C.TENSORBOARD.MODEL_VIS.GRAD_CAM.LAYER_LIST = []
|
684 |
+
# If True, visualize Grad-CAM using true labels for each instances.
|
685 |
+
# If False, use the highest predicted class.
|
686 |
+
_C.TENSORBOARD.MODEL_VIS.GRAD_CAM.USE_TRUE_LABEL = False
|
687 |
+
# Colormap to for text boxes and bounding boxes colors
|
688 |
+
_C.TENSORBOARD.MODEL_VIS.GRAD_CAM.COLORMAP = "viridis"
|
689 |
+
|
690 |
+
# Config for visualization for wrong prediction visualization.
|
691 |
+
# _C.TENSORBOARD.ENABLE must be True.
|
692 |
+
_C.TENSORBOARD.WRONG_PRED_VIS = CfgNode()
|
693 |
+
_C.TENSORBOARD.WRONG_PRED_VIS.ENABLE = False
|
694 |
+
# Folder tag to origanize model eval videos under.
|
695 |
+
_C.TENSORBOARD.WRONG_PRED_VIS.TAG = "Incorrectly classified videos."
|
696 |
+
# Subset of labels to visualize. Only wrong predictions with true labels
|
697 |
+
# within this subset is visualized.
|
698 |
+
_C.TENSORBOARD.WRONG_PRED_VIS.SUBSET_PATH = ""
|
699 |
+
|
700 |
+
|
701 |
+
# ---------------------------------------------------------------------------- #
|
702 |
+
# Demo options
|
703 |
+
# ---------------------------------------------------------------------------- #
|
704 |
+
_C.DEMO = CfgNode()
|
705 |
+
|
706 |
+
# Run model in DEMO mode.
|
707 |
+
_C.DEMO.ENABLE = False
|
708 |
+
|
709 |
+
# Path to a json file providing class_name - id mapping
|
710 |
+
# in the format {"class_name1": id1, "class_name2": id2, ...}.
|
711 |
+
_C.DEMO.LABEL_FILE_PATH = ""
|
712 |
+
|
713 |
+
# Specify a camera device as input. This will be prioritized
|
714 |
+
# over input video if set.
|
715 |
+
# If -1, use input video instead.
|
716 |
+
_C.DEMO.WEBCAM = -1
|
717 |
+
|
718 |
+
# Path to input video for demo.
|
719 |
+
_C.DEMO.INPUT_VIDEO = ""
|
720 |
+
# Custom width for reading input video data.
|
721 |
+
_C.DEMO.DISPLAY_WIDTH = 0
|
722 |
+
# Custom height for reading input video data.
|
723 |
+
_C.DEMO.DISPLAY_HEIGHT = 0
|
724 |
+
# Path to Detectron2 object detection model configuration,
|
725 |
+
# only used for detection tasks.
|
726 |
+
_C.DEMO.DETECTRON2_CFG = "COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"
|
727 |
+
# Path to Detectron2 object detection model pre-trained weights.
|
728 |
+
_C.DEMO.DETECTRON2_WEIGHTS = "detectron2://COCO-Detection/faster_rcnn_R_50_FPN_3x/137849458/model_final_280758.pkl"
|
729 |
+
# Threshold for choosing predicted bounding boxes by Detectron2.
|
730 |
+
_C.DEMO.DETECTRON2_THRESH = 0.9
|
731 |
+
# Number of overlapping frames between 2 consecutive clips.
|
732 |
+
# Increase this number for more frequent action predictions.
|
733 |
+
# The number of overlapping frames cannot be larger than
|
734 |
+
# half of the sequence length `cfg.DATA.NUM_FRAMES * cfg.DATA.SAMPLING_RATE`
|
735 |
+
_C.DEMO.BUFFER_SIZE = 0
|
736 |
+
# If specified, the visualized outputs will be written this a video file of
|
737 |
+
# this path. Otherwise, the visualized outputs will be displayed in a window.
|
738 |
+
_C.DEMO.OUTPUT_FILE = ""
|
739 |
+
# Frames per second rate for writing to output video file.
|
740 |
+
# If not set (-1), use fps rate from input file.
|
741 |
+
_C.DEMO.OUTPUT_FPS = -1
|
742 |
+
# Input format from demo video reader ("RGB" or "BGR").
|
743 |
+
_C.DEMO.INPUT_FORMAT = "BGR"
|
744 |
+
# Draw visualization frames in [keyframe_idx - CLIP_VIS_SIZE, keyframe_idx + CLIP_VIS_SIZE] inclusively.
|
745 |
+
_C.DEMO.CLIP_VIS_SIZE = 10
|
746 |
+
# Number of processes to run video visualizer.
|
747 |
+
_C.DEMO.NUM_VIS_INSTANCES = 2
|
748 |
+
|
749 |
+
# Path to pre-computed predicted boxes
|
750 |
+
_C.DEMO.PREDS_BOXES = ""
|
751 |
+
# Whether to run in with multi-threaded video reader.
|
752 |
+
_C.DEMO.THREAD_ENABLE = False
|
753 |
+
# Take one clip for every `DEMO.NUM_CLIPS_SKIP` + 1 for prediction and visualization.
|
754 |
+
# This is used for fast demo speed by reducing the prediction/visualiztion frequency.
|
755 |
+
# If -1, take the most recent read clip for visualization. This mode is only supported
|
756 |
+
# if `DEMO.THREAD_ENABLE` is set to True.
|
757 |
+
_C.DEMO.NUM_CLIPS_SKIP = 0
|
758 |
+
# Path to ground-truth boxes and labels (optional)
|
759 |
+
_C.DEMO.GT_BOXES = ""
|
760 |
+
# The starting second of the video w.r.t bounding boxes file.
|
761 |
+
_C.DEMO.STARTING_SECOND = 900
|
762 |
+
# Frames per second of the input video/folder of images.
|
763 |
+
_C.DEMO.FPS = 30
|
764 |
+
# Visualize with top-k predictions or predictions above certain threshold(s).
|
765 |
+
# Option: {"thres", "top-k"}
|
766 |
+
_C.DEMO.VIS_MODE = "thres"
|
767 |
+
# Threshold for common class names.
|
768 |
+
_C.DEMO.COMMON_CLASS_THRES = 0.7
|
769 |
+
# Theshold for uncommon class names. This will not be
|
770 |
+
# used if `_C.DEMO.COMMON_CLASS_NAMES` is empty.
|
771 |
+
_C.DEMO.UNCOMMON_CLASS_THRES = 0.3
|
772 |
+
# This is chosen based on distribution of examples in
|
773 |
+
# each classes in AVA dataset.
|
774 |
+
_C.DEMO.COMMON_CLASS_NAMES = [
|
775 |
+
"watch (a person)",
|
776 |
+
"talk to (e.g., self, a person, a group)",
|
777 |
+
"listen to (a person)",
|
778 |
+
"touch (an object)",
|
779 |
+
"carry/hold (an object)",
|
780 |
+
"walk",
|
781 |
+
"sit",
|
782 |
+
"lie/sleep",
|
783 |
+
"bend/bow (at the waist)",
|
784 |
+
]
|
785 |
+
# Slow-motion rate for the visualization. The visualized portions of the
|
786 |
+
# video will be played `_C.DEMO.SLOWMO` times slower than usual speed.
|
787 |
+
_C.DEMO.SLOWMO = 1
|
788 |
+
|
789 |
+
def _assert_and_infer_cfg(cfg):
|
790 |
+
# BN assertions.
|
791 |
+
if cfg.BN.USE_PRECISE_STATS:
|
792 |
+
assert cfg.BN.NUM_BATCHES_PRECISE >= 0
|
793 |
+
# TRAIN assertions.
|
794 |
+
assert cfg.TRAIN.CHECKPOINT_TYPE in ["pytorch", "caffe2"]
|
795 |
+
assert cfg.TRAIN.BATCH_SIZE % cfg.NUM_GPUS == 0
|
796 |
+
|
797 |
+
# TEST assertions.
|
798 |
+
assert cfg.TEST.CHECKPOINT_TYPE in ["pytorch", "caffe2"]
|
799 |
+
assert cfg.TEST.BATCH_SIZE % cfg.NUM_GPUS == 0
|
800 |
+
assert cfg.TEST.NUM_SPATIAL_CROPS == 3
|
801 |
+
|
802 |
+
# RESNET assertions.
|
803 |
+
assert cfg.RESNET.NUM_GROUPS > 0
|
804 |
+
assert cfg.RESNET.WIDTH_PER_GROUP > 0
|
805 |
+
assert cfg.RESNET.WIDTH_PER_GROUP % cfg.RESNET.NUM_GROUPS == 0
|
806 |
+
|
807 |
+
# Execute LR scaling by num_shards.
|
808 |
+
if cfg.SOLVER.BASE_LR_SCALE_NUM_SHARDS:
|
809 |
+
cfg.SOLVER.BASE_LR *= cfg.NUM_SHARDS
|
810 |
+
|
811 |
+
# General assertions.
|
812 |
+
assert cfg.SHARD_ID < cfg.NUM_SHARDS
|
813 |
+
return cfg
|
814 |
+
|
815 |
+
|
816 |
+
def get_cfg():
|
817 |
+
"""
|
818 |
+
Get a copy of the default config.
|
819 |
+
"""
|
820 |
+
return _assert_and_infer_cfg(_C.clone())
|
TimeSformer/timesformer/datasets/DATASET.md
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Dataset Preparation
|
2 |
+
|
3 |
+
## Kinetics
|
4 |
+
|
5 |
+
The Kinetics Dataset could be downloaded from the following [link](https://github.com/cvdfoundation/kinetics-dataset):
|
6 |
+
|
7 |
+
After all the videos were downloaded, resize the video to the short edge size of 256, then prepare the csv files for training, validation, and testing set as `train.csv`, `val.csv`, `test.csv`. The format of the csv file is:
|
8 |
+
|
9 |
+
```
|
10 |
+
path_to_video_1 label_1
|
11 |
+
path_to_video_2 label_2
|
12 |
+
path_to_video_3 label_3
|
13 |
+
...
|
14 |
+
path_to_video_N label_N
|
15 |
+
```
|
16 |
+
|
17 |
+
## Something-Something V2
|
18 |
+
1. Please download the dataset and annotations from [dataset provider](https://20bn.com/datasets/something-something).
|
19 |
+
|
20 |
+
2. Download the *frame list* from the following links: ([train](https://dl.fbaipublicfiles.com/pyslowfast/dataset/ssv2/frame_lists/train.csv), [val](https://dl.fbaipublicfiles.com/pyslowfast/dataset/ssv2/frame_lists/val.csv)).
|
21 |
+
|
22 |
+
3. Extract the frames at 30 FPS. (We used ffmpeg-4.1.3 with command
|
23 |
+
`ffmpeg -i "${video}" -r 30 -q:v 1 "${out_name}"`
|
24 |
+
in experiments.) Please put the frames in a structure consistent with the frame lists.
|
25 |
+
|
26 |
+
Please put all annotation json files and the frame lists in the same folder, and set `DATA.PATH_TO_DATA_DIR` to the path. Set `DATA.PATH_PREFIX` to be the path to the folder containing extracted frames.
|
TimeSformer/timesformer/datasets/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
2 |
+
|
3 |
+
from .build import DATASET_REGISTRY, build_dataset # noqa
|
4 |
+
from .kinetics import Kinetics # noqa
|
5 |
+
from .ssv2 import Ssv2 # noqa
|
TimeSformer/timesformer/datasets/build.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
2 |
+
|
3 |
+
from fvcore.common.registry import Registry
|
4 |
+
|
5 |
+
DATASET_REGISTRY = Registry("DATASET")
|
6 |
+
DATASET_REGISTRY.__doc__ = """
|
7 |
+
Registry for dataset.
|
8 |
+
|
9 |
+
The registered object will be called with `obj(cfg, split)`.
|
10 |
+
The call should return a `torch.utils.data.Dataset` object.
|
11 |
+
"""
|
12 |
+
|
13 |
+
|
14 |
+
def build_dataset(dataset_name, cfg, split):
|
15 |
+
"""
|
16 |
+
Build a dataset, defined by `dataset_name`.
|
17 |
+
Args:
|
18 |
+
dataset_name (str): the name of the dataset to be constructed.
|
19 |
+
cfg (CfgNode): configs. Details can be found in
|
20 |
+
slowfast/config/defaults.py
|
21 |
+
split (str): the split of the data loader. Options include `train`,
|
22 |
+
`val`, and `test`.
|
23 |
+
Returns:
|
24 |
+
Dataset: a constructed dataset specified by dataset_name.
|
25 |
+
"""
|
26 |
+
# Capitalize the the first letter of the dataset_name since the dataset_name
|
27 |
+
# in configs may be in lowercase but the name of dataset class should always
|
28 |
+
# start with an uppercase letter.
|
29 |
+
name = dataset_name.capitalize()
|
30 |
+
return DATASET_REGISTRY.get(name)(cfg, split)
|
TimeSformer/timesformer/datasets/cv2_transform.py
ADDED
@@ -0,0 +1,796 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
2 |
+
|
3 |
+
import math
|
4 |
+
import numpy as np
|
5 |
+
import cv2
|
6 |
+
|
7 |
+
|
8 |
+
def clip_boxes_to_image(boxes, height, width):
|
9 |
+
"""
|
10 |
+
Clip the boxes with the height and width of the image size.
|
11 |
+
Args:
|
12 |
+
boxes (ndarray): bounding boxes to peform crop. The dimension is
|
13 |
+
`num boxes` x 4.
|
14 |
+
height (int): the height of the image.
|
15 |
+
width (int): the width of the image.
|
16 |
+
Returns:
|
17 |
+
boxes (ndarray): cropped bounding boxes.
|
18 |
+
"""
|
19 |
+
boxes[:, [0, 2]] = np.minimum(
|
20 |
+
width - 1.0, np.maximum(0.0, boxes[:, [0, 2]])
|
21 |
+
)
|
22 |
+
boxes[:, [1, 3]] = np.minimum(
|
23 |
+
height - 1.0, np.maximum(0.0, boxes[:, [1, 3]])
|
24 |
+
)
|
25 |
+
return boxes
|
26 |
+
|
27 |
+
|
28 |
+
def random_short_side_scale_jitter_list(images, min_size, max_size, boxes=None):
|
29 |
+
"""
|
30 |
+
Perform a spatial short scale jittering on the given images and
|
31 |
+
corresponding boxes.
|
32 |
+
Args:
|
33 |
+
images (list): list of images to perform scale jitter. Dimension is
|
34 |
+
`height` x `width` x `channel`.
|
35 |
+
min_size (int): the minimal size to scale the frames.
|
36 |
+
max_size (int): the maximal size to scale the frames.
|
37 |
+
boxes (list): optional. Corresponding boxes to images. Dimension is
|
38 |
+
`num boxes` x 4.
|
39 |
+
Returns:
|
40 |
+
(list): the list of scaled images with dimension of
|
41 |
+
`new height` x `new width` x `channel`.
|
42 |
+
(ndarray or None): the scaled boxes with dimension of
|
43 |
+
`num boxes` x 4.
|
44 |
+
"""
|
45 |
+
size = int(round(1.0 / np.random.uniform(1.0 / max_size, 1.0 / min_size)))
|
46 |
+
|
47 |
+
height = images[0].shape[0]
|
48 |
+
width = images[0].shape[1]
|
49 |
+
if (width <= height and width == size) or (
|
50 |
+
height <= width and height == size
|
51 |
+
):
|
52 |
+
return images, boxes
|
53 |
+
new_width = size
|
54 |
+
new_height = size
|
55 |
+
if width < height:
|
56 |
+
new_height = int(math.floor((float(height) / width) * size))
|
57 |
+
if boxes is not None:
|
58 |
+
boxes = [
|
59 |
+
proposal * float(new_height) / height for proposal in boxes
|
60 |
+
]
|
61 |
+
else:
|
62 |
+
new_width = int(math.floor((float(width) / height) * size))
|
63 |
+
if boxes is not None:
|
64 |
+
boxes = [proposal * float(new_width) / width for proposal in boxes]
|
65 |
+
return (
|
66 |
+
[
|
67 |
+
cv2.resize(
|
68 |
+
image, (new_width, new_height), interpolation=cv2.INTER_LINEAR
|
69 |
+
).astype(np.float32)
|
70 |
+
for image in images
|
71 |
+
],
|
72 |
+
boxes,
|
73 |
+
)
|
74 |
+
|
75 |
+
|
76 |
+
def scale(size, image):
|
77 |
+
"""
|
78 |
+
Scale the short side of the image to size.
|
79 |
+
Args:
|
80 |
+
size (int): size to scale the image.
|
81 |
+
image (array): image to perform short side scale. Dimension is
|
82 |
+
`height` x `width` x `channel`.
|
83 |
+
Returns:
|
84 |
+
(ndarray): the scaled image with dimension of
|
85 |
+
`height` x `width` x `channel`.
|
86 |
+
"""
|
87 |
+
height = image.shape[0]
|
88 |
+
width = image.shape[1]
|
89 |
+
if (width <= height and width == size) or (
|
90 |
+
height <= width and height == size
|
91 |
+
):
|
92 |
+
return image
|
93 |
+
new_width = size
|
94 |
+
new_height = size
|
95 |
+
if width < height:
|
96 |
+
new_height = int(math.floor((float(height) / width) * size))
|
97 |
+
else:
|
98 |
+
new_width = int(math.floor((float(width) / height) * size))
|
99 |
+
img = cv2.resize(
|
100 |
+
image, (new_width, new_height), interpolation=cv2.INTER_LINEAR
|
101 |
+
)
|
102 |
+
return img.astype(np.float32)
|
103 |
+
|
104 |
+
|
105 |
+
def scale_boxes(size, boxes, height, width):
|
106 |
+
"""
|
107 |
+
Scale the short side of the box to size.
|
108 |
+
Args:
|
109 |
+
size (int): size to scale the image.
|
110 |
+
boxes (ndarray): bounding boxes to peform scale. The dimension is
|
111 |
+
`num boxes` x 4.
|
112 |
+
height (int): the height of the image.
|
113 |
+
width (int): the width of the image.
|
114 |
+
Returns:
|
115 |
+
boxes (ndarray): scaled bounding boxes.
|
116 |
+
"""
|
117 |
+
if (width <= height and width == size) or (
|
118 |
+
height <= width and height == size
|
119 |
+
):
|
120 |
+
return boxes
|
121 |
+
|
122 |
+
new_width = size
|
123 |
+
new_height = size
|
124 |
+
if width < height:
|
125 |
+
new_height = int(math.floor((float(height) / width) * size))
|
126 |
+
boxes *= float(new_height) / height
|
127 |
+
else:
|
128 |
+
new_width = int(math.floor((float(width) / height) * size))
|
129 |
+
boxes *= float(new_width) / width
|
130 |
+
return boxes
|
131 |
+
|
132 |
+
|
133 |
+
def horizontal_flip_list(prob, images, order="CHW", boxes=None):
|
134 |
+
"""
|
135 |
+
Horizontally flip the list of image and optional boxes.
|
136 |
+
Args:
|
137 |
+
prob (float): probability to flip.
|
138 |
+
image (list): ilist of images to perform short side scale. Dimension is
|
139 |
+
`height` x `width` x `channel` or `channel` x `height` x `width`.
|
140 |
+
order (str): order of the `height`, `channel` and `width`.
|
141 |
+
boxes (list): optional. Corresponding boxes to images.
|
142 |
+
Dimension is `num boxes` x 4.
|
143 |
+
Returns:
|
144 |
+
(ndarray): the scaled image with dimension of
|
145 |
+
`height` x `width` x `channel`.
|
146 |
+
(list): optional. Corresponding boxes to images. Dimension is
|
147 |
+
`num boxes` x 4.
|
148 |
+
"""
|
149 |
+
_, width, _ = images[0].shape
|
150 |
+
if np.random.uniform() < prob:
|
151 |
+
if boxes is not None:
|
152 |
+
boxes = [flip_boxes(proposal, width) for proposal in boxes]
|
153 |
+
if order == "CHW":
|
154 |
+
out_images = []
|
155 |
+
for image in images:
|
156 |
+
image = np.asarray(image).swapaxes(2, 0)
|
157 |
+
image = image[::-1]
|
158 |
+
out_images.append(image.swapaxes(0, 2))
|
159 |
+
return out_images, boxes
|
160 |
+
elif order == "HWC":
|
161 |
+
return [cv2.flip(image, 1) for image in images], boxes
|
162 |
+
return images, boxes
|
163 |
+
|
164 |
+
|
165 |
+
def spatial_shift_crop_list(size, images, spatial_shift_pos, boxes=None):
|
166 |
+
"""
|
167 |
+
Perform left, center, or right crop of the given list of images.
|
168 |
+
Args:
|
169 |
+
size (int): size to crop.
|
170 |
+
image (list): ilist of images to perform short side scale. Dimension is
|
171 |
+
`height` x `width` x `channel` or `channel` x `height` x `width`.
|
172 |
+
spatial_shift_pos (int): option includes 0 (left), 1 (middle), and
|
173 |
+
2 (right) crop.
|
174 |
+
boxes (list): optional. Corresponding boxes to images.
|
175 |
+
Dimension is `num boxes` x 4.
|
176 |
+
Returns:
|
177 |
+
cropped (ndarray): the cropped list of images with dimension of
|
178 |
+
`height` x `width` x `channel`.
|
179 |
+
boxes (list): optional. Corresponding boxes to images. Dimension is
|
180 |
+
`num boxes` x 4.
|
181 |
+
"""
|
182 |
+
|
183 |
+
assert spatial_shift_pos in [0, 1, 2]
|
184 |
+
|
185 |
+
height = images[0].shape[0]
|
186 |
+
width = images[0].shape[1]
|
187 |
+
y_offset = int(math.ceil((height - size) / 2))
|
188 |
+
x_offset = int(math.ceil((width - size) / 2))
|
189 |
+
|
190 |
+
if height > width:
|
191 |
+
if spatial_shift_pos == 0:
|
192 |
+
y_offset = 0
|
193 |
+
elif spatial_shift_pos == 2:
|
194 |
+
y_offset = height - size
|
195 |
+
else:
|
196 |
+
if spatial_shift_pos == 0:
|
197 |
+
x_offset = 0
|
198 |
+
elif spatial_shift_pos == 2:
|
199 |
+
x_offset = width - size
|
200 |
+
|
201 |
+
cropped = [
|
202 |
+
image[y_offset : y_offset + size, x_offset : x_offset + size, :]
|
203 |
+
for image in images
|
204 |
+
]
|
205 |
+
assert cropped[0].shape[0] == size, "Image height not cropped properly"
|
206 |
+
assert cropped[0].shape[1] == size, "Image width not cropped properly"
|
207 |
+
|
208 |
+
if boxes is not None:
|
209 |
+
for i in range(len(boxes)):
|
210 |
+
boxes[i][:, [0, 2]] -= x_offset
|
211 |
+
boxes[i][:, [1, 3]] -= y_offset
|
212 |
+
return cropped, boxes
|
213 |
+
|
214 |
+
|
215 |
+
def CHW2HWC(image):
|
216 |
+
"""
|
217 |
+
Transpose the dimension from `channel` x `height` x `width` to
|
218 |
+
`height` x `width` x `channel`.
|
219 |
+
Args:
|
220 |
+
image (array): image to transpose.
|
221 |
+
Returns
|
222 |
+
(array): transposed image.
|
223 |
+
"""
|
224 |
+
return image.transpose([1, 2, 0])
|
225 |
+
|
226 |
+
|
227 |
+
def HWC2CHW(image):
|
228 |
+
"""
|
229 |
+
Transpose the dimension from `height` x `width` x `channel` to
|
230 |
+
`channel` x `height` x `width`.
|
231 |
+
Args:
|
232 |
+
image (array): image to transpose.
|
233 |
+
Returns
|
234 |
+
(array): transposed image.
|
235 |
+
"""
|
236 |
+
return image.transpose([2, 0, 1])
|
237 |
+
|
238 |
+
|
239 |
+
def color_jitter_list(
|
240 |
+
images, img_brightness=0, img_contrast=0, img_saturation=0
|
241 |
+
):
|
242 |
+
"""
|
243 |
+
Perform color jitter on the list of images.
|
244 |
+
Args:
|
245 |
+
images (list): list of images to perform color jitter.
|
246 |
+
img_brightness (float): jitter ratio for brightness.
|
247 |
+
img_contrast (float): jitter ratio for contrast.
|
248 |
+
img_saturation (float): jitter ratio for saturation.
|
249 |
+
Returns:
|
250 |
+
images (list): the jittered list of images.
|
251 |
+
"""
|
252 |
+
jitter = []
|
253 |
+
if img_brightness != 0:
|
254 |
+
jitter.append("brightness")
|
255 |
+
if img_contrast != 0:
|
256 |
+
jitter.append("contrast")
|
257 |
+
if img_saturation != 0:
|
258 |
+
jitter.append("saturation")
|
259 |
+
|
260 |
+
if len(jitter) > 0:
|
261 |
+
order = np.random.permutation(np.arange(len(jitter)))
|
262 |
+
for idx in range(0, len(jitter)):
|
263 |
+
if jitter[order[idx]] == "brightness":
|
264 |
+
images = brightness_list(img_brightness, images)
|
265 |
+
elif jitter[order[idx]] == "contrast":
|
266 |
+
images = contrast_list(img_contrast, images)
|
267 |
+
elif jitter[order[idx]] == "saturation":
|
268 |
+
images = saturation_list(img_saturation, images)
|
269 |
+
return images
|
270 |
+
|
271 |
+
|
272 |
+
def lighting_list(imgs, alphastd, eigval, eigvec, alpha=None):
|
273 |
+
"""
|
274 |
+
Perform AlexNet-style PCA jitter on the given list of images.
|
275 |
+
Args:
|
276 |
+
images (list): list of images to perform lighting jitter.
|
277 |
+
alphastd (float): jitter ratio for PCA jitter.
|
278 |
+
eigval (list): eigenvalues for PCA jitter.
|
279 |
+
eigvec (list[list]): eigenvectors for PCA jitter.
|
280 |
+
Returns:
|
281 |
+
out_images (list): the list of jittered images.
|
282 |
+
"""
|
283 |
+
if alphastd == 0:
|
284 |
+
return imgs
|
285 |
+
# generate alpha1, alpha2, alpha3
|
286 |
+
alpha = np.random.normal(0, alphastd, size=(1, 3))
|
287 |
+
eig_vec = np.array(eigvec)
|
288 |
+
eig_val = np.reshape(eigval, (1, 3))
|
289 |
+
rgb = np.sum(
|
290 |
+
eig_vec * np.repeat(alpha, 3, axis=0) * np.repeat(eig_val, 3, axis=0),
|
291 |
+
axis=1,
|
292 |
+
)
|
293 |
+
out_images = []
|
294 |
+
for img in imgs:
|
295 |
+
for idx in range(img.shape[0]):
|
296 |
+
img[idx] = img[idx] + rgb[2 - idx]
|
297 |
+
out_images.append(img)
|
298 |
+
return out_images
|
299 |
+
|
300 |
+
|
301 |
+
def color_normalization(image, mean, stddev):
|
302 |
+
"""
|
303 |
+
Perform color normalization on the image with the given mean and stddev.
|
304 |
+
Args:
|
305 |
+
image (array): image to perform color normalization.
|
306 |
+
mean (float): mean value to subtract.
|
307 |
+
stddev (float): stddev to devide.
|
308 |
+
"""
|
309 |
+
# Input image should in format of CHW
|
310 |
+
assert len(mean) == image.shape[0], "channel mean not computed properly"
|
311 |
+
assert len(stddev) == image.shape[0], "channel stddev not computed properly"
|
312 |
+
for idx in range(image.shape[0]):
|
313 |
+
image[idx] = image[idx] - mean[idx]
|
314 |
+
image[idx] = image[idx] / stddev[idx]
|
315 |
+
return image
|
316 |
+
|
317 |
+
|
318 |
+
def pad_image(image, pad_size, order="CHW"):
|
319 |
+
"""
|
320 |
+
Pad the given image with the size of pad_size.
|
321 |
+
Args:
|
322 |
+
image (array): image to pad.
|
323 |
+
pad_size (int): size to pad.
|
324 |
+
order (str): order of the `height`, `channel` and `width`.
|
325 |
+
Returns:
|
326 |
+
img (array): padded image.
|
327 |
+
"""
|
328 |
+
if order == "CHW":
|
329 |
+
img = np.pad(
|
330 |
+
image,
|
331 |
+
((0, 0), (pad_size, pad_size), (pad_size, pad_size)),
|
332 |
+
mode=str("constant"),
|
333 |
+
)
|
334 |
+
elif order == "HWC":
|
335 |
+
img = np.pad(
|
336 |
+
image,
|
337 |
+
((pad_size, pad_size), (pad_size, pad_size), (0, 0)),
|
338 |
+
mode=str("constant"),
|
339 |
+
)
|
340 |
+
return img
|
341 |
+
|
342 |
+
|
343 |
+
def horizontal_flip(prob, image, order="CHW"):
|
344 |
+
"""
|
345 |
+
Horizontally flip the image.
|
346 |
+
Args:
|
347 |
+
prob (float): probability to flip.
|
348 |
+
image (array): image to pad.
|
349 |
+
order (str): order of the `height`, `channel` and `width`.
|
350 |
+
Returns:
|
351 |
+
img (array): flipped image.
|
352 |
+
"""
|
353 |
+
assert order in ["CHW", "HWC"], "order {} is not supported".format(order)
|
354 |
+
if np.random.uniform() < prob:
|
355 |
+
if order == "CHW":
|
356 |
+
image = image[:, :, ::-1]
|
357 |
+
elif order == "HWC":
|
358 |
+
image = image[:, ::-1, :]
|
359 |
+
else:
|
360 |
+
raise NotImplementedError("Unknown order {}".format(order))
|
361 |
+
return image
|
362 |
+
|
363 |
+
|
364 |
+
def flip_boxes(boxes, im_width):
|
365 |
+
"""
|
366 |
+
Horizontally flip the boxes.
|
367 |
+
Args:
|
368 |
+
boxes (array): box to flip.
|
369 |
+
im_width (int): width of the image.
|
370 |
+
Returns:
|
371 |
+
boxes_flipped (array): flipped box.
|
372 |
+
"""
|
373 |
+
|
374 |
+
boxes_flipped = boxes.copy()
|
375 |
+
boxes_flipped[:, 0::4] = im_width - boxes[:, 2::4] - 1
|
376 |
+
boxes_flipped[:, 2::4] = im_width - boxes[:, 0::4] - 1
|
377 |
+
return boxes_flipped
|
378 |
+
|
379 |
+
|
380 |
+
def crop_boxes(boxes, x_offset, y_offset):
|
381 |
+
"""
|
382 |
+
Crop the boxes given the offsets.
|
383 |
+
Args:
|
384 |
+
boxes (array): boxes to crop.
|
385 |
+
x_offset (int): offset on x.
|
386 |
+
y_offset (int): offset on y.
|
387 |
+
"""
|
388 |
+
boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset
|
389 |
+
boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset
|
390 |
+
return boxes
|
391 |
+
|
392 |
+
|
393 |
+
def random_crop_list(images, size, pad_size=0, order="CHW", boxes=None):
|
394 |
+
"""
|
395 |
+
Perform random crop on a list of images.
|
396 |
+
Args:
|
397 |
+
images (list): list of images to perform random crop.
|
398 |
+
size (int): size to crop.
|
399 |
+
pad_size (int): padding size.
|
400 |
+
order (str): order of the `height`, `channel` and `width`.
|
401 |
+
boxes (list): optional. Corresponding boxes to images.
|
402 |
+
Dimension is `num boxes` x 4.
|
403 |
+
Returns:
|
404 |
+
cropped (ndarray): the cropped list of images with dimension of
|
405 |
+
`height` x `width` x `channel`.
|
406 |
+
boxes (list): optional. Corresponding boxes to images. Dimension is
|
407 |
+
`num boxes` x 4.
|
408 |
+
"""
|
409 |
+
# explicitly dealing processing per image order to avoid flipping images.
|
410 |
+
if pad_size > 0:
|
411 |
+
images = [
|
412 |
+
pad_image(pad_size=pad_size, image=image, order=order)
|
413 |
+
for image in images
|
414 |
+
]
|
415 |
+
|
416 |
+
# image format should be CHW.
|
417 |
+
if order == "CHW":
|
418 |
+
if images[0].shape[1] == size and images[0].shape[2] == size:
|
419 |
+
return images, boxes
|
420 |
+
height = images[0].shape[1]
|
421 |
+
width = images[0].shape[2]
|
422 |
+
y_offset = 0
|
423 |
+
if height > size:
|
424 |
+
y_offset = int(np.random.randint(0, height - size))
|
425 |
+
x_offset = 0
|
426 |
+
if width > size:
|
427 |
+
x_offset = int(np.random.randint(0, width - size))
|
428 |
+
cropped = [
|
429 |
+
image[:, y_offset : y_offset + size, x_offset : x_offset + size]
|
430 |
+
for image in images
|
431 |
+
]
|
432 |
+
assert cropped[0].shape[1] == size, "Image not cropped properly"
|
433 |
+
assert cropped[0].shape[2] == size, "Image not cropped properly"
|
434 |
+
elif order == "HWC":
|
435 |
+
if images[0].shape[0] == size and images[0].shape[1] == size:
|
436 |
+
return images, boxes
|
437 |
+
height = images[0].shape[0]
|
438 |
+
width = images[0].shape[1]
|
439 |
+
y_offset = 0
|
440 |
+
if height > size:
|
441 |
+
y_offset = int(np.random.randint(0, height - size))
|
442 |
+
x_offset = 0
|
443 |
+
if width > size:
|
444 |
+
x_offset = int(np.random.randint(0, width - size))
|
445 |
+
cropped = [
|
446 |
+
image[y_offset : y_offset + size, x_offset : x_offset + size, :]
|
447 |
+
for image in images
|
448 |
+
]
|
449 |
+
assert cropped[0].shape[0] == size, "Image not cropped properly"
|
450 |
+
assert cropped[0].shape[1] == size, "Image not cropped properly"
|
451 |
+
|
452 |
+
if boxes is not None:
|
453 |
+
boxes = [crop_boxes(proposal, x_offset, y_offset) for proposal in boxes]
|
454 |
+
return cropped, boxes
|
455 |
+
|
456 |
+
|
457 |
+
def center_crop(size, image):
|
458 |
+
"""
|
459 |
+
Perform center crop on input images.
|
460 |
+
Args:
|
461 |
+
size (int): size of the cropped height and width.
|
462 |
+
image (array): the image to perform center crop.
|
463 |
+
"""
|
464 |
+
height = image.shape[0]
|
465 |
+
width = image.shape[1]
|
466 |
+
y_offset = int(math.ceil((height - size) / 2))
|
467 |
+
x_offset = int(math.ceil((width - size) / 2))
|
468 |
+
cropped = image[y_offset : y_offset + size, x_offset : x_offset + size, :]
|
469 |
+
assert cropped.shape[0] == size, "Image height not cropped properly"
|
470 |
+
assert cropped.shape[1] == size, "Image width not cropped properly"
|
471 |
+
return cropped
|
472 |
+
|
473 |
+
|
474 |
+
# ResNet style scale jittering: randomly select the scale from
|
475 |
+
# [1/max_size, 1/min_size]
|
476 |
+
def random_scale_jitter(image, min_size, max_size):
|
477 |
+
"""
|
478 |
+
Perform ResNet style random scale jittering: randomly select the scale from
|
479 |
+
[1/max_size, 1/min_size].
|
480 |
+
Args:
|
481 |
+
image (array): image to perform random scale.
|
482 |
+
min_size (int): min size to scale.
|
483 |
+
max_size (int) max size to scale.
|
484 |
+
Returns:
|
485 |
+
image (array): scaled image.
|
486 |
+
"""
|
487 |
+
img_scale = int(
|
488 |
+
round(1.0 / np.random.uniform(1.0 / max_size, 1.0 / min_size))
|
489 |
+
)
|
490 |
+
image = scale(img_scale, image)
|
491 |
+
return image
|
492 |
+
|
493 |
+
|
494 |
+
def random_scale_jitter_list(images, min_size, max_size):
|
495 |
+
"""
|
496 |
+
Perform ResNet style random scale jittering on a list of image: randomly
|
497 |
+
select the scale from [1/max_size, 1/min_size]. Note that all the image
|
498 |
+
will share the same scale.
|
499 |
+
Args:
|
500 |
+
images (list): list of images to perform random scale.
|
501 |
+
min_size (int): min size to scale.
|
502 |
+
max_size (int) max size to scale.
|
503 |
+
Returns:
|
504 |
+
images (list): list of scaled image.
|
505 |
+
"""
|
506 |
+
img_scale = int(
|
507 |
+
round(1.0 / np.random.uniform(1.0 / max_size, 1.0 / min_size))
|
508 |
+
)
|
509 |
+
return [scale(img_scale, image) for image in images]
|
510 |
+
|
511 |
+
|
512 |
+
def random_sized_crop(image, size, area_frac=0.08):
|
513 |
+
"""
|
514 |
+
Perform random sized cropping on the given image. Random crop with size
|
515 |
+
8% - 100% image area and aspect ratio in [3/4, 4/3].
|
516 |
+
Args:
|
517 |
+
image (array): image to crop.
|
518 |
+
size (int): size to crop.
|
519 |
+
area_frac (float): area of fraction.
|
520 |
+
Returns:
|
521 |
+
(array): cropped image.
|
522 |
+
"""
|
523 |
+
for _ in range(0, 10):
|
524 |
+
height = image.shape[0]
|
525 |
+
width = image.shape[1]
|
526 |
+
area = height * width
|
527 |
+
target_area = np.random.uniform(area_frac, 1.0) * area
|
528 |
+
aspect_ratio = np.random.uniform(3.0 / 4.0, 4.0 / 3.0)
|
529 |
+
w = int(round(math.sqrt(float(target_area) * aspect_ratio)))
|
530 |
+
h = int(round(math.sqrt(float(target_area) / aspect_ratio)))
|
531 |
+
if np.random.uniform() < 0.5:
|
532 |
+
w, h = h, w
|
533 |
+
if h <= height and w <= width:
|
534 |
+
if height == h:
|
535 |
+
y_offset = 0
|
536 |
+
else:
|
537 |
+
y_offset = np.random.randint(0, height - h)
|
538 |
+
if width == w:
|
539 |
+
x_offset = 0
|
540 |
+
else:
|
541 |
+
x_offset = np.random.randint(0, width - w)
|
542 |
+
y_offset = int(y_offset)
|
543 |
+
x_offset = int(x_offset)
|
544 |
+
cropped = image[y_offset : y_offset + h, x_offset : x_offset + w, :]
|
545 |
+
assert (
|
546 |
+
cropped.shape[0] == h and cropped.shape[1] == w
|
547 |
+
), "Wrong crop size"
|
548 |
+
cropped = cv2.resize(
|
549 |
+
cropped, (size, size), interpolation=cv2.INTER_LINEAR
|
550 |
+
)
|
551 |
+
return cropped.astype(np.float32)
|
552 |
+
return center_crop(size, scale(size, image))
|
553 |
+
|
554 |
+
|
555 |
+
def lighting(img, alphastd, eigval, eigvec):
|
556 |
+
"""
|
557 |
+
Perform AlexNet-style PCA jitter on the given image.
|
558 |
+
Args:
|
559 |
+
image (array): list of images to perform lighting jitter.
|
560 |
+
alphastd (float): jitter ratio for PCA jitter.
|
561 |
+
eigval (array): eigenvalues for PCA jitter.
|
562 |
+
eigvec (list): eigenvectors for PCA jitter.
|
563 |
+
Returns:
|
564 |
+
img (tensor): the jittered image.
|
565 |
+
"""
|
566 |
+
if alphastd == 0:
|
567 |
+
return img
|
568 |
+
# generate alpha1, alpha2, alpha3.
|
569 |
+
alpha = np.random.normal(0, alphastd, size=(1, 3))
|
570 |
+
eig_vec = np.array(eigvec)
|
571 |
+
eig_val = np.reshape(eigval, (1, 3))
|
572 |
+
rgb = np.sum(
|
573 |
+
eig_vec * np.repeat(alpha, 3, axis=0) * np.repeat(eig_val, 3, axis=0),
|
574 |
+
axis=1,
|
575 |
+
)
|
576 |
+
for idx in range(img.shape[0]):
|
577 |
+
img[idx] = img[idx] + rgb[2 - idx]
|
578 |
+
return img
|
579 |
+
|
580 |
+
|
581 |
+
def random_sized_crop_list(images, size, crop_area_fraction=0.08):
|
582 |
+
"""
|
583 |
+
Perform random sized cropping on the given list of images. Random crop with
|
584 |
+
size 8% - 100% image area and aspect ratio in [3/4, 4/3].
|
585 |
+
Args:
|
586 |
+
images (list): image to crop.
|
587 |
+
size (int): size to crop.
|
588 |
+
area_frac (float): area of fraction.
|
589 |
+
Returns:
|
590 |
+
(list): list of cropped image.
|
591 |
+
"""
|
592 |
+
for _ in range(0, 10):
|
593 |
+
height = images[0].shape[0]
|
594 |
+
width = images[0].shape[1]
|
595 |
+
area = height * width
|
596 |
+
target_area = np.random.uniform(crop_area_fraction, 1.0) * area
|
597 |
+
aspect_ratio = np.random.uniform(3.0 / 4.0, 4.0 / 3.0)
|
598 |
+
w = int(round(math.sqrt(float(target_area) * aspect_ratio)))
|
599 |
+
h = int(round(math.sqrt(float(target_area) / aspect_ratio)))
|
600 |
+
if np.random.uniform() < 0.5:
|
601 |
+
w, h = h, w
|
602 |
+
if h <= height and w <= width:
|
603 |
+
if height == h:
|
604 |
+
y_offset = 0
|
605 |
+
else:
|
606 |
+
y_offset = np.random.randint(0, height - h)
|
607 |
+
if width == w:
|
608 |
+
x_offset = 0
|
609 |
+
else:
|
610 |
+
x_offset = np.random.randint(0, width - w)
|
611 |
+
y_offset = int(y_offset)
|
612 |
+
x_offset = int(x_offset)
|
613 |
+
|
614 |
+
croppsed_images = []
|
615 |
+
for image in images:
|
616 |
+
cropped = image[
|
617 |
+
y_offset : y_offset + h, x_offset : x_offset + w, :
|
618 |
+
]
|
619 |
+
assert (
|
620 |
+
cropped.shape[0] == h and cropped.shape[1] == w
|
621 |
+
), "Wrong crop size"
|
622 |
+
cropped = cv2.resize(
|
623 |
+
cropped, (size, size), interpolation=cv2.INTER_LINEAR
|
624 |
+
)
|
625 |
+
croppsed_images.append(cropped.astype(np.float32))
|
626 |
+
return croppsed_images
|
627 |
+
|
628 |
+
return [center_crop(size, scale(size, image)) for image in images]
|
629 |
+
|
630 |
+
|
631 |
+
def blend(image1, image2, alpha):
|
632 |
+
return image1 * alpha + image2 * (1 - alpha)
|
633 |
+
|
634 |
+
|
635 |
+
def grayscale(image):
|
636 |
+
"""
|
637 |
+
Convert the image to gray scale.
|
638 |
+
Args:
|
639 |
+
image (tensor): image to convert to gray scale. Dimension is
|
640 |
+
`channel` x `height` x `width`.
|
641 |
+
Returns:
|
642 |
+
img_gray (tensor): image in gray scale.
|
643 |
+
"""
|
644 |
+
# R -> 0.299, G -> 0.587, B -> 0.114.
|
645 |
+
img_gray = np.copy(image)
|
646 |
+
gray_channel = 0.299 * image[2] + 0.587 * image[1] + 0.114 * image[0]
|
647 |
+
img_gray[0] = gray_channel
|
648 |
+
img_gray[1] = gray_channel
|
649 |
+
img_gray[2] = gray_channel
|
650 |
+
return img_gray
|
651 |
+
|
652 |
+
|
653 |
+
def saturation(var, image):
|
654 |
+
"""
|
655 |
+
Perform color saturation on the given image.
|
656 |
+
Args:
|
657 |
+
var (float): variance.
|
658 |
+
image (array): image to perform color saturation.
|
659 |
+
Returns:
|
660 |
+
(array): image that performed color saturation.
|
661 |
+
"""
|
662 |
+
img_gray = grayscale(image)
|
663 |
+
alpha = 1.0 + np.random.uniform(-var, var)
|
664 |
+
return blend(image, img_gray, alpha)
|
665 |
+
|
666 |
+
|
667 |
+
def brightness(var, image):
|
668 |
+
"""
|
669 |
+
Perform color brightness on the given image.
|
670 |
+
Args:
|
671 |
+
var (float): variance.
|
672 |
+
image (array): image to perform color brightness.
|
673 |
+
Returns:
|
674 |
+
(array): image that performed color brightness.
|
675 |
+
"""
|
676 |
+
img_bright = np.zeros(image.shape).astype(image.dtype)
|
677 |
+
alpha = 1.0 + np.random.uniform(-var, var)
|
678 |
+
return blend(image, img_bright, alpha)
|
679 |
+
|
680 |
+
|
681 |
+
def contrast(var, image):
|
682 |
+
"""
|
683 |
+
Perform color contrast on the given image.
|
684 |
+
Args:
|
685 |
+
var (float): variance.
|
686 |
+
image (array): image to perform color contrast.
|
687 |
+
Returns:
|
688 |
+
(array): image that performed color contrast.
|
689 |
+
"""
|
690 |
+
img_gray = grayscale(image)
|
691 |
+
img_gray.fill(np.mean(img_gray[0]))
|
692 |
+
alpha = 1.0 + np.random.uniform(-var, var)
|
693 |
+
return blend(image, img_gray, alpha)
|
694 |
+
|
695 |
+
|
696 |
+
def saturation_list(var, images):
|
697 |
+
"""
|
698 |
+
Perform color saturation on the list of given images.
|
699 |
+
Args:
|
700 |
+
var (float): variance.
|
701 |
+
images (list): list of images to perform color saturation.
|
702 |
+
Returns:
|
703 |
+
(list): list of images that performed color saturation.
|
704 |
+
"""
|
705 |
+
alpha = 1.0 + np.random.uniform(-var, var)
|
706 |
+
|
707 |
+
out_images = []
|
708 |
+
for image in images:
|
709 |
+
img_gray = grayscale(image)
|
710 |
+
out_images.append(blend(image, img_gray, alpha))
|
711 |
+
return out_images
|
712 |
+
|
713 |
+
|
714 |
+
def brightness_list(var, images):
|
715 |
+
"""
|
716 |
+
Perform color brightness on the given list of images.
|
717 |
+
Args:
|
718 |
+
var (float): variance.
|
719 |
+
images (list): list of images to perform color brightness.
|
720 |
+
Returns:
|
721 |
+
(array): list of images that performed color brightness.
|
722 |
+
"""
|
723 |
+
alpha = 1.0 + np.random.uniform(-var, var)
|
724 |
+
|
725 |
+
out_images = []
|
726 |
+
for image in images:
|
727 |
+
img_bright = np.zeros(image.shape).astype(image.dtype)
|
728 |
+
out_images.append(blend(image, img_bright, alpha))
|
729 |
+
return out_images
|
730 |
+
|
731 |
+
|
732 |
+
def contrast_list(var, images):
|
733 |
+
"""
|
734 |
+
Perform color contrast on the given list of images.
|
735 |
+
Args:
|
736 |
+
var (float): variance.
|
737 |
+
images (list): list of images to perform color contrast.
|
738 |
+
Returns:
|
739 |
+
(array): image that performed color contrast.
|
740 |
+
"""
|
741 |
+
alpha = 1.0 + np.random.uniform(-var, var)
|
742 |
+
|
743 |
+
out_images = []
|
744 |
+
for image in images:
|
745 |
+
img_gray = grayscale(image)
|
746 |
+
img_gray.fill(np.mean(img_gray[0]))
|
747 |
+
out_images.append(blend(image, img_gray, alpha))
|
748 |
+
return out_images
|
749 |
+
|
750 |
+
|
751 |
+
def color_jitter(image, img_brightness=0, img_contrast=0, img_saturation=0):
|
752 |
+
"""
|
753 |
+
Perform color jitter on the given image.
|
754 |
+
Args:
|
755 |
+
image (array): image to perform color jitter.
|
756 |
+
img_brightness (float): jitter ratio for brightness.
|
757 |
+
img_contrast (float): jitter ratio for contrast.
|
758 |
+
img_saturation (float): jitter ratio for saturation.
|
759 |
+
Returns:
|
760 |
+
image (array): the jittered image.
|
761 |
+
"""
|
762 |
+
jitter = []
|
763 |
+
if img_brightness != 0:
|
764 |
+
jitter.append("brightness")
|
765 |
+
if img_contrast != 0:
|
766 |
+
jitter.append("contrast")
|
767 |
+
if img_saturation != 0:
|
768 |
+
jitter.append("saturation")
|
769 |
+
|
770 |
+
if len(jitter) > 0:
|
771 |
+
order = np.random.permutation(np.arange(len(jitter)))
|
772 |
+
for idx in range(0, len(jitter)):
|
773 |
+
if jitter[order[idx]] == "brightness":
|
774 |
+
image = brightness(img_brightness, image)
|
775 |
+
elif jitter[order[idx]] == "contrast":
|
776 |
+
image = contrast(img_contrast, image)
|
777 |
+
elif jitter[order[idx]] == "saturation":
|
778 |
+
image = saturation(img_saturation, image)
|
779 |
+
return image
|
780 |
+
|
781 |
+
|
782 |
+
def revert_scaled_boxes(size, boxes, img_height, img_width):
|
783 |
+
"""
|
784 |
+
Revert scaled input boxes to match the original image size.
|
785 |
+
Args:
|
786 |
+
size (int): size of the cropped image.
|
787 |
+
boxes (array): shape (num_boxes, 4).
|
788 |
+
img_height (int): height of original image.
|
789 |
+
img_width (int): width of original image.
|
790 |
+
Returns:
|
791 |
+
reverted_boxes (array): boxes scaled back to the original image size.
|
792 |
+
"""
|
793 |
+
scaled_aspect = np.min([img_height, img_width])
|
794 |
+
scale_ratio = scaled_aspect / size
|
795 |
+
reverted_boxes = boxes * scale_ratio
|
796 |
+
return reverted_boxes
|
TimeSformer/timesformer/datasets/decoder.py
ADDED
@@ -0,0 +1,392 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
2 |
+
|
3 |
+
import math
|
4 |
+
import numpy as np
|
5 |
+
import random
|
6 |
+
import torch
|
7 |
+
import torchvision.io as io
|
8 |
+
|
9 |
+
|
10 |
+
def temporal_sampling(frames, start_idx, end_idx, num_samples):
|
11 |
+
"""
|
12 |
+
Given the start and end frame index, sample num_samples frames between
|
13 |
+
the start and end with equal interval.
|
14 |
+
Args:
|
15 |
+
frames (tensor): a tensor of video frames, dimension is
|
16 |
+
`num video frames` x `channel` x `height` x `width`.
|
17 |
+
start_idx (int): the index of the start frame.
|
18 |
+
end_idx (int): the index of the end frame.
|
19 |
+
num_samples (int): number of frames to sample.
|
20 |
+
Returns:
|
21 |
+
frames (tersor): a tensor of temporal sampled video frames, dimension is
|
22 |
+
`num clip frames` x `channel` x `height` x `width`.
|
23 |
+
"""
|
24 |
+
index = torch.linspace(start_idx, end_idx, num_samples)
|
25 |
+
index = torch.clamp(index, 0, frames.shape[0] - 1).long()
|
26 |
+
frames = torch.index_select(frames, 0, index)
|
27 |
+
return frames
|
28 |
+
|
29 |
+
|
30 |
+
def get_start_end_idx(video_size, clip_size, clip_idx, num_clips):
|
31 |
+
"""
|
32 |
+
Sample a clip of size clip_size from a video of size video_size and
|
33 |
+
return the indices of the first and last frame of the clip. If clip_idx is
|
34 |
+
-1, the clip is randomly sampled, otherwise uniformly split the video to
|
35 |
+
num_clips clips, and select the start and end index of clip_idx-th video
|
36 |
+
clip.
|
37 |
+
Args:
|
38 |
+
video_size (int): number of overall frames.
|
39 |
+
clip_size (int): size of the clip to sample from the frames.
|
40 |
+
clip_idx (int): if clip_idx is -1, perform random jitter sampling. If
|
41 |
+
clip_idx is larger than -1, uniformly split the video to num_clips
|
42 |
+
clips, and select the start and end index of the clip_idx-th video
|
43 |
+
clip.
|
44 |
+
num_clips (int): overall number of clips to uniformly sample from the
|
45 |
+
given video for testing.
|
46 |
+
Returns:
|
47 |
+
start_idx (int): the start frame index.
|
48 |
+
end_idx (int): the end frame index.
|
49 |
+
"""
|
50 |
+
delta = max(video_size - clip_size, 0)
|
51 |
+
if clip_idx == -1:
|
52 |
+
# Random temporal sampling.
|
53 |
+
start_idx = random.uniform(0, delta)
|
54 |
+
else:
|
55 |
+
# Uniformly sample the clip with the given index.
|
56 |
+
start_idx = delta * clip_idx / num_clips
|
57 |
+
end_idx = start_idx + clip_size - 1
|
58 |
+
return start_idx, end_idx
|
59 |
+
|
60 |
+
|
61 |
+
def pyav_decode_stream(
|
62 |
+
container, start_pts, end_pts, stream, stream_name, buffer_size=0
|
63 |
+
):
|
64 |
+
"""
|
65 |
+
Decode the video with PyAV decoder.
|
66 |
+
Args:
|
67 |
+
container (container): PyAV container.
|
68 |
+
start_pts (int): the starting Presentation TimeStamp to fetch the
|
69 |
+
video frames.
|
70 |
+
end_pts (int): the ending Presentation TimeStamp of the decoded frames.
|
71 |
+
stream (stream): PyAV stream.
|
72 |
+
stream_name (dict): a dictionary of streams. For example, {"video": 0}
|
73 |
+
means video stream at stream index 0.
|
74 |
+
buffer_size (int): number of additional frames to decode beyond end_pts.
|
75 |
+
Returns:
|
76 |
+
result (list): list of frames decoded.
|
77 |
+
max_pts (int): max Presentation TimeStamp of the video sequence.
|
78 |
+
"""
|
79 |
+
# Seeking in the stream is imprecise. Thus, seek to an ealier PTS by a
|
80 |
+
# margin pts.
|
81 |
+
margin = 1024
|
82 |
+
seek_offset = max(start_pts - margin, 0)
|
83 |
+
|
84 |
+
container.seek(seek_offset, any_frame=False, backward=True, stream=stream)
|
85 |
+
frames = {}
|
86 |
+
buffer_count = 0
|
87 |
+
max_pts = 0
|
88 |
+
for frame in container.decode(**stream_name):
|
89 |
+
max_pts = max(max_pts, frame.pts)
|
90 |
+
if frame.pts < start_pts:
|
91 |
+
continue
|
92 |
+
if frame.pts <= end_pts:
|
93 |
+
frames[frame.pts] = frame
|
94 |
+
else:
|
95 |
+
buffer_count += 1
|
96 |
+
frames[frame.pts] = frame
|
97 |
+
if buffer_count >= buffer_size:
|
98 |
+
break
|
99 |
+
result = [frames[pts] for pts in sorted(frames)]
|
100 |
+
return result, max_pts
|
101 |
+
|
102 |
+
|
103 |
+
def torchvision_decode(
|
104 |
+
video_handle,
|
105 |
+
sampling_rate,
|
106 |
+
num_frames,
|
107 |
+
clip_idx,
|
108 |
+
video_meta,
|
109 |
+
num_clips=10,
|
110 |
+
target_fps=30,
|
111 |
+
modalities=("visual",),
|
112 |
+
max_spatial_scale=0,
|
113 |
+
):
|
114 |
+
"""
|
115 |
+
If video_meta is not empty, perform temporal selective decoding to sample a
|
116 |
+
clip from the video with TorchVision decoder. If video_meta is empty, decode
|
117 |
+
the entire video and update the video_meta.
|
118 |
+
Args:
|
119 |
+
video_handle (bytes): raw bytes of the video file.
|
120 |
+
sampling_rate (int): frame sampling rate (interval between two sampled
|
121 |
+
frames).
|
122 |
+
num_frames (int): number of frames to sample.
|
123 |
+
clip_idx (int): if clip_idx is -1, perform random temporal
|
124 |
+
sampling. If clip_idx is larger than -1, uniformly split the
|
125 |
+
video to num_clips clips, and select the clip_idx-th video clip.
|
126 |
+
video_meta (dict): a dict contains VideoMetaData. Details can be found
|
127 |
+
at `pytorch/vision/torchvision/io/_video_opt.py`.
|
128 |
+
num_clips (int): overall number of clips to uniformly sample from the
|
129 |
+
given video.
|
130 |
+
target_fps (int): the input video may has different fps, convert it to
|
131 |
+
the target video fps.
|
132 |
+
modalities (tuple): tuple of modalities to decode. Currently only
|
133 |
+
support `visual`, planning to support `acoustic` soon.
|
134 |
+
max_spatial_scale (int): the maximal resolution of the spatial shorter
|
135 |
+
edge size during decoding.
|
136 |
+
Returns:
|
137 |
+
frames (tensor): decoded frames from the video.
|
138 |
+
fps (float): the number of frames per second of the video.
|
139 |
+
decode_all_video (bool): if True, the entire video was decoded.
|
140 |
+
"""
|
141 |
+
# Convert the bytes to a tensor.
|
142 |
+
video_tensor = torch.from_numpy(np.frombuffer(video_handle, dtype=np.uint8))
|
143 |
+
|
144 |
+
decode_all_video = True
|
145 |
+
video_start_pts, video_end_pts = 0, -1
|
146 |
+
# The video_meta is empty, fetch the meta data from the raw video.
|
147 |
+
if len(video_meta) == 0:
|
148 |
+
# Tracking the meta info for selective decoding in the future.
|
149 |
+
meta = io._probe_video_from_memory(video_tensor)
|
150 |
+
# Using the information from video_meta to perform selective decoding.
|
151 |
+
video_meta["video_timebase"] = meta.video_timebase
|
152 |
+
video_meta["video_numerator"] = meta.video_timebase.numerator
|
153 |
+
video_meta["video_denominator"] = meta.video_timebase.denominator
|
154 |
+
video_meta["has_video"] = meta.has_video
|
155 |
+
video_meta["video_duration"] = meta.video_duration
|
156 |
+
video_meta["video_fps"] = meta.video_fps
|
157 |
+
video_meta["audio_timebas"] = meta.audio_timebase
|
158 |
+
video_meta["audio_numerator"] = meta.audio_timebase.numerator
|
159 |
+
video_meta["audio_denominator"] = meta.audio_timebase.denominator
|
160 |
+
video_meta["has_audio"] = meta.has_audio
|
161 |
+
video_meta["audio_duration"] = meta.audio_duration
|
162 |
+
video_meta["audio_sample_rate"] = meta.audio_sample_rate
|
163 |
+
|
164 |
+
fps = video_meta["video_fps"]
|
165 |
+
if (
|
166 |
+
video_meta["has_video"]
|
167 |
+
and video_meta["video_denominator"] > 0
|
168 |
+
and video_meta["video_duration"] > 0
|
169 |
+
):
|
170 |
+
# try selective decoding.
|
171 |
+
decode_all_video = False
|
172 |
+
clip_size = sampling_rate * num_frames / target_fps * fps
|
173 |
+
start_idx, end_idx = get_start_end_idx(
|
174 |
+
fps * video_meta["video_duration"], clip_size, clip_idx, num_clips
|
175 |
+
)
|
176 |
+
# Convert frame index to pts.
|
177 |
+
pts_per_frame = video_meta["video_denominator"] / fps
|
178 |
+
video_start_pts = int(start_idx * pts_per_frame)
|
179 |
+
video_end_pts = int(end_idx * pts_per_frame)
|
180 |
+
|
181 |
+
# Decode the raw video with the tv decoder.
|
182 |
+
v_frames, _ = io._read_video_from_memory(
|
183 |
+
video_tensor,
|
184 |
+
seek_frame_margin=1.0,
|
185 |
+
read_video_stream="visual" in modalities,
|
186 |
+
video_width=0,
|
187 |
+
video_height=0,
|
188 |
+
video_min_dimension=max_spatial_scale,
|
189 |
+
video_pts_range=(video_start_pts, video_end_pts),
|
190 |
+
video_timebase_numerator=video_meta["video_numerator"],
|
191 |
+
video_timebase_denominator=video_meta["video_denominator"],
|
192 |
+
)
|
193 |
+
|
194 |
+
if v_frames.shape == torch.Size([0]):
|
195 |
+
# failed selective decoding
|
196 |
+
decode_all_video = True
|
197 |
+
video_start_pts, video_end_pts = 0, -1
|
198 |
+
v_frames, _ = io._read_video_from_memory(
|
199 |
+
video_tensor,
|
200 |
+
seek_frame_margin=1.0,
|
201 |
+
read_video_stream="visual" in modalities,
|
202 |
+
video_width=0,
|
203 |
+
video_height=0,
|
204 |
+
video_min_dimension=max_spatial_scale,
|
205 |
+
video_pts_range=(video_start_pts, video_end_pts),
|
206 |
+
video_timebase_numerator=video_meta["video_numerator"],
|
207 |
+
video_timebase_denominator=video_meta["video_denominator"],
|
208 |
+
)
|
209 |
+
|
210 |
+
return v_frames, fps, decode_all_video
|
211 |
+
|
212 |
+
|
213 |
+
def pyav_decode(
|
214 |
+
container, sampling_rate, num_frames, clip_idx, num_clips=10, target_fps=30, start=None, end=None
|
215 |
+
, duration=None, frames_length=None):
|
216 |
+
"""
|
217 |
+
Convert the video from its original fps to the target_fps. If the video
|
218 |
+
support selective decoding (contain decoding information in the video head),
|
219 |
+
the perform temporal selective decoding and sample a clip from the video
|
220 |
+
with the PyAV decoder. If the video does not support selective decoding,
|
221 |
+
decode the entire video.
|
222 |
+
|
223 |
+
Args:
|
224 |
+
container (container): pyav container.
|
225 |
+
sampling_rate (int): frame sampling rate (interval between two sampled
|
226 |
+
frames.
|
227 |
+
num_frames (int): number of frames to sample.
|
228 |
+
clip_idx (int): if clip_idx is -1, perform random temporal sampling. If
|
229 |
+
clip_idx is larger than -1, uniformly split the video to num_clips
|
230 |
+
clips, and select the clip_idx-th video clip.
|
231 |
+
num_clips (int): overall number of clips to uniformly sample from the
|
232 |
+
given video.
|
233 |
+
target_fps (int): the input video may has different fps, convert it to
|
234 |
+
the target video fps before frame sampling.
|
235 |
+
Returns:
|
236 |
+
frames (tensor): decoded frames from the video. Return None if the no
|
237 |
+
video stream was found.
|
238 |
+
fps (float): the number of frames per second of the video.
|
239 |
+
decode_all_video (bool): If True, the entire video was decoded.
|
240 |
+
"""
|
241 |
+
# Try to fetch the decoding information from the video head. Some of the
|
242 |
+
# videos does not support fetching the decoding information, for that case
|
243 |
+
# it will get None duration.
|
244 |
+
fps = float(container.streams.video[0].average_rate)
|
245 |
+
|
246 |
+
orig_duration = duration
|
247 |
+
tb = float(container.streams.video[0].time_base)
|
248 |
+
frames_length = container.streams.video[0].frames
|
249 |
+
duration = container.streams.video[0].duration
|
250 |
+
if duration is None and orig_duration is not None:
|
251 |
+
duration = orig_duration / tb
|
252 |
+
|
253 |
+
if duration is None:
|
254 |
+
# If failed to fetch the decoding information, decode the entire video.
|
255 |
+
decode_all_video = True
|
256 |
+
video_start_pts, video_end_pts = 0, math.inf
|
257 |
+
else:
|
258 |
+
# Perform selective decoding.
|
259 |
+
decode_all_video = False
|
260 |
+
start_idx, end_idx = get_start_end_idx(
|
261 |
+
frames_length,
|
262 |
+
sampling_rate * num_frames / target_fps * fps,
|
263 |
+
clip_idx,
|
264 |
+
num_clips,
|
265 |
+
)
|
266 |
+
timebase = duration / frames_length
|
267 |
+
video_start_pts = int(start_idx * timebase)
|
268 |
+
video_end_pts = int(end_idx * timebase)
|
269 |
+
|
270 |
+
if start is not None and end is not None:
|
271 |
+
decode_all_video = False
|
272 |
+
|
273 |
+
frames = None
|
274 |
+
# If video stream was found, fetch video frames from the video.
|
275 |
+
if container.streams.video:
|
276 |
+
if start is None and end is None:
|
277 |
+
video_frames, max_pts = pyav_decode_stream(
|
278 |
+
container,
|
279 |
+
video_start_pts,
|
280 |
+
video_end_pts,
|
281 |
+
container.streams.video[0],
|
282 |
+
{"video": 0},
|
283 |
+
)
|
284 |
+
else:
|
285 |
+
timebase = duration / frames_length
|
286 |
+
start_i = start
|
287 |
+
end_i = end
|
288 |
+
video_frames, max_pts = pyav_decode_stream(
|
289 |
+
container,
|
290 |
+
start_i,
|
291 |
+
end_i,
|
292 |
+
container.streams.video[0],
|
293 |
+
{"video": 0},
|
294 |
+
)
|
295 |
+
container.close()
|
296 |
+
|
297 |
+
frames = [frame.to_rgb().to_ndarray() for frame in video_frames]
|
298 |
+
frames = torch.as_tensor(np.stack(frames))
|
299 |
+
|
300 |
+
return frames, fps, decode_all_video
|
301 |
+
|
302 |
+
|
303 |
+
def decode(
|
304 |
+
container,
|
305 |
+
sampling_rate,
|
306 |
+
num_frames,
|
307 |
+
clip_idx=-1,
|
308 |
+
num_clips=10,
|
309 |
+
video_meta=None,
|
310 |
+
target_fps=30,
|
311 |
+
backend="pyav",
|
312 |
+
max_spatial_scale=0,
|
313 |
+
start=None,
|
314 |
+
end=None,
|
315 |
+
duration=None,
|
316 |
+
frames_length=None,
|
317 |
+
):
|
318 |
+
"""
|
319 |
+
Decode the video and perform temporal sampling.
|
320 |
+
Args:
|
321 |
+
container (container): pyav container.
|
322 |
+
sampling_rate (int): frame sampling rate (interval between two sampled
|
323 |
+
frames).
|
324 |
+
num_frames (int): number of frames to sample.
|
325 |
+
clip_idx (int): if clip_idx is -1, perform random temporal
|
326 |
+
sampling. If clip_idx is larger than -1, uniformly split the
|
327 |
+
video to num_clips clips, and select the
|
328 |
+
clip_idx-th video clip.
|
329 |
+
num_clips (int): overall number of clips to uniformly
|
330 |
+
sample from the given video.
|
331 |
+
video_meta (dict): a dict contains VideoMetaData. Details can be find
|
332 |
+
at `pytorch/vision/torchvision/io/_video_opt.py`.
|
333 |
+
target_fps (int): the input video may have different fps, convert it to
|
334 |
+
the target video fps before frame sampling.
|
335 |
+
backend (str): decoding backend includes `pyav` and `torchvision`. The
|
336 |
+
default one is `pyav`.
|
337 |
+
max_spatial_scale (int): keep the aspect ratio and resize the frame so
|
338 |
+
that shorter edge size is max_spatial_scale. Only used in
|
339 |
+
`torchvision` backend.
|
340 |
+
Returns:
|
341 |
+
frames (tensor): decoded frames from the video.
|
342 |
+
"""
|
343 |
+
# Currently support two decoders: 1) PyAV, and 2) TorchVision.
|
344 |
+
assert clip_idx >= -1, "Not valied clip_idx {}".format(clip_idx)
|
345 |
+
try:
|
346 |
+
if backend == "pyav":
|
347 |
+
frames, fps, decode_all_video = pyav_decode(
|
348 |
+
container,
|
349 |
+
sampling_rate,
|
350 |
+
num_frames,
|
351 |
+
clip_idx,
|
352 |
+
num_clips,
|
353 |
+
target_fps,
|
354 |
+
start,
|
355 |
+
end,
|
356 |
+
duration,
|
357 |
+
frames_length,
|
358 |
+
)
|
359 |
+
elif backend == "torchvision":
|
360 |
+
frames, fps, decode_all_video = torchvision_decode(
|
361 |
+
container,
|
362 |
+
sampling_rate,
|
363 |
+
num_frames,
|
364 |
+
clip_idx,
|
365 |
+
video_meta,
|
366 |
+
num_clips,
|
367 |
+
target_fps,
|
368 |
+
("visual",),
|
369 |
+
max_spatial_scale,
|
370 |
+
)
|
371 |
+
else:
|
372 |
+
raise NotImplementedError(
|
373 |
+
"Unknown decoding backend {}".format(backend)
|
374 |
+
)
|
375 |
+
except Exception as e:
|
376 |
+
print("Failed to decode by {} with exception: {}".format(backend, e))
|
377 |
+
return None
|
378 |
+
|
379 |
+
# Return None if the frames was not decoded successfully.
|
380 |
+
if frames is None or frames.size(0) == 0:
|
381 |
+
return None
|
382 |
+
|
383 |
+
clip_sz = sampling_rate * num_frames / target_fps * fps
|
384 |
+
start_idx, end_idx = get_start_end_idx(
|
385 |
+
frames.shape[0],
|
386 |
+
clip_sz,
|
387 |
+
clip_idx if decode_all_video else 0,
|
388 |
+
num_clips if decode_all_video else 1,
|
389 |
+
)
|
390 |
+
# Perform temporal sampling from the decoded video.
|
391 |
+
frames = temporal_sampling(frames, start_idx, end_idx, num_frames)
|
392 |
+
return frames
|
TimeSformer/timesformer/datasets/kinetics.py
ADDED
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
2 |
+
|
3 |
+
import os
|
4 |
+
import random
|
5 |
+
import torch
|
6 |
+
import torch.utils.data
|
7 |
+
from fvcore.common.file_io import PathManager
|
8 |
+
|
9 |
+
import timesformer.utils.logging as logging
|
10 |
+
|
11 |
+
from . import decoder as decoder
|
12 |
+
from . import utils as utils
|
13 |
+
from . import video_container as container
|
14 |
+
from .build import DATASET_REGISTRY
|
15 |
+
logger = logging.get_logger(__name__)
|
16 |
+
|
17 |
+
|
18 |
+
@DATASET_REGISTRY.register()
|
19 |
+
class Kinetics(torch.utils.data.Dataset):
|
20 |
+
"""
|
21 |
+
Kinetics video loader. Construct the Kinetics video loader, then sample
|
22 |
+
clips from the videos. For training and validation, a single clip is
|
23 |
+
randomly sampled from every video with random cropping, scaling, and
|
24 |
+
flipping. For testing, multiple clips are uniformaly sampled from every
|
25 |
+
video with uniform cropping. For uniform cropping, we take the left, center,
|
26 |
+
and right crop if the width is larger than height, or take top, center, and
|
27 |
+
bottom crop if the height is larger than the width.
|
28 |
+
"""
|
29 |
+
|
30 |
+
def __init__(self, cfg, mode, num_retries=10):
|
31 |
+
"""
|
32 |
+
Construct the Kinetics video loader with a given csv file. The format of
|
33 |
+
the csv file is:
|
34 |
+
```
|
35 |
+
path_to_video_1 label_1
|
36 |
+
path_to_video_2 label_2
|
37 |
+
...
|
38 |
+
path_to_video_N label_N
|
39 |
+
```
|
40 |
+
Args:
|
41 |
+
cfg (CfgNode): configs.
|
42 |
+
mode (string): Options includes `train`, `val`, or `test` mode.
|
43 |
+
For the train and val mode, the data loader will take data
|
44 |
+
from the train or val set, and sample one clip per video.
|
45 |
+
For the test mode, the data loader will take data from test set,
|
46 |
+
and sample multiple clips per video.
|
47 |
+
num_retries (int): number of retries.
|
48 |
+
"""
|
49 |
+
# Only support train, val, and test mode.
|
50 |
+
assert mode in [
|
51 |
+
"train",
|
52 |
+
"val",
|
53 |
+
"test",
|
54 |
+
], "Split '{}' not supported for Kinetics".format(mode)
|
55 |
+
self.mode = mode
|
56 |
+
self.cfg = cfg
|
57 |
+
|
58 |
+
self._video_meta = {}
|
59 |
+
self._num_retries = num_retries
|
60 |
+
# For training or validation mode, one single clip is sampled from every
|
61 |
+
# video. For testing, NUM_ENSEMBLE_VIEWS clips are sampled from every
|
62 |
+
# video. For every clip, NUM_SPATIAL_CROPS is cropped spatially from
|
63 |
+
# the frames.
|
64 |
+
if self.mode in ["train", "val"]:
|
65 |
+
self._num_clips = 1
|
66 |
+
elif self.mode in ["test"]:
|
67 |
+
self._num_clips = (
|
68 |
+
cfg.TEST.NUM_ENSEMBLE_VIEWS * cfg.TEST.NUM_SPATIAL_CROPS
|
69 |
+
)
|
70 |
+
|
71 |
+
logger.info("Constructing Kinetics {}...".format(mode))
|
72 |
+
self._construct_loader()
|
73 |
+
|
74 |
+
def _construct_loader(self):
|
75 |
+
"""
|
76 |
+
Construct the video loader.
|
77 |
+
"""
|
78 |
+
path_to_file = os.path.join(
|
79 |
+
self.cfg.DATA.PATH_TO_DATA_DIR, "{}.csv".format(self.mode)
|
80 |
+
)
|
81 |
+
assert PathManager.exists(path_to_file), "{} dir not found".format(
|
82 |
+
path_to_file
|
83 |
+
)
|
84 |
+
|
85 |
+
self._path_to_videos = []
|
86 |
+
self._labels = []
|
87 |
+
self._spatial_temporal_idx = []
|
88 |
+
with PathManager.open(path_to_file, "r") as f:
|
89 |
+
for clip_idx, path_label in enumerate(f.read().splitlines()):
|
90 |
+
assert (
|
91 |
+
len(path_label.split(self.cfg.DATA.PATH_LABEL_SEPARATOR))
|
92 |
+
== 2
|
93 |
+
)
|
94 |
+
path, label = path_label.split(
|
95 |
+
self.cfg.DATA.PATH_LABEL_SEPARATOR
|
96 |
+
)
|
97 |
+
for idx in range(self._num_clips):
|
98 |
+
self._path_to_videos.append(
|
99 |
+
os.path.join(self.cfg.DATA.PATH_PREFIX, path)
|
100 |
+
)
|
101 |
+
self._labels.append(int(label))
|
102 |
+
self._spatial_temporal_idx.append(idx)
|
103 |
+
self._video_meta[clip_idx * self._num_clips + idx] = {}
|
104 |
+
assert (
|
105 |
+
len(self._path_to_videos) > 0
|
106 |
+
), "Failed to load Kinetics split {} from {}".format(
|
107 |
+
self._split_idx, path_to_file
|
108 |
+
)
|
109 |
+
logger.info(
|
110 |
+
"Constructing kinetics dataloader (size: {}) from {}".format(
|
111 |
+
len(self._path_to_videos), path_to_file
|
112 |
+
)
|
113 |
+
)
|
114 |
+
|
115 |
+
def __getitem__(self, index):
|
116 |
+
"""
|
117 |
+
Given the video index, return the list of frames, label, and video
|
118 |
+
index if the video can be fetched and decoded successfully, otherwise
|
119 |
+
repeatly find a random video that can be decoded as a replacement.
|
120 |
+
Args:
|
121 |
+
index (int): the video index provided by the pytorch sampler.
|
122 |
+
Returns:
|
123 |
+
frames (tensor): the frames of sampled from the video. The dimension
|
124 |
+
is `channel` x `num frames` x `height` x `width`.
|
125 |
+
label (int): the label of the current video.
|
126 |
+
index (int): if the video provided by pytorch sampler can be
|
127 |
+
decoded, then return the index of the video. If not, return the
|
128 |
+
index of the video replacement that can be decoded.
|
129 |
+
"""
|
130 |
+
short_cycle_idx = None
|
131 |
+
# When short cycle is used, input index is a tupple.
|
132 |
+
if isinstance(index, tuple):
|
133 |
+
index, short_cycle_idx = index
|
134 |
+
|
135 |
+
if self.mode in ["train", "val"]:
|
136 |
+
# -1 indicates random sampling.
|
137 |
+
temporal_sample_index = -1
|
138 |
+
spatial_sample_index = -1
|
139 |
+
min_scale = self.cfg.DATA.TRAIN_JITTER_SCALES[0]
|
140 |
+
max_scale = self.cfg.DATA.TRAIN_JITTER_SCALES[1]
|
141 |
+
crop_size = self.cfg.DATA.TRAIN_CROP_SIZE
|
142 |
+
if short_cycle_idx in [0, 1]:
|
143 |
+
crop_size = int(
|
144 |
+
round(
|
145 |
+
self.cfg.MULTIGRID.SHORT_CYCLE_FACTORS[short_cycle_idx]
|
146 |
+
* self.cfg.MULTIGRID.DEFAULT_S
|
147 |
+
)
|
148 |
+
)
|
149 |
+
if self.cfg.MULTIGRID.DEFAULT_S > 0:
|
150 |
+
# Decreasing the scale is equivalent to using a larger "span"
|
151 |
+
# in a sampling grid.
|
152 |
+
min_scale = int(
|
153 |
+
round(
|
154 |
+
float(min_scale)
|
155 |
+
* crop_size
|
156 |
+
/ self.cfg.MULTIGRID.DEFAULT_S
|
157 |
+
)
|
158 |
+
)
|
159 |
+
elif self.mode in ["test"]:
|
160 |
+
temporal_sample_index = (
|
161 |
+
self._spatial_temporal_idx[index]
|
162 |
+
// self.cfg.TEST.NUM_SPATIAL_CROPS
|
163 |
+
)
|
164 |
+
# spatial_sample_index is in [0, 1, 2]. Corresponding to left,
|
165 |
+
# center, or right if width is larger than height, and top, middle,
|
166 |
+
# or bottom if height is larger than width.
|
167 |
+
spatial_sample_index = (
|
168 |
+
(
|
169 |
+
self._spatial_temporal_idx[index]
|
170 |
+
% self.cfg.TEST.NUM_SPATIAL_CROPS
|
171 |
+
)
|
172 |
+
if self.cfg.TEST.NUM_SPATIAL_CROPS > 1
|
173 |
+
else 1
|
174 |
+
)
|
175 |
+
min_scale, max_scale, crop_size = (
|
176 |
+
[self.cfg.DATA.TEST_CROP_SIZE] * 3
|
177 |
+
if self.cfg.TEST.NUM_SPATIAL_CROPS > 1
|
178 |
+
else [self.cfg.DATA.TRAIN_JITTER_SCALES[0]] * 2
|
179 |
+
+ [self.cfg.DATA.TEST_CROP_SIZE]
|
180 |
+
)
|
181 |
+
# The testing is deterministic and no jitter should be performed.
|
182 |
+
# min_scale, max_scale, and crop_size are expect to be the same.
|
183 |
+
assert len({min_scale, max_scale}) == 1
|
184 |
+
else:
|
185 |
+
raise NotImplementedError(
|
186 |
+
"Does not support {} mode".format(self.mode)
|
187 |
+
)
|
188 |
+
sampling_rate = utils.get_random_sampling_rate(
|
189 |
+
self.cfg.MULTIGRID.LONG_CYCLE_SAMPLING_RATE,
|
190 |
+
self.cfg.DATA.SAMPLING_RATE,
|
191 |
+
)
|
192 |
+
# Try to decode and sample a clip from a video. If the video can not be
|
193 |
+
# decoded, repeatly find a random video replacement that can be decoded.
|
194 |
+
for i_try in range(self._num_retries):
|
195 |
+
video_container = None
|
196 |
+
try:
|
197 |
+
video_container = container.get_video_container(
|
198 |
+
self._path_to_videos[index],
|
199 |
+
self.cfg.DATA_LOADER.ENABLE_MULTI_THREAD_DECODE,
|
200 |
+
self.cfg.DATA.DECODING_BACKEND,
|
201 |
+
)
|
202 |
+
except Exception as e:
|
203 |
+
logger.info(
|
204 |
+
"Failed to load video from {} with error {}".format(
|
205 |
+
self._path_to_videos[index], e
|
206 |
+
)
|
207 |
+
)
|
208 |
+
# Select a random video if the current video was not able to access.
|
209 |
+
if video_container is None:
|
210 |
+
logger.warning(
|
211 |
+
"Failed to meta load video idx {} from {}; trial {}".format(
|
212 |
+
index, self._path_to_videos[index], i_try
|
213 |
+
)
|
214 |
+
)
|
215 |
+
if self.mode not in ["test"] and i_try > self._num_retries // 2:
|
216 |
+
# let's try another one
|
217 |
+
index = random.randint(0, len(self._path_to_videos) - 1)
|
218 |
+
continue
|
219 |
+
|
220 |
+
# Decode video. Meta info is used to perform selective decoding.
|
221 |
+
frames = decoder.decode(
|
222 |
+
video_container,
|
223 |
+
sampling_rate,
|
224 |
+
self.cfg.DATA.NUM_FRAMES,
|
225 |
+
temporal_sample_index,
|
226 |
+
self.cfg.TEST.NUM_ENSEMBLE_VIEWS,
|
227 |
+
video_meta=self._video_meta[index],
|
228 |
+
target_fps=self.cfg.DATA.TARGET_FPS,
|
229 |
+
backend=self.cfg.DATA.DECODING_BACKEND,
|
230 |
+
max_spatial_scale=min_scale,
|
231 |
+
)
|
232 |
+
|
233 |
+
# If decoding failed (wrong format, video is too short, and etc),
|
234 |
+
# select another video.
|
235 |
+
if frames is None:
|
236 |
+
logger.warning(
|
237 |
+
"Failed to decode video idx {} from {}; trial {}".format(
|
238 |
+
index, self._path_to_videos[index], i_try
|
239 |
+
)
|
240 |
+
)
|
241 |
+
if self.mode not in ["test"] and i_try > self._num_retries // 2:
|
242 |
+
# let's try another one
|
243 |
+
index = random.randint(0, len(self._path_to_videos) - 1)
|
244 |
+
continue
|
245 |
+
|
246 |
+
|
247 |
+
label = self._labels[index]
|
248 |
+
|
249 |
+
# Perform color normalization.
|
250 |
+
frames = utils.tensor_normalize(
|
251 |
+
frames, self.cfg.DATA.MEAN, self.cfg.DATA.STD
|
252 |
+
)
|
253 |
+
|
254 |
+
# T H W C -> C T H W.
|
255 |
+
frames = frames.permute(3, 0, 1, 2)
|
256 |
+
# Perform data augmentation.
|
257 |
+
frames = utils.spatial_sampling(
|
258 |
+
frames,
|
259 |
+
spatial_idx=spatial_sample_index,
|
260 |
+
min_scale=min_scale,
|
261 |
+
max_scale=max_scale,
|
262 |
+
crop_size=crop_size,
|
263 |
+
random_horizontal_flip=self.cfg.DATA.RANDOM_FLIP,
|
264 |
+
inverse_uniform_sampling=self.cfg.DATA.INV_UNIFORM_SAMPLE,
|
265 |
+
)
|
266 |
+
|
267 |
+
|
268 |
+
if not self.cfg.MODEL.ARCH in ['vit']:
|
269 |
+
frames = utils.pack_pathway_output(self.cfg, frames)
|
270 |
+
else:
|
271 |
+
# Perform temporal sampling from the fast pathway.
|
272 |
+
frames = torch.index_select(
|
273 |
+
frames,
|
274 |
+
1,
|
275 |
+
torch.linspace(
|
276 |
+
0, frames.shape[1] - 1, self.cfg.DATA.NUM_FRAMES
|
277 |
+
|
278 |
+
).long(),
|
279 |
+
)
|
280 |
+
|
281 |
+
return frames, label, index, {}
|
282 |
+
else:
|
283 |
+
raise RuntimeError(
|
284 |
+
"Failed to fetch video after {} retries.".format(
|
285 |
+
self._num_retries
|
286 |
+
)
|
287 |
+
)
|
288 |
+
|
289 |
+
def __len__(self):
|
290 |
+
"""
|
291 |
+
Returns:
|
292 |
+
(int): the number of videos in the dataset.
|
293 |
+
"""
|
294 |
+
return len(self._path_to_videos)
|
TimeSformer/timesformer/datasets/loader.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
2 |
+
|
3 |
+
"""Data loader."""
|
4 |
+
|
5 |
+
import itertools
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from torch.utils.data._utils.collate import default_collate
|
9 |
+
from torch.utils.data.distributed import DistributedSampler
|
10 |
+
from torch.utils.data.sampler import RandomSampler
|
11 |
+
|
12 |
+
from timesformer.datasets.multigrid_helper import ShortCycleBatchSampler
|
13 |
+
|
14 |
+
from . import utils as utils
|
15 |
+
from .build import build_dataset
|
16 |
+
|
17 |
+
|
18 |
+
def detection_collate(batch):
|
19 |
+
"""
|
20 |
+
Collate function for detection task. Concatanate bboxes, labels and
|
21 |
+
metadata from different samples in the first dimension instead of
|
22 |
+
stacking them to have a batch-size dimension.
|
23 |
+
Args:
|
24 |
+
batch (tuple or list): data batch to collate.
|
25 |
+
Returns:
|
26 |
+
(tuple): collated detection data batch.
|
27 |
+
"""
|
28 |
+
inputs, labels, video_idx, extra_data = zip(*batch)
|
29 |
+
inputs, video_idx = default_collate(inputs), default_collate(video_idx)
|
30 |
+
labels = torch.tensor(np.concatenate(labels, axis=0)).float()
|
31 |
+
|
32 |
+
collated_extra_data = {}
|
33 |
+
for key in extra_data[0].keys():
|
34 |
+
data = [d[key] for d in extra_data]
|
35 |
+
if key == "boxes" or key == "ori_boxes":
|
36 |
+
# Append idx info to the bboxes before concatenating them.
|
37 |
+
bboxes = [
|
38 |
+
np.concatenate(
|
39 |
+
[np.full((data[i].shape[0], 1), float(i)), data[i]], axis=1
|
40 |
+
)
|
41 |
+
for i in range(len(data))
|
42 |
+
]
|
43 |
+
bboxes = np.concatenate(bboxes, axis=0)
|
44 |
+
collated_extra_data[key] = torch.tensor(bboxes).float()
|
45 |
+
elif key == "metadata":
|
46 |
+
collated_extra_data[key] = torch.tensor(
|
47 |
+
list(itertools.chain(*data))
|
48 |
+
).view(-1, 2)
|
49 |
+
else:
|
50 |
+
collated_extra_data[key] = default_collate(data)
|
51 |
+
|
52 |
+
return inputs, labels, video_idx, collated_extra_data
|
53 |
+
|
54 |
+
|
55 |
+
def construct_loader(cfg, split, is_precise_bn=False):
|
56 |
+
"""
|
57 |
+
Constructs the data loader for the given dataset.
|
58 |
+
Args:
|
59 |
+
cfg (CfgNode): configs. Details can be found in
|
60 |
+
slowfast/config/defaults.py
|
61 |
+
split (str): the split of the data loader. Options include `train`,
|
62 |
+
`val`, and `test`.
|
63 |
+
"""
|
64 |
+
assert split in ["train", "val", "test"]
|
65 |
+
if split in ["train"]:
|
66 |
+
dataset_name = cfg.TRAIN.DATASET
|
67 |
+
batch_size = int(cfg.TRAIN.BATCH_SIZE / max(1, cfg.NUM_GPUS))
|
68 |
+
shuffle = True
|
69 |
+
drop_last = True
|
70 |
+
elif split in ["val"]:
|
71 |
+
dataset_name = cfg.TRAIN.DATASET
|
72 |
+
batch_size = int(cfg.TRAIN.BATCH_SIZE / max(1, cfg.NUM_GPUS))
|
73 |
+
shuffle = False
|
74 |
+
drop_last = False
|
75 |
+
elif split in ["test"]:
|
76 |
+
dataset_name = cfg.TEST.DATASET
|
77 |
+
batch_size = int(cfg.TEST.BATCH_SIZE / max(1, cfg.NUM_GPUS))
|
78 |
+
shuffle = False
|
79 |
+
drop_last = False
|
80 |
+
|
81 |
+
# Construct the dataset
|
82 |
+
dataset = build_dataset(dataset_name, cfg, split)
|
83 |
+
|
84 |
+
if cfg.MULTIGRID.SHORT_CYCLE and split in ["train"] and not is_precise_bn:
|
85 |
+
# Create a sampler for multi-process training
|
86 |
+
sampler = utils.create_sampler(dataset, shuffle, cfg)
|
87 |
+
batch_sampler = ShortCycleBatchSampler(
|
88 |
+
sampler, batch_size=batch_size, drop_last=drop_last, cfg=cfg
|
89 |
+
)
|
90 |
+
# Create a loader
|
91 |
+
loader = torch.utils.data.DataLoader(
|
92 |
+
dataset,
|
93 |
+
batch_sampler=batch_sampler,
|
94 |
+
num_workers=cfg.DATA_LOADER.NUM_WORKERS,
|
95 |
+
pin_memory=cfg.DATA_LOADER.PIN_MEMORY,
|
96 |
+
worker_init_fn=utils.loader_worker_init_fn(dataset),
|
97 |
+
)
|
98 |
+
else:
|
99 |
+
# Create a sampler for multi-process training
|
100 |
+
sampler = utils.create_sampler(dataset, shuffle, cfg)
|
101 |
+
# Create a loader
|
102 |
+
loader = torch.utils.data.DataLoader(
|
103 |
+
dataset,
|
104 |
+
batch_size=batch_size,
|
105 |
+
shuffle=(False if sampler else shuffle),
|
106 |
+
sampler=sampler,
|
107 |
+
num_workers=cfg.DATA_LOADER.NUM_WORKERS,
|
108 |
+
pin_memory=cfg.DATA_LOADER.PIN_MEMORY,
|
109 |
+
drop_last=drop_last,
|
110 |
+
collate_fn=detection_collate if cfg.DETECTION.ENABLE else None,
|
111 |
+
worker_init_fn=utils.loader_worker_init_fn(dataset),
|
112 |
+
)
|
113 |
+
return loader
|
114 |
+
|
115 |
+
|
116 |
+
def shuffle_dataset(loader, cur_epoch):
|
117 |
+
""" "
|
118 |
+
Shuffles the data.
|
119 |
+
Args:
|
120 |
+
loader (loader): data loader to perform shuffle.
|
121 |
+
cur_epoch (int): number of the current epoch.
|
122 |
+
"""
|
123 |
+
sampler = (
|
124 |
+
loader.batch_sampler.sampler
|
125 |
+
if isinstance(loader.batch_sampler, ShortCycleBatchSampler)
|
126 |
+
else loader.sampler
|
127 |
+
)
|
128 |
+
assert isinstance(
|
129 |
+
sampler, (RandomSampler, DistributedSampler)
|
130 |
+
), "Sampler type '{}' not supported".format(type(sampler))
|
131 |
+
# RandomSampler handles shuffling automatically
|
132 |
+
if isinstance(sampler, DistributedSampler):
|
133 |
+
# DistributedSampler shuffles data based on epoch
|
134 |
+
sampler.set_epoch(cur_epoch)
|
TimeSformer/timesformer/datasets/multigrid_helper.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
2 |
+
|
3 |
+
"""Helper functions for multigrid training."""
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
from torch._six import int_classes as _int_classes
|
7 |
+
from torch.utils.data.sampler import Sampler
|
8 |
+
|
9 |
+
|
10 |
+
class ShortCycleBatchSampler(Sampler):
|
11 |
+
"""
|
12 |
+
Extend Sampler to support "short cycle" sampling.
|
13 |
+
See paper "A Multigrid Method for Efficiently Training Video Models",
|
14 |
+
Wu et al., 2019 (https://arxiv.org/abs/1912.00998) for details.
|
15 |
+
"""
|
16 |
+
|
17 |
+
def __init__(self, sampler, batch_size, drop_last, cfg):
|
18 |
+
if not isinstance(sampler, Sampler):
|
19 |
+
raise ValueError(
|
20 |
+
"sampler should be an instance of "
|
21 |
+
"torch.utils.data.Sampler, but got sampler={}".format(sampler)
|
22 |
+
)
|
23 |
+
if (
|
24 |
+
not isinstance(batch_size, _int_classes)
|
25 |
+
or isinstance(batch_size, bool)
|
26 |
+
or batch_size <= 0
|
27 |
+
):
|
28 |
+
raise ValueError(
|
29 |
+
"batch_size should be a positive integer value, "
|
30 |
+
"but got batch_size={}".format(batch_size)
|
31 |
+
)
|
32 |
+
if not isinstance(drop_last, bool):
|
33 |
+
raise ValueError(
|
34 |
+
"drop_last should be a boolean value, but got "
|
35 |
+
"drop_last={}".format(drop_last)
|
36 |
+
)
|
37 |
+
self.sampler = sampler
|
38 |
+
self.drop_last = drop_last
|
39 |
+
|
40 |
+
bs_factor = [
|
41 |
+
int(
|
42 |
+
round(
|
43 |
+
(
|
44 |
+
float(cfg.DATA.TRAIN_CROP_SIZE)
|
45 |
+
/ (s * cfg.MULTIGRID.DEFAULT_S)
|
46 |
+
)
|
47 |
+
** 2
|
48 |
+
)
|
49 |
+
)
|
50 |
+
for s in cfg.MULTIGRID.SHORT_CYCLE_FACTORS
|
51 |
+
]
|
52 |
+
|
53 |
+
self.batch_sizes = [
|
54 |
+
batch_size * bs_factor[0],
|
55 |
+
batch_size * bs_factor[1],
|
56 |
+
batch_size,
|
57 |
+
]
|
58 |
+
|
59 |
+
def __iter__(self):
|
60 |
+
counter = 0
|
61 |
+
batch_size = self.batch_sizes[0]
|
62 |
+
batch = []
|
63 |
+
for idx in self.sampler:
|
64 |
+
batch.append((idx, counter % 3))
|
65 |
+
if len(batch) == batch_size:
|
66 |
+
yield batch
|
67 |
+
counter += 1
|
68 |
+
batch_size = self.batch_sizes[counter % 3]
|
69 |
+
batch = []
|
70 |
+
if len(batch) > 0 and not self.drop_last:
|
71 |
+
yield batch
|
72 |
+
|
73 |
+
def __len__(self):
|
74 |
+
avg_batch_size = sum(self.batch_sizes) / 3.0
|
75 |
+
if self.drop_last:
|
76 |
+
return int(np.floor(len(self.sampler) / avg_batch_size))
|
77 |
+
else:
|
78 |
+
return int(np.ceil(len(self.sampler) / avg_batch_size))
|
TimeSformer/timesformer/datasets/ssv2.py
ADDED
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
2 |
+
|
3 |
+
import json
|
4 |
+
import numpy as np
|
5 |
+
import os
|
6 |
+
import random
|
7 |
+
from itertools import chain as chain
|
8 |
+
import torch
|
9 |
+
import torch.utils.data
|
10 |
+
from fvcore.common.file_io import PathManager
|
11 |
+
|
12 |
+
import timesformer.utils.logging as logging
|
13 |
+
|
14 |
+
from . import utils as utils
|
15 |
+
from .build import DATASET_REGISTRY
|
16 |
+
|
17 |
+
logger = logging.get_logger(__name__)
|
18 |
+
|
19 |
+
|
20 |
+
@DATASET_REGISTRY.register()
|
21 |
+
class Ssv2(torch.utils.data.Dataset):
|
22 |
+
"""
|
23 |
+
Something-Something v2 (SSV2) video loader. Construct the SSV2 video loader,
|
24 |
+
then sample clips from the videos. For training and validation, a single
|
25 |
+
clip is randomly sampled from every video with random cropping, scaling, and
|
26 |
+
flipping. For testing, multiple clips are uniformaly sampled from every
|
27 |
+
video with uniform cropping. For uniform cropping, we take the left, center,
|
28 |
+
and right crop if the width is larger than height, or take top, center, and
|
29 |
+
bottom crop if the height is larger than the width.
|
30 |
+
"""
|
31 |
+
|
32 |
+
def __init__(self, cfg, mode, num_retries=10):
|
33 |
+
"""
|
34 |
+
Load Something-Something V2 data (frame paths, labels, etc. ) to a given
|
35 |
+
Dataset object. The dataset could be downloaded from Something-Something
|
36 |
+
official website (https://20bn.com/datasets/something-something).
|
37 |
+
Please see datasets/DATASET.md for more information about the data format.
|
38 |
+
Args:
|
39 |
+
cfg (CfgNode): configs.
|
40 |
+
mode (string): Options includes `train`, `val`, or `test` mode.
|
41 |
+
For the train and val mode, the data loader will take data
|
42 |
+
from the train or val set, and sample one clip per video.
|
43 |
+
For the test mode, the data loader will take data from test set,
|
44 |
+
and sample multiple clips per video.
|
45 |
+
num_retries (int): number of retries for reading frames from disk.
|
46 |
+
"""
|
47 |
+
# Only support train, val, and test mode.
|
48 |
+
assert mode in [
|
49 |
+
"train",
|
50 |
+
"val",
|
51 |
+
"test",
|
52 |
+
], "Split '{}' not supported for Something-Something V2".format(mode)
|
53 |
+
self.mode = mode
|
54 |
+
self.cfg = cfg
|
55 |
+
|
56 |
+
self._video_meta = {}
|
57 |
+
self._num_retries = num_retries
|
58 |
+
# For training or validation mode, one single clip is sampled from every
|
59 |
+
# video. For testing, NUM_ENSEMBLE_VIEWS clips are sampled from every
|
60 |
+
# video. For every clip, NUM_SPATIAL_CROPS is cropped spatially from
|
61 |
+
# the frames.
|
62 |
+
if self.mode in ["train", "val"]:
|
63 |
+
self._num_clips = 1
|
64 |
+
elif self.mode in ["test"]:
|
65 |
+
self._num_clips = (
|
66 |
+
cfg.TEST.NUM_ENSEMBLE_VIEWS * cfg.TEST.NUM_SPATIAL_CROPS
|
67 |
+
)
|
68 |
+
|
69 |
+
logger.info("Constructing Something-Something V2 {}...".format(mode))
|
70 |
+
self._construct_loader()
|
71 |
+
|
72 |
+
def _construct_loader(self):
|
73 |
+
"""
|
74 |
+
Construct the video loader.
|
75 |
+
"""
|
76 |
+
# Loading label names.
|
77 |
+
with PathManager.open(
|
78 |
+
os.path.join(
|
79 |
+
self.cfg.DATA.PATH_TO_DATA_DIR,
|
80 |
+
"something-something-v2-labels.json",
|
81 |
+
),
|
82 |
+
"r",
|
83 |
+
) as f:
|
84 |
+
label_dict = json.load(f)
|
85 |
+
|
86 |
+
# Loading labels.
|
87 |
+
label_file = os.path.join(
|
88 |
+
self.cfg.DATA.PATH_TO_DATA_DIR,
|
89 |
+
"something-something-v2-{}.json".format(
|
90 |
+
"train" if self.mode == "train" else "validation"
|
91 |
+
),
|
92 |
+
)
|
93 |
+
with PathManager.open(label_file, "r") as f:
|
94 |
+
label_json = json.load(f)
|
95 |
+
|
96 |
+
self._video_names = []
|
97 |
+
self._labels = []
|
98 |
+
for video in label_json:
|
99 |
+
video_name = video["id"]
|
100 |
+
template = video["template"]
|
101 |
+
template = template.replace("[", "")
|
102 |
+
template = template.replace("]", "")
|
103 |
+
label = int(label_dict[template])
|
104 |
+
self._video_names.append(video_name)
|
105 |
+
self._labels.append(label)
|
106 |
+
|
107 |
+
path_to_file = os.path.join(
|
108 |
+
self.cfg.DATA.PATH_TO_DATA_DIR,
|
109 |
+
"{}.csv".format("train" if self.mode == "train" else "val"),
|
110 |
+
)
|
111 |
+
assert PathManager.exists(path_to_file), "{} dir not found".format(
|
112 |
+
path_to_file
|
113 |
+
)
|
114 |
+
|
115 |
+
self._path_to_videos, _ = utils.load_image_lists(
|
116 |
+
path_to_file, self.cfg.DATA.PATH_PREFIX
|
117 |
+
)
|
118 |
+
|
119 |
+
assert len(self._path_to_videos) == len(self._video_names), (
|
120 |
+
len(self._path_to_videos),
|
121 |
+
len(self._video_names),
|
122 |
+
)
|
123 |
+
|
124 |
+
|
125 |
+
# From dict to list.
|
126 |
+
new_paths, new_labels = [], []
|
127 |
+
for index in range(len(self._video_names)):
|
128 |
+
if self._video_names[index] in self._path_to_videos:
|
129 |
+
new_paths.append(self._path_to_videos[self._video_names[index]])
|
130 |
+
new_labels.append(self._labels[index])
|
131 |
+
|
132 |
+
self._labels = new_labels
|
133 |
+
self._path_to_videos = new_paths
|
134 |
+
|
135 |
+
# Extend self when self._num_clips > 1 (during testing).
|
136 |
+
self._path_to_videos = list(
|
137 |
+
chain.from_iterable(
|
138 |
+
[[x] * self._num_clips for x in self._path_to_videos]
|
139 |
+
)
|
140 |
+
)
|
141 |
+
self._labels = list(
|
142 |
+
chain.from_iterable([[x] * self._num_clips for x in self._labels])
|
143 |
+
)
|
144 |
+
self._spatial_temporal_idx = list(
|
145 |
+
chain.from_iterable(
|
146 |
+
[
|
147 |
+
range(self._num_clips)
|
148 |
+
for _ in range(len(self._path_to_videos))
|
149 |
+
]
|
150 |
+
)
|
151 |
+
)
|
152 |
+
logger.info(
|
153 |
+
"Something-Something V2 dataloader constructed "
|
154 |
+
" (size: {}) from {}".format(
|
155 |
+
len(self._path_to_videos), path_to_file
|
156 |
+
)
|
157 |
+
)
|
158 |
+
|
159 |
+
def __getitem__(self, index):
|
160 |
+
"""
|
161 |
+
Given the video index, return the list of frames, label, and video
|
162 |
+
index if the video frames can be fetched.
|
163 |
+
Args:
|
164 |
+
index (int): the video index provided by the pytorch sampler.
|
165 |
+
Returns:
|
166 |
+
frames (tensor): the frames of sampled from the video. The dimension
|
167 |
+
is `channel` x `num frames` x `height` x `width`.
|
168 |
+
label (int): the label of the current video.
|
169 |
+
index (int): the index of the video.
|
170 |
+
"""
|
171 |
+
short_cycle_idx = None
|
172 |
+
# When short cycle is used, input index is a tupple.
|
173 |
+
if isinstance(index, tuple):
|
174 |
+
index, short_cycle_idx = index
|
175 |
+
|
176 |
+
if self.mode in ["train", "val"]: #or self.cfg.MODEL.ARCH in ['resformer', 'vit']:
|
177 |
+
# -1 indicates random sampling.
|
178 |
+
spatial_sample_index = -1
|
179 |
+
min_scale = self.cfg.DATA.TRAIN_JITTER_SCALES[0]
|
180 |
+
max_scale = self.cfg.DATA.TRAIN_JITTER_SCALES[1]
|
181 |
+
crop_size = self.cfg.DATA.TRAIN_CROP_SIZE
|
182 |
+
if short_cycle_idx in [0, 1]:
|
183 |
+
crop_size = int(
|
184 |
+
round(
|
185 |
+
self.cfg.MULTIGRID.SHORT_CYCLE_FACTORS[short_cycle_idx]
|
186 |
+
* self.cfg.MULTIGRID.DEFAULT_S
|
187 |
+
)
|
188 |
+
)
|
189 |
+
if self.cfg.MULTIGRID.DEFAULT_S > 0:
|
190 |
+
# Decreasing the scale is equivalent to using a larger "span"
|
191 |
+
# in a sampling grid.
|
192 |
+
min_scale = int(
|
193 |
+
round(
|
194 |
+
float(min_scale)
|
195 |
+
* crop_size
|
196 |
+
/ self.cfg.MULTIGRID.DEFAULT_S
|
197 |
+
)
|
198 |
+
)
|
199 |
+
elif self.mode in ["test"]:
|
200 |
+
# spatial_sample_index is in [0, 1, 2]. Corresponding to left,
|
201 |
+
# center, or right if width is larger than height, and top, middle,
|
202 |
+
# or bottom if height is larger than width.
|
203 |
+
spatial_sample_index = (
|
204 |
+
self._spatial_temporal_idx[index]
|
205 |
+
% self.cfg.TEST.NUM_SPATIAL_CROPS
|
206 |
+
)
|
207 |
+
if self.cfg.TEST.NUM_SPATIAL_CROPS == 1:
|
208 |
+
spatial_sample_index = 1
|
209 |
+
|
210 |
+
min_scale, max_scale, crop_size = [self.cfg.DATA.TEST_CROP_SIZE] * 3
|
211 |
+
# The testing is deterministic and no jitter should be performed.
|
212 |
+
# min_scale, max_scale, and crop_size are expect to be the same.
|
213 |
+
assert len({min_scale, max_scale, crop_size}) == 1
|
214 |
+
else:
|
215 |
+
raise NotImplementedError(
|
216 |
+
"Does not support {} mode".format(self.mode)
|
217 |
+
)
|
218 |
+
|
219 |
+
label = self._labels[index]
|
220 |
+
|
221 |
+
num_frames = self.cfg.DATA.NUM_FRAMES
|
222 |
+
video_length = len(self._path_to_videos[index])
|
223 |
+
|
224 |
+
|
225 |
+
seg_size = float(video_length - 1) / num_frames
|
226 |
+
seq = []
|
227 |
+
for i in range(num_frames):
|
228 |
+
start = int(np.round(seg_size * i))
|
229 |
+
end = int(np.round(seg_size * (i + 1)))
|
230 |
+
if self.mode == "train":
|
231 |
+
seq.append(random.randint(start, end))
|
232 |
+
else:
|
233 |
+
seq.append((start + end) // 2)
|
234 |
+
|
235 |
+
frames = torch.as_tensor(
|
236 |
+
utils.retry_load_images(
|
237 |
+
[self._path_to_videos[index][frame] for frame in seq],
|
238 |
+
self._num_retries,
|
239 |
+
)
|
240 |
+
)
|
241 |
+
|
242 |
+
# Perform color normalization.
|
243 |
+
frames = utils.tensor_normalize(
|
244 |
+
frames, self.cfg.DATA.MEAN, self.cfg.DATA.STD
|
245 |
+
)
|
246 |
+
|
247 |
+
# T H W C -> C T H W.
|
248 |
+
frames = frames.permute(3, 0, 1, 2)
|
249 |
+
frames = utils.spatial_sampling(
|
250 |
+
frames,
|
251 |
+
spatial_idx=spatial_sample_index,
|
252 |
+
min_scale=min_scale,
|
253 |
+
max_scale=max_scale,
|
254 |
+
crop_size=crop_size,
|
255 |
+
random_horizontal_flip=self.cfg.DATA.RANDOM_FLIP,
|
256 |
+
inverse_uniform_sampling=self.cfg.DATA.INV_UNIFORM_SAMPLE,
|
257 |
+
)
|
258 |
+
#if not self.cfg.RESFORMER.ACTIVE:
|
259 |
+
if not self.cfg.MODEL.ARCH in ['vit']:
|
260 |
+
frames = utils.pack_pathway_output(self.cfg, frames)
|
261 |
+
else:
|
262 |
+
# Perform temporal sampling from the fast pathway.
|
263 |
+
frames = torch.index_select(
|
264 |
+
frames,
|
265 |
+
1,
|
266 |
+
torch.linspace(
|
267 |
+
0, frames.shape[1] - 1, self.cfg.DATA.NUM_FRAMES
|
268 |
+
|
269 |
+
).long(),
|
270 |
+
)
|
271 |
+
return frames, label, index, {}
|
272 |
+
|
273 |
+
def __len__(self):
|
274 |
+
"""
|
275 |
+
Returns:
|
276 |
+
(int): the number of videos in the dataset.
|
277 |
+
"""
|
278 |
+
return len(self._path_to_videos)
|
TimeSformer/timesformer/datasets/transform.py
ADDED
@@ -0,0 +1,459 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
2 |
+
|
3 |
+
import math
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
|
7 |
+
|
8 |
+
def random_short_side_scale_jitter(
|
9 |
+
images, min_size, max_size, boxes=None, inverse_uniform_sampling=False
|
10 |
+
):
|
11 |
+
"""
|
12 |
+
Perform a spatial short scale jittering on the given images and
|
13 |
+
corresponding boxes.
|
14 |
+
Args:
|
15 |
+
images (tensor): images to perform scale jitter. Dimension is
|
16 |
+
`num frames` x `channel` x `height` x `width`.
|
17 |
+
min_size (int): the minimal size to scale the frames.
|
18 |
+
max_size (int): the maximal size to scale the frames.
|
19 |
+
boxes (ndarray): optional. Corresponding boxes to images.
|
20 |
+
Dimension is `num boxes` x 4.
|
21 |
+
inverse_uniform_sampling (bool): if True, sample uniformly in
|
22 |
+
[1 / max_scale, 1 / min_scale] and take a reciprocal to get the
|
23 |
+
scale. If False, take a uniform sample from [min_scale, max_scale].
|
24 |
+
Returns:
|
25 |
+
(tensor): the scaled images with dimension of
|
26 |
+
`num frames` x `channel` x `new height` x `new width`.
|
27 |
+
(ndarray or None): the scaled boxes with dimension of
|
28 |
+
`num boxes` x 4.
|
29 |
+
"""
|
30 |
+
if inverse_uniform_sampling:
|
31 |
+
size = int(
|
32 |
+
round(1.0 / np.random.uniform(1.0 / max_size, 1.0 / min_size))
|
33 |
+
)
|
34 |
+
else:
|
35 |
+
size = int(round(np.random.uniform(min_size, max_size)))
|
36 |
+
|
37 |
+
height = images.shape[2]
|
38 |
+
width = images.shape[3]
|
39 |
+
if (width <= height and width == size) or (
|
40 |
+
height <= width and height == size
|
41 |
+
):
|
42 |
+
return images, boxes
|
43 |
+
new_width = size
|
44 |
+
new_height = size
|
45 |
+
if width < height:
|
46 |
+
new_height = int(math.floor((float(height) / width) * size))
|
47 |
+
if boxes is not None:
|
48 |
+
boxes = boxes * float(new_height) / height
|
49 |
+
else:
|
50 |
+
new_width = int(math.floor((float(width) / height) * size))
|
51 |
+
if boxes is not None:
|
52 |
+
boxes = boxes * float(new_width) / width
|
53 |
+
|
54 |
+
return (
|
55 |
+
torch.nn.functional.interpolate(
|
56 |
+
images,
|
57 |
+
size=(new_height, new_width),
|
58 |
+
mode="bilinear",
|
59 |
+
align_corners=False,
|
60 |
+
),
|
61 |
+
boxes,
|
62 |
+
)
|
63 |
+
|
64 |
+
|
65 |
+
def crop_boxes(boxes, x_offset, y_offset):
|
66 |
+
"""
|
67 |
+
Peform crop on the bounding boxes given the offsets.
|
68 |
+
Args:
|
69 |
+
boxes (ndarray or None): bounding boxes to peform crop. The dimension
|
70 |
+
is `num boxes` x 4.
|
71 |
+
x_offset (int): cropping offset in the x axis.
|
72 |
+
y_offset (int): cropping offset in the y axis.
|
73 |
+
Returns:
|
74 |
+
cropped_boxes (ndarray or None): the cropped boxes with dimension of
|
75 |
+
`num boxes` x 4.
|
76 |
+
"""
|
77 |
+
cropped_boxes = boxes.copy()
|
78 |
+
cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset
|
79 |
+
cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset
|
80 |
+
|
81 |
+
return cropped_boxes
|
82 |
+
|
83 |
+
|
84 |
+
def random_crop(images, size, boxes=None):
|
85 |
+
"""
|
86 |
+
Perform random spatial crop on the given images and corresponding boxes.
|
87 |
+
Args:
|
88 |
+
images (tensor): images to perform random crop. The dimension is
|
89 |
+
`num frames` x `channel` x `height` x `width`.
|
90 |
+
size (int): the size of height and width to crop on the image.
|
91 |
+
boxes (ndarray or None): optional. Corresponding boxes to images.
|
92 |
+
Dimension is `num boxes` x 4.
|
93 |
+
Returns:
|
94 |
+
cropped (tensor): cropped images with dimension of
|
95 |
+
`num frames` x `channel` x `size` x `size`.
|
96 |
+
cropped_boxes (ndarray or None): the cropped boxes with dimension of
|
97 |
+
`num boxes` x 4.
|
98 |
+
"""
|
99 |
+
if images.shape[2] == size and images.shape[3] == size:
|
100 |
+
return images, None
|
101 |
+
height = images.shape[2]
|
102 |
+
width = images.shape[3]
|
103 |
+
y_offset = 0
|
104 |
+
if height > size:
|
105 |
+
y_offset = int(np.random.randint(0, height - size))
|
106 |
+
x_offset = 0
|
107 |
+
if width > size:
|
108 |
+
x_offset = int(np.random.randint(0, width - size))
|
109 |
+
cropped = images[
|
110 |
+
:, :, y_offset : y_offset + size, x_offset : x_offset + size
|
111 |
+
]
|
112 |
+
|
113 |
+
cropped_boxes = (
|
114 |
+
crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
|
115 |
+
)
|
116 |
+
|
117 |
+
return cropped, cropped_boxes
|
118 |
+
|
119 |
+
|
120 |
+
def horizontal_flip(prob, images, boxes=None):
|
121 |
+
"""
|
122 |
+
Perform horizontal flip on the given images and corresponding boxes.
|
123 |
+
Args:
|
124 |
+
prob (float): probility to flip the images.
|
125 |
+
images (tensor): images to perform horizontal flip, the dimension is
|
126 |
+
`num frames` x `channel` x `height` x `width`.
|
127 |
+
boxes (ndarray or None): optional. Corresponding boxes to images.
|
128 |
+
Dimension is `num boxes` x 4.
|
129 |
+
Returns:
|
130 |
+
images (tensor): images with dimension of
|
131 |
+
`num frames` x `channel` x `height` x `width`.
|
132 |
+
flipped_boxes (ndarray or None): the flipped boxes with dimension of
|
133 |
+
`num boxes` x 4.
|
134 |
+
"""
|
135 |
+
if boxes is None:
|
136 |
+
flipped_boxes = None
|
137 |
+
else:
|
138 |
+
flipped_boxes = boxes.copy()
|
139 |
+
|
140 |
+
if np.random.uniform() < prob:
|
141 |
+
images = images.flip((-1))
|
142 |
+
|
143 |
+
width = images.shape[3]
|
144 |
+
if boxes is not None:
|
145 |
+
flipped_boxes[:, [0, 2]] = width - boxes[:, [2, 0]] - 1
|
146 |
+
|
147 |
+
return images, flipped_boxes
|
148 |
+
|
149 |
+
|
150 |
+
def uniform_crop(images, size, spatial_idx, boxes=None):
|
151 |
+
"""
|
152 |
+
Perform uniform spatial sampling on the images and corresponding boxes.
|
153 |
+
Args:
|
154 |
+
images (tensor): images to perform uniform crop. The dimension is
|
155 |
+
`num frames` x `channel` x `height` x `width`.
|
156 |
+
size (int): size of height and weight to crop the images.
|
157 |
+
spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
|
158 |
+
is larger than height. Or 0, 1, or 2 for top, center, and bottom
|
159 |
+
crop if height is larger than width.
|
160 |
+
boxes (ndarray or None): optional. Corresponding boxes to images.
|
161 |
+
Dimension is `num boxes` x 4.
|
162 |
+
Returns:
|
163 |
+
cropped (tensor): images with dimension of
|
164 |
+
`num frames` x `channel` x `size` x `size`.
|
165 |
+
cropped_boxes (ndarray or None): the cropped boxes with dimension of
|
166 |
+
`num boxes` x 4.
|
167 |
+
"""
|
168 |
+
assert spatial_idx in [0, 1, 2]
|
169 |
+
height = images.shape[2]
|
170 |
+
width = images.shape[3]
|
171 |
+
|
172 |
+
y_offset = int(math.ceil((height - size) / 2))
|
173 |
+
x_offset = int(math.ceil((width - size) / 2))
|
174 |
+
|
175 |
+
if height > width:
|
176 |
+
if spatial_idx == 0:
|
177 |
+
y_offset = 0
|
178 |
+
elif spatial_idx == 2:
|
179 |
+
y_offset = height - size
|
180 |
+
else:
|
181 |
+
if spatial_idx == 0:
|
182 |
+
x_offset = 0
|
183 |
+
elif spatial_idx == 2:
|
184 |
+
x_offset = width - size
|
185 |
+
cropped = images[
|
186 |
+
:, :, y_offset : y_offset + size, x_offset : x_offset + size
|
187 |
+
]
|
188 |
+
|
189 |
+
cropped_boxes = (
|
190 |
+
crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
|
191 |
+
)
|
192 |
+
|
193 |
+
return cropped, cropped_boxes
|
194 |
+
|
195 |
+
|
196 |
+
def uniform_crop_2crops(images, size, spatial_idx, boxes=None):
|
197 |
+
"""
|
198 |
+
Perform uniform spatial sampling on the images and corresponding boxes.
|
199 |
+
Args:
|
200 |
+
images (tensor): images to perform uniform crop. The dimension is
|
201 |
+
`num frames` x `channel` x `height` x `width`.
|
202 |
+
size (int): size of height and weight to crop the images.
|
203 |
+
spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
|
204 |
+
is larger than height. Or 0, 1, or 2 for top, center, and bottom
|
205 |
+
crop if height is larger than width.
|
206 |
+
boxes (ndarray or None): optional. Corresponding boxes to images.
|
207 |
+
Dimension is `num boxes` x 4.
|
208 |
+
Returns:
|
209 |
+
cropped (tensor): images with dimension of
|
210 |
+
`num frames` x `channel` x `size` x `size`.
|
211 |
+
cropped_boxes (ndarray or None): the cropped boxes with dimension of
|
212 |
+
`num boxes` x 4.
|
213 |
+
"""
|
214 |
+
assert spatial_idx in [0, 1, 2]
|
215 |
+
height = images.shape[2]
|
216 |
+
width = images.shape[3]
|
217 |
+
|
218 |
+
|
219 |
+
if height > width:
|
220 |
+
x_offset = 0
|
221 |
+
if height > size * 2:
|
222 |
+
if spatial_idx == 0:
|
223 |
+
y_offset = int((height - size * 2) // 2)
|
224 |
+
elif spatial_idx == 1:
|
225 |
+
y_offset = int(height - size - ((height - size * 2) // 2))
|
226 |
+
else:
|
227 |
+
if spatial_idx == 0:
|
228 |
+
y_offset = 0
|
229 |
+
elif spatial_idx == 1:
|
230 |
+
y_offset = height - size
|
231 |
+
else:
|
232 |
+
y_offset = 0
|
233 |
+
if width > size * 2:
|
234 |
+
if spatial_idx == 0:
|
235 |
+
x_offset = int((width - size * 2) // 2)
|
236 |
+
elif spatial_idx == 1:
|
237 |
+
x_offset = int(width - size - ((width - size * 2) // 2))
|
238 |
+
else:
|
239 |
+
if spatial_idx == 0:
|
240 |
+
x_offset = 0
|
241 |
+
elif spatial_idx == 1:
|
242 |
+
x_offset = width - size
|
243 |
+
|
244 |
+
cropped = images[
|
245 |
+
:, :, y_offset : y_offset + size, x_offset : x_offset + size
|
246 |
+
]
|
247 |
+
|
248 |
+
cropped_boxes = (
|
249 |
+
crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
|
250 |
+
)
|
251 |
+
|
252 |
+
return cropped, cropped_boxes
|
253 |
+
|
254 |
+
def clip_boxes_to_image(boxes, height, width):
|
255 |
+
"""
|
256 |
+
Clip an array of boxes to an image with the given height and width.
|
257 |
+
Args:
|
258 |
+
boxes (ndarray): bounding boxes to perform clipping.
|
259 |
+
Dimension is `num boxes` x 4.
|
260 |
+
height (int): given image height.
|
261 |
+
width (int): given image width.
|
262 |
+
Returns:
|
263 |
+
clipped_boxes (ndarray): the clipped boxes with dimension of
|
264 |
+
`num boxes` x 4.
|
265 |
+
"""
|
266 |
+
clipped_boxes = boxes.copy()
|
267 |
+
clipped_boxes[:, [0, 2]] = np.minimum(
|
268 |
+
width - 1.0, np.maximum(0.0, boxes[:, [0, 2]])
|
269 |
+
)
|
270 |
+
clipped_boxes[:, [1, 3]] = np.minimum(
|
271 |
+
height - 1.0, np.maximum(0.0, boxes[:, [1, 3]])
|
272 |
+
)
|
273 |
+
return clipped_boxes
|
274 |
+
|
275 |
+
|
276 |
+
def blend(images1, images2, alpha):
|
277 |
+
"""
|
278 |
+
Blend two images with a given weight alpha.
|
279 |
+
Args:
|
280 |
+
images1 (tensor): the first images to be blended, the dimension is
|
281 |
+
`num frames` x `channel` x `height` x `width`.
|
282 |
+
images2 (tensor): the second images to be blended, the dimension is
|
283 |
+
`num frames` x `channel` x `height` x `width`.
|
284 |
+
alpha (float): the blending weight.
|
285 |
+
Returns:
|
286 |
+
(tensor): blended images, the dimension is
|
287 |
+
`num frames` x `channel` x `height` x `width`.
|
288 |
+
"""
|
289 |
+
return images1 * alpha + images2 * (1 - alpha)
|
290 |
+
|
291 |
+
|
292 |
+
def grayscale(images):
|
293 |
+
"""
|
294 |
+
Get the grayscale for the input images. The channels of images should be
|
295 |
+
in order BGR.
|
296 |
+
Args:
|
297 |
+
images (tensor): the input images for getting grayscale. Dimension is
|
298 |
+
`num frames` x `channel` x `height` x `width`.
|
299 |
+
Returns:
|
300 |
+
img_gray (tensor): blended images, the dimension is
|
301 |
+
`num frames` x `channel` x `height` x `width`.
|
302 |
+
"""
|
303 |
+
# R -> 0.299, G -> 0.587, B -> 0.114.
|
304 |
+
img_gray = torch.tensor(images)
|
305 |
+
gray_channel = (
|
306 |
+
0.299 * images[:, 2] + 0.587 * images[:, 1] + 0.114 * images[:, 0]
|
307 |
+
)
|
308 |
+
img_gray[:, 0] = gray_channel
|
309 |
+
img_gray[:, 1] = gray_channel
|
310 |
+
img_gray[:, 2] = gray_channel
|
311 |
+
return img_gray
|
312 |
+
|
313 |
+
|
314 |
+
def color_jitter(images, img_brightness=0, img_contrast=0, img_saturation=0):
|
315 |
+
"""
|
316 |
+
Perfrom a color jittering on the input images. The channels of images
|
317 |
+
should be in order BGR.
|
318 |
+
Args:
|
319 |
+
images (tensor): images to perform color jitter. Dimension is
|
320 |
+
`num frames` x `channel` x `height` x `width`.
|
321 |
+
img_brightness (float): jitter ratio for brightness.
|
322 |
+
img_contrast (float): jitter ratio for contrast.
|
323 |
+
img_saturation (float): jitter ratio for saturation.
|
324 |
+
Returns:
|
325 |
+
images (tensor): the jittered images, the dimension is
|
326 |
+
`num frames` x `channel` x `height` x `width`.
|
327 |
+
"""
|
328 |
+
|
329 |
+
jitter = []
|
330 |
+
if img_brightness != 0:
|
331 |
+
jitter.append("brightness")
|
332 |
+
if img_contrast != 0:
|
333 |
+
jitter.append("contrast")
|
334 |
+
if img_saturation != 0:
|
335 |
+
jitter.append("saturation")
|
336 |
+
|
337 |
+
if len(jitter) > 0:
|
338 |
+
order = np.random.permutation(np.arange(len(jitter)))
|
339 |
+
for idx in range(0, len(jitter)):
|
340 |
+
if jitter[order[idx]] == "brightness":
|
341 |
+
images = brightness_jitter(img_brightness, images)
|
342 |
+
elif jitter[order[idx]] == "contrast":
|
343 |
+
images = contrast_jitter(img_contrast, images)
|
344 |
+
elif jitter[order[idx]] == "saturation":
|
345 |
+
images = saturation_jitter(img_saturation, images)
|
346 |
+
return images
|
347 |
+
|
348 |
+
|
349 |
+
def brightness_jitter(var, images):
|
350 |
+
"""
|
351 |
+
Perfrom brightness jittering on the input images. The channels of images
|
352 |
+
should be in order BGR.
|
353 |
+
Args:
|
354 |
+
var (float): jitter ratio for brightness.
|
355 |
+
images (tensor): images to perform color jitter. Dimension is
|
356 |
+
`num frames` x `channel` x `height` x `width`.
|
357 |
+
Returns:
|
358 |
+
images (tensor): the jittered images, the dimension is
|
359 |
+
`num frames` x `channel` x `height` x `width`.
|
360 |
+
"""
|
361 |
+
alpha = 1.0 + np.random.uniform(-var, var)
|
362 |
+
|
363 |
+
img_bright = torch.zeros(images.shape)
|
364 |
+
images = blend(images, img_bright, alpha)
|
365 |
+
return images
|
366 |
+
|
367 |
+
|
368 |
+
def contrast_jitter(var, images):
|
369 |
+
"""
|
370 |
+
Perfrom contrast jittering on the input images. The channels of images
|
371 |
+
should be in order BGR.
|
372 |
+
Args:
|
373 |
+
var (float): jitter ratio for contrast.
|
374 |
+
images (tensor): images to perform color jitter. Dimension is
|
375 |
+
`num frames` x `channel` x `height` x `width`.
|
376 |
+
Returns:
|
377 |
+
images (tensor): the jittered images, the dimension is
|
378 |
+
`num frames` x `channel` x `height` x `width`.
|
379 |
+
"""
|
380 |
+
alpha = 1.0 + np.random.uniform(-var, var)
|
381 |
+
|
382 |
+
img_gray = grayscale(images)
|
383 |
+
img_gray[:] = torch.mean(img_gray, dim=(1, 2, 3), keepdim=True)
|
384 |
+
images = blend(images, img_gray, alpha)
|
385 |
+
return images
|
386 |
+
|
387 |
+
|
388 |
+
def saturation_jitter(var, images):
|
389 |
+
"""
|
390 |
+
Perfrom saturation jittering on the input images. The channels of images
|
391 |
+
should be in order BGR.
|
392 |
+
Args:
|
393 |
+
var (float): jitter ratio for saturation.
|
394 |
+
images (tensor): images to perform color jitter. Dimension is
|
395 |
+
`num frames` x `channel` x `height` x `width`.
|
396 |
+
Returns:
|
397 |
+
images (tensor): the jittered images, the dimension is
|
398 |
+
`num frames` x `channel` x `height` x `width`.
|
399 |
+
"""
|
400 |
+
alpha = 1.0 + np.random.uniform(-var, var)
|
401 |
+
img_gray = grayscale(images)
|
402 |
+
images = blend(images, img_gray, alpha)
|
403 |
+
|
404 |
+
return images
|
405 |
+
|
406 |
+
|
407 |
+
def lighting_jitter(images, alphastd, eigval, eigvec):
|
408 |
+
"""
|
409 |
+
Perform AlexNet-style PCA jitter on the given images.
|
410 |
+
Args:
|
411 |
+
images (tensor): images to perform lighting jitter. Dimension is
|
412 |
+
`num frames` x `channel` x `height` x `width`.
|
413 |
+
alphastd (float): jitter ratio for PCA jitter.
|
414 |
+
eigval (list): eigenvalues for PCA jitter.
|
415 |
+
eigvec (list[list]): eigenvectors for PCA jitter.
|
416 |
+
Returns:
|
417 |
+
out_images (tensor): the jittered images, the dimension is
|
418 |
+
`num frames` x `channel` x `height` x `width`.
|
419 |
+
"""
|
420 |
+
if alphastd == 0:
|
421 |
+
return images
|
422 |
+
# generate alpha1, alpha2, alpha3.
|
423 |
+
alpha = np.random.normal(0, alphastd, size=(1, 3))
|
424 |
+
eig_vec = np.array(eigvec)
|
425 |
+
eig_val = np.reshape(eigval, (1, 3))
|
426 |
+
rgb = np.sum(
|
427 |
+
eig_vec * np.repeat(alpha, 3, axis=0) * np.repeat(eig_val, 3, axis=0),
|
428 |
+
axis=1,
|
429 |
+
)
|
430 |
+
out_images = torch.zeros_like(images)
|
431 |
+
for idx in range(images.shape[1]):
|
432 |
+
out_images[:, idx] = images[:, idx] + rgb[2 - idx]
|
433 |
+
|
434 |
+
return out_images
|
435 |
+
|
436 |
+
|
437 |
+
def color_normalization(images, mean, stddev):
|
438 |
+
"""
|
439 |
+
Perform color nomration on the given images.
|
440 |
+
Args:
|
441 |
+
images (tensor): images to perform color normalization. Dimension is
|
442 |
+
`num frames` x `channel` x `height` x `width`.
|
443 |
+
mean (list): mean values for normalization.
|
444 |
+
stddev (list): standard deviations for normalization.
|
445 |
+
|
446 |
+
Returns:
|
447 |
+
out_images (tensor): the noramlized images, the dimension is
|
448 |
+
`num frames` x `channel` x `height` x `width`.
|
449 |
+
"""
|
450 |
+
assert len(mean) == images.shape[1], "channel mean not computed properly"
|
451 |
+
assert (
|
452 |
+
len(stddev) == images.shape[1]
|
453 |
+
), "channel stddev not computed properly"
|
454 |
+
|
455 |
+
out_images = torch.zeros_like(images)
|
456 |
+
for idx in range(len(mean)):
|
457 |
+
out_images[:, idx] = (images[:, idx] - mean[idx]) / stddev[idx]
|
458 |
+
|
459 |
+
return out_images
|
TimeSformer/timesformer/datasets/utils.py
ADDED
@@ -0,0 +1,380 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import logging
|
4 |
+
import numpy as np
|
5 |
+
import os
|
6 |
+
import random
|
7 |
+
import time
|
8 |
+
from collections import defaultdict
|
9 |
+
import cv2
|
10 |
+
import torch
|
11 |
+
from fvcore.common.file_io import PathManager
|
12 |
+
from torch.utils.data.distributed import DistributedSampler
|
13 |
+
|
14 |
+
from . import transform as transform
|
15 |
+
|
16 |
+
logger = logging.getLogger(__name__)
|
17 |
+
|
18 |
+
|
19 |
+
def retry_load_images(image_paths, retry=10, backend="pytorch"):
|
20 |
+
"""
|
21 |
+
This function is to load images with support of retrying for failed load.
|
22 |
+
|
23 |
+
Args:
|
24 |
+
image_paths (list): paths of images needed to be loaded.
|
25 |
+
retry (int, optional): maximum time of loading retrying. Defaults to 10.
|
26 |
+
backend (str): `pytorch` or `cv2`.
|
27 |
+
|
28 |
+
Returns:
|
29 |
+
imgs (list): list of loaded images.
|
30 |
+
"""
|
31 |
+
for i in range(retry):
|
32 |
+
imgs = []
|
33 |
+
for image_path in image_paths:
|
34 |
+
with PathManager.open(image_path, "rb") as f:
|
35 |
+
img_str = np.frombuffer(f.read(), np.uint8)
|
36 |
+
img = cv2.imdecode(img_str, flags=cv2.IMREAD_COLOR)
|
37 |
+
imgs.append(img)
|
38 |
+
|
39 |
+
if all(img is not None for img in imgs):
|
40 |
+
if backend == "pytorch":
|
41 |
+
imgs = torch.as_tensor(np.stack(imgs))
|
42 |
+
return imgs
|
43 |
+
else:
|
44 |
+
logger.warn("Reading failed. Will retry.")
|
45 |
+
time.sleep(1.0)
|
46 |
+
if i == retry - 1:
|
47 |
+
raise Exception("Failed to load images {}".format(image_paths))
|
48 |
+
|
49 |
+
|
50 |
+
def get_sequence(center_idx, half_len, sample_rate, num_frames):
|
51 |
+
"""
|
52 |
+
Sample frames among the corresponding clip.
|
53 |
+
|
54 |
+
Args:
|
55 |
+
center_idx (int): center frame idx for current clip
|
56 |
+
half_len (int): half of the clip length
|
57 |
+
sample_rate (int): sampling rate for sampling frames inside of the clip
|
58 |
+
num_frames (int): number of expected sampled frames
|
59 |
+
|
60 |
+
Returns:
|
61 |
+
seq (list): list of indexes of sampled frames in this clip.
|
62 |
+
"""
|
63 |
+
seq = list(range(center_idx - half_len, center_idx + half_len, sample_rate))
|
64 |
+
|
65 |
+
for seq_idx in range(len(seq)):
|
66 |
+
if seq[seq_idx] < 0:
|
67 |
+
seq[seq_idx] = 0
|
68 |
+
elif seq[seq_idx] >= num_frames:
|
69 |
+
seq[seq_idx] = num_frames - 1
|
70 |
+
return seq
|
71 |
+
|
72 |
+
|
73 |
+
def pack_pathway_output(cfg, frames):
|
74 |
+
"""
|
75 |
+
Prepare output as a list of tensors. Each tensor corresponding to a
|
76 |
+
unique pathway.
|
77 |
+
Args:
|
78 |
+
frames (tensor): frames of images sampled from the video. The
|
79 |
+
dimension is `channel` x `num frames` x `height` x `width`.
|
80 |
+
Returns:
|
81 |
+
frame_list (list): list of tensors with the dimension of
|
82 |
+
`channel` x `num frames` x `height` x `width`.
|
83 |
+
"""
|
84 |
+
if cfg.DATA.REVERSE_INPUT_CHANNEL:
|
85 |
+
frames = frames[[2, 1, 0], :, :, :]
|
86 |
+
if cfg.MODEL.ARCH in cfg.MODEL.SINGLE_PATHWAY_ARCH:
|
87 |
+
frame_list = [frames]
|
88 |
+
elif cfg.MODEL.ARCH in cfg.MODEL.MULTI_PATHWAY_ARCH:
|
89 |
+
fast_pathway = frames
|
90 |
+
# Perform temporal sampling from the fast pathway.
|
91 |
+
slow_pathway = torch.index_select(
|
92 |
+
frames,
|
93 |
+
1,
|
94 |
+
torch.linspace(
|
95 |
+
0, frames.shape[1] - 1, frames.shape[1] // cfg.SLOWFAST.ALPHA
|
96 |
+
).long(),
|
97 |
+
)
|
98 |
+
frame_list = [slow_pathway, fast_pathway]
|
99 |
+
else:
|
100 |
+
raise NotImplementedError(
|
101 |
+
"Model arch {} is not in {}".format(
|
102 |
+
cfg.MODEL.ARCH,
|
103 |
+
cfg.MODEL.SINGLE_PATHWAY_ARCH + cfg.MODEL.MULTI_PATHWAY_ARCH,
|
104 |
+
)
|
105 |
+
)
|
106 |
+
return frame_list
|
107 |
+
|
108 |
+
|
109 |
+
def spatial_sampling(
|
110 |
+
frames,
|
111 |
+
spatial_idx=-1,
|
112 |
+
min_scale=256,
|
113 |
+
max_scale=320,
|
114 |
+
crop_size=224,
|
115 |
+
random_horizontal_flip=True,
|
116 |
+
inverse_uniform_sampling=False,
|
117 |
+
):
|
118 |
+
"""
|
119 |
+
Perform spatial sampling on the given video frames. If spatial_idx is
|
120 |
+
-1, perform random scale, random crop, and random flip on the given
|
121 |
+
frames. If spatial_idx is 0, 1, or 2, perform spatial uniform sampling
|
122 |
+
with the given spatial_idx.
|
123 |
+
Args:
|
124 |
+
frames (tensor): frames of images sampled from the video. The
|
125 |
+
dimension is `num frames` x `height` x `width` x `channel`.
|
126 |
+
spatial_idx (int): if -1, perform random spatial sampling. If 0, 1,
|
127 |
+
or 2, perform left, center, right crop if width is larger than
|
128 |
+
height, and perform top, center, buttom crop if height is larger
|
129 |
+
than width.
|
130 |
+
min_scale (int): the minimal size of scaling.
|
131 |
+
max_scale (int): the maximal size of scaling.
|
132 |
+
crop_size (int): the size of height and width used to crop the
|
133 |
+
frames.
|
134 |
+
inverse_uniform_sampling (bool): if True, sample uniformly in
|
135 |
+
[1 / max_scale, 1 / min_scale] and take a reciprocal to get the
|
136 |
+
scale. If False, take a uniform sample from [min_scale,
|
137 |
+
max_scale].
|
138 |
+
Returns:
|
139 |
+
frames (tensor): spatially sampled frames.
|
140 |
+
"""
|
141 |
+
assert spatial_idx in [-1, 0, 1, 2]
|
142 |
+
if spatial_idx == -1:
|
143 |
+
frames, _ = transform.random_short_side_scale_jitter(
|
144 |
+
images=frames,
|
145 |
+
min_size=min_scale,
|
146 |
+
max_size=max_scale,
|
147 |
+
inverse_uniform_sampling=inverse_uniform_sampling,
|
148 |
+
)
|
149 |
+
frames, _ = transform.random_crop(frames, crop_size)
|
150 |
+
if random_horizontal_flip:
|
151 |
+
frames, _ = transform.horizontal_flip(0.5, frames)
|
152 |
+
else:
|
153 |
+
# The testing is deterministic and no jitter should be performed.
|
154 |
+
# min_scale, max_scale, and crop_size are expect to be the same.
|
155 |
+
#assert len({min_scale, max_scale, crop_size}) == 1
|
156 |
+
frames, _ = transform.random_short_side_scale_jitter(
|
157 |
+
frames, min_scale, max_scale
|
158 |
+
)
|
159 |
+
frames, _ = transform.uniform_crop(frames, crop_size, spatial_idx)
|
160 |
+
return frames
|
161 |
+
|
162 |
+
def spatial_sampling_2crops(
|
163 |
+
frames,
|
164 |
+
spatial_idx=-1,
|
165 |
+
min_scale=256,
|
166 |
+
max_scale=320,
|
167 |
+
crop_size=224,
|
168 |
+
random_horizontal_flip=True,
|
169 |
+
inverse_uniform_sampling=False,
|
170 |
+
):
|
171 |
+
"""
|
172 |
+
Perform spatial sampling on the given video frames. If spatial_idx is
|
173 |
+
-1, perform random scale, random crop, and random flip on the given
|
174 |
+
frames. If spatial_idx is 0, 1, or 2, perform spatial uniform sampling
|
175 |
+
with the given spatial_idx.
|
176 |
+
Args:
|
177 |
+
frames (tensor): frames of images sampled from the video. The
|
178 |
+
dimension is `num frames` x `height` x `width` x `channel`.
|
179 |
+
spatial_idx (int): if -1, perform random spatial sampling. If 0, 1,
|
180 |
+
or 2, perform left, center, right crop if width is larger than
|
181 |
+
height, and perform top, center, buttom crop if height is larger
|
182 |
+
than width.
|
183 |
+
min_scale (int): the minimal size of scaling.
|
184 |
+
max_scale (int): the maximal size of scaling.
|
185 |
+
crop_size (int): the size of height and width used to crop the
|
186 |
+
frames.
|
187 |
+
inverse_uniform_sampling (bool): if True, sample uniformly in
|
188 |
+
[1 / max_scale, 1 / min_scale] and take a reciprocal to get the
|
189 |
+
scale. If False, take a uniform sample from [min_scale,
|
190 |
+
max_scale].
|
191 |
+
Returns:
|
192 |
+
frames (tensor): spatially sampled frames.
|
193 |
+
"""
|
194 |
+
assert spatial_idx in [-1, 0, 1, 2]
|
195 |
+
if spatial_idx == -1:
|
196 |
+
frames, _ = transform.random_short_side_scale_jitter(
|
197 |
+
images=frames,
|
198 |
+
min_size=min_scale,
|
199 |
+
max_size=max_scale,
|
200 |
+
inverse_uniform_sampling=inverse_uniform_sampling,
|
201 |
+
)
|
202 |
+
frames, _ = transform.random_crop(frames, crop_size)
|
203 |
+
if random_horizontal_flip:
|
204 |
+
frames, _ = transform.horizontal_flip(0.5, frames)
|
205 |
+
else:
|
206 |
+
# The testing is deterministic and no jitter should be performed.
|
207 |
+
# min_scale, max_scale, and crop_size are expect to be the same.
|
208 |
+
#assert len({min_scale, max_scale, crop_size}) == 1
|
209 |
+
frames, _ = transform.random_short_side_scale_jitter(
|
210 |
+
frames, min_scale, max_scale
|
211 |
+
)
|
212 |
+
frames, _ = transform.uniform_crop_2crops(frames, crop_size, spatial_idx)
|
213 |
+
return frames
|
214 |
+
|
215 |
+
|
216 |
+
def as_binary_vector(labels, num_classes):
|
217 |
+
"""
|
218 |
+
Construct binary label vector given a list of label indices.
|
219 |
+
Args:
|
220 |
+
labels (list): The input label list.
|
221 |
+
num_classes (int): Number of classes of the label vector.
|
222 |
+
Returns:
|
223 |
+
labels (numpy array): the resulting binary vector.
|
224 |
+
"""
|
225 |
+
label_arr = np.zeros((num_classes,))
|
226 |
+
|
227 |
+
for lbl in set(labels):
|
228 |
+
label_arr[lbl] = 1.0
|
229 |
+
return label_arr
|
230 |
+
|
231 |
+
|
232 |
+
def aggregate_labels(label_list):
|
233 |
+
"""
|
234 |
+
Join a list of label list.
|
235 |
+
Args:
|
236 |
+
labels (list): The input label list.
|
237 |
+
Returns:
|
238 |
+
labels (list): The joint list of all lists in input.
|
239 |
+
"""
|
240 |
+
all_labels = []
|
241 |
+
for labels in label_list:
|
242 |
+
for l in labels:
|
243 |
+
all_labels.append(l)
|
244 |
+
return list(set(all_labels))
|
245 |
+
|
246 |
+
|
247 |
+
def convert_to_video_level_labels(labels):
|
248 |
+
"""
|
249 |
+
Aggregate annotations from all frames of a video to form video-level labels.
|
250 |
+
Args:
|
251 |
+
labels (list): The input label list.
|
252 |
+
Returns:
|
253 |
+
labels (list): Same as input, but with each label replaced by
|
254 |
+
a video-level one.
|
255 |
+
"""
|
256 |
+
for video_id in range(len(labels)):
|
257 |
+
video_level_labels = aggregate_labels(labels[video_id])
|
258 |
+
for i in range(len(labels[video_id])):
|
259 |
+
labels[video_id][i] = video_level_labels
|
260 |
+
return labels
|
261 |
+
|
262 |
+
|
263 |
+
def load_image_lists(frame_list_file, prefix="", return_list=False):
|
264 |
+
"""
|
265 |
+
Load image paths and labels from a "frame list".
|
266 |
+
Each line of the frame list contains:
|
267 |
+
`original_vido_id video_id frame_id path labels`
|
268 |
+
Args:
|
269 |
+
frame_list_file (string): path to the frame list.
|
270 |
+
prefix (str): the prefix for the path.
|
271 |
+
return_list (bool): if True, return a list. If False, return a dict.
|
272 |
+
Returns:
|
273 |
+
image_paths (list or dict): list of list containing path to each frame.
|
274 |
+
If return_list is False, then return in a dict form.
|
275 |
+
labels (list or dict): list of list containing label of each frame.
|
276 |
+
If return_list is False, then return in a dict form.
|
277 |
+
"""
|
278 |
+
image_paths = defaultdict(list)
|
279 |
+
labels = defaultdict(list)
|
280 |
+
with PathManager.open(frame_list_file, "r") as f:
|
281 |
+
assert f.readline().startswith("original_vido_id")
|
282 |
+
for line in f:
|
283 |
+
row = line.split()
|
284 |
+
# original_vido_id video_id frame_id path labels
|
285 |
+
assert len(row) == 5
|
286 |
+
video_name = row[0]
|
287 |
+
if prefix == "":
|
288 |
+
path = row[3]
|
289 |
+
else:
|
290 |
+
path = os.path.join(prefix, row[3])
|
291 |
+
image_paths[video_name].append(path)
|
292 |
+
frame_labels = row[-1].replace('"', "")
|
293 |
+
if frame_labels != "":
|
294 |
+
labels[video_name].append(
|
295 |
+
[int(x) for x in frame_labels.split(",")]
|
296 |
+
)
|
297 |
+
else:
|
298 |
+
labels[video_name].append([])
|
299 |
+
|
300 |
+
if return_list:
|
301 |
+
keys = image_paths.keys()
|
302 |
+
image_paths = [image_paths[key] for key in keys]
|
303 |
+
labels = [labels[key] for key in keys]
|
304 |
+
return image_paths, labels
|
305 |
+
return dict(image_paths), dict(labels)
|
306 |
+
|
307 |
+
|
308 |
+
def tensor_normalize(tensor, mean, std):
|
309 |
+
"""
|
310 |
+
Normalize a given tensor by subtracting the mean and dividing the std.
|
311 |
+
Args:
|
312 |
+
tensor (tensor): tensor to normalize.
|
313 |
+
mean (tensor or list): mean value to subtract.
|
314 |
+
std (tensor or list): std to divide.
|
315 |
+
"""
|
316 |
+
if tensor.dtype == torch.uint8:
|
317 |
+
tensor = tensor.float()
|
318 |
+
tensor = tensor / 255.0
|
319 |
+
if type(mean) == list:
|
320 |
+
mean = torch.tensor(mean)
|
321 |
+
if type(std) == list:
|
322 |
+
std = torch.tensor(std)
|
323 |
+
tensor = tensor - mean
|
324 |
+
tensor = tensor / std
|
325 |
+
return tensor
|
326 |
+
|
327 |
+
|
328 |
+
def get_random_sampling_rate(long_cycle_sampling_rate, sampling_rate):
|
329 |
+
"""
|
330 |
+
When multigrid training uses a fewer number of frames, we randomly
|
331 |
+
increase the sampling rate so that some clips cover the original span.
|
332 |
+
"""
|
333 |
+
if long_cycle_sampling_rate > 0:
|
334 |
+
assert long_cycle_sampling_rate >= sampling_rate
|
335 |
+
return random.randint(sampling_rate, long_cycle_sampling_rate)
|
336 |
+
else:
|
337 |
+
return sampling_rate
|
338 |
+
|
339 |
+
|
340 |
+
def revert_tensor_normalize(tensor, mean, std):
|
341 |
+
"""
|
342 |
+
Revert normalization for a given tensor by multiplying by the std and adding the mean.
|
343 |
+
Args:
|
344 |
+
tensor (tensor): tensor to revert normalization.
|
345 |
+
mean (tensor or list): mean value to add.
|
346 |
+
std (tensor or list): std to multiply.
|
347 |
+
"""
|
348 |
+
if type(mean) == list:
|
349 |
+
mean = torch.tensor(mean)
|
350 |
+
if type(std) == list:
|
351 |
+
std = torch.tensor(std)
|
352 |
+
tensor = tensor * std
|
353 |
+
tensor = tensor + mean
|
354 |
+
return tensor
|
355 |
+
|
356 |
+
|
357 |
+
def create_sampler(dataset, shuffle, cfg):
|
358 |
+
"""
|
359 |
+
Create sampler for the given dataset.
|
360 |
+
Args:
|
361 |
+
dataset (torch.utils.data.Dataset): the given dataset.
|
362 |
+
shuffle (bool): set to ``True`` to have the data reshuffled
|
363 |
+
at every epoch.
|
364 |
+
cfg (CfgNode): configs. Details can be found in
|
365 |
+
slowfast/config/defaults.py
|
366 |
+
Returns:
|
367 |
+
sampler (Sampler): the created sampler.
|
368 |
+
"""
|
369 |
+
sampler = DistributedSampler(dataset) if cfg.NUM_GPUS > 1 else None
|
370 |
+
|
371 |
+
return sampler
|
372 |
+
|
373 |
+
|
374 |
+
def loader_worker_init_fn(dataset):
|
375 |
+
"""
|
376 |
+
Create init function passed to pytorch data loader.
|
377 |
+
Args:
|
378 |
+
dataset (torch.utils.data.Dataset): the given dataset.
|
379 |
+
"""
|
380 |
+
return None
|
TimeSformer/timesformer/datasets/video_container.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
2 |
+
|
3 |
+
import av
|
4 |
+
|
5 |
+
|
6 |
+
def get_video_container(path_to_vid, multi_thread_decode=False, backend="pyav"):
|
7 |
+
"""
|
8 |
+
Given the path to the video, return the pyav video container.
|
9 |
+
Args:
|
10 |
+
path_to_vid (str): path to the video.
|
11 |
+
multi_thread_decode (bool): if True, perform multi-thread decoding.
|
12 |
+
backend (str): decoder backend, options include `pyav` and
|
13 |
+
`torchvision`, default is `pyav`.
|
14 |
+
Returns:
|
15 |
+
container (container): video container.
|
16 |
+
"""
|
17 |
+
if backend == "torchvision":
|
18 |
+
with open(path_to_vid, "rb") as fp:
|
19 |
+
container = fp.read()
|
20 |
+
return container
|
21 |
+
elif backend == "pyav":
|
22 |
+
#try:
|
23 |
+
container = av.open(path_to_vid)
|
24 |
+
if multi_thread_decode:
|
25 |
+
# Enable multiple threads for decoding.
|
26 |
+
container.streams.video[0].thread_type = "AUTO"
|
27 |
+
#except:
|
28 |
+
# container = None
|
29 |
+
return container
|
30 |
+
else:
|
31 |
+
raise NotImplementedError("Unknown backend {}".format(backend))
|
TimeSformer/timesformer/models/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
2 |
+
|
3 |
+
from .build import MODEL_REGISTRY, build_model # noqa
|
4 |
+
from .custom_video_model_builder import * # noqa
|
5 |
+
from .video_model_builder import ResNet, SlowFast # noqa
|
TimeSformer/timesformer/models/batchnorm_helper.py
ADDED
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
2 |
+
|
3 |
+
"""BatchNorm (BN) utility functions and custom batch-size BN implementations"""
|
4 |
+
|
5 |
+
from functools import partial
|
6 |
+
import torch
|
7 |
+
import torch.distributed as dist
|
8 |
+
import torch.nn as nn
|
9 |
+
from torch.autograd.function import Function
|
10 |
+
|
11 |
+
import timesformer.utils.distributed as du
|
12 |
+
|
13 |
+
|
14 |
+
def get_norm(cfg):
|
15 |
+
"""
|
16 |
+
Args:
|
17 |
+
cfg (CfgNode): model building configs, details are in the comments of
|
18 |
+
the config file.
|
19 |
+
Returns:
|
20 |
+
nn.Module: the normalization layer.
|
21 |
+
"""
|
22 |
+
if cfg.BN.NORM_TYPE == "batchnorm":
|
23 |
+
return nn.BatchNorm3d
|
24 |
+
elif cfg.BN.NORM_TYPE == "sub_batchnorm":
|
25 |
+
return partial(SubBatchNorm3d, num_splits=cfg.BN.NUM_SPLITS)
|
26 |
+
elif cfg.BN.NORM_TYPE == "sync_batchnorm":
|
27 |
+
return partial(
|
28 |
+
NaiveSyncBatchNorm3d, num_sync_devices=cfg.BN.NUM_SYNC_DEVICES
|
29 |
+
)
|
30 |
+
else:
|
31 |
+
raise NotImplementedError(
|
32 |
+
"Norm type {} is not supported".format(cfg.BN.NORM_TYPE)
|
33 |
+
)
|
34 |
+
|
35 |
+
|
36 |
+
class SubBatchNorm3d(nn.Module):
|
37 |
+
"""
|
38 |
+
The standard BN layer computes stats across all examples in a GPU. In some
|
39 |
+
cases it is desirable to compute stats across only a subset of examples
|
40 |
+
(e.g., in multigrid training https://arxiv.org/abs/1912.00998).
|
41 |
+
SubBatchNorm3d splits the batch dimension into N splits, and run BN on
|
42 |
+
each of them separately (so that the stats are computed on each subset of
|
43 |
+
examples (1/N of batch) independently. During evaluation, it aggregates
|
44 |
+
the stats from all splits into one BN.
|
45 |
+
"""
|
46 |
+
|
47 |
+
def __init__(self, num_splits, **args):
|
48 |
+
"""
|
49 |
+
Args:
|
50 |
+
num_splits (int): number of splits.
|
51 |
+
args (list): other arguments.
|
52 |
+
"""
|
53 |
+
super(SubBatchNorm3d, self).__init__()
|
54 |
+
self.num_splits = num_splits
|
55 |
+
num_features = args["num_features"]
|
56 |
+
# Keep only one set of weight and bias.
|
57 |
+
if args.get("affine", True):
|
58 |
+
self.affine = True
|
59 |
+
args["affine"] = False
|
60 |
+
self.weight = torch.nn.Parameter(torch.ones(num_features))
|
61 |
+
self.bias = torch.nn.Parameter(torch.zeros(num_features))
|
62 |
+
else:
|
63 |
+
self.affine = False
|
64 |
+
self.bn = nn.BatchNorm3d(**args)
|
65 |
+
args["num_features"] = num_features * num_splits
|
66 |
+
self.split_bn = nn.BatchNorm3d(**args)
|
67 |
+
|
68 |
+
def _get_aggregated_mean_std(self, means, stds, n):
|
69 |
+
"""
|
70 |
+
Calculate the aggregated mean and stds.
|
71 |
+
Args:
|
72 |
+
means (tensor): mean values.
|
73 |
+
stds (tensor): standard deviations.
|
74 |
+
n (int): number of sets of means and stds.
|
75 |
+
"""
|
76 |
+
mean = means.view(n, -1).sum(0) / n
|
77 |
+
std = (
|
78 |
+
stds.view(n, -1).sum(0) / n
|
79 |
+
+ ((means.view(n, -1) - mean) ** 2).view(n, -1).sum(0) / n
|
80 |
+
)
|
81 |
+
return mean.detach(), std.detach()
|
82 |
+
|
83 |
+
def aggregate_stats(self):
|
84 |
+
"""
|
85 |
+
Synchronize running_mean, and running_var. Call this before eval.
|
86 |
+
"""
|
87 |
+
if self.split_bn.track_running_stats:
|
88 |
+
(
|
89 |
+
self.bn.running_mean.data,
|
90 |
+
self.bn.running_var.data,
|
91 |
+
) = self._get_aggregated_mean_std(
|
92 |
+
self.split_bn.running_mean,
|
93 |
+
self.split_bn.running_var,
|
94 |
+
self.num_splits,
|
95 |
+
)
|
96 |
+
|
97 |
+
def forward(self, x):
|
98 |
+
if self.training:
|
99 |
+
n, c, t, h, w = x.shape
|
100 |
+
x = x.view(n // self.num_splits, c * self.num_splits, t, h, w)
|
101 |
+
x = self.split_bn(x)
|
102 |
+
x = x.view(n, c, t, h, w)
|
103 |
+
else:
|
104 |
+
x = self.bn(x)
|
105 |
+
if self.affine:
|
106 |
+
x = x * self.weight.view((-1, 1, 1, 1))
|
107 |
+
x = x + self.bias.view((-1, 1, 1, 1))
|
108 |
+
return x
|
109 |
+
|
110 |
+
|
111 |
+
class GroupGather(Function):
|
112 |
+
"""
|
113 |
+
GroupGather performs all gather on each of the local process/ GPU groups.
|
114 |
+
"""
|
115 |
+
|
116 |
+
@staticmethod
|
117 |
+
def forward(ctx, input, num_sync_devices, num_groups):
|
118 |
+
"""
|
119 |
+
Perform forwarding, gathering the stats across different process/ GPU
|
120 |
+
group.
|
121 |
+
"""
|
122 |
+
ctx.num_sync_devices = num_sync_devices
|
123 |
+
ctx.num_groups = num_groups
|
124 |
+
|
125 |
+
input_list = [
|
126 |
+
torch.zeros_like(input) for k in range(du.get_local_size())
|
127 |
+
]
|
128 |
+
dist.all_gather(
|
129 |
+
input_list, input, async_op=False, group=du._LOCAL_PROCESS_GROUP
|
130 |
+
)
|
131 |
+
|
132 |
+
inputs = torch.stack(input_list, dim=0)
|
133 |
+
if num_groups > 1:
|
134 |
+
rank = du.get_local_rank()
|
135 |
+
group_idx = rank // num_sync_devices
|
136 |
+
inputs = inputs[
|
137 |
+
group_idx
|
138 |
+
* num_sync_devices : (group_idx + 1)
|
139 |
+
* num_sync_devices
|
140 |
+
]
|
141 |
+
inputs = torch.sum(inputs, dim=0)
|
142 |
+
return inputs
|
143 |
+
|
144 |
+
@staticmethod
|
145 |
+
def backward(ctx, grad_output):
|
146 |
+
"""
|
147 |
+
Perform backwarding, gathering the gradients across different process/ GPU
|
148 |
+
group.
|
149 |
+
"""
|
150 |
+
grad_output_list = [
|
151 |
+
torch.zeros_like(grad_output) for k in range(du.get_local_size())
|
152 |
+
]
|
153 |
+
dist.all_gather(
|
154 |
+
grad_output_list,
|
155 |
+
grad_output,
|
156 |
+
async_op=False,
|
157 |
+
group=du._LOCAL_PROCESS_GROUP,
|
158 |
+
)
|
159 |
+
|
160 |
+
grads = torch.stack(grad_output_list, dim=0)
|
161 |
+
if ctx.num_groups > 1:
|
162 |
+
rank = du.get_local_rank()
|
163 |
+
group_idx = rank // ctx.num_sync_devices
|
164 |
+
grads = grads[
|
165 |
+
group_idx
|
166 |
+
* ctx.num_sync_devices : (group_idx + 1)
|
167 |
+
* ctx.num_sync_devices
|
168 |
+
]
|
169 |
+
grads = torch.sum(grads, dim=0)
|
170 |
+
return grads, None, None
|
171 |
+
|
172 |
+
|
173 |
+
class NaiveSyncBatchNorm3d(nn.BatchNorm3d):
|
174 |
+
def __init__(self, num_sync_devices, **args):
|
175 |
+
"""
|
176 |
+
Naive version of Synchronized 3D BatchNorm.
|
177 |
+
Args:
|
178 |
+
num_sync_devices (int): number of device to sync.
|
179 |
+
args (list): other arguments.
|
180 |
+
"""
|
181 |
+
self.num_sync_devices = num_sync_devices
|
182 |
+
if self.num_sync_devices > 0:
|
183 |
+
assert du.get_local_size() % self.num_sync_devices == 0, (
|
184 |
+
du.get_local_size(),
|
185 |
+
self.num_sync_devices,
|
186 |
+
)
|
187 |
+
self.num_groups = du.get_local_size() // self.num_sync_devices
|
188 |
+
else:
|
189 |
+
self.num_sync_devices = du.get_local_size()
|
190 |
+
self.num_groups = 1
|
191 |
+
super(NaiveSyncBatchNorm3d, self).__init__(**args)
|
192 |
+
|
193 |
+
def forward(self, input):
|
194 |
+
if du.get_local_size() == 1 or not self.training:
|
195 |
+
return super().forward(input)
|
196 |
+
|
197 |
+
assert input.shape[0] > 0, "SyncBatchNorm does not support empty inputs"
|
198 |
+
C = input.shape[1]
|
199 |
+
mean = torch.mean(input, dim=[0, 2, 3, 4])
|
200 |
+
meansqr = torch.mean(input * input, dim=[0, 2, 3, 4])
|
201 |
+
|
202 |
+
vec = torch.cat([mean, meansqr], dim=0)
|
203 |
+
vec = GroupGather.apply(vec, self.num_sync_devices, self.num_groups) * (
|
204 |
+
1.0 / self.num_sync_devices
|
205 |
+
)
|
206 |
+
|
207 |
+
mean, meansqr = torch.split(vec, C)
|
208 |
+
var = meansqr - mean * mean
|
209 |
+
self.running_mean += self.momentum * (mean.detach() - self.running_mean)
|
210 |
+
self.running_var += self.momentum * (var.detach() - self.running_var)
|
211 |
+
|
212 |
+
invstd = torch.rsqrt(var + self.eps)
|
213 |
+
scale = self.weight * invstd
|
214 |
+
bias = self.bias - mean * scale
|
215 |
+
scale = scale.reshape(1, -1, 1, 1, 1)
|
216 |
+
bias = bias.reshape(1, -1, 1, 1, 1)
|
217 |
+
return input * scale + bias
|
TimeSformer/timesformer/models/build.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
2 |
+
|
3 |
+
"""Model construction functions."""
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from fvcore.common.registry import Registry
|
7 |
+
|
8 |
+
MODEL_REGISTRY = Registry("MODEL")
|
9 |
+
MODEL_REGISTRY.__doc__ = """
|
10 |
+
Registry for video model.
|
11 |
+
|
12 |
+
The registered object will be called with `obj(cfg)`.
|
13 |
+
The call should return a `torch.nn.Module` object.
|
14 |
+
"""
|
15 |
+
|
16 |
+
|
17 |
+
def build_model(cfg, gpu_id=None):
|
18 |
+
"""
|
19 |
+
Builds the video model.
|
20 |
+
Args:
|
21 |
+
cfg (configs): configs that contains the hyper-parameters to build the
|
22 |
+
backbone. Details can be seen in slowfast/config/defaults.py.
|
23 |
+
gpu_id (Optional[int]): specify the gpu index to build model.
|
24 |
+
"""
|
25 |
+
if torch.cuda.is_available():
|
26 |
+
assert (
|
27 |
+
cfg.NUM_GPUS <= torch.cuda.device_count()
|
28 |
+
), "Cannot use more GPU devices than available"
|
29 |
+
else:
|
30 |
+
assert (
|
31 |
+
cfg.NUM_GPUS == 0
|
32 |
+
), "Cuda is not available. Please set `NUM_GPUS: 0 for running on CPUs."
|
33 |
+
|
34 |
+
# Construct the model
|
35 |
+
name = cfg.MODEL.MODEL_NAME
|
36 |
+
model = MODEL_REGISTRY.get(name)(cfg)
|
37 |
+
|
38 |
+
if cfg.NUM_GPUS:
|
39 |
+
if gpu_id is None:
|
40 |
+
# Determine the GPU used by the current process
|
41 |
+
cur_device = torch.cuda.current_device()
|
42 |
+
else:
|
43 |
+
cur_device = gpu_id
|
44 |
+
# Transfer the model to the current GPU device
|
45 |
+
model = model.cuda(device=cur_device)
|
46 |
+
|
47 |
+
|
48 |
+
# Use multi-process data parallel model in the multi-gpu setting
|
49 |
+
if cfg.NUM_GPUS > 1:
|
50 |
+
# Make model replica operate on the current device
|
51 |
+
model = torch.nn.parallel.DistributedDataParallel(
|
52 |
+
module=model, device_ids=[cur_device], output_device=cur_device
|
53 |
+
)
|
54 |
+
return model
|
TimeSformer/timesformer/models/conv2d_same.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 Ross Wightman
|
2 |
+
# Conv2d w/ Same Padding
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from typing import Tuple, Optional
|
8 |
+
|
9 |
+
import math
|
10 |
+
from typing import List, Tuple
|
11 |
+
#from .padding import pad_same, get_padding_value
|
12 |
+
|
13 |
+
# Dynamically pad input x with 'SAME' padding for conv with specified args
|
14 |
+
def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0):
|
15 |
+
ih, iw = x.size()[-2:]
|
16 |
+
pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1])
|
17 |
+
if pad_h > 0 or pad_w > 0:
|
18 |
+
x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value)
|
19 |
+
return x
|
20 |
+
|
21 |
+
# Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution
|
22 |
+
def get_same_padding(x: int, k: int, s: int, d: int):
|
23 |
+
return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0)
|
24 |
+
|
25 |
+
def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:
|
26 |
+
dynamic = False
|
27 |
+
if isinstance(padding, str):
|
28 |
+
# for any string padding, the padding will be calculated for you, one of three ways
|
29 |
+
padding = padding.lower()
|
30 |
+
if padding == 'same':
|
31 |
+
# TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
|
32 |
+
if is_static_pad(kernel_size, **kwargs):
|
33 |
+
# static case, no extra overhead
|
34 |
+
padding = get_padding(kernel_size, **kwargs)
|
35 |
+
else:
|
36 |
+
# dynamic 'SAME' padding, has runtime/GPU memory overhead
|
37 |
+
padding = 0
|
38 |
+
dynamic = True
|
39 |
+
elif padding == 'valid':
|
40 |
+
# 'VALID' padding, same as padding=0
|
41 |
+
padding = 0
|
42 |
+
else:
|
43 |
+
# Default to PyTorch style 'same'-ish symmetric padding
|
44 |
+
padding = get_padding(kernel_size, **kwargs)
|
45 |
+
return padding, dynamic
|
46 |
+
|
47 |
+
def conv2d_same(
|
48 |
+
x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1),
|
49 |
+
padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1):
|
50 |
+
x = pad_same(x, weight.shape[-2:], stride, dilation)
|
51 |
+
return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups)
|
52 |
+
|
53 |
+
|
54 |
+
class Conv2dSame(nn.Conv2d):
|
55 |
+
""" Tensorflow like 'SAME' convolution wrapper for 2D convolutions
|
56 |
+
"""
|
57 |
+
|
58 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
|
59 |
+
padding=0, dilation=1, groups=1, bias=True):
|
60 |
+
super(Conv2dSame, self).__init__(
|
61 |
+
in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
|
62 |
+
|
63 |
+
def forward(self, x):
|
64 |
+
return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
65 |
+
|
66 |
+
|
67 |
+
def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
|
68 |
+
padding = kwargs.pop('padding', '')
|
69 |
+
kwargs.setdefault('bias', False)
|
70 |
+
padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs)
|
71 |
+
if is_dynamic:
|
72 |
+
return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)
|
73 |
+
else:
|
74 |
+
return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)
|
TimeSformer/timesformer/models/custom_video_model_builder.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
2 |
+
|
3 |
+
|
4 |
+
"""A More Flexible Video models."""
|
TimeSformer/timesformer/models/features.py
ADDED
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 Ross Wightman
|
2 |
+
|
3 |
+
from collections import OrderedDict, defaultdict
|
4 |
+
from copy import deepcopy
|
5 |
+
from functools import partial
|
6 |
+
from typing import Dict, List, Tuple
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
|
11 |
+
|
12 |
+
class FeatureInfo:
|
13 |
+
|
14 |
+
def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]):
|
15 |
+
prev_reduction = 1
|
16 |
+
for fi in feature_info:
|
17 |
+
# sanity check the mandatory fields, there may be additional fields depending on the model
|
18 |
+
assert 'num_chs' in fi and fi['num_chs'] > 0
|
19 |
+
assert 'reduction' in fi and fi['reduction'] >= prev_reduction
|
20 |
+
prev_reduction = fi['reduction']
|
21 |
+
assert 'module' in fi
|
22 |
+
self.out_indices = out_indices
|
23 |
+
self.info = feature_info
|
24 |
+
|
25 |
+
def from_other(self, out_indices: Tuple[int]):
|
26 |
+
return FeatureInfo(deepcopy(self.info), out_indices)
|
27 |
+
|
28 |
+
def get(self, key, idx=None):
|
29 |
+
""" Get value by key at specified index (indices)
|
30 |
+
if idx == None, returns value for key at each output index
|
31 |
+
if idx is an integer, return value for that feature module index (ignoring output indices)
|
32 |
+
if idx is a list/tupple, return value for each module index (ignoring output indices)
|
33 |
+
"""
|
34 |
+
if idx is None:
|
35 |
+
return [self.info[i][key] for i in self.out_indices]
|
36 |
+
if isinstance(idx, (tuple, list)):
|
37 |
+
return [self.info[i][key] for i in idx]
|
38 |
+
else:
|
39 |
+
return self.info[idx][key]
|
40 |
+
|
41 |
+
def get_dicts(self, keys=None, idx=None):
|
42 |
+
""" return info dicts for specified keys (or all if None) at specified indices (or out_indices if None)
|
43 |
+
"""
|
44 |
+
if idx is None:
|
45 |
+
if keys is None:
|
46 |
+
return [self.info[i] for i in self.out_indices]
|
47 |
+
else:
|
48 |
+
return [{k: self.info[i][k] for k in keys} for i in self.out_indices]
|
49 |
+
if isinstance(idx, (tuple, list)):
|
50 |
+
return [self.info[i] if keys is None else {k: self.info[i][k] for k in keys} for i in idx]
|
51 |
+
else:
|
52 |
+
return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys}
|
53 |
+
|
54 |
+
def channels(self, idx=None):
|
55 |
+
""" feature channels accessor
|
56 |
+
"""
|
57 |
+
return self.get('num_chs', idx)
|
58 |
+
|
59 |
+
def reduction(self, idx=None):
|
60 |
+
""" feature reduction (output stride) accessor
|
61 |
+
"""
|
62 |
+
return self.get('reduction', idx)
|
63 |
+
|
64 |
+
def module_name(self, idx=None):
|
65 |
+
""" feature module name accessor
|
66 |
+
"""
|
67 |
+
return self.get('module', idx)
|
68 |
+
|
69 |
+
def __getitem__(self, item):
|
70 |
+
return self.info[item]
|
71 |
+
|
72 |
+
def __len__(self):
|
73 |
+
return len(self.info)
|
74 |
+
|
75 |
+
|
76 |
+
class FeatureHooks:
|
77 |
+
""" Feature Hook Helper
|
78 |
+
This module helps with the setup and extraction of hooks for extracting features from
|
79 |
+
internal nodes in a model by node name. This works quite well in eager Python but needs
|
80 |
+
redesign for torcscript.
|
81 |
+
"""
|
82 |
+
|
83 |
+
def __init__(self, hooks, named_modules, out_map=None, default_hook_type='forward'):
|
84 |
+
# setup feature hooks
|
85 |
+
modules = {k: v for k, v in named_modules}
|
86 |
+
for i, h in enumerate(hooks):
|
87 |
+
hook_name = h['module']
|
88 |
+
m = modules[hook_name]
|
89 |
+
hook_id = out_map[i] if out_map else hook_name
|
90 |
+
hook_fn = partial(self._collect_output_hook, hook_id)
|
91 |
+
hook_type = h['hook_type'] if 'hook_type' in h else default_hook_type
|
92 |
+
if hook_type == 'forward_pre':
|
93 |
+
m.register_forward_pre_hook(hook_fn)
|
94 |
+
elif hook_type == 'forward':
|
95 |
+
m.register_forward_hook(hook_fn)
|
96 |
+
else:
|
97 |
+
assert False, "Unsupported hook type"
|
98 |
+
self._feature_outputs = defaultdict(OrderedDict)
|
99 |
+
|
100 |
+
def _collect_output_hook(self, hook_id, *args):
|
101 |
+
x = args[-1] # tensor we want is last argument, output for fwd, input for fwd_pre
|
102 |
+
if isinstance(x, tuple):
|
103 |
+
x = x[0] # unwrap input tuple
|
104 |
+
self._feature_outputs[x.device][hook_id] = x
|
105 |
+
|
106 |
+
def get_output(self, device) -> Dict[str, torch.tensor]:
|
107 |
+
output = self._feature_outputs[device]
|
108 |
+
self._feature_outputs[device] = OrderedDict() # clear after reading
|
109 |
+
return output
|
110 |
+
|
111 |
+
|
112 |
+
def _module_list(module, flatten_sequential=False):
|
113 |
+
# a yield/iter would be better for this but wouldn't be compatible with torchscript
|
114 |
+
ml = []
|
115 |
+
for name, module in module.named_children():
|
116 |
+
if flatten_sequential and isinstance(module, nn.Sequential):
|
117 |
+
# first level of Sequential containers is flattened into containing model
|
118 |
+
for child_name, child_module in module.named_children():
|
119 |
+
combined = [name, child_name]
|
120 |
+
ml.append(('_'.join(combined), '.'.join(combined), child_module))
|
121 |
+
else:
|
122 |
+
ml.append((name, name, module))
|
123 |
+
return ml
|
124 |
+
|
125 |
+
|
126 |
+
def _get_feature_info(net, out_indices):
|
127 |
+
feature_info = getattr(net, 'feature_info')
|
128 |
+
if isinstance(feature_info, FeatureInfo):
|
129 |
+
return feature_info.from_other(out_indices)
|
130 |
+
elif isinstance(feature_info, (list, tuple)):
|
131 |
+
return FeatureInfo(net.feature_info, out_indices)
|
132 |
+
else:
|
133 |
+
assert False, "Provided feature_info is not valid"
|
134 |
+
|
135 |
+
|
136 |
+
def _get_return_layers(feature_info, out_map):
|
137 |
+
module_names = feature_info.module_name()
|
138 |
+
return_layers = {}
|
139 |
+
for i, name in enumerate(module_names):
|
140 |
+
return_layers[name] = out_map[i] if out_map is not None else feature_info.out_indices[i]
|
141 |
+
return return_layers
|
142 |
+
|
143 |
+
|
144 |
+
class FeatureDictNet(nn.ModuleDict):
|
145 |
+
""" Feature extractor with OrderedDict return
|
146 |
+
Wrap a model and extract features as specified by the out indices, the network is
|
147 |
+
partially re-built from contained modules.
|
148 |
+
There is a strong assumption that the modules have been registered into the model in the same
|
149 |
+
order as they are used. There should be no reuse of the same nn.Module more than once, including
|
150 |
+
trivial modules like `self.relu = nn.ReLU`.
|
151 |
+
Only submodules that are directly assigned to the model class (`model.feature1`) or at most
|
152 |
+
one Sequential container deep (`model.features.1`, with flatten_sequent=True) can be captured.
|
153 |
+
All Sequential containers that are directly assigned to the original model will have their
|
154 |
+
modules assigned to this module with the name `model.features.1` being changed to `model.features_1`
|
155 |
+
Arguments:
|
156 |
+
model (nn.Module): model from which we will extract the features
|
157 |
+
out_indices (tuple[int]): model output indices to extract features for
|
158 |
+
out_map (sequence): list or tuple specifying desired return id for each out index,
|
159 |
+
otherwise str(index) is used
|
160 |
+
feature_concat (bool): whether to concatenate intermediate features that are lists or tuples
|
161 |
+
vs select element [0]
|
162 |
+
flatten_sequential (bool): whether to flatten sequential modules assigned to model
|
163 |
+
"""
|
164 |
+
def __init__(
|
165 |
+
self, model,
|
166 |
+
out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False):
|
167 |
+
super(FeatureDictNet, self).__init__()
|
168 |
+
self.feature_info = _get_feature_info(model, out_indices)
|
169 |
+
self.concat = feature_concat
|
170 |
+
self.return_layers = {}
|
171 |
+
return_layers = _get_return_layers(self.feature_info, out_map)
|
172 |
+
modules = _module_list(model, flatten_sequential=flatten_sequential)
|
173 |
+
remaining = set(return_layers.keys())
|
174 |
+
layers = OrderedDict()
|
175 |
+
for new_name, old_name, module in modules:
|
176 |
+
layers[new_name] = module
|
177 |
+
if old_name in remaining:
|
178 |
+
# return id has to be consistently str type for torchscript
|
179 |
+
self.return_layers[new_name] = str(return_layers[old_name])
|
180 |
+
remaining.remove(old_name)
|
181 |
+
if not remaining:
|
182 |
+
break
|
183 |
+
assert not remaining and len(self.return_layers) == len(return_layers), \
|
184 |
+
f'Return layers ({remaining}) are not present in model'
|
185 |
+
self.update(layers)
|
186 |
+
|
187 |
+
def _collect(self, x) -> (Dict[str, torch.Tensor]):
|
188 |
+
out = OrderedDict()
|
189 |
+
for name, module in self.items():
|
190 |
+
x = module(x)
|
191 |
+
if name in self.return_layers:
|
192 |
+
out_id = self.return_layers[name]
|
193 |
+
if isinstance(x, (tuple, list)):
|
194 |
+
# If model tap is a tuple or list, concat or select first element
|
195 |
+
# FIXME this may need to be more generic / flexible for some nets
|
196 |
+
out[out_id] = torch.cat(x, 1) if self.concat else x[0]
|
197 |
+
else:
|
198 |
+
out[out_id] = x
|
199 |
+
return out
|
200 |
+
|
201 |
+
def forward(self, x) -> Dict[str, torch.Tensor]:
|
202 |
+
return self._collect(x)
|
203 |
+
|
204 |
+
|
205 |
+
class FeatureListNet(FeatureDictNet):
|
206 |
+
""" Feature extractor with list return
|
207 |
+
See docstring for FeatureDictNet above, this class exists only to appease Torchscript typing constraints.
|
208 |
+
In eager Python we could have returned List[Tensor] vs Dict[id, Tensor] based on a member bool.
|
209 |
+
"""
|
210 |
+
def __init__(
|
211 |
+
self, model,
|
212 |
+
out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False):
|
213 |
+
super(FeatureListNet, self).__init__(
|
214 |
+
model, out_indices=out_indices, out_map=out_map, feature_concat=feature_concat,
|
215 |
+
flatten_sequential=flatten_sequential)
|
216 |
+
|
217 |
+
def forward(self, x) -> (List[torch.Tensor]):
|
218 |
+
return list(self._collect(x).values())
|
219 |
+
|
220 |
+
|
221 |
+
class FeatureHookNet(nn.ModuleDict):
|
222 |
+
""" FeatureHookNet
|
223 |
+
Wrap a model and extract features specified by the out indices using forward/forward-pre hooks.
|
224 |
+
If `no_rewrite` is True, features are extracted via hooks without modifying the underlying
|
225 |
+
network in any way.
|
226 |
+
If `no_rewrite` is False, the model will be re-written as in the
|
227 |
+
FeatureList/FeatureDict case by folding first to second (Sequential only) level modules into this one.
|
228 |
+
FIXME this does not currently work with Torchscript, see FeatureHooks class
|
229 |
+
"""
|
230 |
+
def __init__(
|
231 |
+
self, model,
|
232 |
+
out_indices=(0, 1, 2, 3, 4), out_map=None, out_as_dict=False, no_rewrite=False,
|
233 |
+
feature_concat=False, flatten_sequential=False, default_hook_type='forward'):
|
234 |
+
super(FeatureHookNet, self).__init__()
|
235 |
+
assert not torch.jit.is_scripting()
|
236 |
+
self.feature_info = _get_feature_info(model, out_indices)
|
237 |
+
self.out_as_dict = out_as_dict
|
238 |
+
layers = OrderedDict()
|
239 |
+
hooks = []
|
240 |
+
if no_rewrite:
|
241 |
+
assert not flatten_sequential
|
242 |
+
if hasattr(model, 'reset_classifier'): # make sure classifier is removed?
|
243 |
+
model.reset_classifier(0)
|
244 |
+
layers['body'] = model
|
245 |
+
hooks.extend(self.feature_info.get_dicts())
|
246 |
+
else:
|
247 |
+
modules = _module_list(model, flatten_sequential=flatten_sequential)
|
248 |
+
remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type
|
249 |
+
for f in self.feature_info.get_dicts()}
|
250 |
+
for new_name, old_name, module in modules:
|
251 |
+
layers[new_name] = module
|
252 |
+
for fn, fm in module.named_modules(prefix=old_name):
|
253 |
+
if fn in remaining:
|
254 |
+
hooks.append(dict(module=fn, hook_type=remaining[fn]))
|
255 |
+
del remaining[fn]
|
256 |
+
if not remaining:
|
257 |
+
break
|
258 |
+
assert not remaining, f'Return layers ({remaining}) are not present in model'
|
259 |
+
self.update(layers)
|
260 |
+
self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map)
|
261 |
+
|
262 |
+
def forward(self, x):
|
263 |
+
for name, module in self.items():
|
264 |
+
x = module(x)
|
265 |
+
out = self.hooks.get_output(x.device)
|
266 |
+
return out if self.out_as_dict else list(out.values())
|
TimeSformer/timesformer/models/head_helper.py
ADDED
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
2 |
+
|
3 |
+
"""ResNe(X)t Head helper."""
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
class ResNetBasicHead(nn.Module):
|
9 |
+
"""
|
10 |
+
ResNe(X)t 3D head.
|
11 |
+
This layer performs a fully-connected projection during training, when the
|
12 |
+
input size is 1x1x1. It performs a convolutional projection during testing
|
13 |
+
when the input size is larger than 1x1x1. If the inputs are from multiple
|
14 |
+
different pathways, the inputs will be concatenated after pooling.
|
15 |
+
"""
|
16 |
+
|
17 |
+
def __init__(
|
18 |
+
self,
|
19 |
+
dim_in,
|
20 |
+
num_classes,
|
21 |
+
pool_size,
|
22 |
+
dropout_rate=0.0,
|
23 |
+
act_func="softmax",
|
24 |
+
):
|
25 |
+
"""
|
26 |
+
The `__init__` method of any subclass should also contain these
|
27 |
+
arguments.
|
28 |
+
ResNetBasicHead takes p pathways as input where p in [1, infty].
|
29 |
+
|
30 |
+
Args:
|
31 |
+
dim_in (list): the list of channel dimensions of the p inputs to the
|
32 |
+
ResNetHead.
|
33 |
+
num_classes (int): the channel dimensions of the p outputs to the
|
34 |
+
ResNetHead.
|
35 |
+
pool_size (list): the list of kernel sizes of p spatial temporal
|
36 |
+
poolings, temporal pool kernel size, spatial pool kernel size,
|
37 |
+
spatial pool kernel size in order.
|
38 |
+
dropout_rate (float): dropout rate. If equal to 0.0, perform no
|
39 |
+
dropout.
|
40 |
+
act_func (string): activation function to use. 'softmax': applies
|
41 |
+
softmax on the output. 'sigmoid': applies sigmoid on the output.
|
42 |
+
"""
|
43 |
+
super(ResNetBasicHead, self).__init__()
|
44 |
+
assert (
|
45 |
+
len({len(pool_size), len(dim_in)}) == 1
|
46 |
+
), "pathway dimensions are not consistent."
|
47 |
+
self.num_pathways = len(pool_size)
|
48 |
+
|
49 |
+
for pathway in range(self.num_pathways):
|
50 |
+
if pool_size[pathway] is None:
|
51 |
+
avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
|
52 |
+
else:
|
53 |
+
avg_pool = nn.AvgPool3d(pool_size[pathway], stride=1)
|
54 |
+
self.add_module("pathway{}_avgpool".format(pathway), avg_pool)
|
55 |
+
|
56 |
+
if dropout_rate > 0.0:
|
57 |
+
self.dropout = nn.Dropout(dropout_rate)
|
58 |
+
# Perform FC in a fully convolutional manner. The FC layer will be
|
59 |
+
# initialized with a different std comparing to convolutional layers.
|
60 |
+
self.projection = nn.Linear(sum(dim_in), num_classes, bias=True)
|
61 |
+
|
62 |
+
# Softmax for evaluation and testing.
|
63 |
+
if act_func == "softmax":
|
64 |
+
self.act = nn.Softmax(dim=4)
|
65 |
+
elif act_func == "sigmoid":
|
66 |
+
self.act = nn.Sigmoid()
|
67 |
+
else:
|
68 |
+
raise NotImplementedError(
|
69 |
+
"{} is not supported as an activation"
|
70 |
+
"function.".format(act_func)
|
71 |
+
)
|
72 |
+
|
73 |
+
def forward(self, inputs):
|
74 |
+
assert (
|
75 |
+
len(inputs) == self.num_pathways
|
76 |
+
), "Input tensor does not contain {} pathway".format(self.num_pathways)
|
77 |
+
pool_out = []
|
78 |
+
for pathway in range(self.num_pathways):
|
79 |
+
m = getattr(self, "pathway{}_avgpool".format(pathway))
|
80 |
+
pool_out.append(m(inputs[pathway]))
|
81 |
+
x = torch.cat(pool_out, 1)
|
82 |
+
# (N, C, T, H, W) -> (N, T, H, W, C).
|
83 |
+
x = x.permute((0, 2, 3, 4, 1))
|
84 |
+
# Perform dropout.
|
85 |
+
if hasattr(self, "dropout"):
|
86 |
+
x = self.dropout(x)
|
87 |
+
x = self.projection(x)
|
88 |
+
|
89 |
+
# Performs fully convlutional inference.
|
90 |
+
if not self.training:
|
91 |
+
x = self.act(x)
|
92 |
+
x = x.mean([1, 2, 3])
|
93 |
+
|
94 |
+
x = x.view(x.shape[0], -1)
|
95 |
+
return x
|
96 |
+
|
97 |
+
|
98 |
+
class X3DHead(nn.Module):
|
99 |
+
"""
|
100 |
+
X3D head.
|
101 |
+
This layer performs a fully-connected projection during training, when the
|
102 |
+
input size is 1x1x1. It performs a convolutional projection during testing
|
103 |
+
when the input size is larger than 1x1x1. If the inputs are from multiple
|
104 |
+
different pathways, the inputs will be concatenated after pooling.
|
105 |
+
"""
|
106 |
+
|
107 |
+
def __init__(
|
108 |
+
self,
|
109 |
+
dim_in,
|
110 |
+
dim_inner,
|
111 |
+
dim_out,
|
112 |
+
num_classes,
|
113 |
+
pool_size,
|
114 |
+
dropout_rate=0.0,
|
115 |
+
act_func="softmax",
|
116 |
+
inplace_relu=True,
|
117 |
+
eps=1e-5,
|
118 |
+
bn_mmt=0.1,
|
119 |
+
norm_module=nn.BatchNorm3d,
|
120 |
+
bn_lin5_on=False,
|
121 |
+
):
|
122 |
+
"""
|
123 |
+
The `__init__` method of any subclass should also contain these
|
124 |
+
arguments.
|
125 |
+
X3DHead takes a 5-dim feature tensor (BxCxTxHxW) as input.
|
126 |
+
|
127 |
+
Args:
|
128 |
+
dim_in (float): the channel dimension C of the input.
|
129 |
+
num_classes (int): the channel dimensions of the output.
|
130 |
+
pool_size (float): a single entry list of kernel size for
|
131 |
+
spatiotemporal pooling for the TxHxW dimensions.
|
132 |
+
dropout_rate (float): dropout rate. If equal to 0.0, perform no
|
133 |
+
dropout.
|
134 |
+
act_func (string): activation function to use. 'softmax': applies
|
135 |
+
softmax on the output. 'sigmoid': applies sigmoid on the output.
|
136 |
+
inplace_relu (bool): if True, calculate the relu on the original
|
137 |
+
input without allocating new memory.
|
138 |
+
eps (float): epsilon for batch norm.
|
139 |
+
bn_mmt (float): momentum for batch norm. Noted that BN momentum in
|
140 |
+
PyTorch = 1 - BN momentum in Caffe2.
|
141 |
+
norm_module (nn.Module): nn.Module for the normalization layer. The
|
142 |
+
default is nn.BatchNorm3d.
|
143 |
+
bn_lin5_on (bool): if True, perform normalization on the features
|
144 |
+
before the classifier.
|
145 |
+
"""
|
146 |
+
super(X3DHead, self).__init__()
|
147 |
+
self.pool_size = pool_size
|
148 |
+
self.dropout_rate = dropout_rate
|
149 |
+
self.num_classes = num_classes
|
150 |
+
self.act_func = act_func
|
151 |
+
self.eps = eps
|
152 |
+
self.bn_mmt = bn_mmt
|
153 |
+
self.inplace_relu = inplace_relu
|
154 |
+
self.bn_lin5_on = bn_lin5_on
|
155 |
+
self._construct_head(dim_in, dim_inner, dim_out, norm_module)
|
156 |
+
|
157 |
+
def _construct_head(self, dim_in, dim_inner, dim_out, norm_module):
|
158 |
+
|
159 |
+
self.conv_5 = nn.Conv3d(
|
160 |
+
dim_in,
|
161 |
+
dim_inner,
|
162 |
+
kernel_size=(1, 1, 1),
|
163 |
+
stride=(1, 1, 1),
|
164 |
+
padding=(0, 0, 0),
|
165 |
+
bias=False,
|
166 |
+
)
|
167 |
+
self.conv_5_bn = norm_module(
|
168 |
+
num_features=dim_inner, eps=self.eps, momentum=self.bn_mmt
|
169 |
+
)
|
170 |
+
self.conv_5_relu = nn.ReLU(self.inplace_relu)
|
171 |
+
|
172 |
+
if self.pool_size is None:
|
173 |
+
self.avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
|
174 |
+
else:
|
175 |
+
self.avg_pool = nn.AvgPool3d(self.pool_size, stride=1)
|
176 |
+
|
177 |
+
self.lin_5 = nn.Conv3d(
|
178 |
+
dim_inner,
|
179 |
+
dim_out,
|
180 |
+
kernel_size=(1, 1, 1),
|
181 |
+
stride=(1, 1, 1),
|
182 |
+
padding=(0, 0, 0),
|
183 |
+
bias=False,
|
184 |
+
)
|
185 |
+
if self.bn_lin5_on:
|
186 |
+
self.lin_5_bn = norm_module(
|
187 |
+
num_features=dim_out, eps=self.eps, momentum=self.bn_mmt
|
188 |
+
)
|
189 |
+
self.lin_5_relu = nn.ReLU(self.inplace_relu)
|
190 |
+
|
191 |
+
if self.dropout_rate > 0.0:
|
192 |
+
self.dropout = nn.Dropout(self.dropout_rate)
|
193 |
+
# Perform FC in a fully convolutional manner. The FC layer will be
|
194 |
+
# initialized with a different std comparing to convolutional layers.
|
195 |
+
self.projection = nn.Linear(dim_out, self.num_classes, bias=True)
|
196 |
+
|
197 |
+
# Softmax for evaluation and testing.
|
198 |
+
if self.act_func == "softmax":
|
199 |
+
self.act = nn.Softmax(dim=4)
|
200 |
+
elif self.act_func == "sigmoid":
|
201 |
+
self.act = nn.Sigmoid()
|
202 |
+
else:
|
203 |
+
raise NotImplementedError(
|
204 |
+
"{} is not supported as an activation"
|
205 |
+
"function.".format(self.act_func)
|
206 |
+
)
|
207 |
+
|
208 |
+
def forward(self, inputs):
|
209 |
+
# In its current design the X3D head is only useable for a single
|
210 |
+
# pathway input.
|
211 |
+
assert len(inputs) == 1, "Input tensor does not contain 1 pathway"
|
212 |
+
x = self.conv_5(inputs[0])
|
213 |
+
x = self.conv_5_bn(x)
|
214 |
+
x = self.conv_5_relu(x)
|
215 |
+
x = self.avg_pool(x)
|
216 |
+
|
217 |
+
x = self.lin_5(x)
|
218 |
+
if self.bn_lin5_on:
|
219 |
+
x = self.lin_5_bn(x)
|
220 |
+
x = self.lin_5_relu(x)
|
221 |
+
|
222 |
+
# (N, C, T, H, W) -> (N, T, H, W, C).
|
223 |
+
x = x.permute((0, 2, 3, 4, 1))
|
224 |
+
# Perform dropout.
|
225 |
+
if hasattr(self, "dropout"):
|
226 |
+
x = self.dropout(x)
|
227 |
+
x = self.projection(x)
|
228 |
+
|
229 |
+
# Performs fully convlutional inference.
|
230 |
+
if not self.training:
|
231 |
+
x = self.act(x)
|
232 |
+
x = x.mean([1, 2, 3])
|
233 |
+
|
234 |
+
x = x.view(x.shape[0], -1)
|
235 |
+
return x
|
TimeSformer/timesformer/models/helpers.py
ADDED
@@ -0,0 +1,360 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
2 |
+
# Copyright 2020 Ross Wightman
|
3 |
+
# Modified model creation / weight loading / state_dict helpers
|
4 |
+
|
5 |
+
import logging
|
6 |
+
import os
|
7 |
+
import math
|
8 |
+
from collections import OrderedDict
|
9 |
+
from copy import deepcopy
|
10 |
+
from typing import Callable
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
import torch.utils.model_zoo as model_zoo
|
15 |
+
import torch.nn.functional as F
|
16 |
+
|
17 |
+
from timesformer.models.features import FeatureListNet, FeatureDictNet, FeatureHookNet
|
18 |
+
from timesformer.models.conv2d_same import Conv2dSame
|
19 |
+
from timesformer.models.linear import Linear
|
20 |
+
|
21 |
+
|
22 |
+
_logger = logging.getLogger(__name__)
|
23 |
+
|
24 |
+
def load_state_dict(checkpoint_path, use_ema=False):
|
25 |
+
if checkpoint_path and os.path.isfile(checkpoint_path):
|
26 |
+
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
27 |
+
state_dict_key = 'state_dict'
|
28 |
+
if isinstance(checkpoint, dict):
|
29 |
+
if use_ema and 'state_dict_ema' in checkpoint:
|
30 |
+
state_dict_key = 'state_dict_ema'
|
31 |
+
if state_dict_key and state_dict_key in checkpoint:
|
32 |
+
new_state_dict = OrderedDict()
|
33 |
+
for k, v in checkpoint[state_dict_key].items():
|
34 |
+
# strip `module.` prefix
|
35 |
+
name = k[7:] if k.startswith('module') else k
|
36 |
+
new_state_dict[name] = v
|
37 |
+
state_dict = new_state_dict
|
38 |
+
elif 'model_state' in checkpoint:
|
39 |
+
state_dict_key = 'model_state'
|
40 |
+
new_state_dict = OrderedDict()
|
41 |
+
for k, v in checkpoint[state_dict_key].items():
|
42 |
+
# strip `model.` prefix
|
43 |
+
name = k[6:] if k.startswith('model') else k
|
44 |
+
new_state_dict[name] = v
|
45 |
+
state_dict = new_state_dict
|
46 |
+
else:
|
47 |
+
state_dict = checkpoint
|
48 |
+
_logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path))
|
49 |
+
return state_dict
|
50 |
+
else:
|
51 |
+
_logger.error("No checkpoint found at '{}'".format(checkpoint_path))
|
52 |
+
raise FileNotFoundError()
|
53 |
+
|
54 |
+
|
55 |
+
def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True):
|
56 |
+
state_dict = load_state_dict(checkpoint_path, use_ema)
|
57 |
+
model.load_state_dict(state_dict, strict=strict)
|
58 |
+
|
59 |
+
|
60 |
+
def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True):
|
61 |
+
resume_epoch = None
|
62 |
+
if os.path.isfile(checkpoint_path):
|
63 |
+
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
64 |
+
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
|
65 |
+
if log_info:
|
66 |
+
_logger.info('Restoring model state from checkpoint...')
|
67 |
+
new_state_dict = OrderedDict()
|
68 |
+
for k, v in checkpoint['state_dict'].items():
|
69 |
+
name = k[7:] if k.startswith('module') else k
|
70 |
+
new_state_dict[name] = v
|
71 |
+
model.load_state_dict(new_state_dict)
|
72 |
+
|
73 |
+
if optimizer is not None and 'optimizer' in checkpoint:
|
74 |
+
if log_info:
|
75 |
+
_logger.info('Restoring optimizer state from checkpoint...')
|
76 |
+
optimizer.load_state_dict(checkpoint['optimizer'])
|
77 |
+
|
78 |
+
if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint:
|
79 |
+
if log_info:
|
80 |
+
_logger.info('Restoring AMP loss scaler state from checkpoint...')
|
81 |
+
loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key])
|
82 |
+
|
83 |
+
if 'epoch' in checkpoint:
|
84 |
+
resume_epoch = checkpoint['epoch']
|
85 |
+
if 'version' in checkpoint and checkpoint['version'] > 1:
|
86 |
+
resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save
|
87 |
+
|
88 |
+
if log_info:
|
89 |
+
_logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
|
90 |
+
else:
|
91 |
+
model.load_state_dict(checkpoint)
|
92 |
+
if log_info:
|
93 |
+
_logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
|
94 |
+
return resume_epoch
|
95 |
+
else:
|
96 |
+
_logger.error("No checkpoint found at '{}'".format(checkpoint_path))
|
97 |
+
raise FileNotFoundError()
|
98 |
+
|
99 |
+
|
100 |
+
def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, img_size=224, num_frames=8, num_patches=196, attention_type='divided_space_time', pretrained_model="", strict=True):
|
101 |
+
if cfg is None:
|
102 |
+
cfg = getattr(model, 'default_cfg')
|
103 |
+
if cfg is None or 'url' not in cfg or not cfg['url']:
|
104 |
+
_logger.warning("Pretrained model URL is invalid, using random initialization.")
|
105 |
+
return
|
106 |
+
|
107 |
+
if len(pretrained_model) == 0:
|
108 |
+
state_dict = model_zoo.load_url(cfg['url'], progress=False, map_location='cpu')
|
109 |
+
else:
|
110 |
+
try:
|
111 |
+
state_dict = load_state_dict(pretrained_model)['model']
|
112 |
+
except:
|
113 |
+
state_dict = load_state_dict(pretrained_model)
|
114 |
+
|
115 |
+
|
116 |
+
if filter_fn is not None:
|
117 |
+
state_dict = filter_fn(state_dict)
|
118 |
+
|
119 |
+
if in_chans == 1:
|
120 |
+
conv1_name = cfg['first_conv']
|
121 |
+
_logger.info('Converting first conv (%s) pretrained weights from 3 to 1 channel' % conv1_name)
|
122 |
+
conv1_weight = state_dict[conv1_name + '.weight']
|
123 |
+
conv1_type = conv1_weight.dtype
|
124 |
+
conv1_weight = conv1_weight.float()
|
125 |
+
O, I, J, K = conv1_weight.shape
|
126 |
+
if I > 3:
|
127 |
+
assert conv1_weight.shape[1] % 3 == 0
|
128 |
+
# For models with space2depth stems
|
129 |
+
conv1_weight = conv1_weight.reshape(O, I // 3, 3, J, K)
|
130 |
+
conv1_weight = conv1_weight.sum(dim=2, keepdim=False)
|
131 |
+
else:
|
132 |
+
conv1_weight = conv1_weight.sum(dim=1, keepdim=True)
|
133 |
+
conv1_weight = conv1_weight.to(conv1_type)
|
134 |
+
state_dict[conv1_name + '.weight'] = conv1_weight
|
135 |
+
elif in_chans != 3:
|
136 |
+
conv1_name = cfg['first_conv']
|
137 |
+
conv1_weight = state_dict[conv1_name + '.weight']
|
138 |
+
conv1_type = conv1_weight.dtype
|
139 |
+
conv1_weight = conv1_weight.float()
|
140 |
+
O, I, J, K = conv1_weight.shape
|
141 |
+
if I != 3:
|
142 |
+
_logger.warning('Deleting first conv (%s) from pretrained weights.' % conv1_name)
|
143 |
+
del state_dict[conv1_name + '.weight']
|
144 |
+
strict = False
|
145 |
+
else:
|
146 |
+
_logger.info('Repeating first conv (%s) weights in channel dim.' % conv1_name)
|
147 |
+
repeat = int(math.ceil(in_chans / 3))
|
148 |
+
conv1_weight = conv1_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
|
149 |
+
conv1_weight *= (3 / float(in_chans))
|
150 |
+
conv1_weight = conv1_weight.to(conv1_type)
|
151 |
+
state_dict[conv1_name + '.weight'] = conv1_weight
|
152 |
+
|
153 |
+
|
154 |
+
classifier_name = cfg['classifier']
|
155 |
+
if num_classes == 1000 and cfg['num_classes'] == 1001:
|
156 |
+
# special case for imagenet trained models with extra background class in pretrained weights
|
157 |
+
classifier_weight = state_dict[classifier_name + '.weight']
|
158 |
+
state_dict[classifier_name + '.weight'] = classifier_weight[1:]
|
159 |
+
classifier_bias = state_dict[classifier_name + '.bias']
|
160 |
+
state_dict[classifier_name + '.bias'] = classifier_bias[1:]
|
161 |
+
elif num_classes != state_dict[classifier_name + '.weight'].size(0):
|
162 |
+
#print('Removing the last fully connected layer due to dimensions mismatch ('+str(num_classes)+ ' != '+str(state_dict[classifier_name + '.weight'].size(0))+').', flush=True)
|
163 |
+
# completely discard fully connected for all other differences between pretrained and created model
|
164 |
+
del state_dict[classifier_name + '.weight']
|
165 |
+
del state_dict[classifier_name + '.bias']
|
166 |
+
strict = False
|
167 |
+
|
168 |
+
|
169 |
+
## Resizing the positional embeddings in case they don't match
|
170 |
+
if num_patches + 1 != state_dict['pos_embed'].size(1):
|
171 |
+
pos_embed = state_dict['pos_embed']
|
172 |
+
cls_pos_embed = pos_embed[0,0,:].unsqueeze(0).unsqueeze(1)
|
173 |
+
other_pos_embed = pos_embed[0,1:,:].unsqueeze(0).transpose(1, 2)
|
174 |
+
new_pos_embed = F.interpolate(other_pos_embed, size=(num_patches), mode='nearest')
|
175 |
+
new_pos_embed = new_pos_embed.transpose(1, 2)
|
176 |
+
new_pos_embed = torch.cat((cls_pos_embed, new_pos_embed), 1)
|
177 |
+
state_dict['pos_embed'] = new_pos_embed
|
178 |
+
|
179 |
+
## Resizing time embeddings in case they don't match
|
180 |
+
if 'time_embed' in state_dict and num_frames != state_dict['time_embed'].size(1):
|
181 |
+
time_embed = state_dict['time_embed'].transpose(1, 2)
|
182 |
+
new_time_embed = F.interpolate(time_embed, size=(num_frames), mode='nearest')
|
183 |
+
state_dict['time_embed'] = new_time_embed.transpose(1, 2)
|
184 |
+
|
185 |
+
## Initializing temporal attention
|
186 |
+
if attention_type == 'divided_space_time':
|
187 |
+
new_state_dict = state_dict.copy()
|
188 |
+
for key in state_dict:
|
189 |
+
if 'blocks' in key and 'attn' in key:
|
190 |
+
new_key = key.replace('attn','temporal_attn')
|
191 |
+
if not new_key in state_dict:
|
192 |
+
new_state_dict[new_key] = state_dict[key]
|
193 |
+
else:
|
194 |
+
new_state_dict[new_key] = state_dict[new_key]
|
195 |
+
if 'blocks' in key and 'norm1' in key:
|
196 |
+
new_key = key.replace('norm1','temporal_norm1')
|
197 |
+
if not new_key in state_dict:
|
198 |
+
new_state_dict[new_key] = state_dict[key]
|
199 |
+
else:
|
200 |
+
new_state_dict[new_key] = state_dict[new_key]
|
201 |
+
state_dict = new_state_dict
|
202 |
+
|
203 |
+
## Loading the weights
|
204 |
+
model.load_state_dict(state_dict, strict=False)
|
205 |
+
|
206 |
+
|
207 |
+
def extract_layer(model, layer):
|
208 |
+
layer = layer.split('.')
|
209 |
+
module = model
|
210 |
+
if hasattr(model, 'module') and layer[0] != 'module':
|
211 |
+
module = model.module
|
212 |
+
if not hasattr(model, 'module') and layer[0] == 'module':
|
213 |
+
layer = layer[1:]
|
214 |
+
for l in layer:
|
215 |
+
if hasattr(module, l):
|
216 |
+
if not l.isdigit():
|
217 |
+
module = getattr(module, l)
|
218 |
+
else:
|
219 |
+
module = module[int(l)]
|
220 |
+
else:
|
221 |
+
return module
|
222 |
+
return module
|
223 |
+
|
224 |
+
|
225 |
+
def set_layer(model, layer, val):
|
226 |
+
layer = layer.split('.')
|
227 |
+
module = model
|
228 |
+
if hasattr(model, 'module') and layer[0] != 'module':
|
229 |
+
module = model.module
|
230 |
+
lst_index = 0
|
231 |
+
module2 = module
|
232 |
+
for l in layer:
|
233 |
+
if hasattr(module2, l):
|
234 |
+
if not l.isdigit():
|
235 |
+
module2 = getattr(module2, l)
|
236 |
+
else:
|
237 |
+
module2 = module2[int(l)]
|
238 |
+
lst_index += 1
|
239 |
+
lst_index -= 1
|
240 |
+
for l in layer[:lst_index]:
|
241 |
+
if not l.isdigit():
|
242 |
+
module = getattr(module, l)
|
243 |
+
else:
|
244 |
+
module = module[int(l)]
|
245 |
+
l = layer[lst_index]
|
246 |
+
setattr(module, l, val)
|
247 |
+
|
248 |
+
|
249 |
+
def adapt_model_from_string(parent_module, model_string):
|
250 |
+
separator = '***'
|
251 |
+
state_dict = {}
|
252 |
+
lst_shape = model_string.split(separator)
|
253 |
+
for k in lst_shape:
|
254 |
+
k = k.split(':')
|
255 |
+
key = k[0]
|
256 |
+
shape = k[1][1:-1].split(',')
|
257 |
+
if shape[0] != '':
|
258 |
+
state_dict[key] = [int(i) for i in shape]
|
259 |
+
|
260 |
+
new_module = deepcopy(parent_module)
|
261 |
+
for n, m in parent_module.named_modules():
|
262 |
+
old_module = extract_layer(parent_module, n)
|
263 |
+
if isinstance(old_module, nn.Conv2d) or isinstance(old_module, Conv2dSame):
|
264 |
+
if isinstance(old_module, Conv2dSame):
|
265 |
+
conv = Conv2dSame
|
266 |
+
else:
|
267 |
+
conv = nn.Conv2d
|
268 |
+
s = state_dict[n + '.weight']
|
269 |
+
in_channels = s[1]
|
270 |
+
out_channels = s[0]
|
271 |
+
g = 1
|
272 |
+
if old_module.groups > 1:
|
273 |
+
in_channels = out_channels
|
274 |
+
g = in_channels
|
275 |
+
new_conv = conv(
|
276 |
+
in_channels=in_channels, out_channels=out_channels, kernel_size=old_module.kernel_size,
|
277 |
+
bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation,
|
278 |
+
groups=g, stride=old_module.stride)
|
279 |
+
set_layer(new_module, n, new_conv)
|
280 |
+
if isinstance(old_module, nn.BatchNorm2d):
|
281 |
+
new_bn = nn.BatchNorm2d(
|
282 |
+
num_features=state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
|
283 |
+
affine=old_module.affine, track_running_stats=True)
|
284 |
+
set_layer(new_module, n, new_bn)
|
285 |
+
if isinstance(old_module, nn.Linear):
|
286 |
+
num_features = state_dict[n + '.weight'][1]
|
287 |
+
new_fc = Linear(
|
288 |
+
in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None)
|
289 |
+
set_layer(new_module, n, new_fc)
|
290 |
+
if hasattr(new_module, 'num_features'):
|
291 |
+
new_module.num_features = num_features
|
292 |
+
new_module.eval()
|
293 |
+
parent_module.eval()
|
294 |
+
|
295 |
+
return new_module
|
296 |
+
|
297 |
+
|
298 |
+
def adapt_model_from_file(parent_module, model_variant):
|
299 |
+
adapt_file = os.path.join(os.path.dirname(__file__), 'pruned', model_variant + '.txt')
|
300 |
+
with open(adapt_file, 'r') as f:
|
301 |
+
return adapt_model_from_string(parent_module, f.read().strip())
|
302 |
+
|
303 |
+
|
304 |
+
def default_cfg_for_features(default_cfg):
|
305 |
+
default_cfg = deepcopy(default_cfg)
|
306 |
+
# remove default pretrained cfg fields that don't have much relevance for feature backbone
|
307 |
+
to_remove = ('num_classes', 'crop_pct', 'classifier') # add default final pool size?
|
308 |
+
for tr in to_remove:
|
309 |
+
default_cfg.pop(tr, None)
|
310 |
+
return default_cfg
|
311 |
+
|
312 |
+
|
313 |
+
def build_model_with_cfg(
|
314 |
+
model_cls: Callable,
|
315 |
+
variant: str,
|
316 |
+
pretrained: bool,
|
317 |
+
default_cfg: dict,
|
318 |
+
model_cfg: dict = None,
|
319 |
+
feature_cfg: dict = None,
|
320 |
+
pretrained_strict: bool = True,
|
321 |
+
pretrained_filter_fn: Callable = None,
|
322 |
+
**kwargs):
|
323 |
+
pruned = kwargs.pop('pruned', False)
|
324 |
+
features = False
|
325 |
+
feature_cfg = feature_cfg or {}
|
326 |
+
|
327 |
+
if kwargs.pop('features_only', False):
|
328 |
+
features = True
|
329 |
+
feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
|
330 |
+
if 'out_indices' in kwargs:
|
331 |
+
feature_cfg['out_indices'] = kwargs.pop('out_indices')
|
332 |
+
|
333 |
+
model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs)
|
334 |
+
model.default_cfg = deepcopy(default_cfg)
|
335 |
+
|
336 |
+
if pruned:
|
337 |
+
model = adapt_model_from_file(model, variant)
|
338 |
+
|
339 |
+
# for classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
|
340 |
+
num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
|
341 |
+
if pretrained:
|
342 |
+
load_pretrained(
|
343 |
+
model,
|
344 |
+
num_classes=num_classes_pretrained, in_chans=kwargs.get('in_chans', 3),
|
345 |
+
filter_fn=pretrained_filter_fn, strict=pretrained_strict)
|
346 |
+
|
347 |
+
if features:
|
348 |
+
feature_cls = FeatureListNet
|
349 |
+
if 'feature_cls' in feature_cfg:
|
350 |
+
feature_cls = feature_cfg.pop('feature_cls')
|
351 |
+
if isinstance(feature_cls, str):
|
352 |
+
feature_cls = feature_cls.lower()
|
353 |
+
if 'hook' in feature_cls:
|
354 |
+
feature_cls = FeatureHookNet
|
355 |
+
else:
|
356 |
+
assert False, f'Unknown feature class {feature_cls}'
|
357 |
+
model = feature_cls(model, **feature_cfg)
|
358 |
+
model.default_cfg = default_cfg_for_features(default_cfg) # add back default_cfg
|
359 |
+
|
360 |
+
return model
|