mshukor committed on
Commit
87d7283
·
1 Parent(s): d41c3ca
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. LICENSE +21 -0
  2. README.md +8 -9
  3. TimeSformer/.gitignore +143 -0
  4. TimeSformer/CODE_OF_CONDUCT.md +5 -0
  5. TimeSformer/CONTRIBUTING.md +25 -0
  6. TimeSformer/LICENSE +399 -0
  7. TimeSformer/README.md +248 -0
  8. TimeSformer/configs/Kinetics/SLOWFAST_4x16_R50.yaml +63 -0
  9. TimeSformer/configs/Kinetics/SLOWFAST_8x8_R101.yaml +63 -0
  10. TimeSformer/configs/Kinetics/SLOWFAST_8x8_R50.yaml +63 -0
  11. TimeSformer/configs/Kinetics/TimeSformer_divST_16x16_448.yaml +45 -0
  12. TimeSformer/configs/Kinetics/TimeSformer_divST_8x32_224.yaml +45 -0
  13. TimeSformer/configs/Kinetics/TimeSformer_divST_8x32_224_4gpus.yaml +45 -0
  14. TimeSformer/configs/Kinetics/TimeSformer_divST_8x32_224_TEST.yaml +46 -0
  15. TimeSformer/configs/Kinetics/TimeSformer_divST_96x4_224.yaml +45 -0
  16. TimeSformer/configs/Kinetics/TimeSformer_jointST_8x32_224.yaml +45 -0
  17. TimeSformer/configs/Kinetics/TimeSformer_spaceOnly_8x32_224.yaml +45 -0
  18. TimeSformer/configs/SSv2/SLOWFAST_16x8_R50.yaml +83 -0
  19. TimeSformer/configs/SSv2/TimeSformer_divST_16_448.yaml +48 -0
  20. TimeSformer/configs/SSv2/TimeSformer_divST_64_224.yaml +48 -0
  21. TimeSformer/configs/SSv2/TimeSformer_divST_8_224.yaml +48 -0
  22. TimeSformer/environment.yml +26 -0
  23. TimeSformer/example.ipynb +84 -0
  24. TimeSformer/setup.cfg +23 -0
  25. TimeSformer/setup.py +23 -0
  26. TimeSformer/slurm_scripts/run_multi_node_job.sh +25 -0
  27. TimeSformer/slurm_scripts/run_single_node_job.sh +35 -0
  28. TimeSformer/timesformer/__init__.py +5 -0
  29. TimeSformer/timesformer/config/__init__.py +1 -0
  30. TimeSformer/timesformer/config/defaults.py +820 -0
  31. TimeSformer/timesformer/datasets/DATASET.md +26 -0
  32. TimeSformer/timesformer/datasets/__init__.py +5 -0
  33. TimeSformer/timesformer/datasets/build.py +30 -0
  34. TimeSformer/timesformer/datasets/cv2_transform.py +796 -0
  35. TimeSformer/timesformer/datasets/decoder.py +392 -0
  36. TimeSformer/timesformer/datasets/kinetics.py +294 -0
  37. TimeSformer/timesformer/datasets/loader.py +134 -0
  38. TimeSformer/timesformer/datasets/multigrid_helper.py +78 -0
  39. TimeSformer/timesformer/datasets/ssv2.py +278 -0
  40. TimeSformer/timesformer/datasets/transform.py +459 -0
  41. TimeSformer/timesformer/datasets/utils.py +380 -0
  42. TimeSformer/timesformer/datasets/video_container.py +31 -0
  43. TimeSformer/timesformer/models/__init__.py +5 -0
  44. TimeSformer/timesformer/models/batchnorm_helper.py +217 -0
  45. TimeSformer/timesformer/models/build.py +54 -0
  46. TimeSformer/timesformer/models/conv2d_same.py +74 -0
  47. TimeSformer/timesformer/models/custom_video_model_builder.py +4 -0
  48. TimeSformer/timesformer/models/features.py +266 -0
  49. TimeSformer/timesformer/models/head_helper.py +235 -0
  50. TimeSformer/timesformer/models/helpers.py +360 -0
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 mshukor
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,12 +1,11 @@
1
  ---
2
- title: EP ALM Video Text
3
- emoji: 📉
4
- colorFrom: pink
5
- colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 3.36.1
8
  app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: eP-ALM
3
+ emoji: 🌍
4
+ colorFrom: purple
5
+ colorTo: pink
6
  sdk: gradio
7
+ sdk_version: 3.12.0
8
  app_file: app.py
9
+ pinned: true
10
+ license: apache-2.0
11
+ ---
 
TimeSformer/.gitignore ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+
4
+
5
+ # Ignore rules for Python, adapted from:
6
+ # https://github.com/github/gitignore/blob/master/Python.gitignore
7
+
8
+ # Byte-compiled / optimized / DLL files
9
+ __pycache__/
10
+ *.py[cod]
11
+ *$py.class
12
+
13
+ # C extensions
14
+ *.so
15
+
16
+ # Distribution / packaging
17
+ .Python
18
+ build/
19
+ develop-eggs/
20
+ dist/
21
+ downloads/
22
+ eggs/
23
+ .eggs/
24
+ lib/
25
+ lib64/
26
+ parts/
27
+ sdist/
28
+ var/
29
+ wheels/
30
+ share/python-wheels/
31
+ *.egg-info/
32
+ .installed.cfg
33
+ *.egg
34
+ MANIFEST
35
+
36
+ # PyInstaller
37
+ # Usually these files are written by a python script from a template
38
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
39
+ *.manifest
40
+ *.spec
41
+
42
+ # Installer logs
43
+ pip-log.txt
44
+ pip-delete-this-directory.txt
45
+
46
+ # Unit test / coverage reports
47
+ tests/report/
48
+ .coverage
49
+ .coverage.*
50
+ .cache
51
+ nosetests.xml
52
+ coverage.xml
53
+ *.cover
54
+ *.py,cover
55
+ .hypothesis/
56
+ .pytest_cache/
57
+
58
+ # Translations
59
+ *.mo
60
+ *.pot
61
+
62
+ # Django stuff:
63
+ *.log
64
+ local_settings.py
65
+ db.sqlite3
66
+ db.sqlite3-journal
67
+
68
+ # Flask stuff:
69
+ instance/
70
+ .webassets-cache
71
+
72
+ # Scrapy stuff:
73
+ .scrapy
74
+
75
+ # Sphinx documentation
76
+ docs/_build/
77
+
78
+ # PyBuilder
79
+ .pybuilder/
80
+ target/
81
+
82
+ # Jupyter Notebook
83
+ .ipynb_checkpoints
84
+
85
+ # IPython
86
+ profile_default/
87
+ ipython_config.py
88
+
89
+ # pyenv
90
+ # For a library or package, you might want to ignore these files since the code is
91
+ # intended to run in multiple environments; otherwise, check them in:
92
+ # .python-version
93
+
94
+ # pipenv
95
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
97
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
98
+ # install all needed dependencies.
99
+ #Pipfile.lock
100
+
101
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
102
+ __pypackages__/
103
+
104
+ # Celery stuff
105
+ celerybeat-schedule
106
+ celerybeat.pid
107
+
108
+ # SageMath parsed files
109
+ *.sage.py
110
+
111
+ # Environments
112
+ .env
113
+ .venv
114
+ env/
115
+ venv/
116
+ ENV/
117
+ env.bak/
118
+ venv.bak/
119
+
120
+ # Spyder project settings
121
+ .spyderproject
122
+ .spyproject
123
+
124
+ # Rope project settings
125
+ .ropeproject
126
+
127
+ # mkdocs documentation
128
+ /site
129
+
130
+ # mypy
131
+ .mypy_cache/
132
+ .dmypy.json
133
+ dmypy.json
134
+
135
+ # Pyre type checker
136
+ .pyre/
137
+
138
+ # pytype static type analyzer
139
+ .pytype/
140
+
141
+
142
+ # Cython debug symbols
143
+ cython_debug/
TimeSformer/CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Code of Conduct
2
+
3
+ Facebook has adopted a Code of Conduct that we expect project participants to adhere to.
4
+ Please read the [full text](https://code.fb.com/codeofconduct/)
5
+ so that you can understand what actions will and will not be tolerated.
TimeSformer/CONTRIBUTING.md ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Contributing to TimeSformer
2
+
3
+ ## Pull Requests
4
+ We actively welcome your pull requests.
5
+
6
+ 1. Fork the repo and create your branch from `master`.
7
+ 2. If you've added code that should be tested, add tests.
8
+ 3. If you've changed APIs, update the documentation.
9
+ 4. Ensure the test suite passes.
10
+ 5. Make sure your code lints.
11
+ 6. If you haven't already, complete the Contributor License Agreement ("CLA").
12
+
13
+ ## Contributor License Agreement ("CLA")
14
+ In order to accept your pull request, we need you to submit a CLA. You only need
15
+ to do this once to work on any of Facebook's open source projects.
16
+
17
+ Complete your CLA here: <https://code.facebook.com/cla>
18
+
19
+ ## Issues
20
+ We use GitHub issues to track public bugs. Please ensure your description is
21
+ clear and has sufficient instructions to be able to reproduce the issue.
22
+
23
+ ## License
24
+ By contributing to TimeSformer, you agree that your contributions will be licensed
25
+ under the [LICENSE](LICENSE) file in the root directory of this source tree.
TimeSformer/LICENSE ADDED
@@ -0,0 +1,399 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Attribution-NonCommercial 4.0 International
2
+
3
+ =======================================================================
4
+
5
+ Creative Commons Corporation ("Creative Commons") is not a law firm and
6
+ does not provide legal services or legal advice. Distribution of
7
+ Creative Commons public licenses does not create a lawyer-client or
8
+ other relationship. Creative Commons makes its licenses and related
9
+ information available on an "as-is" basis. Creative Commons gives no
10
+ warranties regarding its licenses, any material licensed under their
11
+ terms and conditions, or any related information. Creative Commons
12
+ disclaims all liability for damages resulting from their use to the
13
+ fullest extent possible.
14
+
15
+ Using Creative Commons Public Licenses
16
+
17
+ Creative Commons public licenses provide a standard set of terms and
18
+ conditions that creators and other rights holders may use to share
19
+ original works of authorship and other material subject to copyright
20
+ and certain other rights specified in the public license below. The
21
+ following considerations are for informational purposes only, are not
22
+ exhaustive, and do not form part of our licenses.
23
+
24
+ Considerations for licensors: Our public licenses are
25
+ intended for use by those authorized to give the public
26
+ permission to use material in ways otherwise restricted by
27
+ copyright and certain other rights. Our licenses are
28
+ irrevocable. Licensors should read and understand the terms
29
+ and conditions of the license they choose before applying it.
30
+ Licensors should also secure all rights necessary before
31
+ applying our licenses so that the public can reuse the
32
+ material as expected. Licensors should clearly mark any
33
+ material not subject to the license. This includes other CC-
34
+ licensed material, or material used under an exception or
35
+ limitation to copyright. More considerations for licensors:
36
+ wiki.creativecommons.org/Considerations_for_licensors
37
+
38
+ Considerations for the public: By using one of our public
39
+ licenses, a licensor grants the public permission to use the
40
+ licensed material under specified terms and conditions. If
41
+ the licensor's permission is not necessary for any reason--for
42
+ example, because of any applicable exception or limitation to
43
+ copyright--then that use is not regulated by the license. Our
44
+ licenses grant only permissions under copyright and certain
45
+ other rights that a licensor has authority to grant. Use of
46
+ the licensed material may still be restricted for other
47
+ reasons, including because others have copyright or other
48
+ rights in the material. A licensor may make special requests,
49
+ such as asking that all changes be marked or described.
50
+ Although not required by our licenses, you are encouraged to
51
+ respect those requests where reasonable. More_considerations
52
+ for the public:
53
+ wiki.creativecommons.org/Considerations_for_licensees
54
+
55
+ =======================================================================
56
+
57
+ Creative Commons Attribution-NonCommercial 4.0 International Public
58
+ License
59
+
60
+ By exercising the Licensed Rights (defined below), You accept and agree
61
+ to be bound by the terms and conditions of this Creative Commons
62
+ Attribution-NonCommercial 4.0 International Public License ("Public
63
+ License"). To the extent this Public License may be interpreted as a
64
+ contract, You are granted the Licensed Rights in consideration of Your
65
+ acceptance of these terms and conditions, and the Licensor grants You
66
+ such rights in consideration of benefits the Licensor receives from
67
+ making the Licensed Material available under these terms and
68
+ conditions.
69
+
70
+ Section 1 -- Definitions.
71
+
72
+ a. Adapted Material means material subject to Copyright and Similar
73
+ Rights that is derived from or based upon the Licensed Material
74
+ and in which the Licensed Material is translated, altered,
75
+ arranged, transformed, or otherwise modified in a manner requiring
76
+ permission under the Copyright and Similar Rights held by the
77
+ Licensor. For purposes of this Public License, where the Licensed
78
+ Material is a musical work, performance, or sound recording,
79
+ Adapted Material is always produced where the Licensed Material is
80
+ synched in timed relation with a moving image.
81
+
82
+ b. Adapter's License means the license You apply to Your Copyright
83
+ and Similar Rights in Your contributions to Adapted Material in
84
+ accordance with the terms and conditions of this Public License.
85
+
86
+ c. Copyright and Similar Rights means copyright and/or similar rights
87
+ closely related to copyright including, without limitation,
88
+ performance, broadcast, sound recording, and Sui Generis Database
89
+ Rights, without regard to how the rights are labeled or
90
+ categorized. For purposes of this Public License, the rights
91
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
92
+ Rights.
93
+ d. Effective Technological Measures means those measures that, in the
94
+ absence of proper authority, may not be circumvented under laws
95
+ fulfilling obligations under Article 11 of the WIPO Copyright
96
+ Treaty adopted on December 20, 1996, and/or similar international
97
+ agreements.
98
+
99
+ e. Exceptions and Limitations means fair use, fair dealing, and/or
100
+ any other exception or limitation to Copyright and Similar Rights
101
+ that applies to Your use of the Licensed Material.
102
+
103
+ f. Licensed Material means the artistic or literary work, database,
104
+ or other material to which the Licensor applied this Public
105
+ License.
106
+
107
+ g. Licensed Rights means the rights granted to You subject to the
108
+ terms and conditions of this Public License, which are limited to
109
+ all Copyright and Similar Rights that apply to Your use of the
110
+ Licensed Material and that the Licensor has authority to license.
111
+
112
+ h. Licensor means the individual(s) or entity(ies) granting rights
113
+ under this Public License.
114
+
115
+ i. NonCommercial means not primarily intended for or directed towards
116
+ commercial advantage or monetary compensation. For purposes of
117
+ this Public License, the exchange of the Licensed Material for
118
+ other material subject to Copyright and Similar Rights by digital
119
+ file-sharing or similar means is NonCommercial provided there is
120
+ no payment of monetary compensation in connection with the
121
+ exchange.
122
+
123
+ j. Share means to provide material to the public by any means or
124
+ process that requires permission under the Licensed Rights, such
125
+ as reproduction, public display, public performance, distribution,
126
+ dissemination, communication, or importation, and to make material
127
+ available to the public including in ways that members of the
128
+ public may access the material from a place and at a time
129
+ individually chosen by them.
130
+
131
+ k. Sui Generis Database Rights means rights other than copyright
132
+ resulting from Directive 96/9/EC of the European Parliament and of
133
+ the Council of 11 March 1996 on the legal protection of databases,
134
+ as amended and/or succeeded, as well as other essentially
135
+ equivalent rights anywhere in the world.
136
+
137
+ l. You means the individual or entity exercising the Licensed Rights
138
+ under this Public License. Your has a corresponding meaning.
139
+
140
+ Section 2 -- Scope.
141
+
142
+ a. License grant.
143
+
144
+ 1. Subject to the terms and conditions of this Public License,
145
+ the Licensor hereby grants You a worldwide, royalty-free,
146
+ non-sublicensable, non-exclusive, irrevocable license to
147
+ exercise the Licensed Rights in the Licensed Material to:
148
+
149
+ a. reproduce and Share the Licensed Material, in whole or
150
+ in part, for NonCommercial purposes only; and
151
+
152
+ b. produce, reproduce, and Share Adapted Material for
153
+ NonCommercial purposes only.
154
+
155
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
156
+ Exceptions and Limitations apply to Your use, this Public
157
+ License does not apply, and You do not need to comply with
158
+ its terms and conditions.
159
+
160
+ 3. Term. The term of this Public License is specified in Section
161
+ 6(a).
162
+
163
+ 4. Media and formats; technical modifications allowed. The
164
+ Licensor authorizes You to exercise the Licensed Rights in
165
+ all media and formats whether now known or hereafter created,
166
+ and to make technical modifications necessary to do so. The
167
+ Licensor waives and/or agrees not to assert any right or
168
+ authority to forbid You from making technical modifications
169
+ necessary to exercise the Licensed Rights, including
170
+ technical modifications necessary to circumvent Effective
171
+ Technological Measures. For purposes of this Public License,
172
+ simply making modifications authorized by this Section 2(a)
173
+ (4) never produces Adapted Material.
174
+
175
+ 5. Downstream recipients.
176
+
177
+ a. Offer from the Licensor -- Licensed Material. Every
178
+ recipient of the Licensed Material automatically
179
+ receives an offer from the Licensor to exercise the
180
+ Licensed Rights under the terms and conditions of this
181
+ Public License.
182
+
183
+ b. No downstream restrictions. You may not offer or impose
184
+ any additional or different terms or conditions on, or
185
+ apply any Effective Technological Measures to, the
186
+ Licensed Material if doing so restricts exercise of the
187
+ Licensed Rights by any recipient of the Licensed
188
+ Material.
189
+
190
+ 6. No endorsement. Nothing in this Public License constitutes or
191
+ may be construed as permission to assert or imply that You
192
+ are, or that Your use of the Licensed Material is, connected
193
+ with, or sponsored, endorsed, or granted official status by,
194
+ the Licensor or others designated to receive attribution as
195
+ provided in Section 3(a)(1)(A)(i).
196
+
197
+ b. Other rights.
198
+
199
+ 1. Moral rights, such as the right of integrity, are not
200
+ licensed under this Public License, nor are publicity,
201
+ privacy, and/or other similar personality rights; however, to
202
+ the extent possible, the Licensor waives and/or agrees not to
203
+ assert any such rights held by the Licensor to the limited
204
+ extent necessary to allow You to exercise the Licensed
205
+ Rights, but not otherwise.
206
+
207
+ 2. Patent and trademark rights are not licensed under this
208
+ Public License.
209
+
210
+ 3. To the extent possible, the Licensor waives any right to
211
+ collect royalties from You for the exercise of the Licensed
212
+ Rights, whether directly or through a collecting society
213
+ under any voluntary or waivable statutory or compulsory
214
+ licensing scheme. In all other cases the Licensor expressly
215
+ reserves any right to collect such royalties, including when
216
+ the Licensed Material is used other than for NonCommercial
217
+ purposes.
218
+
219
+ Section 3 -- License Conditions.
220
+
221
+ Your exercise of the Licensed Rights is expressly made subject to the
222
+ following conditions.
223
+
224
+ a. Attribution.
225
+
226
+ 1. If You Share the Licensed Material (including in modified
227
+ form), You must:
228
+
229
+ a. retain the following if it is supplied by the Licensor
230
+ with the Licensed Material:
231
+
232
+ i. identification of the creator(s) of the Licensed
233
+ Material and any others designated to receive
234
+ attribution, in any reasonable manner requested by
235
+ the Licensor (including by pseudonym if
236
+ designated);
237
+
238
+ ii. a copyright notice;
239
+
240
+ iii. a notice that refers to this Public License;
241
+
242
+ iv. a notice that refers to the disclaimer of
243
+ warranties;
244
+
245
+ v. a URI or hyperlink to the Licensed Material to the
246
+ extent reasonably practicable;
247
+
248
+ b. indicate if You modified the Licensed Material and
249
+ retain an indication of any previous modifications; and
250
+
251
+ c. indicate the Licensed Material is licensed under this
252
+ Public License, and include the text of, or the URI or
253
+ hyperlink to, this Public License.
254
+
255
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
256
+ reasonable manner based on the medium, means, and context in
257
+ which You Share the Licensed Material. For example, it may be
258
+ reasonable to satisfy the conditions by providing a URI or
259
+ hyperlink to a resource that includes the required
260
+ information.
261
+
262
+ 3. If requested by the Licensor, You must remove any of the
263
+ information required by Section 3(a)(1)(A) to the extent
264
+ reasonably practicable.
265
+
266
+ 4. If You Share Adapted Material You produce, the Adapter's
267
+ License You apply must not prevent recipients of the Adapted
268
+ Material from complying with this Public License.
269
+
270
+ Section 4 -- Sui Generis Database Rights.
271
+
272
+ Where the Licensed Rights include Sui Generis Database Rights that
273
+ apply to Your use of the Licensed Material:
274
+
275
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
276
+ to extract, reuse, reproduce, and Share all or a substantial
277
+ portion of the contents of the database for NonCommercial purposes
278
+ only;
279
+
280
+ b. if You include all or a substantial portion of the database
281
+ contents in a database in which You have Sui Generis Database
282
+ Rights, then the database in which You have Sui Generis Database
283
+ Rights (but not its individual contents) is Adapted Material; and
284
+
285
+ c. You must comply with the conditions in Section 3(a) if You Share
286
+ all or a substantial portion of the contents of the database.
287
+
288
+ For the avoidance of doubt, this Section 4 supplements and does not
289
+ replace Your obligations under this Public License where the Licensed
290
+ Rights include other Copyright and Similar Rights.
291
+
292
+ Section 5 -- Disclaimer of Warranties and Limitation of Liability.
293
+
294
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
295
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
296
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
297
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
298
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
299
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
300
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
301
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
302
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
303
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
304
+
305
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
306
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
307
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
308
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
309
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
310
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
311
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
312
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
313
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
314
+
315
+ c. The disclaimer of warranties and limitation of liability provided
316
+ above shall be interpreted in a manner that, to the extent
317
+ possible, most closely approximates an absolute disclaimer and
318
+ waiver of all liability.
319
+
320
+ Section 6 -- Term and Termination.
321
+
322
+ a. This Public License applies for the term of the Copyright and
323
+ Similar Rights licensed here. However, if You fail to comply with
324
+ this Public License, then Your rights under this Public License
325
+ terminate automatically.
326
+
327
+ b. Where Your right to use the Licensed Material has terminated under
328
+ Section 6(a), it reinstates:
329
+
330
+ 1. automatically as of the date the violation is cured, provided
331
+ it is cured within 30 days of Your discovery of the
332
+ violation; or
333
+
334
+ 2. upon express reinstatement by the Licensor.
335
+
336
+ For the avoidance of doubt, this Section 6(b) does not affect any
337
+ right the Licensor may have to seek remedies for Your violations
338
+ of this Public License.
339
+
340
+ c. For the avoidance of doubt, the Licensor may also offer the
341
+ Licensed Material under separate terms or conditions or stop
342
+ distributing the Licensed Material at any time; however, doing so
343
+ will not terminate this Public License.
344
+
345
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
346
+ License.
347
+
348
+ Section 7 -- Other Terms and Conditions.
349
+
350
+ a. The Licensor shall not be bound by any additional or different
351
+ terms or conditions communicated by You unless expressly agreed.
352
+
353
+ b. Any arrangements, understandings, or agreements regarding the
354
+ Licensed Material not stated herein are separate from and
355
+ independent of the terms and conditions of this Public License.
356
+
357
+ Section 8 -- Interpretation.
358
+
359
+ a. For the avoidance of doubt, this Public License does not, and
360
+ shall not be interpreted to, reduce, limit, restrict, or impose
361
+ conditions on any use of the Licensed Material that could lawfully
362
+ be made without permission under this Public License.
363
+
364
+ b. To the extent possible, if any provision of this Public License is
365
+ deemed unenforceable, it shall be automatically reformed to the
366
+ minimum extent necessary to make it enforceable. If the provision
367
+ cannot be reformed, it shall be severed from this Public License
368
+ without affecting the enforceability of the remaining terms and
369
+ conditions.
370
+
371
+ c. No term or condition of this Public License will be waived and no
372
+ failure to comply consented to unless expressly agreed to by the
373
+ Licensor.
374
+
375
+ d. Nothing in this Public License constitutes or may be interpreted
376
+ as a limitation upon, or waiver of, any privileges and immunities
377
+ that apply to the Licensor or You, including from the legal
378
+ processes of any jurisdiction or authority.
379
+
380
+ =======================================================================
381
+
382
+ Creative Commons is not a party to its public
383
+ licenses. Notwithstanding, Creative Commons may elect to apply one of
384
+ its public licenses to material it publishes and in those instances
385
+ will be considered the “Licensor.” The text of the Creative Commons
386
+ public licenses is dedicated to the public domain under the CC0 Public
387
+ Domain Dedication. Except for the limited purpose of indicating that
388
+ material is shared under a Creative Commons public license or as
389
+ otherwise permitted by the Creative Commons policies published at
390
+ creativecommons.org/policies, Creative Commons does not authorize the
391
+ use of the trademark "Creative Commons" or any other trademark or logo
392
+ of Creative Commons without its prior written consent including,
393
+ without limitation, in connection with any unauthorized modifications
394
+ to any of its public licenses or any other arrangements,
395
+ understandings, or agreements concerning use of licensed material. For
396
+ the avoidance of doubt, this paragraph does not form part of the
397
+ public licenses.
398
+
399
+ Creative Commons may be contacted at creativecommons.org.
TimeSformer/README.md ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # TimeSformer
2
+
3
+ This is an official PyTorch implementation of our ICML 2021 paper [Is Space-Time Attention All You Need for Video Understanding?](https://arxiv.org/pdf/2102.05095.pdf). In this repository, we provide PyTorch code for training and testing our proposed TimeSformer model. TimeSformer provides an efficient video classification framework that achieves state-of-the-art results on several video action recognition benchmarks such as Kinetics-400.
4
+
5
+ If you find TimeSformer useful in your research, please use the following BibTeX entry for citation.
6
+
7
+ ```BibTeX
8
+ @inproceedings{gberta_2021_ICML,
9
+ author = {Gedas Bertasius and Heng Wang and Lorenzo Torresani},
10
+ title = {Is Space-Time Attention All You Need for Video Understanding?},
11
+ booktitle = {Proceedings of the International Conference on Machine Learning (ICML)},
12
+ month = {July},
13
+ year = {2021}
14
+ }
15
+ ```
16
+
17
+ # Model Zoo
18
+
19
+ We provide TimeSformer models pretrained on Kinetics-400 (K400), Kinetics-600 (K600), Something-Something-V2 (SSv2), and HowTo100M datasets.
20
+
21
+ | name | dataset | # of frames | spatial crop | acc@1 | acc@5 | url |
22
+ | --- | --- | --- | --- | --- | --- | --- |
23
+ | TimeSformer | K400 | 8 | 224 | 77.9 | 93.2 | [model](https://www.dropbox.com/s/g5t24we9gl5yk88/TimeSformer_divST_8x32_224_K400.pyth?dl=0) |
24
+ | TimeSformer-HR | K400 | 16 | 448 | 79.6 | 94.0 | [model](https://www.dropbox.com/s/6f0x172lpqy3oxt/TimeSformer_divST_16x16_448_K400.pyth?dl=0) |
25
+ | TimeSformer-L | K400 | 96 | 224 | 80.6 | 94.7 | [model](https://www.dropbox.com/s/r1iuxahif3sgimo/TimeSformer_divST_96x4_224_K400.pyth?dl=0) |
26
+
27
+ | name | dataset | # of frames | spatial crop | acc@1 | acc@5 | url |
28
+ | --- | --- | --- | --- | --- | --- | --- |
29
+ | TimeSformer | K600 | 8 | 224 | 79.1 | 94.4 | [model](https://www.dropbox.com/s/4h2qt41m2z3aqrb/TimeSformer_divST_8x32_224_K600.pyth?dl=0) |
30
+ | TimeSformer-HR | K600 | 16 | 448 | 81.8 | 95.8 | [model](https://www.dropbox.com/s/ft1e92g2vhvxecv/TimeSformer_divST_16x16_448_K600.pyth?dl=0) |
31
+ | TimeSformer-L | K600 | 96 | 224 | 82.2 | 95.6 | [model](https://www.dropbox.com/s/857rx6xeclxfhdg/TimeSformer_divST_96x4_224_K600.pyth?dl=0) |
32
+
33
+ | name | dataset | # of frames | spatial crop | acc@1 | acc@5 | url |
34
+ | --- | --- | --- | --- | --- | --- | --- |
35
+ | TimeSformer | SSv2 | 8 | 224 | 59.1 | 85.6 | [model](https://www.dropbox.com/s/tybhuml57y24wpm/TimeSformer_divST_8_224_SSv2.pyth?dl=0) |
36
+ | TimeSformer-HR | SSv2 | 16 | 448 | 61.8 | 86.9 | [model](https://www.dropbox.com/s/9t68uzk8w2fpfnv/TimeSformer_divST_16_448_SSv2.pyth?dl=0) |
37
+ | TimeSformer-L | SSv2 | 64 | 224 | 62.0 | 87.5 | [model](https://www.dropbox.com/s/3f1rm2al8mhprwa/TimeSformer_divST_64_224_SSv2.pyth?dl=0) |
38
+
39
+ | name | dataset | # of frames | spatial crop | single clip coverage | acc@1 | url |
40
+ | --- | --- | --- | --- | --- | --- | --- |
41
+ | TimeSformer | HowTo100M | 8 | 224 | 8.5s | 56.8 | [model](https://www.dropbox.com/s/9v8hcm88b9tc6ff/TimeSformer_divST_8x32_224_HowTo100M.pyth?dl=0) |
42
+ | TimeSformer | HowTo100M | 32 | 224 | 34.1s | 61.2 | [model](https://www.dropbox.com/s/4roflx4q1gscu85/TimeSformer_divST_32x32_224_HowTo100M.pyth?dl=0) |
43
+ | TimeSformer | HowTo100M | 64 | 224 | 68.3s | 62.2 | [model](https://www.dropbox.com/s/15bvqltl1j5vyp3/TimeSformer_divST_64x32_224_HowTo100M.pyth?dl=0) |
44
+ | TimeSformer | HowTo100M | 96 | 224 | 102.4s | 62.6 | [model](https://www.dropbox.com/s/t2mzgahnfhgakma/TimeSformer_divST_96x32_224_HowTo100M.pyth?dl=0) |
45
+
46
+ We note that these models were re-trained using a slightly different implementation than the one used in the paper. Therefore, there might be a small difference in performance compared to the results reported in the paper.
47
+
48
+ You can load the pretrained models as follows:
49
+
50
+ ```python
51
+ import torch
52
+ from timesformer.models.vit import TimeSformer
53
+
54
+ model = TimeSformer(img_size=224, num_classes=400, num_frames=8, attention_type='divided_space_time', pretrained_model='/path/to/pretrained/model.pyth')
55
+
56
+ dummy_video = torch.randn(2, 3, 8, 224, 224) # (batch x channels x frames x height x width)
57
+
58
+ pred = model(dummy_video,) # (2, 400)
59
+ ```
60
+
61
+ # Installation
62
+
63
+ First, create a conda virtual environment and activate it:
64
+ ```
65
+ conda create -n timesformer python=3.7 -y
66
+ source activate timesformer
67
+ ```
68
+
69
+ Then, install the following packages:
70
+
71
+ - torchvision: `pip install torchvision` or `conda install torchvision -c pytorch`
72
+ - [fvcore](https://github.com/facebookresearch/fvcore/): `pip install 'git+https://github.com/facebookresearch/fvcore'`
73
+ - simplejson: `pip install simplejson`
74
+ - einops: `pip install einops`
75
+ - timm: `pip install timm`
76
+ - PyAV: `conda install av -c conda-forge`
77
+ - psutil: `pip install psutil`
78
+ - scikit-learn: `pip install scikit-learn`
79
+ - OpenCV: `pip install opencv-python`
80
+ - tensorboard: `pip install tensorboard`
81
+
82
+ Lastly, build the TimeSformer codebase by running:
83
+ ```
84
+ git clone https://github.com/facebookresearch/TimeSformer
85
+ cd TimeSformer
86
+ python setup.py build develop
87
+ ```
88
+
89
+ # Usage
90
+
91
+ ## Dataset Preparation
92
+
93
+ Please use the dataset preparation instructions provided in [DATASET.md](timesformer/datasets/DATASET.md).
94
+
95
+ ## Training the Default TimeSformer
96
+
97
+ Training the default TimeSformer that uses divided space-time attention, and operates on 8-frame clips cropped at 224x224 spatial resolution, can be done using the following command:
98
+
99
+ ```
100
+ python tools/run_net.py \
101
+ --cfg configs/Kinetics/TimeSformer_divST_8x32_224.yaml \
102
+ DATA.PATH_TO_DATA_DIR path_to_your_dataset \
103
+ NUM_GPUS 8 \
104
+ TRAIN.BATCH_SIZE 8
105
+ ```
106
+ You may need to pass the location of your dataset on the command line by adding `DATA.PATH_TO_DATA_DIR path_to_your_dataset`, or you can simply add
107
+
108
+ ```
109
+ DATA:
110
+ PATH_TO_DATA_DIR: path_to_your_dataset
111
+ ```
112
+
113
+ to the YAML config file, so that you do not need to pass it on the command line every time.
114
+
115
+ ## Using a Different Number of GPUs
116
+
117
+ If you want to use a smaller number of GPUs, you need to modify the .yaml configuration files in [`configs/`](configs/). Specifically, you need to modify the NUM_GPUS, TRAIN.BATCH_SIZE, TEST.BATCH_SIZE, and DATA_LOADER.NUM_WORKERS entries in each configuration file. Each BATCH_SIZE entry should be equal to or greater than the NUM_GPUS entry. In [`configs/Kinetics/TimeSformer_divST_8x32_224_4gpus.yaml`](configs/Kinetics/TimeSformer_divST_8x32_224_4gpus.yaml), we provide a sample configuration file for a 4-GPU setup.
118
+
119
+
120
+ ## Using Different Self-Attention Schemes
121
+
122
+ If you want to experiment with different space-time self-attention schemes, e.g., space-only or joint space-time attention, use the following commands:
123
+
124
+
125
+ ```
126
+ python tools/run_net.py \
127
+ --cfg configs/Kinetics/TimeSformer_spaceOnly_8x32_224.yaml \
128
+ DATA.PATH_TO_DATA_DIR path_to_your_dataset \
129
+ NUM_GPUS 8 \
130
+ TRAIN.BATCH_SIZE 8
131
+ ```
132
+
133
+ and
134
+
135
+ ```
136
+ python tools/run_net.py \
137
+ --cfg configs/Kinetics/TimeSformer_jointST_8x32_224.yaml \
138
+ DATA.PATH_TO_DATA_DIR path_to_your_dataset \
139
+ NUM_GPUS 8 \
140
+ TRAIN.BATCH_SIZE 8
141
+ ```
142
+
143
+ ## Training Different TimeSformer Variants
144
+
145
+ If you want to train more powerful TimeSformer variants, e.g., TimeSformer-HR (operating on 16-frame clips sampled at 448x448 spatial resolution), and TimeSformer-L (operating on 96-frame clips sampled at 224x224 spatial resolution), use the following commands:
146
+
147
+ ```
148
+ python tools/run_net.py \
149
+ --cfg configs/Kinetics/TimeSformer_divST_16x16_448.yaml \
150
+ DATA.PATH_TO_DATA_DIR path_to_your_dataset \
151
+ NUM_GPUS 8 \
152
+ TRAIN.BATCH_SIZE 8
153
+ ```
154
+
155
+ and
156
+
157
+ ```
158
+ python tools/run_net.py \
159
+ --cfg configs/Kinetics/TimeSformer_divST_96x4_224.yaml \
160
+ DATA.PATH_TO_DATA_DIR path_to_your_dataset \
161
+ NUM_GPUS 8 \
162
+ TRAIN.BATCH_SIZE 8
163
+ ```
164
+
165
+ Note that for these models you will need a set of GPUs with ~32GB of memory.
166
+
167
+ ## Inference
168
+
169
+ Use `TRAIN.ENABLE` and `TEST.ENABLE` to control whether training or testing is required for a given run. When testing, you also have to provide the path to the checkpoint model via TEST.CHECKPOINT_FILE_PATH.
170
+ ```
171
+ python tools/run_net.py \
172
+ --cfg configs/Kinetics/TimeSformer_divST_8x32_224_TEST.yaml \
173
+ DATA.PATH_TO_DATA_DIR path_to_your_dataset \
174
+ TEST.CHECKPOINT_FILE_PATH path_to_your_checkpoint \
175
+ TRAIN.ENABLE False
176
+ ```
177
+
178
+ ## Single-Node Training via Slurm
179
+
180
+ To train TimeSformer via Slurm, please check out our single node Slurm training script [`slurm_scripts/run_single_node_job.sh`](slurm_scripts/run_single_node_job.sh).
181
+
182
+
183
+ ## Multi-Node Training via Submitit
184
+
185
+ Distributed training is available via Slurm and submitit. First, install submitit:
186
+
187
+ ```
188
+ pip install submitit
189
+ ```
190
+
191
+ To train a TimeSformer model on Kinetics using 4 nodes with 8 GPUs each, use the following command:
192
+ ```
193
+ python tools/submit.py --cfg configs/Kinetics/TimeSformer_divST_8x32_224.yaml --job_dir /your/job/dir/${JOB_NAME}/ --num_shards 4 --name ${JOB_NAME} --use_volta32
194
+ ```
195
+
196
+ We provide a script for launching slurm jobs in [`slurm_scripts/run_multi_node_job.sh`](slurm_scripts/run_multi_node_job.sh).
197
+
198
+ ## Finetuning
199
+
200
+ To finetune from an existing PyTorch checkpoint, add the following options to the command line, or add them to the YAML config:
201
+
202
+ ```
203
+ TRAIN.CHECKPOINT_FILE_PATH path_to_your_PyTorch_checkpoint
204
+ TRAIN.FINETUNE True
205
+ ```
206
+
207
+ ## HowTo100M Dataset Split
208
+
209
+ If you want to experiment with the long-term video modeling task on HowTo100M, please download the train/test split files from [here](https://www.dropbox.com/sh/ttvsxwqypijjuda/AACmJx1CnddW6cVBoc21eSuva?dl=0).
210
+
211
+
212
+ # Environment
213
+
214
+ The code was developed using Python 3.7 on Ubuntu 20.04. For training, we used four GPU compute nodes, each containing 8 Tesla V100 GPUs (32 GPUs in total). Other platforms or GPU cards have not been fully tested.
215
+
216
+ # License
217
+
218
+ The majority of this work is licensed under the [CC-NC 4.0 International license](LICENSE). However, portions of the project are available under separate license terms: [SlowFast](https://github.com/facebookresearch/SlowFast) and [pytorch-image-models](https://github.com/rwightman/pytorch-image-models) are licensed under the Apache 2.0 license.
219
+
220
+ # Contributing
221
+
222
+ We actively welcome your pull requests. Please see [CONTRIBUTING.md](CONTRIBUTING.md) and [CODE_OF_CONDUCT.md](CODE_OF_CONDUCT.md) for more info.
223
+
224
+ # Acknowledgements
225
+
226
+ TimeSformer is built on top of [PySlowFast](https://github.com/facebookresearch/SlowFast) and [pytorch-image-models](https://github.com/rwightman/pytorch-image-models) by [Ross Wightman](https://github.com/rwightman). We thank the authors for releasing their code. If you use our model, please consider citing these works as well:
227
+
228
+ ```BibTeX
229
+ @misc{fan2020pyslowfast,
230
+ author = {Haoqi Fan and Yanghao Li and Bo Xiong and Wan-Yen Lo and
231
+ Christoph Feichtenhofer},
232
+ title = {PySlowFast},
233
+ howpublished = {\url{https://github.com/facebookresearch/slowfast}},
234
+ year = {2020}
235
+ }
236
+ ```
237
+
238
+ ```BibTeX
239
+ @misc{rw2019timm,
240
+ author = {Ross Wightman},
241
+ title = {PyTorch Image Models},
242
+ year = {2019},
243
+ publisher = {GitHub},
244
+ journal = {GitHub repository},
245
+ doi = {10.5281/zenodo.4414861},
246
+ howpublished = {\url{https://github.com/rwightman/pytorch-image-models}}
247
+ }
248
+ ```
TimeSformer/configs/Kinetics/SLOWFAST_4x16_R50.yaml ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ TRAIN:
2
+ ENABLE: True
3
+ DATASET: kinetics
4
+ BATCH_SIZE: 64
5
+ EVAL_PERIOD: 10
6
+ CHECKPOINT_PERIOD: 5
7
+ AUTO_RESUME: True
8
+ DATA:
9
+ PATH_TO_DATA_DIR: /path/to/kinetics/
10
+ NUM_FRAMES: 32
11
+ SAMPLING_RATE: 2
12
+ TRAIN_JITTER_SCALES: [256, 320]
13
+ TRAIN_CROP_SIZE: 224
14
+ TEST_CROP_SIZE: 256
15
+ INPUT_CHANNEL_NUM: [3, 3]
16
+ SLOWFAST:
17
+ ALPHA: 8
18
+ BETA_INV: 8
19
+ FUSION_CONV_CHANNEL_RATIO: 2
20
+ FUSION_KERNEL_SZ: 5
21
+ RESNET:
22
+ ZERO_INIT_FINAL_BN: True
23
+ WIDTH_PER_GROUP: 64
24
+ NUM_GROUPS: 1
25
+ DEPTH: 50
26
+ TRANS_FUNC: bottleneck_transform
27
+ STRIDE_1X1: False
28
+ NUM_BLOCK_TEMP_KERNEL: [[3, 3], [4, 4], [6, 6], [3, 3]]
29
+ SPATIAL_STRIDES: [[1, 1], [2, 2], [2, 2], [2, 2]]
30
+ SPATIAL_DILATIONS: [[1, 1], [1, 1], [1, 1], [1, 1]]
31
+ NONLOCAL:
32
+ LOCATION: [[[], []], [[], []], [[], []], [[], []]]
33
+ GROUP: [[1, 1], [1, 1], [1, 1], [1, 1]]
34
+ INSTANTIATION: dot_product
35
+ BN:
36
+ USE_PRECISE_STATS: True
37
+ NUM_BATCHES_PRECISE: 200
38
+ SOLVER:
39
+ BASE_LR: 0.8
40
+ LR_POLICY: cosine
41
+ MAX_EPOCH: 196
42
+ MOMENTUM: 0.9
43
+ WEIGHT_DECAY: 1e-4
44
+ WARMUP_EPOCHS: 34.0
45
+ WARMUP_START_LR: 0.01
46
+ OPTIMIZING_METHOD: sgd
47
+ MODEL:
48
+ NUM_CLASSES: 400
49
+ ARCH: slowfast
50
+ MODEL_NAME: SlowFast
51
+ LOSS_FUNC: cross_entropy
52
+ DROPOUT_RATE: 0.5
53
+ TEST:
54
+ ENABLE: True
55
+ DATASET: kinetics
56
+ BATCH_SIZE: 64
57
+ DATA_LOADER:
58
+ NUM_WORKERS: 8
59
+ PIN_MEMORY: True
60
+ NUM_GPUS: 8
61
+ NUM_SHARDS: 1
62
+ RNG_SEED: 0
63
+ OUTPUT_DIR: .
TimeSformer/configs/Kinetics/SLOWFAST_8x8_R101.yaml ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ TRAIN:
2
+ ENABLE: True
3
+ DATASET: kinetics
4
+ BATCH_SIZE: 64
5
+ EVAL_PERIOD: 10
6
+ CHECKPOINT_PERIOD: 5
7
+ AUTO_RESUME: True
8
+ DATA:
9
+ PATH_TO_DATA_DIR: /path/to/kinetics/
10
+ NUM_FRAMES: 32
11
+ SAMPLING_RATE: 2
12
+ TRAIN_JITTER_SCALES: [256, 340]
13
+ TRAIN_CROP_SIZE: 224
14
+ TEST_CROP_SIZE: 256
15
+ INPUT_CHANNEL_NUM: [3, 3]
16
+ SLOWFAST:
17
+ ALPHA: 4
18
+ BETA_INV: 8
19
+ FUSION_CONV_CHANNEL_RATIO: 2
20
+ FUSION_KERNEL_SZ: 5
21
+ RESNET:
22
+ ZERO_INIT_FINAL_BN: True
23
+ WIDTH_PER_GROUP: 64
24
+ NUM_GROUPS: 1
25
+ DEPTH: 101
26
+ TRANS_FUNC: bottleneck_transform
27
+ STRIDE_1X1: False
28
+ NUM_BLOCK_TEMP_KERNEL: [[3, 3], [4, 4], [6, 6], [3, 3]]
29
+ SPATIAL_STRIDES: [[1, 1], [2, 2], [2, 2], [2, 2]]
30
+ SPATIAL_DILATIONS: [[1, 1], [1, 1], [1, 1], [1, 1]]
31
+ NONLOCAL:
32
+ LOCATION: [[[], []], [[], []], [[], []], [[], []]]
33
+ GROUP: [[1, 1], [1, 1], [1, 1], [1, 1]]
34
+ INSTANTIATION: dot_product
35
+ BN:
36
+ USE_PRECISE_STATS: True
37
+ NUM_BATCHES_PRECISE: 200
38
+ SOLVER:
39
+ BASE_LR: 0.8 ## 8 nodes
40
+ LR_POLICY: cosine
41
+ MAX_EPOCH: 196
42
+ MOMENTUM: 0.9
43
+ WEIGHT_DECAY: 1e-4
44
+ WARMUP_EPOCHS: 34.0
45
+ WARMUP_START_LR: 0.01
46
+ OPTIMIZING_METHOD: sgd
47
+ MODEL:
48
+ NUM_CLASSES: 400
49
+ ARCH: slowfast
50
+ MODEL_NAME: SlowFast
51
+ LOSS_FUNC: cross_entropy
52
+ DROPOUT_RATE: 0.5
53
+ TEST:
54
+ ENABLE: True
55
+ DATASET: kinetics
56
+ BATCH_SIZE: 64
57
+ DATA_LOADER:
58
+ NUM_WORKERS: 8
59
+ PIN_MEMORY: True
60
+ NUM_GPUS: 8
61
+ NUM_SHARDS: 1
62
+ RNG_SEED: 0
63
+ OUTPUT_DIR: .
TimeSformer/configs/Kinetics/SLOWFAST_8x8_R50.yaml ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ TRAIN:
2
+ ENABLE: True
3
+ DATASET: kinetics
4
+ BATCH_SIZE: 64
5
+ EVAL_PERIOD: 10
6
+ CHECKPOINT_PERIOD: 5
7
+ AUTO_RESUME: True
8
+ DATA:
9
+ PATH_TO_DATA_DIR: /path/to/kinetics/
10
+ NUM_FRAMES: 32
11
+ SAMPLING_RATE: 2
12
+ TRAIN_JITTER_SCALES: [256, 320]
13
+ TRAIN_CROP_SIZE: 224
14
+ TEST_CROP_SIZE: 256
15
+ INPUT_CHANNEL_NUM: [3, 3]
16
+ SLOWFAST:
17
+ ALPHA: 4
18
+ BETA_INV: 8
19
+ FUSION_CONV_CHANNEL_RATIO: 2
20
+ FUSION_KERNEL_SZ: 7
21
+ RESNET:
22
+ ZERO_INIT_FINAL_BN: True
23
+ WIDTH_PER_GROUP: 64
24
+ NUM_GROUPS: 1
25
+ DEPTH: 50
26
+ TRANS_FUNC: bottleneck_transform
27
+ STRIDE_1X1: False
28
+ NUM_BLOCK_TEMP_KERNEL: [[3, 3], [4, 4], [6, 6], [3, 3]]
29
+ SPATIAL_STRIDES: [[1, 1], [2, 2], [2, 2], [2, 2]]
30
+ SPATIAL_DILATIONS: [[1, 1], [1, 1], [1, 1], [1, 1]]
31
+ NONLOCAL:
32
+ LOCATION: [[[], []], [[], []], [[], []], [[], []]]
33
+ GROUP: [[1, 1], [1, 1], [1, 1], [1, 1]]
34
+ INSTANTIATION: dot_product
35
+ BN:
36
+ USE_PRECISE_STATS: True
37
+ NUM_BATCHES_PRECISE: 200
38
+ SOLVER:
39
+ BASE_LR: 0.8
40
+ LR_POLICY: cosine
41
+ MAX_EPOCH: 196
42
+ MOMENTUM: 0.9
43
+ WEIGHT_DECAY: 1e-4
44
+ WARMUP_EPOCHS: 34.0
45
+ WARMUP_START_LR: 0.01
46
+ OPTIMIZING_METHOD: sgd
47
+ MODEL:
48
+ NUM_CLASSES: 400
49
+ ARCH: slowfast
50
+ MODEL_NAME: SlowFast
51
+ LOSS_FUNC: cross_entropy
52
+ DROPOUT_RATE: 0.5
53
+ TEST:
54
+ ENABLE: True
55
+ DATASET: kinetics
56
+ BATCH_SIZE: 64
57
+ DATA_LOADER:
58
+ NUM_WORKERS: 8
59
+ PIN_MEMORY: True
60
+ NUM_GPUS: 8
61
+ NUM_SHARDS: 1
62
+ RNG_SEED: 0
63
+ OUTPUT_DIR: .
TimeSformer/configs/Kinetics/TimeSformer_divST_16x16_448.yaml ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ TRAIN:
2
+ ENABLE: True
3
+ DATASET: kinetics
4
+ BATCH_SIZE: 8
5
+ EVAL_PERIOD: 5
6
+ CHECKPOINT_PERIOD: 5
7
+ AUTO_RESUME: True
8
+ DATA:
9
+ PATH_TO_DATA_DIR: /path/to/kinetics/
10
+ NUM_FRAMES: 16
11
+ SAMPLING_RATE: 16
12
+ TRAIN_JITTER_SCALES: [448, 512]
13
+ TRAIN_CROP_SIZE: 448
14
+ TEST_CROP_SIZE: 448
15
+ INPUT_CHANNEL_NUM: [3]
16
+ TIMESFORMER:
17
+ ATTENTION_TYPE: 'divided_space_time'
18
+ SOLVER:
19
+ BASE_LR: 0.005
20
+ LR_POLICY: steps_with_relative_lrs
21
+ STEPS: [0, 11, 14]
22
+ LRS: [1, 0.1, 0.01]
23
+ MAX_EPOCH: 15
24
+ MOMENTUM: 0.9
25
+ WEIGHT_DECAY: 1e-4
26
+ OPTIMIZING_METHOD: sgd
27
+ MODEL:
28
+ MODEL_NAME: vit_base_patch16_224
29
+ NUM_CLASSES: 400
30
+ ARCH: vit
31
+ LOSS_FUNC: cross_entropy
32
+ DROPOUT_RATE: 0.5
33
+ TEST:
34
+ ENABLE: True
35
+ DATASET: kinetics
36
+ BATCH_SIZE: 8
37
+ NUM_ENSEMBLE_VIEWS: 1
38
+ NUM_SPATIAL_CROPS: 3
39
+ DATA_LOADER:
40
+ NUM_WORKERS: 8
41
+ PIN_MEMORY: True
42
+ NUM_GPUS: 8
43
+ NUM_SHARDS: 1
44
+ RNG_SEED: 0
45
+ OUTPUT_DIR: .
TimeSformer/configs/Kinetics/TimeSformer_divST_8x32_224.yaml ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ TRAIN:
2
+ ENABLE: True
3
+ DATASET: kinetics
4
+ BATCH_SIZE: 8
5
+ EVAL_PERIOD: 5
6
+ CHECKPOINT_PERIOD: 5
7
+ AUTO_RESUME: True
8
+ DATA:
9
+ PATH_TO_DATA_DIR: /path/to/kinetics/
10
+ NUM_FRAMES: 8
11
+ SAMPLING_RATE: 32
12
+ TRAIN_JITTER_SCALES: [256, 320]
13
+ TRAIN_CROP_SIZE: 224
14
+ TEST_CROP_SIZE: 224
15
+ INPUT_CHANNEL_NUM: [3]
16
+ TIMESFORMER:
17
+ ATTENTION_TYPE: 'divided_space_time'
18
+ SOLVER:
19
+ BASE_LR: 0.005
20
+ LR_POLICY: steps_with_relative_lrs
21
+ STEPS: [0, 11, 14]
22
+ LRS: [1, 0.1, 0.01]
23
+ MAX_EPOCH: 15
24
+ MOMENTUM: 0.9
25
+ WEIGHT_DECAY: 1e-4
26
+ OPTIMIZING_METHOD: sgd
27
+ MODEL:
28
+ MODEL_NAME: vit_base_patch16_224
29
+ NUM_CLASSES: 400
30
+ ARCH: vit
31
+ LOSS_FUNC: cross_entropy
32
+ DROPOUT_RATE: 0.5
33
+ TEST:
34
+ ENABLE: True
35
+ DATASET: kinetics
36
+ BATCH_SIZE: 8
37
+ NUM_ENSEMBLE_VIEWS: 1
38
+ NUM_SPATIAL_CROPS: 3
39
+ DATA_LOADER:
40
+ NUM_WORKERS: 8
41
+ PIN_MEMORY: True
42
+ NUM_GPUS: 8
43
+ NUM_SHARDS: 1
44
+ RNG_SEED: 0
45
+ OUTPUT_DIR: .
TimeSformer/configs/Kinetics/TimeSformer_divST_8x32_224_4gpus.yaml ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ TRAIN:
2
+ ENABLE: True
3
+ DATASET: kinetics
4
+ BATCH_SIZE: 4
5
+ EVAL_PERIOD: 5
6
+ CHECKPOINT_PERIOD: 5
7
+ AUTO_RESUME: True
8
+ DATA:
9
+ PATH_TO_DATA_DIR: /path/to/kinetics/
10
+ NUM_FRAMES: 8
11
+ SAMPLING_RATE: 32
12
+ TRAIN_JITTER_SCALES: [256, 320]
13
+ TRAIN_CROP_SIZE: 224
14
+ TEST_CROP_SIZE: 224
15
+ INPUT_CHANNEL_NUM: [3]
16
+ TIMESFORMER:
17
+ ATTENTION_TYPE: 'divided_space_time'
18
+ SOLVER:
19
+ BASE_LR: 0.005
20
+ LR_POLICY: steps_with_relative_lrs
21
+ STEPS: [0, 11, 14]
22
+ LRS: [1, 0.1, 0.01]
23
+ MAX_EPOCH: 15
24
+ MOMENTUM: 0.9
25
+ WEIGHT_DECAY: 1e-4
26
+ OPTIMIZING_METHOD: sgd
27
+ MODEL:
28
+ MODEL_NAME: vit_base_patch16_224
29
+ NUM_CLASSES: 400
30
+ ARCH: vit
31
+ LOSS_FUNC: cross_entropy
32
+ DROPOUT_RATE: 0.5
33
+ TEST:
34
+ ENABLE: True
35
+ DATASET: kinetics
36
+ BATCH_SIZE: 4
37
+ NUM_ENSEMBLE_VIEWS: 1
38
+ NUM_SPATIAL_CROPS: 3
39
+ DATA_LOADER:
40
+ NUM_WORKERS: 4
41
+ PIN_MEMORY: True
42
+ NUM_GPUS: 4
43
+ NUM_SHARDS: 1
44
+ RNG_SEED: 0
45
+ OUTPUT_DIR: .
TimeSformer/configs/Kinetics/TimeSformer_divST_8x32_224_TEST.yaml ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ TRAIN:
2
+ ENABLE: False
3
+ DATASET: kinetics
4
+ BATCH_SIZE: 8
5
+ EVAL_PERIOD: 5
6
+ CHECKPOINT_PERIOD: 5
7
+ AUTO_RESUME: True
8
+ DATA:
9
+ PATH_TO_DATA_DIR: /path/to/kinetics/
10
+ NUM_FRAMES: 8
11
+ SAMPLING_RATE: 32
12
+ TRAIN_JITTER_SCALES: [256, 320]
13
+ TRAIN_CROP_SIZE: 224
14
+ TEST_CROP_SIZE: 224
15
+ INPUT_CHANNEL_NUM: [3]
16
+ TIMESFORMER:
17
+ ATTENTION_TYPE: 'divided_space_time'
18
+ SOLVER:
19
+ BASE_LR: 0.005
20
+ LR_POLICY: steps_with_relative_lrs
21
+ STEPS: [0, 11, 14]
22
+ LRS: [1, 0.1, 0.01]
23
+ MAX_EPOCH: 15
24
+ MOMENTUM: 0.9
25
+ WEIGHT_DECAY: 1e-4
26
+ OPTIMIZING_METHOD: sgd
27
+ MODEL:
28
+ MODEL_NAME: vit_base_patch16_224
29
+ NUM_CLASSES: 400
30
+ ARCH: vit
31
+ LOSS_FUNC: cross_entropy
32
+ DROPOUT_RATE: 0.5
33
+ TEST:
34
+ ENABLE: True
35
+ DATASET: kinetics
36
+ BATCH_SIZE: 8
37
+ NUM_ENSEMBLE_VIEWS: 1
38
+ NUM_SPATIAL_CROPS: 3
39
+ CHECKPOINT_FILE_PATH: '/checkpoint/gedas/jobs/timesformer/kinetics_400/TimeSformer_divST_8x32_224/checkpoints/checkpoint_epoch_00025.pyth'
40
+ DATA_LOADER:
41
+ NUM_WORKERS: 8
42
+ PIN_MEMORY: True
43
+ NUM_GPUS: 8
44
+ NUM_SHARDS: 1
45
+ RNG_SEED: 0
46
+ OUTPUT_DIR: .
TimeSformer/configs/Kinetics/TimeSformer_divST_96x4_224.yaml ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ TRAIN:
2
+ ENABLE: True
3
+ DATASET: kinetics
4
+ BATCH_SIZE: 8
5
+ EVAL_PERIOD: 5
6
+ CHECKPOINT_PERIOD: 5
7
+ AUTO_RESUME: True
8
+ DATA:
9
+ PATH_TO_DATA_DIR: /path/to/kinetics/
10
+ NUM_FRAMES: 96
11
+ SAMPLING_RATE: 4
12
+ TRAIN_JITTER_SCALES: [256, 320]
13
+ TRAIN_CROP_SIZE: 224
14
+ TEST_CROP_SIZE: 224
15
+ INPUT_CHANNEL_NUM: [3]
16
+ TIMESFORMER:
17
+ ATTENTION_TYPE: 'divided_space_time'
18
+ SOLVER:
19
+ BASE_LR: 0.005
20
+ LR_POLICY: steps_with_relative_lrs
21
+ STEPS: [0, 11, 14]
22
+ LRS: [1, 0.1, 0.01]
23
+ MAX_EPOCH: 15
24
+ MOMENTUM: 0.9
25
+ WEIGHT_DECAY: 1e-4
26
+ OPTIMIZING_METHOD: sgd
27
+ MODEL:
28
+ MODEL_NAME: vit_base_patch16_224
29
+ NUM_CLASSES: 400
30
+ ARCH: vit
31
+ LOSS_FUNC: cross_entropy
32
+ DROPOUT_RATE: 0.5
33
+ TEST:
34
+ ENABLE: True
35
+ DATASET: kinetics
36
+ BATCH_SIZE: 8
37
+ NUM_ENSEMBLE_VIEWS: 1
38
+ NUM_SPATIAL_CROPS: 3
39
+ DATA_LOADER:
40
+ NUM_WORKERS: 8
41
+ PIN_MEMORY: True
42
+ NUM_GPUS: 8
43
+ NUM_SHARDS: 1
44
+ RNG_SEED: 0
45
+ OUTPUT_DIR: .
TimeSformer/configs/Kinetics/TimeSformer_jointST_8x32_224.yaml ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ TRAIN:
2
+ ENABLE: True
3
+ DATASET: kinetics
4
+ BATCH_SIZE: 8
5
+ EVAL_PERIOD: 5
6
+ CHECKPOINT_PERIOD: 5
7
+ AUTO_RESUME: True
8
+ DATA:
9
+ PATH_TO_DATA_DIR: /path/to/kinetics/
10
+ NUM_FRAMES: 8
11
+ SAMPLING_RATE: 32
12
+ TRAIN_JITTER_SCALES: [256, 320]
13
+ TRAIN_CROP_SIZE: 224
14
+ TEST_CROP_SIZE: 224
15
+ INPUT_CHANNEL_NUM: [3]
16
+ TIMESFORMER:
17
+ ATTENTION_TYPE: 'joint_space_time'
18
+ SOLVER:
19
+ BASE_LR: 0.005
20
+ LR_POLICY: steps_with_relative_lrs
21
+ STEPS: [0, 11, 14]
22
+ LRS: [1, 0.1, 0.01]
23
+ MAX_EPOCH: 15
24
+ MOMENTUM: 0.9
25
+ WEIGHT_DECAY: 1e-4
26
+ OPTIMIZING_METHOD: sgd
27
+ MODEL:
28
+ MODEL_NAME: vit_base_patch16_224
29
+ NUM_CLASSES: 400
30
+ ARCH: vit
31
+ LOSS_FUNC: cross_entropy
32
+ DROPOUT_RATE: 0.5
33
+ TEST:
34
+ ENABLE: True
35
+ DATASET: kinetics
36
+ BATCH_SIZE: 8
37
+ NUM_ENSEMBLE_VIEWS: 1
38
+ NUM_SPATIAL_CROPS: 3
39
+ DATA_LOADER:
40
+ NUM_WORKERS: 8
41
+ PIN_MEMORY: True
42
+ NUM_GPUS: 8
43
+ NUM_SHARDS: 1
44
+ RNG_SEED: 0
45
+ OUTPUT_DIR: .
TimeSformer/configs/Kinetics/TimeSformer_spaceOnly_8x32_224.yaml ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ TRAIN:
2
+ ENABLE: True
3
+ DATASET: kinetics
4
+ BATCH_SIZE: 8
5
+ EVAL_PERIOD: 5
6
+ CHECKPOINT_PERIOD: 5
7
+ AUTO_RESUME: True
8
+ DATA:
9
+ PATH_TO_DATA_DIR: /path/to/kinetics/
10
+ NUM_FRAMES: 8
11
+ SAMPLING_RATE: 32
12
+ TRAIN_JITTER_SCALES: [256, 320]
13
+ TRAIN_CROP_SIZE: 224
14
+ TEST_CROP_SIZE: 224
15
+ INPUT_CHANNEL_NUM: [3]
16
+ TIMESFORMER:
17
+ ATTENTION_TYPE: 'space_only'
18
+ SOLVER:
19
+ BASE_LR: 0.005
20
+ LR_POLICY: steps_with_relative_lrs
21
+ STEPS: [0, 11, 14]
22
+ LRS: [1, 0.1, 0.01]
23
+ MAX_EPOCH: 15
24
+ MOMENTUM: 0.9
25
+ WEIGHT_DECAY: 1e-4
26
+ OPTIMIZING_METHOD: sgd
27
+ MODEL:
28
+ MODEL_NAME: vit_base_patch16_224
29
+ NUM_CLASSES: 400
30
+ ARCH: vit
31
+ LOSS_FUNC: cross_entropy
32
+ DROPOUT_RATE: 0.5
33
+ TEST:
34
+ ENABLE: True
35
+ DATASET: kinetics
36
+ BATCH_SIZE: 8
37
+ NUM_ENSEMBLE_VIEWS: 1
38
+ NUM_SPATIAL_CROPS: 3
39
+ DATA_LOADER:
40
+ NUM_WORKERS: 8
41
+ PIN_MEMORY: True
42
+ NUM_GPUS: 8
43
+ NUM_SHARDS: 1
44
+ RNG_SEED: 0
45
+ OUTPUT_DIR: .
TimeSformer/configs/SSv2/SLOWFAST_16x8_R50.yaml ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ TRAIN:
2
+ ENABLE: True
3
+ DATASET: ssv2
4
+ BATCH_SIZE: 16
5
+ EVAL_PERIOD: 5
6
+ CHECKPOINT_PERIOD: 5
7
+ AUTO_RESUME: True
8
+ DATA:
9
+ PATH_TO_DATA_DIR: "/path/to/ssv2/annotations/"
10
+ PATH_PREFIX: "/path/to/ssv2/frames/"
11
+ NUM_FRAMES: 64
12
+ SAMPLING_RATE: 2
13
+ TRAIN_JITTER_SCALES: [256, 320]
14
+ TRAIN_CROP_SIZE: 224
15
+ TEST_CROP_SIZE: 256
16
+ INPUT_CHANNEL_NUM: [3, 3]
17
+ INV_UNIFORM_SAMPLE: True
18
+ RANDOM_FLIP: False
19
+ REVERSE_INPUT_CHANNEL: True
20
+ SLOWFAST:
21
+ ALPHA: 4
22
+ BETA_INV: 8
23
+ FUSION_CONV_CHANNEL_RATIO: 2
24
+ FUSION_KERNEL_SZ: 7
25
+ RESNET:
26
+ SPATIAL_STRIDES: [[1, 1], [2, 2], [2, 2], [2, 2]]
27
+ SPATIAL_DILATIONS: [[1, 1], [1, 1], [1, 1], [1, 1]]
28
+ ZERO_INIT_FINAL_BN: True
29
+ WIDTH_PER_GROUP: 64
30
+ NUM_GROUPS: 1
31
+ DEPTH: 50
32
+ TRANS_FUNC: bottleneck_transform
33
+ STRIDE_1X1: False
34
+ NUM_BLOCK_TEMP_KERNEL: [[3, 3], [4, 4], [6, 6], [3, 3]]
35
+ NONLOCAL:
36
+ LOCATION: [[[], []], [[], []], [[], []], [[], []]]
37
+ GROUP: [[1, 1], [1, 1], [1, 1], [1, 1]]
38
+ INSTANTIATION: dot_product
39
+ BN:
40
+ USE_PRECISE_STATS: True
41
+ NUM_BATCHES_PRECISE: 200
42
+ NORM_TYPE: sync_batchnorm
43
+ NUM_SYNC_DEVICES: 4
44
+ SOLVER:
45
+ BASE_LR: 0.2 #8 nodes
46
+ LR_POLICY: cosine
47
+ MAX_EPOCH: 200
48
+ MOMENTUM: 0.9
49
+ WEIGHT_DECAY: 1e-4
50
+ WARMUP_EPOCHS: 34.0
51
+ WARMUP_START_LR: 0.01
52
+ OPTIMIZING_METHOD: sgd
53
+ #SOLVER:
54
+ # BASE_LR: 0.03
55
+ # LR_POLICY: steps_with_relative_lrs
56
+ # LRS: [1, 0.1, 0.01, 0.001, 0.0001, 0.00001]
57
+ # STEPS: [0, 14, 18]
58
+ # MAX_EPOCH: 22
59
+ # MOMENTUM: 0.9
60
+ # WEIGHT_DECAY: 1e-6
61
+ # WARMUP_EPOCHS: 0.19
62
+ # WARMUP_START_LR: 0.0001
63
+ # OPTIMIZING_METHOD: sgd
64
+ MODEL:
65
+ NUM_CLASSES: 174
66
+ ARCH: slowfast
67
+ LOSS_FUNC: cross_entropy
68
+ DROPOUT_RATE: 0.5
69
+ TEST:
70
+ ENABLE: True
71
+ DATASET: ssv2
72
+ BATCH_SIZE: 16
73
+ NUM_ENSEMBLE_VIEWS: 1
74
+ NUM_SPATIAL_CROPS: 1
75
+ DATA_LOADER:
76
+ NUM_WORKERS: 4
77
+ PIN_MEMORY: True
78
+ NUM_GPUS: 8
79
+ NUM_SHARDS: 1
80
+ RNG_SEED: 0
81
+ OUTPUT_DIR: .
82
+ #LOG_MODEL_INFO: False
83
+ LOG_MODEL_INFO: True
TimeSformer/configs/SSv2/TimeSformer_divST_16_448.yaml ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ TRAIN:
2
+ ENABLE: True
3
+ DATASET: ssv2
4
+ BATCH_SIZE: 8
5
+ EVAL_PERIOD: 5
6
+ CHECKPOINT_PERIOD: 5
7
+ AUTO_RESUME: True
8
+ DATA:
9
+ PATH_TO_DATA_DIR: "/path/to/ssv2/annotations/"
10
+ PATH_PREFIX: "/path/to/ssv2/frames/"
11
+ NUM_FRAMES: 16
12
+ TRAIN_JITTER_SCALES: [448, 512]
13
+ TRAIN_CROP_SIZE: 448
14
+ TEST_CROP_SIZE: 448
15
+ INPUT_CHANNEL_NUM: [3]
16
+ INV_UNIFORM_SAMPLE: True
17
+ RANDOM_FLIP: False
18
+ REVERSE_INPUT_CHANNEL: True
19
+ TIMESFORMER:
20
+ ATTENTION_TYPE: 'divided_space_time'
21
+ SOLVER:
22
+ BASE_LR: 0.005
23
+ LR_POLICY: steps_with_relative_lrs
24
+ STEPS: [0, 11, 14]
25
+ LRS: [1, 0.1, 0.01]
26
+ MAX_EPOCH: 15
27
+ MOMENTUM: 0.9
28
+ WEIGHT_DECAY: 1e-4
29
+ OPTIMIZING_METHOD: sgd
30
+ MODEL:
31
+ MODEL_NAME: vit_base_patch16_224
32
+ NUM_CLASSES: 174
33
+ ARCH: vit
34
+ LOSS_FUNC: cross_entropy
35
+ DROPOUT_RATE: 0.5
36
+ TEST:
37
+ ENABLE: True
38
+ DATASET: ssv2
39
+ BATCH_SIZE: 8
40
+ NUM_ENSEMBLE_VIEWS: 1
41
+ NUM_SPATIAL_CROPS: 3
42
+ DATA_LOADER:
43
+ NUM_WORKERS: 4
44
+ PIN_MEMORY: True
45
+ NUM_GPUS: 8
46
+ NUM_SHARDS: 1
47
+ RNG_SEED: 0
48
+ OUTPUT_DIR: .
TimeSformer/configs/SSv2/TimeSformer_divST_64_224.yaml ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ TRAIN:
2
+ ENABLE: True
3
+ DATASET: ssv2
4
+ BATCH_SIZE: 8
5
+ EVAL_PERIOD: 5
6
+ CHECKPOINT_PERIOD: 5
7
+ AUTO_RESUME: True
8
+ DATA:
9
+ PATH_TO_DATA_DIR: "/path/to/ssv2/annotations/"
10
+ PATH_PREFIX: "/path/to/ssv2/frames/"
11
+ NUM_FRAMES: 64
12
+ TRAIN_JITTER_SCALES: [256, 320]
13
+ TRAIN_CROP_SIZE: 224
14
+ TEST_CROP_SIZE: 224
15
+ INPUT_CHANNEL_NUM: [3]
16
+ INV_UNIFORM_SAMPLE: True
17
+ RANDOM_FLIP: False
18
+ REVERSE_INPUT_CHANNEL: True
19
+ TIMESFORMER:
20
+ ATTENTION_TYPE: 'divided_space_time'
21
+ SOLVER:
22
+ BASE_LR: 0.005
23
+ LR_POLICY: steps_with_relative_lrs
24
+ STEPS: [0, 11, 14]
25
+ LRS: [1, 0.1, 0.01]
26
+ MAX_EPOCH: 15
27
+ MOMENTUM: 0.9
28
+ WEIGHT_DECAY: 1e-4
29
+ OPTIMIZING_METHOD: sgd
30
+ MODEL:
31
+ MODEL_NAME: vit_base_patch16_224
32
+ NUM_CLASSES: 174
33
+ ARCH: vit
34
+ LOSS_FUNC: cross_entropy
35
+ DROPOUT_RATE: 0.5
36
+ TEST:
37
+ ENABLE: True
38
+ DATASET: ssv2
39
+ BATCH_SIZE: 8
40
+ NUM_ENSEMBLE_VIEWS: 1
41
+ NUM_SPATIAL_CROPS: 3
42
+ DATA_LOADER:
43
+ NUM_WORKERS: 4
44
+ PIN_MEMORY: True
45
+ NUM_GPUS: 8
46
+ NUM_SHARDS: 1
47
+ RNG_SEED: 0
48
+ OUTPUT_DIR: .
TimeSformer/configs/SSv2/TimeSformer_divST_8_224.yaml ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ TRAIN:
2
+ ENABLE: True
3
+ DATASET: ssv2
4
+ BATCH_SIZE: 8
5
+ EVAL_PERIOD: 5
6
+ CHECKPOINT_PERIOD: 5
7
+ AUTO_RESUME: True
8
+ DATA:
9
+ PATH_TO_DATA_DIR: "/path/to/ssv2/annotations/"
10
+ PATH_PREFIX: "/path/to/ssv2/frames/"
11
+ NUM_FRAMES: 8
12
+ TRAIN_JITTER_SCALES: [256, 320]
13
+ TRAIN_CROP_SIZE: 224
14
+ TEST_CROP_SIZE: 224
15
+ INPUT_CHANNEL_NUM: [3]
16
+ INV_UNIFORM_SAMPLE: True
17
+ RANDOM_FLIP: False
18
+ REVERSE_INPUT_CHANNEL: True
19
+ TIMESFORMER:
20
+ ATTENTION_TYPE: 'divided_space_time'
21
+ SOLVER:
22
+ BASE_LR: 0.005
23
+ LR_POLICY: steps_with_relative_lrs
24
+ STEPS: [0, 11, 14]
25
+ LRS: [1, 0.1, 0.01]
26
+ MAX_EPOCH: 15
27
+ MOMENTUM: 0.9
28
+ WEIGHT_DECAY: 1e-4
29
+ OPTIMIZING_METHOD: sgd
30
+ MODEL:
31
+ MODEL_NAME: vit_base_patch16_224
32
+ NUM_CLASSES: 174
33
+ ARCH: vit
34
+ LOSS_FUNC: cross_entropy
35
+ DROPOUT_RATE: 0.5
36
+ TEST:
37
+ ENABLE: True
38
+ DATASET: ssv2
39
+ BATCH_SIZE: 8
40
+ NUM_ENSEMBLE_VIEWS: 1
41
+ NUM_SPATIAL_CROPS: 3
42
+ DATA_LOADER:
43
+ NUM_WORKERS: 4
44
+ PIN_MEMORY: True
45
+ NUM_GPUS: 8
46
+ NUM_SHARDS: 1
47
+ RNG_SEED: 0
48
+ OUTPUT_DIR: .
TimeSformer/environment.yml ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: timesformer
2
+ channels:
3
+ - pytorch
4
+ - conda-forge
5
+ - defaults
6
+ dependencies:
7
+ - python>3.7
8
+ - jupyterlab
9
+ - pandas>=1.2
10
+ - numpy>1.19
11
+ - pytorch>=1.6
12
+ - torchvision>=0.7
13
+ - scikit-learn>=0.22
14
+ - opencv>=4.2
15
+ - pyyaml>=5.1
16
+ - yacs>=0.1.6
17
+ - einops>=0.3
18
+ - tensorboard
19
+ - psutil
20
+ - tqdm
21
+ - matplotlib
22
+ - simplejson
23
+ - pip
24
+ - pip:
25
+ - fvcore
26
+ - av
TimeSformer/example.ipynb ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "08fe0c59",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "from pathlib import Path\n",
11
+ "\n",
12
+ "import torch\n",
13
+ "from timesformer.models.vit import TimeSformer"
14
+ ]
15
+ },
16
+ {
17
+ "cell_type": "code",
18
+ "execution_count": 2,
19
+ "id": "10239d32",
20
+ "metadata": {},
21
+ "outputs": [
22
+ {
23
+ "data": {
24
+ "text/plain": [
25
+ "True"
26
+ ]
27
+ },
28
+ "execution_count": 2,
29
+ "metadata": {},
30
+ "output_type": "execute_result"
31
+ }
32
+ ],
33
+ "source": [
34
+ "model_file = Path.home()/'TimeSformer/models/TimeSformer_divST_8x32_224_K600.pyth'\n",
35
+ "model_file.exists()"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "code",
40
+ "execution_count": 3,
41
+ "id": "652fb03e",
42
+ "metadata": {},
43
+ "outputs": [],
44
+ "source": [
45
+ "model = TimeSformer(img_size=224, num_classes=600, num_frames=8, attention_type='divided_space_time', pretrained_model=str(model_file))\n",
46
+ "\n",
47
+ "dummy_video = torch.randn(2, 3, 8, 224, 224) # (batch x channels x frames x height x width)\n",
48
+ "\n",
49
+ "pred = model(dummy_video,) # (2, 600)"
50
+ ]
51
+ },
52
+ {
53
+ "cell_type": "code",
54
+ "execution_count": 6,
55
+ "id": "83de13c5-791c-4db7-aba4-6d29ce88584e",
56
+ "metadata": {},
57
+ "outputs": [],
58
+ "source": [
59
+ "assert pred.shape == (2,600)"
60
+ ]
61
+ }
62
+ ],
63
+ "metadata": {
64
+ "kernelspec": {
65
+ "display_name": "Python 3",
66
+ "language": "python",
67
+ "name": "python3"
68
+ },
69
+ "language_info": {
70
+ "codemirror_mode": {
71
+ "name": "ipython",
72
+ "version": 3
73
+ },
74
+ "file_extension": ".py",
75
+ "mimetype": "text/x-python",
76
+ "name": "python",
77
+ "nbconvert_exporter": "python",
78
+ "pygments_lexer": "ipython3",
79
+ "version": "3.9.4"
80
+ }
81
+ },
82
+ "nbformat": 4,
83
+ "nbformat_minor": 5
84
+ }
TimeSformer/setup.cfg ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [isort]
2
+ line_length=100
3
+ multi_line_output=4
4
+ known_standard_library=numpy,setuptools
5
+ known_myself=timesformer
6
+ known_third_party=fvcore,av,torch,pycocotools,yacs,termcolor,scipy,simplejson,matplotlib,torchvision,yaml,tqdm,psutil,opencv-python,pandas,tensorboard,moviepy,sklearn,cv2
7
+ no_lines_before=STDLIB,THIRDPARTY
8
+ sections=FUTURE,STDLIB,THIRDPARTY,myself,FIRSTPARTY,LOCALFOLDER
9
+ default_section=FIRSTPARTY
10
+
11
+ [mypy]
12
+ python_version=3.6
13
+ ignore_missing_imports = True
14
+ warn_unused_configs = True
15
+ disallow_untyped_defs = True
16
+ check_untyped_defs = True
17
+ warn_unused_ignores = True
18
+ warn_redundant_casts = True
19
+ show_column_numbers = True
20
+ follow_imports = silent
21
+ allow_redefinition = True
22
+ ; Require all functions to be annotated
23
+ disallow_incomplete_defs = True
TimeSformer/setup.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+
3
+ from setuptools import find_packages, setup
4
+
5
+ setup(
6
+ name="timesformer",
7
+ version="1.0",
8
+ author="FBAI",
9
+ url="unknown",
10
+ description="TimeSformer",
11
+ keywords = [
12
+ 'artificial intelligence',
13
+ 'attention mechanism',
14
+ 'transformers',
15
+ 'video classification',
16
+ ],
17
+ install_requires=[
18
+ 'einops>=0.3',
19
+ 'torch>=1.6'
20
+ ],
21
+ extras_require={"tensorboard_video_visualization": ["moviepy"]},
22
+ packages=find_packages(exclude=("configs", "tests")),
23
+ )
TimeSformer/slurm_scripts/run_multi_node_job.sh ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+ # A script with a list of commands for submitting SLURM jobs
3
+
4
+ #### Kinetics training
5
+ JOB_NAME=TimeSformer_divST_8x32_224
6
+ python tools/submit.py --cfg configs/Kinetics/TimeSformer_divST_8x32_224.yaml --job_dir /your/job/dir/${JOB_NAME}/ --num_shards 4 --partition dev --comment "" --name ${JOB_NAME} --use_volta32
7
+
8
+ #JOB_NAME=TimeSformer_jointST_8x32_224
9
+ #python tools/submit.py --cfg configs/Kinetics/TimeSformer_jointST_8x32_224.yaml --job_dir /your/job/dir/${JOB_NAME}/ --num_shards 4 --partition learnfair --comment "" --name ${JOB_NAME} --use_volta32
10
+
11
+ #JOB_NAME=TimeSformer_spaceOnly_8x32_224
12
+ #python tools/submit.py --cfg configs/Kinetics/TimeSformer_spaceOnly_8x32_224.yaml --job_dir /your/job/dir/${JOB_NAME}/ --num_shards 4 --partition learnfair --comment "" --name ${JOB_NAME} --use_volta32
13
+
14
+ #### Kinetics inference
15
+ #JOB_NAME=TimeSformer_divST_8x32_224_TEST_3clips
16
+ #python tools/submit.py --cfg configs/Kinetics/TimeSformer_divST_8x32_224_TEST.yaml --job_dir /your/job/dir/${JOB_NAME}/ --num_shards 4 --partition dev --comment "" --name ${JOB_NAME} --use_volta32
17
+
18
+
19
+ ##### SSv2 training
20
+ #JOB_NAME=TimeSformer_divST_8_224
21
+ #python tools/submit.py --cfg configs/SSv2/TimeSformer_divST_8_224.yaml --job_dir /your/job/dir/${JOB_NAME}/ --num_shards 4 --partition learnfair --comment "" --name ${JOB_NAME} --use_volta32
22
+
23
+ ##### Sth-Sth_v2 inference
24
+ #JOB_NAME=TimeSformer_divST_8_224_TEST_3clips
25
+ #python tools/submit.py --cfg configs/SSv2/TimeSformer_divST_8_224_TEST.yaml --job_dir /your/job/dir/${JOB_NAME}/ --num_shards 4 --partition learnfair --comment "" --name ${JOB_NAME} --use_volta32
TimeSformer/slurm_scripts/run_single_node_job.sh ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+ # A script with a list of commands for submitting SLURM jobs
3
+
4
+ #SBATCH --job-name=timesformer
5
+ #SBATCH --mail-type=END,FAIL,REQUEUE
6
+ #SBATCH --mail-user=name@domain.com
7
+
8
+ ## %A is the job array's master job id, %a is the array task index
9
+ #SBATCH --output=/path/to/output/logs/slog-%A-%a.out
10
+
11
+ ## filename for job standard error output (stderr)
12
+ #SBATCH --error=/path/to/error/logs/slog-%A-%a.err
13
+
14
+ #SBATCH --array=1
15
+ #SBATCH --partition=partition_of_your_choice
16
+ #SBATCH --nodes=1 -C volta32gb
17
+ #SBATCH --ntasks-per-node=1
18
+ #SBATCH --gpus-per-node=8
19
+ #SBATCH --cpus-per-task=80
20
+ #SBATCH --mem=480GB
21
+ #SBATCH --signal=USR1@600
22
+ #SBATCH --time=72:00:00
23
+ #SBATCH --open-mode=append
24
+
25
+ module purge
26
+ module load cuda/10.0
27
+ module load NCCL/2.4.7-1-cuda.10.0
28
+ module load cudnn/v7.4-cuda.10.0
29
+ source activate timesformer
30
+
31
+ WORKINGDIR=/path/to/TimeSformer
32
+ CURPYTHON=/path/to/python
33
+
34
+ srun --label ${CURPYTHON} ${WORKINGDIR}/tools/run_net.py --cfg ${WORKINGDIR}/configs/Kinetics/TimeSformer_divST_8x32_224.yaml NUM_GPUS 8 TRAIN.BATCH_SIZE 8
35
+
TimeSformer/timesformer/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+
3
+ from timesformer.utils.env import setup_environment
4
+
5
+ setup_environment()
TimeSformer/timesformer/config/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
TimeSformer/timesformer/config/defaults.py ADDED
@@ -0,0 +1,820 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+
3
+ """Configs."""
4
+ from fvcore.common.config import CfgNode
5
+ # -----------------------------------------------------------------------------
6
+ # Config definition
7
+ # -----------------------------------------------------------------------------
8
+ _C = CfgNode()
9
+
10
+ # ---------------------------------------------------------------------------- #
11
+ # Batch norm options
12
+ # ---------------------------------------------------------------------------- #
13
+ _C.BN = CfgNode()
14
+
15
+ # Precise BN stats.
16
+ _C.BN.USE_PRECISE_STATS = False
17
+
18
+ # Number of samples used to compute precise bn.
19
+ _C.BN.NUM_BATCHES_PRECISE = 200
20
+
21
+ # Weight decay value that applies on BN.
22
+ _C.BN.WEIGHT_DECAY = 0.0
23
+
24
+ # Norm type, options include `batchnorm`, `sub_batchnorm`, `sync_batchnorm`
25
+ _C.BN.NORM_TYPE = "batchnorm"
26
+
27
+ # Parameter for SubBatchNorm, where it splits the batch dimension into
28
+ # NUM_SPLITS splits, and runs BN on each of them separately and independently.
29
+ _C.BN.NUM_SPLITS = 1
30
+
31
+ # Parameter for NaiveSyncBatchNorm3d, where the stats across `NUM_SYNC_DEVICES`
32
+ # devices will be synchronized.
33
+ _C.BN.NUM_SYNC_DEVICES = 1
34
+
35
+
36
+ # ---------------------------------------------------------------------------- #
37
+ # Training options.
38
+ # ---------------------------------------------------------------------------- #
39
+ _C.TRAIN = CfgNode()
40
+
41
+ # If True, train the model; else skip training.
42
+ _C.TRAIN.ENABLE = True
43
+
44
+ # Dataset.
45
+ _C.TRAIN.DATASET = "kinetics"
46
+
47
+ ##
48
+ _C.TRAIN.FINETUNE = False
49
+
50
+ # Total mini-batch size.
51
+ _C.TRAIN.BATCH_SIZE = 64
52
+
53
+ # Evaluate model on test data every eval period epochs.
54
+ _C.TRAIN.EVAL_PERIOD = 10
55
+
56
+ # Save model checkpoint every checkpoint period epochs.
57
+ _C.TRAIN.CHECKPOINT_PERIOD = 10
58
+
59
+ # Resume training from the latest checkpoint in the output directory.
60
+ _C.TRAIN.AUTO_RESUME = True
61
+
62
+ # Path to the checkpoint to load the initial weight.
63
+ _C.TRAIN.CHECKPOINT_FILE_PATH = ""
64
+
65
+ # Checkpoint types include `caffe2` or `pytorch`.
66
+ _C.TRAIN.CHECKPOINT_TYPE = "pytorch"
67
+
68
+ # If True, perform inflation when loading checkpoint.
69
+ _C.TRAIN.CHECKPOINT_INFLATE = False
70
+
71
+ # If True, reset epochs when loading checkpoint.
72
+ _C.TRAIN.CHECKPOINT_EPOCH_RESET = False
73
+
74
+ # If set, clear all layer names according to the pattern provided.
75
+ _C.TRAIN.CHECKPOINT_CLEAR_NAME_PATTERN = () # ("backbone.",)
76
+
77
+ # ---------------------------------------------------------------------------- #
78
+ # Testing options
79
+ # ---------------------------------------------------------------------------- #
80
+ _C.TEST = CfgNode()
81
+
82
+ # If True test the model, else skip the testing.
83
+ _C.TEST.ENABLE = True
84
+
85
+ # Dataset for testing.
86
+ _C.TEST.DATASET = "kinetics"
87
+
88
+ # Total mini-batch size
89
+ _C.TEST.BATCH_SIZE = 8
90
+
91
+ # Path to the checkpoint to load the initial weight.
92
+ _C.TEST.CHECKPOINT_FILE_PATH = ""
93
+
94
+ # Number of clips to sample from a video uniformly for aggregating the
95
+ # prediction results.
96
+ _C.TEST.NUM_ENSEMBLE_VIEWS = 10
97
+
98
+ # Number of crops to sample from a frame spatially for aggregating the
99
+ # prediction results.
100
+ _C.TEST.NUM_SPATIAL_CROPS = 3
101
+
102
+ # Checkpoint types include `caffe2` or `pytorch`.
103
+ _C.TEST.CHECKPOINT_TYPE = "pytorch"
104
+ # Path to saving prediction results file.
105
+ _C.TEST.SAVE_RESULTS_PATH = ""
106
+ # -----------------------------------------------------------------------------
107
+ # ResNet options
108
+ # -----------------------------------------------------------------------------
109
+ _C.RESNET = CfgNode()
110
+
111
+ # Transformation function.
112
+ _C.RESNET.TRANS_FUNC = "bottleneck_transform"
113
+
114
+ # Number of groups (1 for ResNet, and larger than 1 for ResNeXt).
115
+ _C.RESNET.NUM_GROUPS = 1
116
+
117
+ # Width of each group (64 -> ResNet; 4 -> ResNeXt).
118
+ _C.RESNET.WIDTH_PER_GROUP = 64
119
+
120
+ # Apply ReLU in an in-place manner.
121
+ _C.RESNET.INPLACE_RELU = True
122
+
123
+ # Apply stride to 1x1 conv.
124
+ _C.RESNET.STRIDE_1X1 = False
125
+
126
+ # If true, initialize the gamma of the final BN of each block to zero.
127
+ _C.RESNET.ZERO_INIT_FINAL_BN = False
128
+
129
+ # Number of weight layers.
130
+ _C.RESNET.DEPTH = 50
131
+
132
+ # If the current block has more than NUM_BLOCK_TEMP_KERNEL blocks, use temporal
133
+ # kernel of 1 for the rest of the blocks.
134
+ _C.RESNET.NUM_BLOCK_TEMP_KERNEL = [[3], [4], [6], [3]]
135
+
136
+ # Size of stride on different res stages.
137
+ _C.RESNET.SPATIAL_STRIDES = [[1], [2], [2], [2]]
138
+
139
+ # Size of dilation on different res stages.
140
+ _C.RESNET.SPATIAL_DILATIONS = [[1], [1], [1], [1]]
141
+
142
+ # ---------------------------------------------------------------------------- #
143
+ # X3D options
144
+ # See https://arxiv.org/abs/2004.04730 for details about X3D Networks.
145
+ # ---------------------------------------------------------------------------- #
146
+ _C.X3D = CfgNode()
147
+
148
+ # Width expansion factor.
149
+ _C.X3D.WIDTH_FACTOR = 1.0
150
+
151
+ # Depth expansion factor.
152
+ _C.X3D.DEPTH_FACTOR = 1.0
153
+
154
+ # Bottleneck expansion factor for the 3x3x3 conv.
155
+ _C.X3D.BOTTLENECK_FACTOR = 1.0 #
156
+
157
+ # Dimensions of the last linear layer before classification.
158
+ _C.X3D.DIM_C5 = 2048
159
+
160
+ # Dimensions of the first 3x3 conv layer.
161
+ _C.X3D.DIM_C1 = 12
162
+
163
+ # Whether to scale the width of Res2, default is false.
164
+ _C.X3D.SCALE_RES2 = False
165
+
166
+ # Whether to use a BatchNorm (BN) layer before the classifier, default is false.
167
+ _C.X3D.BN_LIN5 = False
168
+
169
+ # Whether to use channelwise (=depthwise) convolution in the center (3x3x3)
170
+ # convolution operation of the residual blocks.
171
+ _C.X3D.CHANNELWISE_3x3x3 = True
172
+
173
+ # -----------------------------------------------------------------------------
174
+ # Nonlocal options
175
+ # -----------------------------------------------------------------------------
176
+ _C.NONLOCAL = CfgNode()
177
+
178
+ # Index of each stage and block to add nonlocal layers.
179
+ _C.NONLOCAL.LOCATION = [[[]], [[]], [[]], [[]]]
180
+
181
+ # Number of groups for nonlocal layers in each stage.
182
+ _C.NONLOCAL.GROUP = [[1], [1], [1], [1]]
183
+
184
+ # Instantiation to use for non-local layer.
185
+ _C.NONLOCAL.INSTANTIATION = "dot_product"
186
+
187
+
188
+ # Size of pooling layers used in Non-Local.
189
+ _C.NONLOCAL.POOL = [
190
+ # Res2
191
+ [[1, 2, 2], [1, 2, 2]],
192
+ # Res3
193
+ [[1, 2, 2], [1, 2, 2]],
194
+ # Res4
195
+ [[1, 2, 2], [1, 2, 2]],
196
+ # Res5
197
+ [[1, 2, 2], [1, 2, 2]],
198
+ ]
199
+
200
+ # -----------------------------------------------------------------------------
201
+ # Model options
202
+ # -----------------------------------------------------------------------------
203
+ _C.MODEL = CfgNode()
204
+
205
+ # Model architecture.
206
+ _C.MODEL.ARCH = "slowfast"
207
+
208
+ # Model name
209
+ _C.MODEL.MODEL_NAME = "SlowFast"
210
+
211
+ # The number of classes to predict for the model.
212
+ _C.MODEL.NUM_CLASSES = 400
213
+
214
+ # Loss function.
215
+ _C.MODEL.LOSS_FUNC = "cross_entropy"
216
+
217
+ # Model architectures that have a single pathway.
218
+ _C.MODEL.SINGLE_PATHWAY_ARCH = ["c2d", "i3d", "slow", "x3d"]
219
+
220
+ # Model architectures that have multiple pathways.
221
+ _C.MODEL.MULTI_PATHWAY_ARCH = ["slowfast"]
222
+
223
+ # Dropout rate before final projection in the backbone.
224
+ _C.MODEL.DROPOUT_RATE = 0.5
225
+
226
+ # Random drop rate for Res-blocks, linearly increasing from res2 to res5.
227
+ _C.MODEL.DROPCONNECT_RATE = 0.0
228
+
229
+ # The std to initialize the fc layer(s).
230
+ _C.MODEL.FC_INIT_STD = 0.01
231
+
232
+ # Activation layer for the output head.
233
+ _C.MODEL.HEAD_ACT = "softmax"
234
+
235
+
236
+ # -----------------------------------------------------------------------------
237
+ # SlowFast options
238
+ # -----------------------------------------------------------------------------
239
+ _C.SLOWFAST = CfgNode()
240
+
241
+ # Corresponds to the inverse of the channel reduction ratio, $\beta$ between
242
+ # the Slow and Fast pathways.
243
+ _C.SLOWFAST.BETA_INV = 8
244
+
245
+ # Corresponds to the frame rate reduction ratio, $\alpha$ between the Slow and
246
+ # Fast pathways.
247
+ _C.SLOWFAST.ALPHA = 8
248
+
249
+ # Ratio of channel dimensions between the Slow and Fast pathways.
250
+ _C.SLOWFAST.FUSION_CONV_CHANNEL_RATIO = 2
251
+
252
+ # Kernel dimension used for fusing information from Fast pathway to Slow
253
+ # pathway.
254
+ _C.SLOWFAST.FUSION_KERNEL_SZ = 5
255
+
256
+ ####### TimeSformer Options
257
+ _C.TIMESFORMER = CfgNode()
258
+ _C.TIMESFORMER.ATTENTION_TYPE = 'divided_space_time'
259
+ _C.TIMESFORMER.PRETRAINED_MODEL = ''
260
+
261
+ ## MixUp parameters
262
+ _C.MIXUP = CfgNode()
263
+ _C.MIXUP.ENABLED = False
264
+ _C.MIXUP.ALPHA = 0.8
265
+ _C.MIXUP.CUTMIX_ALPHA = 1.0
266
+ _C.MIXUP.CUTMIX_MINMAX = None
267
+ _C.MIXUP.PROB = 1.0
268
+ _C.MIXUP.SWITCH_PROB = 0.5
269
+ _C.MIXUP.MODE = 'batch'
270
+
271
+ _C.EMA = CfgNode()
272
+ _C.EMA.ENABLED = False
273
+
274
+ # -----------------------------------------------------------------------------
275
+ # Data options
276
+ # -----------------------------------------------------------------------------
277
+ _C.DATA = CfgNode()
278
+
279
+ # The path to the data directory.
280
+ _C.DATA.PATH_TO_DATA_DIR = ""
281
+
282
+ # The separator used between path and label.
283
+ _C.DATA.PATH_LABEL_SEPARATOR = " "
284
+
285
+ # Video path prefix if any.
286
+ _C.DATA.PATH_PREFIX = ""
287
+
288
+ # The spatial crop size of the input clip.
289
+ _C.DATA.CROP_SIZE = 224
290
+
291
+ # The number of frames of the input clip.
292
+ _C.DATA.NUM_FRAMES = 8
293
+
294
+ # The video sampling rate of the input clip.
295
+ _C.DATA.SAMPLING_RATE = 8
296
+
297
+ # The mean value of the video raw pixels across the R G B channels.
298
+ _C.DATA.MEAN = [0.45, 0.45, 0.45]
299
+ # List of input frame channel dimensions.
300
+
301
+ _C.DATA.INPUT_CHANNEL_NUM = [3, 3]
302
+
303
+ # The std value of the video raw pixels across the R G B channels.
304
+ _C.DATA.STD = [0.225, 0.225, 0.225]
305
+
306
+ # The spatial augmentation jitter scales for training.
307
+ _C.DATA.TRAIN_JITTER_SCALES = [256, 320]
308
+
309
+ # The spatial crop size for training.
310
+ _C.DATA.TRAIN_CROP_SIZE = 224
311
+
312
+ # The spatial crop size for testing.
313
+ _C.DATA.TEST_CROP_SIZE = 256
314
+
315
+ # Input videos may have different fps; convert them to the target video fps before
316
+ # frame sampling.
317
+ _C.DATA.TARGET_FPS = 30
318
+
319
+ # Decoding backend, options include `pyav` or `torchvision`
320
+ _C.DATA.DECODING_BACKEND = "pyav"
321
+
322
+ # if True, sample uniformly in [1 / max_scale, 1 / min_scale] and take a
323
+ # reciprocal to get the scale. If False, take a uniform sample from
324
+ # [min_scale, max_scale].
325
+ _C.DATA.INV_UNIFORM_SAMPLE = False
326
+
327
+ # If True, perform random horizontal flip on the video frames during training.
328
+ _C.DATA.RANDOM_FLIP = True
329
+
330
+ # If True, calculate the map as the metric.
331
+ _C.DATA.MULTI_LABEL = False
332
+
333
+ # Method to perform the ensemble, options include "sum" and "max".
334
+ _C.DATA.ENSEMBLE_METHOD = "sum"
335
+
336
+ # If True, revert the default input channel (RGB <-> BGR).
337
+ _C.DATA.REVERSE_INPUT_CHANNEL = False
338
+
339
+ ############
340
+ _C.DATA.TEMPORAL_EXTENT = 8
341
+ _C.DATA.DEIT_TRANSFORMS = False
342
+ _C.DATA.COLOR_JITTER = 0.
343
+ _C.DATA.AUTO_AUGMENT = ''
344
+ _C.DATA.RE_PROB = 0.0
345
+
346
+ # ---------------------------------------------------------------------------- #
347
+ # Optimizer options
348
+ # ---------------------------------------------------------------------------- #
349
+ _C.SOLVER = CfgNode()
350
+
351
+ # Base learning rate.
352
+ _C.SOLVER.BASE_LR = 0.1
353
+
354
+ # Learning rate policy (see utils/lr_policy.py for options and examples).
355
+ _C.SOLVER.LR_POLICY = "cosine"
356
+
357
+ # Final learning rates for 'cosine' policy.
358
+ _C.SOLVER.COSINE_END_LR = 0.0
359
+
360
+ # Exponential decay factor.
361
+ _C.SOLVER.GAMMA = 0.1
362
+
363
+ # Step size for 'exp' and 'cos' policies (in epochs).
364
+ _C.SOLVER.STEP_SIZE = 1
365
+
366
+ # Steps for 'steps_' policies (in epochs).
367
+ _C.SOLVER.STEPS = []
368
+
369
+ # Learning rates for 'steps_' policies.
370
+ _C.SOLVER.LRS = []
371
+
372
+ # Maximal number of epochs.
373
+ _C.SOLVER.MAX_EPOCH = 300
374
+
375
+ # Momentum.
376
+ _C.SOLVER.MOMENTUM = 0.9
377
+
378
+ # Momentum dampening.
379
+ _C.SOLVER.DAMPENING = 0.0
380
+
381
+ # Nesterov momentum.
382
+ _C.SOLVER.NESTEROV = True
383
+
384
+ # L2 regularization.
385
+ _C.SOLVER.WEIGHT_DECAY = 1e-4
386
+
387
+ # Start the warm up from SOLVER.BASE_LR * SOLVER.WARMUP_FACTOR.
388
+ _C.SOLVER.WARMUP_FACTOR = 0.1
389
+
390
+ # Gradually warm up the SOLVER.BASE_LR over this number of epochs.
391
+ _C.SOLVER.WARMUP_EPOCHS = 0.0
392
+
393
+ # The start learning rate of the warm up.
394
+ _C.SOLVER.WARMUP_START_LR = 0.01
395
+
396
+ # Optimization method.
397
+ _C.SOLVER.OPTIMIZING_METHOD = "sgd"
398
+
399
+ # Base learning rate is linearly scaled with NUM_SHARDS.
400
+ _C.SOLVER.BASE_LR_SCALE_NUM_SHARDS = False
401
+
402
+ # ---------------------------------------------------------------------------- #
403
+ # Misc options
404
+ # ---------------------------------------------------------------------------- #
405
+
406
+ # Number of GPUs to use (applies to both training and testing).
407
+ _C.NUM_GPUS = 1
408
+
409
+ # Number of machines to use for the job.
410
+ _C.NUM_SHARDS = 1
411
+
412
+ # The index of the current machine.
413
+ _C.SHARD_ID = 0
414
+
415
+ # Output basedir.
416
+ _C.OUTPUT_DIR = "./tmp"
417
+
418
+ # Note that non-determinism may still be present due to non-deterministic
419
+ # operator implementations in GPU operator libraries.
420
+ _C.RNG_SEED = 1
421
+
422
+ # Log period in iters.
423
+ _C.LOG_PERIOD = 10
424
+
425
+ # If True, log the model info.
426
+ _C.LOG_MODEL_INFO = False
427
+
428
+ # Distributed backend.
429
+ _C.DIST_BACKEND = "nccl"
430
+
431
+ # Global batch size
432
+ _C.GLOBAL_BATCH_SIZE = 64
433
+
434
+ # ---------------------------------------------------------------------------- #
435
+ # Benchmark options
436
+ # ---------------------------------------------------------------------------- #
437
+ _C.BENCHMARK = CfgNode()
438
+
439
+ # Number of epochs for data loading benchmark.
440
+ _C.BENCHMARK.NUM_EPOCHS = 5
441
+
442
+ # Log period in iters for data loading benchmark.
443
+ _C.BENCHMARK.LOG_PERIOD = 100
444
+
445
+ # If True, shuffle dataloader for epoch during benchmark.
446
+ _C.BENCHMARK.SHUFFLE = True
447
+
448
+
449
+ # ---------------------------------------------------------------------------- #
450
+ # Common train/test data loader options
451
+ # ---------------------------------------------------------------------------- #
452
+ _C.DATA_LOADER = CfgNode()
453
+
454
+ # Number of data loader workers per training process.
455
+ _C.DATA_LOADER.NUM_WORKERS = 8
456
+
457
+ # Load data to pinned host memory.
458
+ _C.DATA_LOADER.PIN_MEMORY = True
459
+
460
+ # Enable multi thread decoding.
461
+ _C.DATA_LOADER.ENABLE_MULTI_THREAD_DECODE = False
462
+
463
+
464
+ # ---------------------------------------------------------------------------- #
465
+ # Detection options.
466
+ # ---------------------------------------------------------------------------- #
467
+ _C.DETECTION = CfgNode()
468
+
469
+ # Whether to enable video detection.
470
+ _C.DETECTION.ENABLE = False
471
+
472
+ # Aligned version of RoI. More details can be found at slowfast/models/head_helper.py
473
+ _C.DETECTION.ALIGNED = True
474
+
475
+ # Spatial scale factor.
476
+ _C.DETECTION.SPATIAL_SCALE_FACTOR = 16
477
+
478
+ # RoI transformation resolution.
479
+ _C.DETECTION.ROI_XFORM_RESOLUTION = 7
480
+
481
+
482
+ # -----------------------------------------------------------------------------
483
+ # AVA Dataset options
484
+ # -----------------------------------------------------------------------------
485
+ _C.AVA = CfgNode()
486
+
487
+ # Directory path of frames.
488
+ _C.AVA.FRAME_DIR = "/mnt/fair-flash3-east/ava_trainval_frames.img/"
489
+
490
+ # Directory path for files of frame lists.
491
+ _C.AVA.FRAME_LIST_DIR = (
492
+ "/mnt/vol/gfsai-flash3-east/ai-group/users/haoqifan/ava/frame_list/"
493
+ )
494
+
495
+ # Directory path for annotation files.
496
+ _C.AVA.ANNOTATION_DIR = (
497
+ "/mnt/vol/gfsai-flash3-east/ai-group/users/haoqifan/ava/frame_list/"
498
+ )
499
+
500
+ # Filenames of training samples list files.
501
+ _C.AVA.TRAIN_LISTS = ["train.csv"]
502
+
503
+ # Filenames of test samples list files.
504
+ _C.AVA.TEST_LISTS = ["val.csv"]
505
+
506
+ # Filenames of box list files for training. Note that we assume files which
507
+ # contain predicted boxes will have a suffix "predicted_boxes" in the
508
+ # filename.
509
+ _C.AVA.TRAIN_GT_BOX_LISTS = ["ava_train_v2.2.csv"]
510
+ _C.AVA.TRAIN_PREDICT_BOX_LISTS = []
511
+
512
+ # Filenames of box list files for test.
513
+ _C.AVA.TEST_PREDICT_BOX_LISTS = ["ava_val_predicted_boxes.csv"]
514
+
515
+ # This option controls the score threshold for the predicted boxes to use.
516
+ _C.AVA.DETECTION_SCORE_THRESH = 0.9
517
+
518
+ # Whether to use BGR as the format of input frames.
519
+ _C.AVA.BGR = False
520
+
521
+ # Training augmentation parameters
522
+ # Whether to use color augmentation method.
523
+ _C.AVA.TRAIN_USE_COLOR_AUGMENTATION = False
524
+
525
+ # Whether to only use PCA jitter augmentation when using color augmentation
526
+ # method (otherwise combine with color jitter method).
527
+ _C.AVA.TRAIN_PCA_JITTER_ONLY = True
528
+
529
+ # Eigenvalues for PCA jittering. Note PCA is RGB based.
530
+ _C.AVA.TRAIN_PCA_EIGVAL = [0.225, 0.224, 0.229]
531
+
532
+ # Eigenvectors for PCA jittering.
533
+ _C.AVA.TRAIN_PCA_EIGVEC = [
534
+ [-0.5675, 0.7192, 0.4009],
535
+ [-0.5808, -0.0045, -0.8140],
536
+ [-0.5836, -0.6948, 0.4203],
537
+ ]
538
+
539
+ # Whether to do horizontal flipping during test.
540
+ _C.AVA.TEST_FORCE_FLIP = False
541
+
542
+ # Whether to use full test set for validation split.
543
+ _C.AVA.FULL_TEST_ON_VAL = False
544
+
545
+ # The name of the file to the ava label map.
546
+ _C.AVA.LABEL_MAP_FILE = "ava_action_list_v2.2_for_activitynet_2019.pbtxt"
547
+
548
+ # The name of the file to the ava exclusion.
549
+ _C.AVA.EXCLUSION_FILE = "ava_val_excluded_timestamps_v2.2.csv"
550
+
551
+ # The name of the file to the ava groundtruth.
552
+ _C.AVA.GROUNDTRUTH_FILE = "ava_val_v2.2.csv"
553
+
554
+ # Backend to process image, includes `pytorch` and `cv2`.
555
+ _C.AVA.IMG_PROC_BACKEND = "cv2"
556
+
557
+ # ---------------------------------------------------------------------------- #
558
+ # Multigrid training options
559
+ # See https://arxiv.org/abs/1912.00998 for details about multigrid training.
560
+ # ---------------------------------------------------------------------------- #
561
+ _C.MULTIGRID = CfgNode()
562
+
563
+ # Multigrid training allows us to train for more epochs with fewer iterations.
564
+ # This hyperparameter specifies how many times more epochs to train.
565
+ # The default setting in paper trains for 1.5x more epochs than baseline.
566
+ _C.MULTIGRID.EPOCH_FACTOR = 1.5
567
+
568
+ # Enable short cycles.
569
+ _C.MULTIGRID.SHORT_CYCLE = False
570
+ # Short cycle additional spatial dimensions relative to the default crop size.
571
+ _C.MULTIGRID.SHORT_CYCLE_FACTORS = [0.5, 0.5 ** 0.5]
572
+
573
+ _C.MULTIGRID.LONG_CYCLE = False
574
+ # (Temporal, Spatial) dimensions relative to the default shape.
575
+ _C.MULTIGRID.LONG_CYCLE_FACTORS = [
576
+ (0.25, 0.5 ** 0.5),
577
+ (0.5, 0.5 ** 0.5),
578
+ (0.5, 1),
579
+ (1, 1),
580
+ ]
581
+
582
+ # While a standard BN computes stats across all examples in a GPU,
583
+ # for multigrid training we fix the number of clips to compute BN stats on.
584
+ # See https://arxiv.org/abs/1912.00998 for details.
585
+ _C.MULTIGRID.BN_BASE_SIZE = 8
586
+
587
+ # Multigrid training epochs are not proportional to actual training time or
588
+ # computations, so _C.TRAIN.EVAL_PERIOD leads to too frequent or rare
589
+ # evaluation. We use a multigrid-specific rule to determine when to evaluate:
590
+ # This hyperparameter defines how many times to evaluate a model per long
591
+ # cycle shape.
592
+ _C.MULTIGRID.EVAL_FREQ = 3
593
+
594
+ # No need to specify; Set automatically and used as global variables.
595
+ _C.MULTIGRID.LONG_CYCLE_SAMPLING_RATE = 0
596
+ _C.MULTIGRID.DEFAULT_B = 0
597
+ _C.MULTIGRID.DEFAULT_T = 0
598
+ _C.MULTIGRID.DEFAULT_S = 0
599
+
600
+ # -----------------------------------------------------------------------------
601
+ # Tensorboard Visualization Options
602
+ # -----------------------------------------------------------------------------
603
+ _C.TENSORBOARD = CfgNode()
604
+
605
+ # Log to summary writer; this will automatically
606
+ # log loss, lr and metrics during train/eval.
607
+ _C.TENSORBOARD.ENABLE = False
608
+ # Provide path to prediction results for visualization.
609
+ # This is a pickle file of [prediction_tensor, label_tensor]
610
+ _C.TENSORBOARD.PREDICTIONS_PATH = ""
611
+ # Path to directory for tensorboard logs.
612
+ # Defaults to cfg.OUTPUT_DIR/runs-{cfg.TRAIN.DATASET}.
613
+ _C.TENSORBOARD.LOG_DIR = ""
614
+ # Path to a json file providing class_name - id mapping
615
+ # in the format {"class_name1": id1, "class_name2": id2, ...}.
616
+ # This file must be provided to enable plotting confusion matrix
617
+ # by a subset or parent categories.
618
+ _C.TENSORBOARD.CLASS_NAMES_PATH = ""
619
+
620
+ # Path to a json file for categories -> classes mapping
621
+ # in the format {"parent_class": ["child_class1", "child_class2",...], ...}.
622
+ _C.TENSORBOARD.CATEGORIES_PATH = ""
623
+
624
+ # Config for confusion matrices visualization.
625
+ _C.TENSORBOARD.CONFUSION_MATRIX = CfgNode()
626
+ # Visualize confusion matrix.
627
+ _C.TENSORBOARD.CONFUSION_MATRIX.ENABLE = False
628
+ # Figure size of the confusion matrices plotted.
629
+ _C.TENSORBOARD.CONFUSION_MATRIX.FIGSIZE = [8, 8]
630
+ # Path to a subset of categories to visualize.
631
+ # File contains class names separated by newline characters.
632
+ _C.TENSORBOARD.CONFUSION_MATRIX.SUBSET_PATH = ""
633
+
634
+ # Config for histogram visualization.
635
+ _C.TENSORBOARD.HISTOGRAM = CfgNode()
636
+ # Visualize histograms.
637
+ _C.TENSORBOARD.HISTOGRAM.ENABLE = False
638
+ # Path to a subset of classes to plot histograms.
639
+ # Class names must be separated by newline characters.
640
+ _C.TENSORBOARD.HISTOGRAM.SUBSET_PATH = ""
641
+ # Visualize top-k most predicted classes on histograms for each
642
+ # chosen true label.
643
+ _C.TENSORBOARD.HISTOGRAM.TOPK = 10
644
+ # Figure size of the histograms plotted.
645
+ _C.TENSORBOARD.HISTOGRAM.FIGSIZE = [8, 8]
646
+
647
+ # Config for layers' weights and activations visualization.
648
+ # _C.TENSORBOARD.ENABLE must be True.
649
+ _C.TENSORBOARD.MODEL_VIS = CfgNode()
650
+
651
+ # If False, skip model visualization.
652
+ _C.TENSORBOARD.MODEL_VIS.ENABLE = False
653
+
654
+ # If False, skip visualizing model weights.
655
+ _C.TENSORBOARD.MODEL_VIS.MODEL_WEIGHTS = False
656
+
657
+ # If False, skip visualizing model activations.
658
+ _C.TENSORBOARD.MODEL_VIS.ACTIVATIONS = False
659
+
660
+ # If False, skip visualizing input videos.
661
+ _C.TENSORBOARD.MODEL_VIS.INPUT_VIDEO = False
662
+
663
+
664
+ # List of strings containing data about layer names and their indexing to
665
+ # visualize weights and activations for. The indexing is meant for
666
+ # choosing a subset of activations outputted by a layer for visualization.
667
+ # If indexing is not specified, visualize all activations outputted by the layer.
668
+ # For each string, layer name and indexing is separated by whitespaces.
669
+ # e.g.: [layer1 1,2;1,2, layer2, layer3 150,151;3,4]; this means for each array `arr`
670
+ # along the batch dimension in `layer1`, we take arr[[1, 2], [1, 2]]
671
+ _C.TENSORBOARD.MODEL_VIS.LAYER_LIST = []
672
+ # Top-k predictions to plot on videos
673
+ _C.TENSORBOARD.MODEL_VIS.TOPK_PREDS = 1
674
+ # Colormap for text box and bounding box colors.
675
+ _C.TENSORBOARD.MODEL_VIS.COLORMAP = "Pastel2"
676
+ # Config for visualization video inputs with Grad-CAM.
677
+ # _C.TENSORBOARD.ENABLE must be True.
678
+ _C.TENSORBOARD.MODEL_VIS.GRAD_CAM = CfgNode()
679
+ # Whether to run visualization using Grad-CAM technique.
680
+ _C.TENSORBOARD.MODEL_VIS.GRAD_CAM.ENABLE = True
681
+ # CNN layers to use for Grad-CAM. The number of layers must be equal to
682
+ # number of pathway(s).
683
+ _C.TENSORBOARD.MODEL_VIS.GRAD_CAM.LAYER_LIST = []
684
+ # If True, visualize Grad-CAM using the true label of each instance.
685
+ # If False, use the highest predicted class.
686
+ _C.TENSORBOARD.MODEL_VIS.GRAD_CAM.USE_TRUE_LABEL = False
687
+ # Colormap for text box and bounding box colors.
688
+ _C.TENSORBOARD.MODEL_VIS.GRAD_CAM.COLORMAP = "viridis"
689
+
690
+ # Config for visualizing wrong predictions.
691
+ # _C.TENSORBOARD.ENABLE must be True.
692
+ _C.TENSORBOARD.WRONG_PRED_VIS = CfgNode()
693
+ _C.TENSORBOARD.WRONG_PRED_VIS.ENABLE = False
694
+ # Folder tag to organize model eval videos under.
695
+ _C.TENSORBOARD.WRONG_PRED_VIS.TAG = "Incorrectly classified videos."
696
+ # Subset of labels to visualize. Only wrong predictions with true labels
697
+ # within this subset are visualized.
698
+ _C.TENSORBOARD.WRONG_PRED_VIS.SUBSET_PATH = ""
699
+
700
+
701
+ # ---------------------------------------------------------------------------- #
702
+ # Demo options
703
+ # ---------------------------------------------------------------------------- #
704
+ _C.DEMO = CfgNode()
705
+
706
+ # Run model in DEMO mode.
707
+ _C.DEMO.ENABLE = False
708
+
709
+ # Path to a json file providing class_name - id mapping
710
+ # in the format {"class_name1": id1, "class_name2": id2, ...}.
711
+ _C.DEMO.LABEL_FILE_PATH = ""
712
+
713
+ # Specify a camera device as input. This will be prioritized
714
+ # over input video if set.
715
+ # If -1, use input video instead.
716
+ _C.DEMO.WEBCAM = -1
717
+
718
+ # Path to input video for demo.
719
+ _C.DEMO.INPUT_VIDEO = ""
720
+ # Custom width for reading input video data.
721
+ _C.DEMO.DISPLAY_WIDTH = 0
722
+ # Custom height for reading input video data.
723
+ _C.DEMO.DISPLAY_HEIGHT = 0
724
+ # Path to Detectron2 object detection model configuration,
725
+ # only used for detection tasks.
726
+ _C.DEMO.DETECTRON2_CFG = "COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"
727
+ # Path to Detectron2 object detection model pre-trained weights.
728
+ _C.DEMO.DETECTRON2_WEIGHTS = "detectron2://COCO-Detection/faster_rcnn_R_50_FPN_3x/137849458/model_final_280758.pkl"
729
+ # Threshold for choosing predicted bounding boxes by Detectron2.
730
+ _C.DEMO.DETECTRON2_THRESH = 0.9
731
+ # Number of overlapping frames between 2 consecutive clips.
732
+ # Increase this number for more frequent action predictions.
733
+ # The number of overlapping frames cannot be larger than
734
+ # half of the sequence length `cfg.DATA.NUM_FRAMES * cfg.DATA.SAMPLING_RATE`
735
+ _C.DEMO.BUFFER_SIZE = 0
736
+ # If specified, the visualized outputs will be written to a video file at
737
+ # this path. Otherwise, the visualized outputs will be displayed in a window.
738
+ _C.DEMO.OUTPUT_FILE = ""
739
+ # Frames per second rate for writing to output video file.
740
+ # If not set (-1), use fps rate from input file.
741
+ _C.DEMO.OUTPUT_FPS = -1
742
+ # Input format from demo video reader ("RGB" or "BGR").
743
+ _C.DEMO.INPUT_FORMAT = "BGR"
744
+ # Draw visualization frames in [keyframe_idx - CLIP_VIS_SIZE, keyframe_idx + CLIP_VIS_SIZE] inclusively.
745
+ _C.DEMO.CLIP_VIS_SIZE = 10
746
+ # Number of processes to run video visualizer.
747
+ _C.DEMO.NUM_VIS_INSTANCES = 2
748
+
749
+ # Path to pre-computed predicted boxes
750
+ _C.DEMO.PREDS_BOXES = ""
751
+ # Whether to run with a multi-threaded video reader.
752
+ _C.DEMO.THREAD_ENABLE = False
753
+ # Take one clip for every `DEMO.NUM_CLIPS_SKIP` + 1 for prediction and visualization.
754
+ # This is used for fast demo speed by reducing the prediction/visualization frequency.
755
+ # If -1, take the most recent read clip for visualization. This mode is only supported
756
+ # if `DEMO.THREAD_ENABLE` is set to True.
757
+ _C.DEMO.NUM_CLIPS_SKIP = 0
758
+ # Path to ground-truth boxes and labels (optional)
759
+ _C.DEMO.GT_BOXES = ""
760
+ # The starting second of the video w.r.t bounding boxes file.
761
+ _C.DEMO.STARTING_SECOND = 900
762
+ # Frames per second of the input video/folder of images.
763
+ _C.DEMO.FPS = 30
764
+ # Visualize with top-k predictions or predictions above certain threshold(s).
765
+ # Option: {"thres", "top-k"}
766
+ _C.DEMO.VIS_MODE = "thres"
767
+ # Threshold for common class names.
768
+ _C.DEMO.COMMON_CLASS_THRES = 0.7
769
+ # Threshold for uncommon class names. This will not be
770
+ # used if `_C.DEMO.COMMON_CLASS_NAMES` is empty.
771
+ _C.DEMO.UNCOMMON_CLASS_THRES = 0.3
772
+ # This is chosen based on distribution of examples in
773
+ # each class in the AVA dataset.
774
+ _C.DEMO.COMMON_CLASS_NAMES = [
775
+ "watch (a person)",
776
+ "talk to (e.g., self, a person, a group)",
777
+ "listen to (a person)",
778
+ "touch (an object)",
779
+ "carry/hold (an object)",
780
+ "walk",
781
+ "sit",
782
+ "lie/sleep",
783
+ "bend/bow (at the waist)",
784
+ ]
785
+ # Slow-motion rate for the visualization. The visualized portions of the
786
+ # video will be played `_C.DEMO.SLOWMO` times slower than usual speed.
787
+ _C.DEMO.SLOWMO = 1
788
+
789
+ def _assert_and_infer_cfg(cfg):
790
+ # BN assertions.
791
+ if cfg.BN.USE_PRECISE_STATS:
792
+ assert cfg.BN.NUM_BATCHES_PRECISE >= 0
793
+ # TRAIN assertions.
794
+ assert cfg.TRAIN.CHECKPOINT_TYPE in ["pytorch", "caffe2"]
795
+ assert cfg.TRAIN.BATCH_SIZE % cfg.NUM_GPUS == 0
796
+
797
+ # TEST assertions.
798
+ assert cfg.TEST.CHECKPOINT_TYPE in ["pytorch", "caffe2"]
799
+ assert cfg.TEST.BATCH_SIZE % cfg.NUM_GPUS == 0
800
+ assert cfg.TEST.NUM_SPATIAL_CROPS == 3
801
+
802
+ # RESNET assertions.
803
+ assert cfg.RESNET.NUM_GROUPS > 0
804
+ assert cfg.RESNET.WIDTH_PER_GROUP > 0
805
+ assert cfg.RESNET.WIDTH_PER_GROUP % cfg.RESNET.NUM_GROUPS == 0
806
+
807
+ # Execute LR scaling by num_shards.
808
+ if cfg.SOLVER.BASE_LR_SCALE_NUM_SHARDS:
809
+ cfg.SOLVER.BASE_LR *= cfg.NUM_SHARDS
810
+
811
+ # General assertions.
812
+ assert cfg.SHARD_ID < cfg.NUM_SHARDS
813
+ return cfg
814
+
815
+
816
def get_cfg():
    """
    Return a clone of the default config `_C` after it has been passed
    through `_assert_and_infer_cfg`.
    """
    default_copy = _C.clone()
    return _assert_and_infer_cfg(default_copy)
TimeSformer/timesformer/datasets/DATASET.md ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Dataset Preparation
2
+
3
+ ## Kinetics
4
+
5
+ The Kinetics dataset can be downloaded from the following [link](https://github.com/cvdfoundation/kinetics-dataset).
6
+
7
+ After all the videos have been downloaded, resize each video so that its short edge has size 256, then prepare the csv files for the training, validation, and testing sets as `train.csv`, `val.csv`, and `test.csv`. The format of the csv file is:
8
+
9
+ ```
10
+ path_to_video_1 label_1
11
+ path_to_video_2 label_2
12
+ path_to_video_3 label_3
13
+ ...
14
+ path_to_video_N label_N
15
+ ```
16
+
17
+ ## Something-Something V2
18
+ 1. Please download the dataset and annotations from [dataset provider](https://20bn.com/datasets/something-something).
19
+
20
+ 2. Download the *frame list* from the following links: ([train](https://dl.fbaipublicfiles.com/pyslowfast/dataset/ssv2/frame_lists/train.csv), [val](https://dl.fbaipublicfiles.com/pyslowfast/dataset/ssv2/frame_lists/val.csv)).
21
+
22
+ 3. Extract the frames at 30 FPS. (We used ffmpeg-4.1.3 with command
23
+ `ffmpeg -i "${video}" -r 30 -q:v 1 "${out_name}"`
24
+ in experiments.) Please put the frames in a structure consistent with the frame lists.
25
+
26
+ Please put all annotation json files and the frame lists in the same folder, and set `DATA.PATH_TO_DATA_DIR` to the path. Set `DATA.PATH_PREFIX` to be the path to the folder containing extracted frames.
TimeSformer/timesformer/datasets/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+
3
+ from .build import DATASET_REGISTRY, build_dataset # noqa
4
+ from .kinetics import Kinetics # noqa
5
+ from .ssv2 import Ssv2 # noqa
TimeSformer/timesformer/datasets/build.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+
3
+ from fvcore.common.registry import Registry
4
+
5
# Module-level registry of datasets; `build_dataset` in this file looks
# entries up by name via `DATASET_REGISTRY.get`.
DATASET_REGISTRY = Registry("DATASET")
DATASET_REGISTRY.__doc__ = """
Registry for dataset.

The registered object will be called with `obj(cfg, split)`.
The call should return a `torch.utils.data.Dataset` object.
"""
12
+
13
+
14
def build_dataset(dataset_name, cfg, split):
    """
    Look up the dataset registered under `dataset_name` and construct it.
    Args:
        dataset_name (str): the name of the dataset to be constructed.
        cfg (CfgNode): configs. Details can be found in
            timesformer/config/defaults.py
        split (str): the split of the data loader. Options include `train`,
            `val`, and `test`.
    Returns:
        Dataset: a constructed dataset specified by dataset_name.
    """
    # Config files may spell the dataset name in lowercase, while the lookup
    # key starts with an uppercase letter. Note that `str.capitalize` also
    # lowercases every character after the first one.
    registry_key = dataset_name.capitalize()
    dataset_factory = DATASET_REGISTRY.get(registry_key)
    return dataset_factory(cfg, split)
TimeSformer/timesformer/datasets/cv2_transform.py ADDED
@@ -0,0 +1,796 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+
3
+ import math
4
+ import numpy as np
5
+ import cv2
6
+
7
+
8
def clip_boxes_to_image(boxes, height, width):
    """
    Clamp box coordinates so they lie inside an image of the given size.
    The input array is modified in place and also returned.
    Args:
        boxes (ndarray): bounding boxes to clip. The dimension is
            `num boxes` x 4; columns 0 and 2 are clamped against the width,
            columns 1 and 3 against the height.
        height (int): the height of the image.
        width (int): the width of the image.
    Returns:
        boxes (ndarray): the clipped bounding boxes (same array as the input).
    """
    for columns, extent in (([0, 2], width), ([1, 3], height)):
        # Clamp to [0, extent - 1]: lower bound first, then upper bound.
        non_negative = np.maximum(0.0, boxes[:, columns])
        boxes[:, columns] = np.minimum(extent - 1.0, non_negative)
    return boxes
26
+
27
+
28
def random_short_side_scale_jitter_list(images, min_size, max_size, boxes=None):
    """
    Resize a list of images so that their short side has a randomly drawn
    length, and rescale the corresponding boxes by the same factor.
    Args:
        images (list): list of images to perform scale jitter. Dimension is
            `height` x `width` x `channel`. The size of the first image is
            used for all of them.
        min_size (int): the minimal size to scale the frames.
        max_size (int): the maximal size to scale the frames.
        boxes (list): optional. Corresponding boxes to images. Dimension is
            `num boxes` x 4.
    Returns:
        (list): the list of scaled images with dimension of
            `new height` x `new width` x `channel`. The input list is
            returned untouched when the short side already has the drawn size.
        (list or None): the scaled boxes with dimension of
            `num boxes` x 4.
    """
    # The short-side length is drawn uniformly in inverse-size space.
    inverse_size = np.random.uniform(1.0 / max_size, 1.0 / min_size)
    size = int(round(1.0 / inverse_size))

    height, width = images[0].shape[0], images[0].shape[1]
    if min(height, width) == size:
        return images, boxes

    if width < height:
        new_width = size
        new_height = int(math.floor((float(height) / width) * size))
        numerator, denominator = float(new_height), height
    else:
        new_height = size
        new_width = int(math.floor((float(width) / height) * size))
        numerator, denominator = float(new_width), width

    if boxes is not None:
        boxes = [proposal * numerator / denominator for proposal in boxes]

    resized_images = []
    for image in images:
        resized = cv2.resize(
            image, (new_width, new_height), interpolation=cv2.INTER_LINEAR
        )
        resized_images.append(resized.astype(np.float32))
    return resized_images, boxes
74
+
75
+
76
def scale(size, image):
    """
    Resize an image so that its short side equals `size`, keeping the
    aspect ratio (the long side is rounded down).
    Args:
        size (int): target length of the short side.
        image (array): image to perform short side scale. Dimension is
            `height` x `width` x `channel`.
    Returns:
        (ndarray): the scaled image as float32 with dimension of
            `height` x `width` x `channel`. The input is returned untouched
            when its short side already equals `size`.
    """
    height, width = image.shape[0], image.shape[1]
    if min(height, width) == size:
        return image

    if width < height:
        target_wh = (size, int(math.floor((float(height) / width) * size)))
    else:
        target_wh = (int(math.floor((float(width) / height) * size)), size)
    # cv2.resize takes the target size as (width, height).
    resized = cv2.resize(image, target_wh, interpolation=cv2.INTER_LINEAR)
    return resized.astype(np.float32)
103
+
104
+
105
def scale_boxes(size, boxes, height, width):
    """
    Rescale boxes to follow an image whose short side is resized to `size`.
    The boxes are multiplied in place and also returned.
    Args:
        size (int): target length of the image's short side.
        boxes (ndarray): bounding boxes to peform scale. The dimension is
            `num boxes` x 4.
        height (int): the height of the image before scaling.
        width (int): the width of the image before scaling.
    Returns:
        boxes (ndarray): scaled bounding boxes (same array as the input).
    """
    if min(height, width) == size:
        return boxes

    if width < height:
        # Width becomes `size`; the factor is taken from the rounded height.
        scaled_height = int(math.floor((float(height) / width) * size))
        factor = float(scaled_height) / height
    else:
        # Height becomes `size`; the factor is taken from the rounded width.
        scaled_width = int(math.floor((float(width) / height) * size))
        factor = float(scaled_width) / width
    boxes *= factor
    return boxes
131
+
132
+
133
def horizontal_flip_list(prob, images, order="CHW", boxes=None):
    """
    Horizontally flip a list of images and their optional boxes with
    probability `prob`. Either everything is flipped or nothing is.
    Args:
        prob (float): probability to flip.
        images (list): list of images to flip. Each image has dimension
            `channel` x `height` x `width` when `order` is "CHW", or
            `height` x `width` x `channel` when `order` is "HWC".
        order (str): layout of the images, either "CHW" or "HWC".
        boxes (list): optional. Corresponding boxes to images.
            Dimension is `num boxes` x 4.
    Returns:
        (list): the images, flipped or untouched, in the input layout.
        (list): optional. The boxes, flipped or untouched. Dimension is
            `num boxes` x 4.
    Raises:
        AssertionError: if `order` is neither "CHW" nor "HWC".
    """
    # Reject unsupported layouts up front; otherwise the boxes would be
    # flipped while the images are returned unchanged.
    assert order in ["CHW", "HWC"], "order {} is not supported".format(order)
    if not np.random.uniform() < prob:
        return images, boxes

    # The width axis depends on the layout: index 2 for CHW, index 1 for HWC.
    width = images[0].shape[2] if order == "CHW" else images[0].shape[1]
    if boxes is not None:
        boxes = [flip_boxes(proposal, width) for proposal in boxes]

    if order == "CHW":
        # Reverse the last (width) axis.
        return [np.asarray(image)[:, :, ::-1] for image in images], boxes
    return [cv2.flip(image, 1) for image in images], boxes
163
+
164
+
165
def spatial_shift_crop_list(size, images, spatial_shift_pos, boxes=None):
    """
    Take a `size` x `size` crop from every image, positioned at the start,
    center, or end of the longer side and centered along the other side.
    Args:
        size (int): size to crop.
        images (list): list of images to crop. The code indexes them as
            `height` x `width` x `channel`; the size of the first image is
            used for all of them.
        spatial_shift_pos (int): option includes 0 (left), 1 (middle), and
            2 (right) crop. When the image is taller than wide the shift is
            applied vertically instead.
        boxes (list): optional. Corresponding boxes to images.
            Dimension is `num boxes` x 4. They are shifted in place.
    Returns:
        cropped (list): the cropped list of images with dimension of
            `height` x `width` x `channel`.
        boxes (list): optional. Corresponding boxes to images. Dimension is
            `num boxes` x 4.
    """
    assert spatial_shift_pos in [0, 1, 2]

    height, width = images[0].shape[0], images[0].shape[1]
    # Start from a centered crop, then move along the longer side if asked.
    y_offset = int(math.ceil((height - size) / 2))
    x_offset = int(math.ceil((width - size) / 2))

    shift_vertically = height > width
    if spatial_shift_pos == 0:
        if shift_vertically:
            y_offset = 0
        else:
            x_offset = 0
    elif spatial_shift_pos == 2:
        if shift_vertically:
            y_offset = height - size
        else:
            x_offset = width - size

    rows = slice(y_offset, y_offset + size)
    cols = slice(x_offset, x_offset + size)
    cropped = [image[rows, cols, :] for image in images]
    assert cropped[0].shape[0] == size, "Image height not cropped properly"
    assert cropped[0].shape[1] == size, "Image width not cropped properly"

    if boxes is not None:
        # Move the boxes into the coordinate frame of the crop.
        for proposal in boxes:
            proposal[:, [0, 2]] -= x_offset
            proposal[:, [1, 3]] -= y_offset
    return cropped, boxes
213
+
214
+
215
def CHW2HWC(image):
    """
    Reorder the axes of an image from `channel` x `height` x `width` to
    `height` x `width` x `channel`.
    Args:
        image (array): image to transpose.
    Returns:
        (array): transposed image.
    """
    channel_axis, height_axis, width_axis = 0, 1, 2
    return image.transpose([height_axis, width_axis, channel_axis])
225
+
226
+
227
def HWC2CHW(image):
    """
    Reorder the axes of an image from `height` x `width` x `channel` to
    `channel` x `height` x `width`.
    Args:
        image (array): image to transpose.
    Returns:
        (array): transposed image.
    """
    height_axis, width_axis, channel_axis = 0, 1, 2
    return image.transpose([channel_axis, height_axis, width_axis])
237
+
238
+
239
def color_jitter_list(
    images, img_brightness=0, img_contrast=0, img_saturation=0
):
    """
    Apply the enabled color jitters (brightness, contrast, saturation) to a
    list of images, in a random order.
    Args:
        images (list): list of images to perform color jitter.
        img_brightness (float): jitter ratio for brightness; 0 disables it.
        img_contrast (float): jitter ratio for contrast; 0 disables it.
        img_saturation (float): jitter ratio for saturation; 0 disables it.
    Returns:
        images (list): the jittered list of images. The input list is
            returned untouched when every ratio is 0.
    """
    candidates = (
        ("brightness", img_brightness),
        ("contrast", img_contrast),
        ("saturation", img_saturation),
    )
    enabled = [name for name, ratio in candidates if ratio != 0]
    if not enabled:
        return images

    # Apply the enabled jitters in a freshly shuffled order.
    for pick in np.random.permutation(np.arange(len(enabled))):
        name = enabled[pick]
        if name == "brightness":
            images = brightness_list(img_brightness, images)
        elif name == "contrast":
            images = contrast_list(img_contrast, images)
        elif name == "saturation":
            images = saturation_list(img_saturation, images)
    return images
270
+
271
+
272
def lighting_list(imgs, alphastd, eigval, eigvec, alpha=None):
    """
    Perform AlexNet-style PCA jitter on the given list of images. The same
    per-channel offsets are added to every image, in place.
    Args:
        imgs (list): list of images to perform lighting jitter. Each image
            is indexed along its first axis; entry `idx` receives offset
            `rgb[2 - idx]`, so the first axis is assumed to hold 3 channels
            in reversed (BGR-like) order - TODO confirm with callers.
        alphastd (float): jitter ratio for PCA jitter; 0 disables the jitter.
        eigval (list): eigenvalues for PCA jitter (3 values).
        eigvec (list[list]): eigenvectors for PCA jitter (3 x 3).
        alpha (array-like): optional. The three mixing coefficients to use.
            When None they are drawn from a normal distribution with
            standard deviation `alphastd`.
    Returns:
        out_images (list): the list of jittered images (the input arrays,
            modified in place). `imgs` itself is returned when
            `alphastd` is 0.
    """
    if alphastd == 0:
        return imgs
    if alpha is None:
        # generate alpha1, alpha2, alpha3
        alpha = np.random.normal(0, alphastd, size=(1, 3))
    else:
        # Honour caller-supplied coefficients instead of overwriting them.
        alpha = np.reshape(np.asarray(alpha, dtype=float), (1, 3))
    eig_vec = np.array(eigvec)
    eig_val = np.reshape(eigval, (1, 3))
    rgb = np.sum(
        eig_vec * np.repeat(alpha, 3, axis=0) * np.repeat(eig_val, 3, axis=0),
        axis=1,
    )
    out_images = []
    for img in imgs:
        for idx in range(img.shape[0]):
            img[idx] = img[idx] + rgb[2 - idx]
        out_images.append(img)
    return out_images
299
+
300
+
301
def color_normalization(image, mean, stddev):
    """
    Normalize an image channel by channel: subtract the channel mean, then
    divide by the channel stddev. The image is modified in place.
    Args:
        image (array): image to perform color normalization, with the
            channels on the first axis (CHW).
        mean (list): per-channel mean values to subtract.
        stddev (list): per-channel stddev values to divide by.
    Returns:
        image (array): the normalized image (same array as the input).
    """
    num_channels = image.shape[0]
    # Input image should in format of CHW
    assert len(mean) == num_channels, "channel mean not computed properly"
    assert len(stddev) == num_channels, "channel stddev not computed properly"
    for channel, (channel_mean, channel_std) in enumerate(zip(mean, stddev)):
        # Two separate assignments, so each step is stored back into `image`.
        image[channel] = image[channel] - channel_mean
        image[channel] = image[channel] / channel_std
    return image
316
+
317
+
318
def pad_image(image, pad_size, order="CHW"):
    """
    Pad the two spatial axes of an image with `pad_size` zeros on every
    side. The channel axis is left unpadded.
    Args:
        image (array): image to pad.
        pad_size (int): number of pixels to add on each side.
        order (str): layout of the image, either "CHW" or "HWC".
    Returns:
        img (array): padded image.
    Raises:
        NotImplementedError: if `order` is neither "CHW" nor "HWC".
    """
    spatial_pad = (pad_size, pad_size)
    if order == "CHW":
        pad_width = ((0, 0), spatial_pad, spatial_pad)
    elif order == "HWC":
        pad_width = (spatial_pad, spatial_pad, (0, 0))
    else:
        # Fail with a clear message instead of an unbound local variable.
        raise NotImplementedError("Unknown order {}".format(order))
    return np.pad(image, pad_width, mode="constant")
341
+
342
+
343
def horizontal_flip(prob, image, order="CHW"):
    """
    Flip an image along its width axis with probability `prob`.
    Args:
        prob (float): probability to flip.
        image (array): image to flip.
        order (str): layout of the image, either "CHW" or "HWC".
    Returns:
        img (array): the flipped image, or the input when no flip happens.
    """
    assert order in ["CHW", "HWC"], "order {} is not supported".format(order)
    flip_now = np.random.uniform() < prob
    if not flip_now:
        return image
    if order == "CHW":
        # Width is the last axis.
        return image[:, :, ::-1]
    if order == "HWC":
        # Width is the middle axis.
        return image[:, ::-1, :]
    raise NotImplementedError("Unknown order {}".format(order))
362
+
363
+
364
def flip_boxes(boxes, im_width):
    """
    Mirror boxes horizontally inside an image of width `im_width`. The
    input array is left unchanged.
    Args:
        boxes (array): boxes to flip, with the x coordinates of the two
            corners in columns 0 and 2 (repeating every 4 columns).
        im_width (int): width of the image.
    Returns:
        boxes_flipped (array): a flipped copy of the boxes.
    """
    # The mirrored right edge becomes the new left edge and vice versa.
    mirrored_left = im_width - boxes[:, 2::4] - 1
    mirrored_right = im_width - boxes[:, 0::4] - 1
    boxes_flipped = boxes.copy()
    boxes_flipped[:, 0::4] = mirrored_left
    boxes_flipped[:, 2::4] = mirrored_right
    return boxes_flipped
378
+
379
+
380
def crop_boxes(boxes, x_offset, y_offset):
    """
    Shift boxes into the coordinate frame of a crop that starts at
    (`x_offset`, `y_offset`). The boxes are modified in place.
    Args:
        boxes (array): boxes to crop. Columns 0 and 2 are shifted by
            `x_offset`, columns 1 and 3 by `y_offset`.
        x_offset (int): offset on x.
        y_offset (int): offset on y.
    Returns:
        boxes (array): the shifted boxes (same array as the input).
    """
    for columns, offset in (([0, 2], x_offset), ([1, 3], y_offset)):
        shifted = boxes[:, columns] - offset
        boxes[:, columns] = shifted
    return boxes
391
+
392
+
393
def random_crop_list(images, size, pad_size=0, order="CHW", boxes=None):
    """
    Take the same randomly positioned `size` x `size` crop from every image
    in a list, optionally padding the images first.
    Args:
        images (list): list of images to perform random crop. The size of
            the first image is used for all of them.
        size (int): size to crop.
        pad_size (int): padding size applied to each side before cropping;
            0 disables padding.
        order (str): layout of the images, either "CHW" or "HWC".
        boxes (list): optional. Corresponding boxes to images.
            Dimension is `num boxes` x 4.
    Returns:
        cropped (list): the cropped list of images, in the input layout.
            The (padded) input list is returned as is when it already has
            the requested size.
        boxes (list): optional. Corresponding boxes to images. Dimension is
            `num boxes` x 4.
    Raises:
        NotImplementedError: if `order` is neither "CHW" nor "HWC".
    """
    if order not in ("CHW", "HWC"):
        raise NotImplementedError("Unknown order {}".format(order))

    # explicitly dealing processing per image order to avoid flipping images.
    if pad_size > 0:
        images = [
            pad_image(pad_size=pad_size, image=image, order=order)
            for image in images
        ]

    h_axis, w_axis = (1, 2) if order == "CHW" else (0, 1)
    height = images[0].shape[h_axis]
    width = images[0].shape[w_axis]
    if height == size and width == size:
        return images, boxes

    # `randint` excludes its upper bound, so add 1 to make the last valid
    # offset (`height - size` / `width - size`) reachable.
    y_offset = 0
    if height > size:
        y_offset = int(np.random.randint(0, height - size + 1))
    x_offset = 0
    if width > size:
        x_offset = int(np.random.randint(0, width - size + 1))

    rows = slice(y_offset, y_offset + size)
    cols = slice(x_offset, x_offset + size)
    if order == "CHW":
        cropped = [image[:, rows, cols] for image in images]
    else:
        cropped = [image[rows, cols, :] for image in images]
    assert cropped[0].shape[h_axis] == size, "Image not cropped properly"
    assert cropped[0].shape[w_axis] == size, "Image not cropped properly"

    if boxes is not None:
        boxes = [crop_boxes(proposal, x_offset, y_offset) for proposal in boxes]
    return cropped, boxes
455
+
456
+
457
def center_crop(size, image):
    """
    Cut a `size` x `size` window out of the middle of an image.
    Args:
        size (int): size of the cropped height and width.
        image (array): the image to perform center crop, indexed as
            `height` x `width` x `channel`.
    Returns:
        cropped (array): the cropped image.
    """
    height, width = image.shape[0], image.shape[1]
    # Odd margins are rounded up, so the window sits one pixel towards the
    # bottom/right in that case.
    top = int(math.ceil((height - size) / 2))
    left = int(math.ceil((width - size) / 2))
    cropped = image[top : top + size, left : left + size, :]
    assert cropped.shape[0] == size, "Image height not cropped properly"
    assert cropped.shape[1] == size, "Image width not cropped properly"
    return cropped
472
+
473
+
474
+ # ResNet style scale jittering: randomly select the scale from
475
+ # [1/max_size, 1/min_size]
476
def random_scale_jitter(image, min_size, max_size):
    """
    Perform ResNet style random scale jittering: the inverse of the target
    size is drawn uniformly from [1/max_size, 1/min_size] and the image is
    passed to `scale` with the resulting (rounded) size.
    Args:
        image (array): image to perform random scale.
        min_size (int): min size to scale.
        max_size (int): max size to scale.
    Returns:
        image (array): scaled image.
    """
    inverse_size = np.random.uniform(1.0 / max_size, 1.0 / min_size)
    target_size = int(round(1.0 / inverse_size))
    return scale(target_size, image)
492
+
493
+
494
def random_scale_jitter_list(images, min_size, max_size):
    """
    Perform ResNet style random scale jittering on a list of images. One
    target size is drawn (its inverse uniformly from
    [1/max_size, 1/min_size]) and shared by all the images.
    Args:
        images (list): list of images to perform random scale.
        min_size (int): min size to scale.
        max_size (int): max size to scale.
    Returns:
        images (list): list of scaled images.
    """
    inverse_size = np.random.uniform(1.0 / max_size, 1.0 / min_size)
    shared_size = int(round(1.0 / inverse_size))
    scaled_images = []
    for image in images:
        scaled_images.append(scale(shared_size, image))
    return scaled_images
510
+
511
+
512
def random_sized_crop(image, size, area_frac=0.08):
    """
    Perform random sized cropping on the given image. A crop covering
    between `area_frac` and 100% of the image area, with aspect ratio in
    [3/4, 4/3], is drawn and resized to `size` x `size`.
    Args:
        image (array): image to crop, indexed as
            `height` x `width` x `channel`.
        size (int): output height and width.
        area_frac (float): minimal fraction of the image area to keep.
    Returns:
        (array): the cropped and resized image as float32. If no valid crop
            is found in 10 attempts, the result of
            `center_crop(size, scale(size, image))` is returned instead.
    """
    height = image.shape[0]
    width = image.shape[1]
    area = height * width
    for _ in range(10):
        target_area = np.random.uniform(area_frac, 1.0) * area
        aspect_ratio = np.random.uniform(3.0 / 4.0, 4.0 / 3.0)
        w = int(round(math.sqrt(float(target_area) * aspect_ratio)))
        h = int(round(math.sqrt(float(target_area) / aspect_ratio)))
        if np.random.uniform() < 0.5:
            w, h = h, w
        if h <= height and w <= width:
            # `randint` excludes its upper bound; the +1 makes the last
            # valid offset reachable and also covers the h == height and
            # w == width cases (offset 0).
            y_offset = int(np.random.randint(0, height - h + 1))
            x_offset = int(np.random.randint(0, width - w + 1))
            cropped = image[y_offset : y_offset + h, x_offset : x_offset + w, :]
            assert (
                cropped.shape[0] == h and cropped.shape[1] == w
            ), "Wrong crop size"
            cropped = cv2.resize(
                cropped, (size, size), interpolation=cv2.INTER_LINEAR
            )
            return cropped.astype(np.float32)
    return center_crop(size, scale(size, image))
553
+
554
+
555
def lighting(img, alphastd, eigval, eigvec):
    """
    Perform AlexNet-style PCA jitter on the given image, in place.
    Args:
        img (array): image to perform lighting jitter. It is indexed along
            its first axis; entry `idx` receives offset `rgb[2 - idx]`.
        alphastd (float): jitter ratio for PCA jitter; 0 disables the jitter.
        eigval (array): eigenvalues for PCA jitter (3 values).
        eigvec (list): eigenvectors for PCA jitter (3 x 3).
    Returns:
        img (array): the jittered image (same array as the input).
    """
    if alphastd == 0:
        return img
    # generate alpha1, alpha2, alpha3.
    alpha = np.random.normal(0, alphastd, size=(1, 3))
    eig_val = np.reshape(eigval, (1, 3))
    # Stack the (1, 3) rows three times so they line up with the 3 x 3
    # eigenvector matrix.
    alpha_rows = np.repeat(alpha, 3, axis=0)
    eig_val_rows = np.repeat(eig_val, 3, axis=0)
    weighted = np.array(eigvec) * alpha_rows * eig_val_rows
    rgb = np.sum(weighted, axis=1)
    for idx, offset_idx in enumerate(range(2, 2 - img.shape[0], -1)):
        img[idx] = img[idx] + rgb[offset_idx]
    return img
579
+
580
+
581
def random_sized_crop_list(images, size, crop_area_fraction=0.08):
    """
    Perform random sized cropping on the given list of images. One crop window
    is sampled and applied to every image. The window covers between
    `crop_area_fraction` and 100% of the image area with an aspect ratio in
    [3/4, 4/3], and each crop is resized to `size` x `size`.
    Up to 10 candidate windows are drawn. If none of them fits inside the
    image, every image is scaled and center cropped instead.
    Args:
        images (list): images to crop, each indexed as `height` x `width` x
            `channel`. The geometry is taken from the first image.
        size (int): side length of the square outputs.
        crop_area_fraction (float): lower bound of the sampled crop area,
            expressed as a fraction of the image area.
    Returns:
        (list): list of cropped images. Crops produced by the random path are
            float32.
    """
    # The geometry is taken from the first image and is the same for every
    # attempt, so read it once.
    height = images[0].shape[0]
    width = images[0].shape[1]
    area = height * width
    for _ in range(10):
        target_area = np.random.uniform(crop_area_fraction, 1.0) * area
        aspect_ratio = np.random.uniform(3.0 / 4.0, 4.0 / 3.0)
        w = int(round(math.sqrt(float(target_area) * aspect_ratio)))
        h = int(round(math.sqrt(float(target_area) / aspect_ratio)))
        if np.random.uniform() < 0.5:
            w, h = h, w
        if h > height or w > width:
            # Candidate does not fit inside the image; draw another one.
            continue
        # np.random.randint excludes its upper bound, so `+ 1` is needed for
        # the window to be able to touch the bottom / right border. It also
        # covers the full-size case (the only valid offset is then 0).
        y_offset = int(np.random.randint(0, height - h + 1))
        x_offset = int(np.random.randint(0, width - w + 1))

        cropped_images = []
        for image in images:
            cropped = image[
                y_offset : y_offset + h, x_offset : x_offset + w, :
            ]
            assert (
                cropped.shape[0] == h and cropped.shape[1] == w
            ), "Wrong crop size"
            cropped = cv2.resize(
                cropped, (size, size), interpolation=cv2.INTER_LINEAR
            )
            cropped_images.append(cropped.astype(np.float32))
        return cropped_images

    # No candidate fitted: fall back to a deterministic scale + center crop.
    return [center_crop(size, scale(size, image)) for image in images]
629
+
630
+
631
def blend(image1, image2, alpha):
    """
    Linearly mix two images.
    Args:
        image1 (array): image weighted by `alpha`.
        image2 (array): image weighted by `1 - alpha`.
        alpha (float): mixing weight of `image1`.
    Returns:
        (array): `image1 * alpha + image2 * (1 - alpha)`.
    """
    weighted_first = image1 * alpha
    weighted_second = image2 * (1 - alpha)
    return weighted_first + weighted_second
633
+
634
+
635
def grayscale(image):
    """
    Convert the image to gray scale.
    Args:
        image (tensor): image to convert to gray scale. Dimension is
            `channel` x `height` x `width`.
    Returns:
        img_gray (tensor): copy of the image whose first three channels all
            hold the gray value. The input is not modified.
    """
    # Weights: R -> 0.299, G -> 0.587, B -> 0.114. Channel 2 receives the red
    # weight and channel 0 the blue weight.
    luminance = 0.299 * image[2] + 0.587 * image[1] + 0.114 * image[0]
    img_gray = np.copy(image)
    for channel in range(3):
        img_gray[channel] = luminance
    return img_gray
651
+
652
+
653
def saturation(var, image):
    """
    Randomly change the color saturation of the given image by mixing it with
    its gray-scale version.
    Args:
        var (float): jitter range; the mixing weight is drawn uniformly from
            [1 - var, 1 + var].
        image (array): image to perform color saturation.
    Returns:
        (array): the saturation-jittered image.
    """
    gray_version = grayscale(image)
    mix_weight = 1.0 + np.random.uniform(-var, var)
    return blend(image, gray_version, mix_weight)
665
+
666
+
667
def brightness(var, image):
    """
    Randomly change the brightness of the given image by mixing it with an
    all-zero image of the same shape and dtype.
    Args:
        var (float): jitter range; the mixing weight is drawn uniformly from
            [1 - var, 1 + var].
        image (array): image to perform color brightness.
    Returns:
        (array): the brightness-jittered image.
    """
    black = np.zeros(image.shape).astype(image.dtype)
    factor = 1.0 + np.random.uniform(-var, var)
    # Same linear mix as `blend(image, black, factor)`.
    return image * factor + black * (1 - factor)
679
+
680
+
681
def contrast(var, image):
    """
    Randomly change the contrast of the given image by mixing it with a
    constant image holding the mean gray value.
    Args:
        var (float): jitter range; the mixing weight is drawn uniformly from
            [1 - var, 1 + var].
        image (array): image to perform color contrast.
    Returns:
        (array): the contrast-jittered image.
    """
    flat_gray = grayscale(image)
    # Replace every entry by the mean of the gray channel.
    mean_gray = np.mean(flat_gray[0])
    flat_gray.fill(mean_gray)
    mix_weight = 1.0 + np.random.uniform(-var, var)
    return blend(image, flat_gray, mix_weight)
694
+
695
+
696
def saturation_list(var, images):
    """
    Randomly change the color saturation of a list of images. A single mixing
    weight is drawn and shared by all images.
    Args:
        var (float): jitter range; the mixing weight is drawn uniformly from
            [1 - var, 1 + var].
        images (list): list of images to perform color saturation.
    Returns:
        (list): list of saturation-jittered images.
    """
    mix_weight = 1.0 + np.random.uniform(-var, var)
    return [blend(image, grayscale(image), mix_weight) for image in images]
712
+
713
+
714
def brightness_list(var, images):
    """
    Randomly change the brightness of a list of images. A single mixing weight
    is drawn and shared by all images.
    Args:
        var (float): jitter range; the mixing weight is drawn uniformly from
            [1 - var, 1 + var].
        images (list): list of images to perform color brightness.
    Returns:
        (list): list of brightness-jittered images.
    """
    factor = 1.0 + np.random.uniform(-var, var)

    def _mix_with_black(image):
        # Same linear mix as `blend(image, black, factor)`.
        black = np.zeros(image.shape).astype(image.dtype)
        return image * factor + black * (1 - factor)

    return [_mix_with_black(image) for image in images]
730
+
731
+
732
def contrast_list(var, images):
    """
    Randomly change the contrast of a list of images. A single mixing weight
    is drawn and shared by all images; the mean gray value is computed per
    image.
    Args:
        var (float): jitter range; the mixing weight is drawn uniformly from
            [1 - var, 1 + var].
        images (list): list of images to perform color contrast.
    Returns:
        (list): list of contrast-jittered images.
    """
    mix_weight = 1.0 + np.random.uniform(-var, var)

    def _mean_gray_image(image):
        # Constant image holding the mean of the gray channel.
        flat_gray = grayscale(image)
        flat_gray.fill(np.mean(flat_gray[0]))
        return flat_gray

    return [
        blend(image, _mean_gray_image(image), mix_weight) for image in images
    ]
749
+
750
+
751
def color_jitter(image, img_brightness=0, img_contrast=0, img_saturation=0):
    """
    Perform color jitter on the given image. Every jitter whose ratio is
    non-zero is applied exactly once, in a random order.
    Args:
        image (array): image to perform color jitter.
        img_brightness (float): jitter ratio for brightness.
        img_contrast (float): jitter ratio for contrast.
        img_saturation (float): jitter ratio for saturation.
    Returns:
        image (array): the jittered image. The input is returned unchanged
            when all ratios are 0.
    """
    # Collect the enabled transforms together with their jitter ratio.
    pending = []
    if img_brightness != 0:
        pending.append((brightness, img_brightness))
    if img_contrast != 0:
        pending.append((contrast, img_contrast))
    if img_saturation != 0:
        pending.append((saturation, img_saturation))

    if not pending:
        return image

    order = np.random.permutation(np.arange(len(pending)))
    for position in order:
        transform, ratio = pending[position]
        image = transform(ratio, image)
    return image
780
+
781
+
782
def revert_scaled_boxes(size, boxes, img_height, img_width):
    """
    Map boxes given in scaled-image coordinates back to the original image
    size. The factor applied is `min(img_height, img_width) / size`.
    Args:
        size (int): size of the cropped image.
        boxes (array): shape (num_boxes, 4).
        img_height (int): height of original image.
        img_width (int): width of original image.
    Returns:
        reverted_boxes (array): boxes scaled back to the original image size.
    """
    shorter_side = np.min([img_height, img_width])
    ratio = shorter_side / size
    return boxes * ratio
TimeSformer/timesformer/datasets/decoder.py ADDED
@@ -0,0 +1,392 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+
3
+ import math
4
+ import numpy as np
5
+ import random
6
+ import torch
7
+ import torchvision.io as io
8
+
9
+
10
def temporal_sampling(frames, start_idx, end_idx, num_samples):
    """
    Sample `num_samples` frames at equally spaced positions between the start
    and end frame index. Positions outside the video are clamped to the first
    / last frame, and fractional positions are truncated.
    Args:
        frames (tensor): a tensor of video frames, dimension is
            `num video frames` x `channel` x `height` x `width`.
        start_idx (int): the index of the start frame.
        end_idx (int): the index of the end frame.
        num_samples (int): number of frames to sample.
    Returns:
        frames (tensor): a tensor of temporal sampled video frames, dimension
            is `num clip frames` x `channel` x `height` x `width`.
    """
    last_valid = frames.shape[0] - 1
    positions = torch.linspace(start_idx, end_idx, num_samples)
    positions = positions.clamp(0, last_valid).long()
    return frames.index_select(0, positions)
28
+
29
+
30
def get_start_end_idx(video_size, clip_size, clip_idx, num_clips):
    """
    Pick a clip of `clip_size` frames inside a video of `video_size` frames
    and return the indices of its first and last frame. With `clip_idx == -1`
    the clip start is drawn at random; otherwise the admissible start range is
    divided evenly by `num_clips` and the `clip_idx`-th position is used.
    Args:
        video_size (int): number of overall frames.
        clip_size (int): size of the clip to sample from the frames.
        clip_idx (int): -1 for random sampling, otherwise the index of the
            clip among the `num_clips` uniformly placed clips.
        num_clips (int): overall number of clips to uniformly sample from the
            given video for testing.
    Returns:
        start_idx (float): the start frame index (may be fractional).
        end_idx (float): the end frame index, `start_idx + clip_size - 1`.
    """
    # Largest admissible start; 0 when the video is shorter than the clip.
    slack = max(video_size - clip_size, 0)
    if clip_idx == -1:
        # Random temporal sampling.
        first = random.uniform(0, slack)
    else:
        # Uniformly placed clip with the given index.
        first = slack * clip_idx / num_clips
    last = first + clip_size - 1
    return first, last
59
+
60
+
61
def pyav_decode_stream(
    container, start_pts, end_pts, stream, stream_name, buffer_size=0
):
    """
    Decode the frames of a stream whose Presentation TimeStamp (PTS) lies in
    [start_pts, end_pts] with the PyAV decoder.
    Args:
        container (container): PyAV container.
        start_pts (int): the starting Presentation TimeStamp to fetch the
            video frames.
        end_pts (int): the ending Presentation TimeStamp of the decoded frames.
        stream (stream): PyAV stream.
        stream_name (dict): a dictionary of streams. For example, {"video": 0}
            means video stream at stream index 0.
        buffer_size (int): number of additional frames to decode beyond
            end_pts. At least one frame past end_pts is kept when one exists,
            even with the default of 0.
    Returns:
        result (list): list of frames decoded, sorted by PTS.
        max_pts (int): largest Presentation TimeStamp seen while decoding.
    """
    # Seeking in the stream is imprecise. Thus, seek to an earlier PTS by a
    # margin of pts.
    margin = 1024
    container.seek(
        max(start_pts - margin, 0),
        any_frame=False,
        backward=True,
        stream=stream,
    )

    collected = {}
    overshoot = 0
    max_pts = 0
    for frame in container.decode(**stream_name):
        pts = frame.pts
        max_pts = max(max_pts, pts)
        if pts < start_pts:
            # Still inside the seek margin, before the requested range.
            continue
        collected[pts] = frame
        if pts <= end_pts:
            continue
        # Past the requested range: stop once enough extra frames are kept.
        overshoot += 1
        if overshoot >= buffer_size:
            break
    return [collected[pts] for pts in sorted(collected)], max_pts
101
+
102
+
103
def torchvision_decode(
    video_handle,
    sampling_rate,
    num_frames,
    clip_idx,
    video_meta,
    num_clips=10,
    target_fps=30,
    modalities=("visual",),
    max_spatial_scale=0,
):
    """
    Decode a clip from raw video bytes with the TorchVision decoder. If
    `video_meta` is empty it is filled in from the video first. When the meta
    data allows it, only the PTS range of the selected clip is decoded;
    otherwise, or when that yields no frames, the entire video is decoded.
    Args:
        video_handle (bytes): raw bytes of the video file.
        sampling_rate (int): frame sampling rate (interval between two sampled
            frames).
        num_frames (int): number of frames to sample.
        clip_idx (int): if clip_idx is -1, perform random temporal
            sampling. If clip_idx is larger than -1, uniformly split the
            video to num_clips clips, and select the clip_idx-th video clip.
        video_meta (dict): a dict contains VideoMetaData. Details can be found
            at `pytorch/vision/torchvision/io/_video_opt.py`. Updated in place
            when it is empty.
        num_clips (int): overall number of clips to uniformly sample from the
            given video.
        target_fps (int): the input video may have a different fps, convert it
            to the target video fps.
        modalities (tuple): tuple of modalities to decode. Only `visual` is
            read here.
        max_spatial_scale (int): the maximal resolution of the spatial shorter
            edge size during decoding.
    Returns:
        frames (tensor): decoded frames from the video.
        fps (float): the number of frames per second of the video.
        decode_all_video (bool): if True, the entire video was decoded.
    """
    # Convert the bytes to a tensor.
    video_tensor = torch.from_numpy(np.frombuffer(video_handle, dtype=np.uint8))

    # The video_meta is empty, fetch the meta data from the raw video and keep
    # it for selective decoding in the future.
    if len(video_meta) == 0:
        meta = io._probe_video_from_memory(video_tensor)
        video_meta["video_timebase"] = meta.video_timebase
        video_meta["video_numerator"] = meta.video_timebase.numerator
        video_meta["video_denominator"] = meta.video_timebase.denominator
        video_meta["has_video"] = meta.has_video
        video_meta["video_duration"] = meta.video_duration
        video_meta["video_fps"] = meta.video_fps
        # NOTE: the key below is spelled "audio_timebas" (sic); kept as is.
        video_meta["audio_timebas"] = meta.audio_timebase
        video_meta["audio_numerator"] = meta.audio_timebase.numerator
        video_meta["audio_denominator"] = meta.audio_timebase.denominator
        video_meta["has_audio"] = meta.has_audio
        video_meta["audio_duration"] = meta.audio_duration
        video_meta["audio_sample_rate"] = meta.audio_sample_rate

    def _read_frames(pts_range):
        # Decode the raw video with the tv decoder for the given
        # (start_pts, end_pts) range.
        decoded, _ = io._read_video_from_memory(
            video_tensor,
            seek_frame_margin=1.0,
            read_video_stream="visual" in modalities,
            video_width=0,
            video_height=0,
            video_min_dimension=max_spatial_scale,
            video_pts_range=pts_range,
            video_timebase_numerator=video_meta["video_numerator"],
            video_timebase_denominator=video_meta["video_denominator"],
        )
        return decoded

    # (0, -1) is the PTS range used for decoding the entire video.
    whole_video = (0, -1)

    fps = video_meta["video_fps"]
    selective = bool(
        video_meta["has_video"]
        and video_meta["video_denominator"] > 0
        and video_meta["video_duration"] > 0
    )
    if selective:
        # Try selective decoding of the clip only.
        clip_size = sampling_rate * num_frames / target_fps * fps
        start_idx, end_idx = get_start_end_idx(
            fps * video_meta["video_duration"], clip_size, clip_idx, num_clips
        )
        # Convert frame index to pts.
        pts_per_frame = video_meta["video_denominator"] / fps
        pts_range = (int(start_idx * pts_per_frame), int(end_idx * pts_per_frame))
    else:
        pts_range = whole_video

    decode_all_video = not selective
    v_frames = _read_frames(pts_range)

    if v_frames.shape == torch.Size([0]):
        # Nothing was decoded: retry on the entire video.
        decode_all_video = True
        v_frames = _read_frames(whole_video)

    return v_frames, fps, decode_all_video
211
+
212
+
213
def pyav_decode(
    container,
    sampling_rate,
    num_frames,
    clip_idx,
    num_clips=10,
    target_fps=30,
    start=None,
    end=None,
    duration=None,
    frames_length=None,
):
    """
    Decode frames from the first video stream of a PyAV container. If the
    stream exposes its duration and frame count (or the caller supplies
    them), only the PTS range of the selected clip is decoded; otherwise the
    entire video is decoded. When `start` / `end` are given, that PTS range is
    decoded instead of a sampled clip. The container is closed after decoding.

    Args:
        container (container): pyav container.
        sampling_rate (int): frame sampling rate (interval between two sampled
            frames).
        num_frames (int): number of frames to sample.
        clip_idx (int): if clip_idx is -1, perform random temporal sampling. If
            clip_idx is larger than -1, uniformly split the video to num_clips
            clips, and select the clip_idx-th video clip.
        num_clips (int): overall number of clips to uniformly sample from the
            given video.
        target_fps (int): the input video may have a different fps, convert it
            to the target video fps before frame sampling.
        start (int): optional start PTS of an explicit range to decode.
        end (int): optional end PTS of an explicit range to decode.
        duration (float): fallback duration used only when the stream does not
            report one. It is divided by the stream time base to obtain PTS
            units (presumably seconds; confirm against the caller).
        frames_length (int): fallback frame count used only when the stream
            does not report a positive one.
    Returns:
        frames (tensor): decoded frames from the video.
        fps (float): the number of frames per second of the video.
        decode_all_video (bool): If True, the entire video was decoded.
    """
    video_stream = container.streams.video[0]
    fps = float(video_stream.average_rate)
    time_base = float(video_stream.time_base)

    # Prefer the decoding information stored in the video header and fall
    # back to the caller supplied values only when the header lacks it.
    header_frames = video_stream.frames
    if header_frames:
        frames_length = header_frames
    caller_duration = duration
    duration = video_stream.duration
    if duration is None and caller_duration is not None:
        duration = caller_duration / time_base

    if duration is None or not frames_length:
        # Without a duration and a positive frame count the clip position
        # cannot be converted to PTS, so decode the entire video.
        decode_all_video = True
        video_start_pts, video_end_pts = 0, math.inf
    else:
        # Perform selective decoding.
        decode_all_video = False
        start_idx, end_idx = get_start_end_idx(
            frames_length,
            sampling_rate * num_frames / target_fps * fps,
            clip_idx,
            num_clips,
        )
        pts_per_frame = duration / frames_length
        video_start_pts = int(start_idx * pts_per_frame)
        video_end_pts = int(end_idx * pts_per_frame)

    if start is not None and end is not None:
        # An explicit range never covers the entire video.
        decode_all_video = False

    frames = None
    # If video stream was found, fetch video frames from the video.
    if container.streams.video:
        if start is None and end is None:
            range_start, range_end = video_start_pts, video_end_pts
        else:
            # Explicit range requested by the caller; no conversion needed.
            range_start, range_end = start, end
        video_frames, _ = pyav_decode_stream(
            container,
            range_start,
            range_end,
            video_stream,
            {"video": 0},
        )
        container.close()

        frames = [frame.to_rgb().to_ndarray() for frame in video_frames]
        frames = torch.as_tensor(np.stack(frames))

    return frames, fps, decode_all_video
301
+
302
+
303
def decode(
    container,
    sampling_rate,
    num_frames,
    clip_idx=-1,
    num_clips=10,
    video_meta=None,
    target_fps=30,
    backend="pyav",
    max_spatial_scale=0,
    start=None,
    end=None,
    duration=None,
    frames_length=None,
):
    """
    Decode the video and perform temporal sampling.
    Args:
        container (container): pyav container for the `pyav` backend; passed
            to `torchvision_decode` as the raw video for the `torchvision`
            backend.
        sampling_rate (int): frame sampling rate (interval between two sampled
            frames).
        num_frames (int): number of frames to sample.
        clip_idx (int): if clip_idx is -1, perform random temporal
            sampling. If clip_idx is larger than -1, uniformly split the
            video to num_clips clips, and select the
            clip_idx-th video clip.
        num_clips (int): overall number of clips to uniformly
            sample from the given video.
        video_meta (dict): a dict contains VideoMetaData. Details can be found
            at `pytorch/vision/torchvision/io/_video_opt.py`. Only used by the
            `torchvision` backend; None is treated as an empty dict.
        target_fps (int): the input video may have different fps, convert it to
            the target video fps before frame sampling.
        backend (str): decoding backend includes `pyav` and `torchvision`. The
            default one is `pyav`.
        max_spatial_scale (int): keep the aspect ratio and resize the frame so
            that shorter edge size is max_spatial_scale. Only used in
            `torchvision` backend.
        start (int): forwarded to `pyav_decode`; optional start of an explicit
            range to decode. Only used in `pyav` backend.
        end (int): forwarded to `pyav_decode`; optional end of an explicit
            range to decode. Only used in `pyav` backend.
        duration (float): forwarded to `pyav_decode`. Only used in `pyav`
            backend.
        frames_length (int): forwarded to `pyav_decode`. Only used in `pyav`
            backend.
    Returns:
        frames (tensor): decoded frames from the video, or None when decoding
            failed or produced no frames.
    """
    # Currently support two decoders: 1) PyAV, and 2) TorchVision.
    assert clip_idx >= -1, "Not valied clip_idx {}".format(clip_idx)
    if video_meta is None:
        # The torchvision backend calls len() on the meta dict and fills it
        # in, so the default must be a dict rather than None.
        video_meta = {}
    try:
        if backend == "pyav":
            frames, fps, decode_all_video = pyav_decode(
                container,
                sampling_rate,
                num_frames,
                clip_idx,
                num_clips,
                target_fps,
                start,
                end,
                duration,
                frames_length,
            )
        elif backend == "torchvision":
            frames, fps, decode_all_video = torchvision_decode(
                container,
                sampling_rate,
                num_frames,
                clip_idx,
                video_meta,
                num_clips,
                target_fps,
                ("visual",),
                max_spatial_scale,
            )
        else:
            raise NotImplementedError(
                "Unknown decoding backend {}".format(backend)
            )
    except Exception as e:
        # Best effort: any decoding failure is reported and signalled to the
        # caller by returning None.
        print("Failed to decode by {} with exception: {}".format(backend, e))
        return None

    # Return None if the frames was not decoded successfully.
    if frames is None or frames.size(0) == 0:
        return None

    clip_sz = sampling_rate * num_frames / target_fps * fps
    # When only the clip itself was decoded, it is the single clip (index 0
    # of 1) of the decoded frames.
    start_idx, end_idx = get_start_end_idx(
        frames.shape[0],
        clip_sz,
        clip_idx if decode_all_video else 0,
        num_clips if decode_all_video else 1,
    )
    # Perform temporal sampling from the decoded video.
    frames = temporal_sampling(frames, start_idx, end_idx, num_frames)
    return frames
TimeSformer/timesformer/datasets/kinetics.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+
3
+ import os
4
+ import random
5
+ import torch
6
+ import torch.utils.data
7
+ from fvcore.common.file_io import PathManager
8
+
9
+ import timesformer.utils.logging as logging
10
+
11
+ from . import decoder as decoder
12
+ from . import utils as utils
13
+ from . import video_container as container
14
+ from .build import DATASET_REGISTRY
15
+ logger = logging.get_logger(__name__)
16
+
17
+
18
@DATASET_REGISTRY.register()
class Kinetics(torch.utils.data.Dataset):
    """
    Kinetics video loader. Construct the Kinetics video loader, then sample
    clips from the videos. For training and validation, a single clip is
    randomly sampled from every video with random cropping, scaling, and
    flipping. For testing, multiple clips are uniformly sampled from every
    video with uniform cropping. For uniform cropping, we take the left, center,
    and right crop if the width is larger than height, or take top, center, and
    bottom crop if the height is larger than the width.
    """

    def __init__(self, cfg, mode, num_retries=10):
        """
        Construct the Kinetics video loader with a given csv file. The format of
        the csv file is:
        ```
        path_to_video_1 label_1
        path_to_video_2 label_2
        ...
        path_to_video_N label_N
        ```
        Args:
            cfg (CfgNode): configs.
            mode (string): Options includes `train`, `val`, or `test` mode.
                For the train and val mode, the data loader will take data
                from the train or val set, and sample one clip per video.
                For the test mode, the data loader will take data from test set,
                and sample multiple clips per video.
            num_retries (int): number of retries.
        """
        # Only support train, val, and test mode.
        assert mode in [
            "train",
            "val",
            "test",
        ], "Split '{}' not supported for Kinetics".format(mode)
        self.mode = mode
        self.cfg = cfg

        self._video_meta = {}
        self._num_retries = num_retries
        # For training or validation mode, one single clip is sampled from every
        # video. For testing, NUM_ENSEMBLE_VIEWS clips are sampled from every
        # video. For every clip, NUM_SPATIAL_CROPS is cropped spatially from
        # the frames.
        if self.mode in ["train", "val"]:
            self._num_clips = 1
        elif self.mode in ["test"]:
            self._num_clips = (
                cfg.TEST.NUM_ENSEMBLE_VIEWS * cfg.TEST.NUM_SPATIAL_CROPS
            )

        logger.info("Constructing Kinetics {}...".format(mode))
        self._construct_loader()

    def _construct_loader(self):
        """
        Construct the video loader: read `<mode>.csv` from the data directory
        and register every video `self._num_clips` times, once per
        spatial-temporal clip index.
        """
        path_to_file = os.path.join(
            self.cfg.DATA.PATH_TO_DATA_DIR, "{}.csv".format(self.mode)
        )
        assert PathManager.exists(path_to_file), "{} dir not found".format(
            path_to_file
        )

        self._path_to_videos = []
        self._labels = []
        self._spatial_temporal_idx = []
        separator = self.cfg.DATA.PATH_LABEL_SEPARATOR
        with PathManager.open(path_to_file, "r") as f:
            for clip_idx, path_label in enumerate(f.read().splitlines()):
                # Every row must hold exactly a path and a label.
                fields = path_label.split(separator)
                assert len(fields) == 2
                path, label = fields
                for idx in range(self._num_clips):
                    self._path_to_videos.append(
                        os.path.join(self.cfg.DATA.PATH_PREFIX, path)
                    )
                    self._labels.append(int(label))
                    self._spatial_temporal_idx.append(idx)
                    self._video_meta[clip_idx * self._num_clips + idx] = {}
        # The split is identified by the mode; the message must not reference
        # attributes that do not exist, or the assertion itself would fail
        # with an AttributeError.
        assert (
            len(self._path_to_videos) > 0
        ), "Failed to load Kinetics split {} from {}".format(
            self.mode, path_to_file
        )
        logger.info(
            "Constructing kinetics dataloader (size: {}) from {}".format(
                len(self._path_to_videos), path_to_file
            )
        )

    def _next_index_after_failure(self, index, i_try):
        """
        Choose the video index for the next retry. Outside of test mode, once
        more than half of the retries are used up, a random video replaces the
        failing one; otherwise the same index is tried again.
        """
        if self.mode not in ["test"] and i_try > self._num_retries // 2:
            # let's try another one
            return random.randint(0, len(self._path_to_videos) - 1)
        return index

    def __getitem__(self, index):
        """
        Given the video index, return the list of frames, label, and video
        index if the video can be fetched and decoded successfully, otherwise
        repeatedly find a random video that can be decoded as a replacement.
        Args:
            index (int): the video index provided by the pytorch sampler. When
                short cycle is used it is a tuple of
                (index, short_cycle_idx).
        Returns:
            frames (tensor): the frames of sampled from the video. The dimension
                is `channel` x `num frames` x `height` x `width`.
            label (int): the label of the current video.
            index (int): if the video provided by pytorch sampler can be
                decoded, then return the index of the video. If not, return the
                index of the video replacement that can be decoded.
            extra (dict): always an empty dict.
        Raises:
            RuntimeError: if no video could be fetched after all retries.
        """
        short_cycle_idx = None
        # When short cycle is used, input index is a tuple.
        if isinstance(index, tuple):
            index, short_cycle_idx = index

        if self.mode in ["train", "val"]:
            # -1 indicates random sampling.
            temporal_sample_index = -1
            spatial_sample_index = -1
            min_scale = self.cfg.DATA.TRAIN_JITTER_SCALES[0]
            max_scale = self.cfg.DATA.TRAIN_JITTER_SCALES[1]
            crop_size = self.cfg.DATA.TRAIN_CROP_SIZE
            if short_cycle_idx in [0, 1]:
                crop_size = int(
                    round(
                        self.cfg.MULTIGRID.SHORT_CYCLE_FACTORS[short_cycle_idx]
                        * self.cfg.MULTIGRID.DEFAULT_S
                    )
                )
            if self.cfg.MULTIGRID.DEFAULT_S > 0:
                # Decreasing the scale is equivalent to using a larger "span"
                # in a sampling grid.
                min_scale = int(
                    round(
                        float(min_scale)
                        * crop_size
                        / self.cfg.MULTIGRID.DEFAULT_S
                    )
                )
        elif self.mode in ["test"]:
            temporal_sample_index = (
                self._spatial_temporal_idx[index]
                // self.cfg.TEST.NUM_SPATIAL_CROPS
            )
            # spatial_sample_index is in [0, 1, 2]. Corresponding to left,
            # center, or right if width is larger than height, and top, middle,
            # or bottom if height is larger than width.
            spatial_sample_index = (
                (
                    self._spatial_temporal_idx[index]
                    % self.cfg.TEST.NUM_SPATIAL_CROPS
                )
                if self.cfg.TEST.NUM_SPATIAL_CROPS > 1
                else 1
            )
            min_scale, max_scale, crop_size = (
                [self.cfg.DATA.TEST_CROP_SIZE] * 3
                if self.cfg.TEST.NUM_SPATIAL_CROPS > 1
                else [self.cfg.DATA.TRAIN_JITTER_SCALES[0]] * 2
                + [self.cfg.DATA.TEST_CROP_SIZE]
            )
            # The testing is deterministic and no jitter should be performed.
            # min_scale and max_scale are expected to be the same.
            assert len({min_scale, max_scale}) == 1
        else:
            raise NotImplementedError(
                "Does not support {} mode".format(self.mode)
            )
        sampling_rate = utils.get_random_sampling_rate(
            self.cfg.MULTIGRID.LONG_CYCLE_SAMPLING_RATE,
            self.cfg.DATA.SAMPLING_RATE,
        )
        # Try to decode and sample a clip from a video. If the video can not be
        # decoded, repeatedly find a random video replacement that can be
        # decoded.
        for i_try in range(self._num_retries):
            video_container = None
            try:
                video_container = container.get_video_container(
                    self._path_to_videos[index],
                    self.cfg.DATA_LOADER.ENABLE_MULTI_THREAD_DECODE,
                    self.cfg.DATA.DECODING_BACKEND,
                )
            except Exception as e:
                logger.info(
                    "Failed to load video from {} with error {}".format(
                        self._path_to_videos[index], e
                    )
                )
            # Select a random video if the current video was not able to access.
            if video_container is None:
                logger.warning(
                    "Failed to meta load video idx {} from {}; trial {}".format(
                        index, self._path_to_videos[index], i_try
                    )
                )
                index = self._next_index_after_failure(index, i_try)
                continue

            # Decode video. Meta info is used to perform selective decoding.
            frames = decoder.decode(
                video_container,
                sampling_rate,
                self.cfg.DATA.NUM_FRAMES,
                temporal_sample_index,
                self.cfg.TEST.NUM_ENSEMBLE_VIEWS,
                video_meta=self._video_meta[index],
                target_fps=self.cfg.DATA.TARGET_FPS,
                backend=self.cfg.DATA.DECODING_BACKEND,
                max_spatial_scale=min_scale,
            )

            # If decoding failed (wrong format, video is too short, and etc),
            # select another video.
            if frames is None:
                logger.warning(
                    "Failed to decode video idx {} from {}; trial {}".format(
                        index, self._path_to_videos[index], i_try
                    )
                )
                index = self._next_index_after_failure(index, i_try)
                continue

            label = self._labels[index]

            # Perform color normalization.
            frames = utils.tensor_normalize(
                frames, self.cfg.DATA.MEAN, self.cfg.DATA.STD
            )

            # T H W C -> C T H W.
            frames = frames.permute(3, 0, 1, 2)
            # Perform data augmentation.
            frames = utils.spatial_sampling(
                frames,
                spatial_idx=spatial_sample_index,
                min_scale=min_scale,
                max_scale=max_scale,
                crop_size=crop_size,
                random_horizontal_flip=self.cfg.DATA.RANDOM_FLIP,
                inverse_uniform_sampling=self.cfg.DATA.INV_UNIFORM_SAMPLE,
            )

            if self.cfg.MODEL.ARCH not in ["vit"]:
                frames = utils.pack_pathway_output(self.cfg, frames)
            else:
                # Keep NUM_FRAMES frames, equally spaced along the temporal
                # axis.
                frames = torch.index_select(
                    frames,
                    1,
                    torch.linspace(
                        0, frames.shape[1] - 1, self.cfg.DATA.NUM_FRAMES
                    ).long(),
                )

            return frames, label, index, {}
        else:
            # The loop only finishes without returning when every retry
            # failed.
            raise RuntimeError(
                "Failed to fetch video after {} retries.".format(
                    self._num_retries
                )
            )

    def __len__(self):
        """
        Returns:
            (int): the number of videos in the dataset.
        """
        return len(self._path_to_videos)
TimeSformer/timesformer/datasets/loader.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+
3
+ """Data loader."""
4
+
5
+ import itertools
6
+ import numpy as np
7
+ import torch
8
+ from torch.utils.data._utils.collate import default_collate
9
+ from torch.utils.data.distributed import DistributedSampler
10
+ from torch.utils.data.sampler import RandomSampler
11
+
12
+ from timesformer.datasets.multigrid_helper import ShortCycleBatchSampler
13
+
14
+ from . import utils as utils
15
+ from .build import build_dataset
16
+
17
+
18
def detection_collate(batch):
    """
    Collate function for the detection task. Boxes, labels and metadata of
    the individual samples are concatenated along the first dimension rather
    than stacked into a new batch dimension.
    Args:
        batch (tuple or list): samples to collate; each sample is an
            `(inputs, labels, video_idx, extra_data)` tuple.
    Returns:
        (tuple): `(inputs, labels, video_idx, collated_extra_data)`.
    """
    inputs, labels, video_idx, extra_data = zip(*batch)
    inputs = default_collate(inputs)
    video_idx = default_collate(video_idx)
    labels = torch.tensor(np.concatenate(labels, axis=0)).float()

    collated_extra_data = {}
    for key in extra_data[0].keys():
        per_sample = [sample[key] for sample in extra_data]
        if key in ("boxes", "ori_boxes"):
            # Prepend the sample's position in the batch as an extra first
            # column so each box can be traced back after concatenation.
            tagged = []
            for sample_idx, sample_boxes in enumerate(per_sample):
                idx_column = np.full(
                    (sample_boxes.shape[0], 1), float(sample_idx)
                )
                tagged.append(
                    np.concatenate([idx_column, sample_boxes], axis=1)
                )
            collated = torch.tensor(np.concatenate(tagged, axis=0)).float()
        elif key == "metadata":
            flat = list(itertools.chain.from_iterable(per_sample))
            collated = torch.tensor(flat).view(-1, 2)
        else:
            collated = default_collate(per_sample)
        collated_extra_data[key] = collated

    return inputs, labels, video_idx, collated_extra_data
53
+
54
+
55
def construct_loader(cfg, split, is_precise_bn=False):
    """
    Constructs the data loader for the given split.
    Args:
        cfg (CfgNode): configs. Details can be found in
            timesformer/config/defaults.py
        split (str): the split of the data loader. Options include `train`,
            `val`, and `test`.
        is_precise_bn (bool): when True the multigrid short-cycle batch
            sampler is not used, even if `cfg.MULTIGRID.SHORT_CYCLE` is set.
    Returns:
        (torch.utils.data.DataLoader): the constructed loader.
    """
    assert split in ["train", "val", "test"]
    is_train = split in ["train"]
    # `val` reads its dataset name and batch size from the TRAIN section;
    # only `test` uses the TEST section.
    split_cfg = cfg.TEST if split in ["test"] else cfg.TRAIN
    dataset_name = split_cfg.DATASET
    batch_size = int(split_cfg.BATCH_SIZE / max(1, cfg.NUM_GPUS))
    # Only the training split is shuffled and drops its incomplete batch.
    shuffle = is_train
    drop_last = is_train

    # Construct the dataset
    dataset = build_dataset(dataset_name, cfg, split)
    # Create a sampler for multi-process training
    sampler = utils.create_sampler(dataset, shuffle, cfg)

    use_short_cycle = (
        cfg.MULTIGRID.SHORT_CYCLE and is_train and not is_precise_bn
    )
    if use_short_cycle:
        batch_sampler = ShortCycleBatchSampler(
            sampler, batch_size=batch_size, drop_last=drop_last, cfg=cfg
        )
        return torch.utils.data.DataLoader(
            dataset,
            batch_sampler=batch_sampler,
            num_workers=cfg.DATA_LOADER.NUM_WORKERS,
            pin_memory=cfg.DATA_LOADER.PIN_MEMORY,
            worker_init_fn=utils.loader_worker_init_fn(dataset),
        )

    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        # DataLoader rejects shuffle=True together with an explicit sampler.
        shuffle=(False if sampler else shuffle),
        sampler=sampler,
        num_workers=cfg.DATA_LOADER.NUM_WORKERS,
        pin_memory=cfg.DATA_LOADER.PIN_MEMORY,
        drop_last=drop_last,
        collate_fn=detection_collate if cfg.DETECTION.ENABLE else None,
        worker_init_fn=utils.loader_worker_init_fn(dataset),
    )
114
+
115
+
116
def shuffle_dataset(loader, cur_epoch):
    """
    Shuffles the data for the given epoch.
    Args:
        loader (loader): data loader to perform shuffle.
        cur_epoch (int): number of the current epoch.
    """
    # With the short-cycle batch sampler the index sampler is nested one
    # level deeper than on a regular DataLoader.
    if isinstance(loader.batch_sampler, ShortCycleBatchSampler):
        sampler = loader.batch_sampler.sampler
    else:
        sampler = loader.sampler
    assert isinstance(
        sampler, (RandomSampler, DistributedSampler)
    ), "Sampler type '{}' not supported".format(type(sampler))
    # RandomSampler handles shuffling automatically; DistributedSampler
    # shuffles data based on the epoch it is told about.
    if isinstance(sampler, DistributedSampler):
        sampler.set_epoch(cur_epoch)
TimeSformer/timesformer/datasets/multigrid_helper.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+
3
+ """Helper functions for multigrid training."""
4
+
5
+ import numpy as np
6
+ from torch._six import int_classes as _int_classes
7
+ from torch.utils.data.sampler import Sampler
8
+
9
+
10
class ShortCycleBatchSampler(Sampler):
    """
    Extend Sampler to support "short cycle" sampling.
    See paper "A Multigrid Method for Efficiently Training Video Models",
    Wu et al., 2019 (https://arxiv.org/abs/1912.00998) for details.

    Batches are yielded as lists of `(sample_idx, cycle_idx)` tuples, where
    `cycle_idx` is 0, 1 or 2 and selects which of the three batch sizes in
    `self.batch_sizes` the batch was built with.
    """

    def __init__(self, sampler, batch_size, drop_last, cfg):
        """
        Args:
            sampler (Sampler): base sampler yielding sample indices.
            batch_size (int): batch size of the full-resolution cycle step.
            drop_last (bool): whether to drop a trailing incomplete batch.
            cfg (CfgNode): configs; reads `DATA.TRAIN_CROP_SIZE`,
                `MULTIGRID.DEFAULT_S` and `MULTIGRID.SHORT_CYCLE_FACTORS`.
        Raises:
            ValueError: if any argument has the wrong type or `batch_size`
                is not positive.
        """
        if not isinstance(sampler, Sampler):
            raise ValueError(
                "sampler should be an instance of "
                "torch.utils.data.Sampler, but got sampler={}".format(sampler)
            )
        # The builtin `int` is used instead of the private
        # `torch._six.int_classes` alias. `bool` is a subclass of `int`, so
        # it has to be rejected explicitly.
        if (
            not isinstance(batch_size, int)
            or isinstance(batch_size, bool)
            or batch_size <= 0
        ):
            raise ValueError(
                "batch_size should be a positive integer value, "
                "but got batch_size={}".format(batch_size)
            )
        if not isinstance(drop_last, bool):
            raise ValueError(
                "drop_last should be a boolean value, but got "
                "drop_last={}".format(drop_last)
            )
        self.sampler = sampler
        self.drop_last = drop_last

        # Smaller crops allow proportionally larger batches: the factor is
        # the squared ratio of the default crop size to the short-cycle one.
        bs_factor = [
            int(
                round(
                    (
                        float(cfg.DATA.TRAIN_CROP_SIZE)
                        / (s * cfg.MULTIGRID.DEFAULT_S)
                    )
                    ** 2
                )
            )
            for s in cfg.MULTIGRID.SHORT_CYCLE_FACTORS
        ]

        self.batch_sizes = [
            batch_size * bs_factor[0],
            batch_size * bs_factor[1],
            batch_size,
        ]

    def __iter__(self):
        counter = 0
        batch_size = self.batch_sizes[0]
        batch = []
        for idx in self.sampler:
            batch.append((idx, counter % 3))
            if len(batch) == batch_size:
                yield batch
                counter += 1
                batch_size = self.batch_sizes[counter % 3]
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        """
        Returns:
            (int): the exact number of batches `__iter__` yields.
        """
        # Replay the batch-size cycle instead of dividing by the average
        # batch size: the average-based estimate can disagree with what
        # `__iter__` produces (e.g. 7 samples with sizes [8, 4, 2] and
        # drop_last yields 0 batches, not 1).
        remaining = len(self.sampler)
        cycle_total = sum(self.batch_sizes)
        num_batches = 3 * (remaining // cycle_total)
        remaining %= cycle_total
        for size in self.batch_sizes:
            if remaining < size:
                break
            num_batches += 1
            remaining -= size
        if remaining > 0 and not self.drop_last:
            num_batches += 1
        return num_batches
TimeSformer/timesformer/datasets/ssv2.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+
3
+ import json
4
+ import numpy as np
5
+ import os
6
+ import random
7
+ from itertools import chain as chain
8
+ import torch
9
+ import torch.utils.data
10
+ from fvcore.common.file_io import PathManager
11
+
12
+ import timesformer.utils.logging as logging
13
+
14
+ from . import utils as utils
15
+ from .build import DATASET_REGISTRY
16
+
17
+ logger = logging.get_logger(__name__)
18
+
19
+
20
@DATASET_REGISTRY.register()
class Ssv2(torch.utils.data.Dataset):
    """
    Something-Something v2 (SSV2) video loader. Construct the SSV2 video loader,
    then sample clips from the videos. For training and validation, a single
    clip is randomly sampled from every video with random cropping, scaling, and
    flipping. For testing, multiple clips are uniformly sampled from every
    video with uniform cropping. For uniform cropping, we take the left, center,
    and right crop if the width is larger than height, or take top, center, and
    bottom crop if the height is larger than the width.
    """

    def __init__(self, cfg, mode, num_retries=10):
        """
        Load Something-Something V2 data (frame paths, labels, etc. ) to a given
        Dataset object. The dataset could be downloaded from Something-Something
        official website (https://20bn.com/datasets/something-something).
        Please see datasets/DATASET.md for more information about the data format.
        Args:
            cfg (CfgNode): configs.
            mode (string): Options includes `train`, `val`, or `test` mode.
                For the train and val mode, the data loader will take data
                from the train or val set, and sample one clip per video.
                For the test mode, the data loader reads the same
                validation files as `val` and samples multiple clips per
                video.
            num_retries (int): number of retries for reading frames from disk.
        """
        # Only support train, val, and test mode.
        assert mode in [
            "train",
            "val",
            "test",
        ], "Split '{}' not supported for Something-Something V2".format(mode)
        self.mode = mode
        self.cfg = cfg

        self._video_meta = {}
        self._num_retries = num_retries
        # For training or validation mode, one single clip is sampled from every
        # video. For testing, NUM_ENSEMBLE_VIEWS clips are sampled from every
        # video. For every clip, NUM_SPATIAL_CROPS is cropped spatially from
        # the frames.
        if self.mode in ["train", "val"]:
            self._num_clips = 1
        elif self.mode in ["test"]:
            self._num_clips = (
                cfg.TEST.NUM_ENSEMBLE_VIEWS * cfg.TEST.NUM_SPATIAL_CROPS
            )

        logger.info("Constructing Something-Something V2 {}...".format(mode))
        self._construct_loader()

    def _construct_loader(self):
        """
        Construct the video loader: read the label files and the frame
        lists, then populate `self._path_to_videos`, `self._labels` and
        `self._spatial_temporal_idx` with one entry per clip.
        """
        data_dir = self.cfg.DATA.PATH_TO_DATA_DIR

        # Loading label names.
        with PathManager.open(
            os.path.join(data_dir, "something-something-v2-labels.json"),
            "r",
        ) as f:
            label_dict = json.load(f)

        # Loading labels. Every mode other than `train` uses the
        # validation annotations.
        label_file = os.path.join(
            data_dir,
            "something-something-v2-{}.json".format(
                "train" if self.mode == "train" else "validation"
            ),
        )
        with PathManager.open(label_file, "r") as f:
            label_json = json.load(f)

        self._video_names = []
        self._labels = []
        for video in label_json:
            # The label dictionary is keyed by the template with the
            # square-bracket placeholder markers stripped.
            template = video["template"].replace("[", "").replace("]", "")
            self._video_names.append(video["id"])
            self._labels.append(int(label_dict[template]))

        path_to_file = os.path.join(
            data_dir,
            "{}.csv".format("train" if self.mode == "train" else "val"),
        )
        assert PathManager.exists(path_to_file), "{} dir not found".format(
            path_to_file
        )

        self._path_to_videos, _ = utils.load_image_lists(
            path_to_file, self.cfg.DATA.PATH_PREFIX
        )

        assert len(self._path_to_videos) == len(self._video_names), (
            len(self._path_to_videos),
            len(self._video_names),
        )

        # From dict to list: keep the annotation order and drop videos
        # that have no frame list.
        new_paths, new_labels = [], []
        for video_name, label in zip(self._video_names, self._labels):
            if video_name in self._path_to_videos:
                new_paths.append(self._path_to_videos[video_name])
                new_labels.append(label)

        # Number of distinct videos, taken BEFORE the per-clip expansion
        # below so that `_spatial_temporal_idx` ends up with exactly one
        # entry per clip (previously it was built from the already expanded
        # list and was `_num_clips` times longer than needed).
        num_videos = len(new_paths)

        # Extend self when self._num_clips > 1 (during testing).
        self._path_to_videos = list(
            chain.from_iterable([[x] * self._num_clips for x in new_paths])
        )
        self._labels = list(
            chain.from_iterable([[x] * self._num_clips for x in new_labels])
        )
        self._spatial_temporal_idx = list(
            chain.from_iterable(
                [range(self._num_clips) for _ in range(num_videos)]
            )
        )
        logger.info(
            "Something-Something V2 dataloader constructed "
            " (size: {}) from {}".format(
                len(self._path_to_videos), path_to_file
            )
        )

    def __getitem__(self, index):
        """
        Given the video index, return the list of frames, label, and video
        index if the video frames can be fetched.
        Args:
            index (int): the video index provided by the pytorch sampler.
                When short cycle is used it is an `(index, short_cycle_idx)`
                tuple instead.
        Returns:
            frames (tensor): the frames of sampled from the video. The dimension
                is `channel` x `num frames` x `height` x `width`. For
                architectures other than `vit` this is whatever
                `utils.pack_pathway_output` returns.
            label (int): the label of the current video.
            index (int): the index of the video.
            (dict): always empty; placeholder for extra data.
        """
        short_cycle_idx = None
        # When short cycle is used, input index is a tuple.
        if isinstance(index, tuple):
            index, short_cycle_idx = index

        if self.mode in ["train", "val"]:
            # -1 indicates random sampling.
            spatial_sample_index = -1
            min_scale = self.cfg.DATA.TRAIN_JITTER_SCALES[0]
            max_scale = self.cfg.DATA.TRAIN_JITTER_SCALES[1]
            crop_size = self.cfg.DATA.TRAIN_CROP_SIZE
            if short_cycle_idx in [0, 1]:
                crop_size = int(
                    round(
                        self.cfg.MULTIGRID.SHORT_CYCLE_FACTORS[short_cycle_idx]
                        * self.cfg.MULTIGRID.DEFAULT_S
                    )
                )
            if self.cfg.MULTIGRID.DEFAULT_S > 0:
                # Decreasing the scale is equivalent to using a larger "span"
                # in a sampling grid.
                min_scale = int(
                    round(
                        float(min_scale)
                        * crop_size
                        / self.cfg.MULTIGRID.DEFAULT_S
                    )
                )
        elif self.mode in ["test"]:
            # spatial_sample_index is in [0, 1, 2]. Corresponding to left,
            # center, or right if width is larger than height, and top, middle,
            # or bottom if height is larger than width.
            spatial_sample_index = (
                self._spatial_temporal_idx[index]
                % self.cfg.TEST.NUM_SPATIAL_CROPS
            )
            if self.cfg.TEST.NUM_SPATIAL_CROPS == 1:
                # A single crop is always the center one.
                spatial_sample_index = 1

            min_scale, max_scale, crop_size = [self.cfg.DATA.TEST_CROP_SIZE] * 3
            # The testing is deterministic and no jitter should be performed.
            # min_scale, max_scale, and crop_size are expect to be the same.
            assert len({min_scale, max_scale, crop_size}) == 1
        else:
            raise NotImplementedError(
                "Does not support {} mode".format(self.mode)
            )

        label = self._labels[index]

        num_frames = self.cfg.DATA.NUM_FRAMES
        video_length = len(self._path_to_videos[index])

        # Split the video into `num_frames` equal segments and take one
        # frame per segment: a random one when training, the middle one
        # otherwise.
        seg_size = float(video_length - 1) / num_frames
        seq = []
        for i in range(num_frames):
            start = int(np.round(seg_size * i))
            end = int(np.round(seg_size * (i + 1)))
            if self.mode == "train":
                seq.append(random.randint(start, end))
            else:
                seq.append((start + end) // 2)

        frames = torch.as_tensor(
            utils.retry_load_images(
                [self._path_to_videos[index][frame] for frame in seq],
                self._num_retries,
            )
        )

        # Perform color normalization.
        frames = utils.tensor_normalize(
            frames, self.cfg.DATA.MEAN, self.cfg.DATA.STD
        )

        # T H W C -> C T H W.
        frames = frames.permute(3, 0, 1, 2)
        frames = utils.spatial_sampling(
            frames,
            spatial_idx=spatial_sample_index,
            min_scale=min_scale,
            max_scale=max_scale,
            crop_size=crop_size,
            random_horizontal_flip=self.cfg.DATA.RANDOM_FLIP,
            inverse_uniform_sampling=self.cfg.DATA.INV_UNIFORM_SAMPLE,
        )
        if self.cfg.MODEL.ARCH not in ['vit']:
            frames = utils.pack_pathway_output(self.cfg, frames)
        else:
            # Perform temporal sampling from the fast pathway.
            frames = torch.index_select(
                frames,
                1,
                torch.linspace(
                    0, frames.shape[1] - 1, self.cfg.DATA.NUM_FRAMES
                ).long(),
            )
        return frames, label, index, {}

    def __len__(self):
        """
        Returns:
            (int): the number of videos in the dataset.
        """
        return len(self._path_to_videos)
TimeSformer/timesformer/datasets/transform.py ADDED
@@ -0,0 +1,459 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+
3
+ import math
4
+ import numpy as np
5
+ import torch
6
+
7
+
8
def random_short_side_scale_jitter(
    images, min_size, max_size, boxes=None, inverse_uniform_sampling=False
):
    """
    Randomly rescale the images (and the matching boxes) so that their
    short side has a size drawn between `min_size` and `max_size`.
    Args:
        images (tensor): images to perform scale jitter. Dimension is
            `num frames` x `channel` x `height` x `width`.
        min_size (int): the minimal size to scale the frames.
        max_size (int): the maximal size to scale the frames.
        boxes (ndarray): optional. Corresponding boxes to images.
            Dimension is `num boxes` x 4.
        inverse_uniform_sampling (bool): if True, sample uniformly in
            [1 / max_scale, 1 / min_scale] and take a reciprocal to get the
            scale. If False, take a uniform sample from [min_scale, max_scale].
    Returns:
        (tensor): the scaled images with dimension of
            `num frames` x `channel` x `new height` x `new width`.
        (ndarray or None): the scaled boxes with dimension of
            `num boxes` x 4.
    """
    if inverse_uniform_sampling:
        drawn = 1.0 / np.random.uniform(1.0 / max_size, 1.0 / min_size)
    else:
        drawn = np.random.uniform(min_size, max_size)
    size = int(round(drawn))

    height, width = images.shape[2], images.shape[3]
    short_side_is_width = width <= height and width == size
    short_side_is_height = height <= width and height == size
    if short_side_is_width or short_side_is_height:
        # The short side already has the requested size.
        return images, boxes

    if width < height:
        new_width = size
        new_height = int(math.floor((float(height) / width) * size))
        box_scale = float(new_height) / height
    else:
        new_height = size
        new_width = int(math.floor((float(width) / height) * size))
        box_scale = float(new_width) / width
    if boxes is not None:
        boxes = boxes * box_scale

    resized = torch.nn.functional.interpolate(
        images,
        size=(new_height, new_width),
        mode="bilinear",
        align_corners=False,
    )
    return resized, boxes
63
+
64
+
65
def crop_boxes(boxes, x_offset, y_offset):
    """
    Shift bounding boxes into the coordinate frame of a crop.
    Args:
        boxes (ndarray): bounding boxes to perform crop. The dimension
            is `num boxes` x 4.
        x_offset (int): cropping offset in the x axis.
        y_offset (int): cropping offset in the y axis.
    Returns:
        cropped_boxes (ndarray): a shifted copy of the boxes with
            dimension of `num boxes` x 4; the input is left untouched.
    """
    cropped_boxes = boxes.copy()
    # Columns 0 and 2 are shifted by the x offset, columns 1 and 3 by the
    # y offset.
    for columns, offset in (([0, 2], x_offset), ([1, 3], y_offset)):
        cropped_boxes[:, columns] = boxes[:, columns] - offset
    return cropped_boxes
82
+
83
+
84
def random_crop(images, size, boxes=None):
    """
    Perform random spatial crop on the given images and corresponding boxes.
    Args:
        images (tensor): images to perform random crop. The dimension is
            `num frames` x `channel` x `height` x `width`.
        size (int): the size of height and width to crop on the image.
        boxes (ndarray or None): optional. Corresponding boxes to images.
            Dimension is `num boxes` x 4.
    Returns:
        cropped (tensor): cropped images with dimension of
            `num frames` x `channel` x `size` x `size`.
        cropped_boxes (ndarray or None): the cropped boxes with dimension of
            `num boxes` x 4, or None when no boxes were given.
    """
    height = images.shape[2]
    width = images.shape[3]
    if height == size and width == size:
        # Nothing to crop: hand the boxes back unchanged instead of
        # discarding them.
        return images, boxes

    # `np.random.randint` excludes its upper bound, hence the `+ 1` so that
    # the last valid offset (`dim - size`) can be drawn as well.
    y_offset = 0
    if height > size:
        y_offset = int(np.random.randint(0, height - size + 1))
    x_offset = 0
    if width > size:
        x_offset = int(np.random.randint(0, width - size + 1))
    cropped = images[
        :, :, y_offset : y_offset + size, x_offset : x_offset + size
    ]

    cropped_boxes = (
        crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
    )

    return cropped, cropped_boxes
118
+
119
+
120
def horizontal_flip(prob, images, boxes=None):
    """
    Flip the images (and the matching boxes) left-to-right with the given
    probability.
    Args:
        prob (float): probability to flip the images.
        images (tensor): images to perform horizontal flip, the dimension is
            `num frames` x `channel` x `height` x `width`.
        boxes (ndarray or None): optional. Corresponding boxes to images.
            Dimension is `num boxes` x 4.
    Returns:
        images (tensor): images with dimension of
            `num frames` x `channel` x `height` x `width`.
        flipped_boxes (ndarray or None): a copy of the boxes, flipped when
            the images were flipped, with dimension of `num boxes` x 4.
    """
    flipped_boxes = None if boxes is None else boxes.copy()

    if not (np.random.uniform() < prob):
        return images, flipped_boxes

    images = images.flip((-1))
    width = images.shape[3]
    if boxes is not None:
        # Mirroring swaps the roles of the left and right x coordinates.
        flipped_boxes[:, [0, 2]] = width - boxes[:, [2, 0]] - 1
    return images, flipped_boxes
148
+
149
+
150
def uniform_crop(images, size, spatial_idx, boxes=None):
    """
    Take one of three deterministic square crops along the longer side of
    the images, and shift the boxes accordingly.
    Args:
        images (tensor): images to perform uniform crop. The dimension is
            `num frames` x `channel` x `height` x `width`.
        size (int): size of height and width to crop the images.
        spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
            is larger than height. Or 0, 1, or 2 for top, center, and bottom
            crop if height is larger than width.
        boxes (ndarray or None): optional. Corresponding boxes to images.
            Dimension is `num boxes` x 4.
    Returns:
        cropped (tensor): images with dimension of
            `num frames` x `channel` x `size` x `size`.
        cropped_boxes (ndarray or None): the cropped boxes with dimension of
            `num boxes` x 4.
    """
    assert spatial_idx in [0, 1, 2]
    height, width = images.shape[2], images.shape[3]

    # Start from the center crop on both axes.
    y_offset = int(math.ceil((height - size) / 2))
    x_offset = int(math.ceil((width - size) / 2))

    # Index 0 / 2 move the crop to the start / end of the longer side.
    if spatial_idx != 1:
        if height > width:
            y_offset = 0 if spatial_idx == 0 else height - size
        else:
            x_offset = 0 if spatial_idx == 0 else width - size

    cropped = images[
        :, :, y_offset : y_offset + size, x_offset : x_offset + size
    ]
    if boxes is None:
        return cropped, None
    return cropped, crop_boxes(boxes, x_offset, y_offset)
194
+
195
+
196
def uniform_crop_2crops(images, size, spatial_idx, boxes=None):
    """
    Take one of two deterministic square crops along the longer side of the
    images, and shift the boxes accordingly. The crop always starts at
    offset 0 on the shorter side.
    Args:
        images (tensor): images to perform uniform crop. The dimension is
            `num frames` x `channel` x `height` x `width`.
        size (int): size of height and width to crop the images.
        spatial_idx (int): 0 or 1 for the first (left / top) and second
            (right / bottom) crop along the longer side.
        boxes (ndarray or None): optional. Corresponding boxes to images.
            Dimension is `num boxes` x 4.
    Returns:
        cropped (tensor): images with dimension of
            `num frames` x `channel` x `size` x `size`.
        cropped_boxes (ndarray or None): the cropped boxes with dimension of
            `num boxes` x 4.
    Raises:
        AssertionError: if `spatial_idx` is not 0 or 1.
    """
    # Only two crops exist. Index 2 used to pass the check and then fail
    # with an UnboundLocalError because no offset was ever assigned for it.
    assert spatial_idx in [0, 1], (
        "uniform_crop_2crops only supports spatial_idx 0 or 1, "
        "got {}".format(spatial_idx)
    )
    height = images.shape[2]
    width = images.shape[3]

    long_side = height if height > width else width
    if long_side > size * 2:
        # Enough room for two non-overlapping crops: place them
        # symmetrically around the center of the longer side.
        margin = (long_side - size * 2) // 2
        long_offset = int(margin) if spatial_idx == 0 else int(
            long_side - size - margin
        )
    else:
        # Otherwise the two crops hug the two ends of the longer side.
        long_offset = 0 if spatial_idx == 0 else long_side - size

    if height > width:
        x_offset, y_offset = 0, long_offset
    else:
        x_offset, y_offset = long_offset, 0

    cropped = images[
        :, :, y_offset : y_offset + size, x_offset : x_offset + size
    ]

    cropped_boxes = (
        crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
    )

    return cropped, cropped_boxes
253
+
254
def clip_boxes_to_image(boxes, height, width):
    """
    Clamp box coordinates so they lie inside an image of the given size.
    Args:
        boxes (ndarray): bounding boxes to perform clipping.
            Dimension is `num boxes` x 4.
        height (int): given image height.
        width (int): given image width.
    Returns:
        clipped_boxes (ndarray): a clipped copy of the boxes with dimension
            of `num boxes` x 4.
    """
    clipped_boxes = boxes.copy()
    # x coordinates live in columns 0 and 2, y coordinates in 1 and 3.
    x_cols, y_cols = [0, 2], [1, 3]
    non_negative_x = np.maximum(0.0, boxes[:, x_cols])
    clipped_boxes[:, x_cols] = np.minimum(width - 1.0, non_negative_x)
    non_negative_y = np.maximum(0.0, boxes[:, y_cols])
    clipped_boxes[:, y_cols] = np.minimum(height - 1.0, non_negative_y)
    return clipped_boxes
274
+
275
+
276
def blend(images1, images2, alpha):
    """
    Linearly mix two sets of images: `alpha` weights the first one and
    `1 - alpha` weights the second one.
    Args:
        images1 (tensor): the first images to be blended, the dimension is
            `num frames` x `channel` x `height` x `width`.
        images2 (tensor): the second images to be blended, the dimension is
            `num frames` x `channel` x `height` x `width`.
        alpha (float): the blending weight.
    Returns:
        (tensor): blended images, the dimension is
            `num frames` x `channel` x `height` x `width`.
    """
    weighted_first = images1 * alpha
    weighted_second = images2 * (1 - alpha)
    return weighted_first + weighted_second
290
+
291
+
292
def grayscale(images):
    """
    Get the grayscale for the input images. The channels of images should be
    in order BGR.
    Args:
        images (tensor): the input images for getting grayscale. Dimension is
            `num frames` x `channel` x `height` x `width`.
    Returns:
        img_gray (tensor): grayscale images with the gray value repeated in
            the first three channels, the dimension is
            `num frames` x `channel` x `height` x `width`. The input is not
            modified.
    """
    # `clone().detach()` is the supported way to copy a tensor;
    # `torch.tensor(tensor)` does the same but emits a UserWarning.
    img_gray = images.clone().detach()
    # R -> 0.299, G -> 0.587, B -> 0.114 (channel 0 is B, channel 2 is R).
    gray_channel = (
        0.299 * images[:, 2] + 0.587 * images[:, 1] + 0.114 * images[:, 0]
    )
    img_gray[:, 0] = gray_channel
    img_gray[:, 1] = gray_channel
    img_gray[:, 2] = gray_channel
    return img_gray
312
+
313
+
314
def color_jitter(images, img_brightness=0, img_contrast=0, img_saturation=0):
    """
    Perform color jittering on the input images. The enabled jitters
    (those with a non-zero ratio) are applied in a random order. The
    channels of images should be in order BGR.
    Args:
        images (tensor): images to perform color jitter. Dimension is
            `num frames` x `channel` x `height` x `width`.
        img_brightness (float): jitter ratio for brightness.
        img_contrast (float): jitter ratio for contrast.
        img_saturation (float): jitter ratio for saturation.
    Returns:
        images (tensor): the jittered images, the dimension is
            `num frames` x `channel` x `height` x `width`.
    """
    candidates = (
        (img_brightness, lambda imgs: brightness_jitter(img_brightness, imgs)),
        (img_contrast, lambda imgs: contrast_jitter(img_contrast, imgs)),
        (img_saturation, lambda imgs: saturation_jitter(img_saturation, imgs)),
    )
    enabled = [apply for ratio, apply in candidates if ratio != 0]

    if len(enabled) > 0:
        order = np.random.permutation(np.arange(len(enabled)))
        for position in order:
            images = enabled[position](images)
    return images
347
+
348
+
349
def brightness_jitter(var, images):
    """
    Perform brightness jittering on the input images by blending them with
    an all-zero (black) image. The channels of images should be in order
    BGR.
    Args:
        var (float): jitter ratio for brightness.
        images (tensor): images to perform color jitter. Dimension is
            `num frames` x `channel` x `height` x `width`.
    Returns:
        images (tensor): the jittered images, the dimension is
            `num frames` x `channel` x `height` x `width`.
    """
    alpha = 1.0 + np.random.uniform(-var, var)

    # `zeros_like` keeps the black image on the same device and in the same
    # dtype as the input (as the other helpers in this file do), whereas
    # `torch.zeros(images.shape)` is always a CPU float32 tensor.
    img_bright = torch.zeros_like(images)
    images = blend(images, img_bright, alpha)
    return images
366
+
367
+
368
def contrast_jitter(var, images):
    """
    Perform contrast jittering on the input images by blending every frame
    with its own mean gray level. The channels of images should be in order
    BGR.
    Args:
        var (float): jitter ratio for contrast.
        images (tensor): images to perform color jitter. Dimension is
            `num frames` x `channel` x `height` x `width`.
    Returns:
        images (tensor): the jittered images, the dimension is
            `num frames` x `channel` x `height` x `width`.
    """
    alpha = 1.0 + np.random.uniform(-var, var)

    flat_gray = grayscale(images)
    # Replace every pixel of a frame by that frame's mean gray value.
    per_frame_mean = torch.mean(flat_gray, dim=(1, 2, 3), keepdim=True)
    flat_gray[:] = per_frame_mean
    return blend(images, flat_gray, alpha)
386
+
387
+
388
def saturation_jitter(var, images):
    """
    Perform saturation jittering on the input images by blending them with
    their grayscale version. The channels of images should be in order BGR.
    Args:
        var (float): jitter ratio for saturation.
        images (tensor): images to perform color jitter. Dimension is
            `num frames` x `channel` x `height` x `width`.
    Returns:
        images (tensor): the jittered images, the dimension is
            `num frames` x `channel` x `height` x `width`.
    """
    alpha = 1.0 + np.random.uniform(-var, var)
    return blend(images, grayscale(images), alpha)
405
+
406
+
407
def lighting_jitter(images, alphastd, eigval, eigvec):
    """
    Perform AlexNet-style PCA jitter on the given images: one random offset
    per color channel is added to every pixel.
    Args:
        images (tensor): images to perform lighting jitter. Dimension is
            `num frames` x `channel` x `height` x `width`.
        alphastd (float): jitter ratio for PCA jitter.
        eigval (list): eigenvalues for PCA jitter.
        eigvec (list[list]): eigenvectors for PCA jitter.
    Returns:
        out_images (tensor): the jittered images, the dimension is
            `num frames` x `channel` x `height` x `width`. When `alphastd`
            is 0 the input tensor itself is returned.
    """
    if alphastd == 0:
        return images

    # generate alpha1, alpha2, alpha3.
    alpha = np.random.normal(0, alphastd, size=(1, 3))
    eig_vec = np.array(eigvec)
    eig_val = np.reshape(eigval, (1, 3))
    alpha_rows = np.repeat(alpha, 3, axis=0)
    eig_val_rows = np.repeat(eig_val, 3, axis=0)
    rgb = np.sum(eig_vec * alpha_rows * eig_val_rows, axis=1)

    out_images = torch.zeros_like(images)
    num_channels = images.shape[1]
    for channel in range(num_channels):
        # `rgb` is indexed in reverse relative to the channel axis.
        out_images[:, channel] = images[:, channel] + rgb[2 - channel]
    return out_images
435
+
436
+
437
def color_normalization(images, mean, stddev):
    """
    Perform color normalization on the given images: subtract the mean and
    divide by the standard deviation, channel by channel.
    Args:
        images (tensor): images to perform color normalization. Dimension is
            `num frames` x `channel` x `height` x `width`.
        mean (list): mean values for normalization.
        stddev (list): standard deviations for normalization.

    Returns:
        out_images (tensor): the normalized images, the dimension is
            `num frames` x `channel` x `height` x `width`.
    """
    num_channels = images.shape[1]
    assert len(mean) == num_channels, "channel mean not computed properly"
    assert (
        len(stddev) == num_channels
    ), "channel stddev not computed properly"

    out_images = torch.zeros_like(images)
    for channel, channel_mean in enumerate(mean):
        centered = images[:, channel] - channel_mean
        out_images[:, channel] = centered / stddev[channel]

    return out_images
TimeSformer/timesformer/datasets/utils.py ADDED
@@ -0,0 +1,380 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ import logging
4
+ import numpy as np
5
+ import os
6
+ import random
7
+ import time
8
+ from collections import defaultdict
9
+ import cv2
10
+ import torch
11
+ from fvcore.common.file_io import PathManager
12
+ from torch.utils.data.distributed import DistributedSampler
13
+
14
+ from . import transform as transform
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
def retry_load_images(image_paths, retry=10, backend="pytorch"):
    """
    Load a list of images, retrying the whole batch when any decode fails.

    Args:
        image_paths (list): paths of images needed to be loaded.
        retry (int, optional): maximum number of loading attempts.
            Defaults to 10.
        backend (str): `pytorch` or `cv2`. With `pytorch` the decoded images
            are stacked into a single tensor; any other value returns the
            list of arrays produced by `cv2.imdecode`.

    Returns:
        imgs (tensor or list): the loaded images.

    Raises:
        Exception: if the images could not all be decoded within `retry`
            attempts (including when `retry` is not positive).
    """
    for attempt in range(retry):
        imgs = []
        for image_path in image_paths:
            with PathManager.open(image_path, "rb") as f:
                img_str = np.frombuffer(f.read(), np.uint8)
            # cv2.imdecode returns None instead of raising on a bad buffer.
            imgs.append(cv2.imdecode(img_str, flags=cv2.IMREAD_COLOR))

        if all(img is not None for img in imgs):
            if backend == "pytorch":
                imgs = torch.as_tensor(np.stack(imgs))
            return imgs

        logger.warning("Reading failed. Will retry.")
        # Only wait when another attempt will actually follow.
        if attempt < retry - 1:
            time.sleep(1.0)

    raise Exception("Failed to load images {}".format(image_paths))
48
+
49
+
50
def get_sequence(center_idx, half_len, sample_rate, num_frames):
    """
    Sample frame indexes around a center frame, clamped to the video range.

    Args:
        center_idx (int): center frame idx for current clip
        half_len (int): half of the clip length
        sample_rate (int): sampling rate for sampling frames inside of the clip
        num_frames (int): number of frames in the video; indexes at or beyond
            it are clamped to `num_frames - 1`, negative ones to 0.

    Returns:
        seq (list): list of indexes of sampled frames in this clip.
    """
    first = center_idx - half_len
    stop = center_idx + half_len
    last_valid = num_frames - 1
    return [
        0 if idx < 0 else (last_valid if idx >= num_frames else idx)
        for idx in range(first, stop, sample_rate)
    ]
71
+
72
+
73
def pack_pathway_output(cfg, frames):
    """
    Pack the sampled frames into a list with one tensor per model pathway.
    Args:
        cfg (CfgNode): configs. Reads DATA.REVERSE_INPUT_CHANNEL, MODEL.ARCH,
            MODEL.SINGLE_PATHWAY_ARCH, MODEL.MULTI_PATHWAY_ARCH and
            SLOWFAST.ALPHA.
        frames (tensor): frames of images sampled from the video. The
            dimension is `channel` x `num frames` x `height` x `width`.
    Returns:
        frame_list (list): list of tensors with the dimension of
            `channel` x `num frames` x `height` x `width`. One tensor for a
            single-pathway arch, `[slow, fast]` for a multi-pathway arch.
    """
    if cfg.DATA.REVERSE_INPUT_CHANNEL:
        # Flip the order of the three input channels.
        frames = frames[[2, 1, 0], :, :, :]

    arch = cfg.MODEL.ARCH
    if arch in cfg.MODEL.SINGLE_PATHWAY_ARCH:
        return [frames]

    if arch in cfg.MODEL.MULTI_PATHWAY_ARCH:
        # The slow pathway is a temporal subsample of the fast pathway.
        num_frames = frames.shape[1]
        slow_indices = torch.linspace(
            0, num_frames - 1, num_frames // cfg.SLOWFAST.ALPHA
        ).long()
        slow_pathway = torch.index_select(frames, 1, slow_indices)
        return [slow_pathway, frames]

    raise NotImplementedError(
        "Model arch {} is not in {}".format(
            cfg.MODEL.ARCH,
            cfg.MODEL.SINGLE_PATHWAY_ARCH + cfg.MODEL.MULTI_PATHWAY_ARCH,
        )
    )
107
+
108
+
109
def spatial_sampling(
    frames,
    spatial_idx=-1,
    min_scale=256,
    max_scale=320,
    crop_size=224,
    random_horizontal_flip=True,
    inverse_uniform_sampling=False,
):
    """
    Perform spatial sampling on the given video frames. With spatial_idx
    -1 the frames get a random short-side scale, a random crop and an
    optional random horizontal flip. With spatial_idx 0, 1 or 2 the frames
    are rescaled and then uniformly cropped at the given position.
    Args:
        frames (tensor): frames of images sampled from the video. The
            dimension is `num frames` x `height` x `width` x `channel`.
        spatial_idx (int): if -1, perform random spatial sampling. If 0, 1,
            or 2, perform left, center, right crop if width is larger than
            height, and perform top, center, bottom crop if height is larger
            than width.
        min_scale (int): the minimal size of scaling.
        max_scale (int): the maximal size of scaling.
        crop_size (int): the size of height and width used to crop the
            frames.
        random_horizontal_flip (bool): if True, flip the frames horizontally
            with probability 0.5. Only used when spatial_idx is -1.
        inverse_uniform_sampling (bool): if True, sample uniformly in
            [1 / max_scale, 1 / min_scale] and take a reciprocal to get the
            scale. If False, take a uniform sample from [min_scale,
            max_scale]. Only used when spatial_idx is -1.
    Returns:
        frames (tensor): spatially sampled frames.
    """
    assert spatial_idx in [-1, 0, 1, 2]

    if spatial_idx != -1:
        # Testing path. The upstream check that min_scale, max_scale and
        # crop_size are all equal is disabled here, so the rescale still goes
        # through the jitter helper with the caller's range.
        frames, _ = transform.random_short_side_scale_jitter(
            frames, min_scale, max_scale
        )
        frames, _ = transform.uniform_crop(frames, crop_size, spatial_idx)
        return frames

    # Training path: random scale, random crop, optional random flip.
    frames, _ = transform.random_short_side_scale_jitter(
        images=frames,
        min_size=min_scale,
        max_size=max_scale,
        inverse_uniform_sampling=inverse_uniform_sampling,
    )
    frames, _ = transform.random_crop(frames, crop_size)
    if random_horizontal_flip:
        frames, _ = transform.horizontal_flip(0.5, frames)
    return frames
161
+
162
def spatial_sampling_2crops(
    frames,
    spatial_idx=-1,
    min_scale=256,
    max_scale=320,
    crop_size=224,
    random_horizontal_flip=True,
    inverse_uniform_sampling=False,
):
    """
    Perform spatial sampling on the given video frames, using the two-crop
    uniform crop helper on the deterministic path. With spatial_idx -1 the
    frames get a random short-side scale, a random crop and an optional
    random horizontal flip. Otherwise the frames are rescaled and passed to
    `transform.uniform_crop_2crops` with the given spatial_idx.
    Args:
        frames (tensor): frames of images sampled from the video. The
            dimension is `num frames` x `height` x `width` x `channel`.
        spatial_idx (int): if -1, perform random spatial sampling. If 0, 1,
            or 2, it selects the crop position handed to
            `transform.uniform_crop_2crops`.
        min_scale (int): the minimal size of scaling.
        max_scale (int): the maximal size of scaling.
        crop_size (int): the size of height and width used to crop the
            frames.
        random_horizontal_flip (bool): if True, flip the frames horizontally
            with probability 0.5. Only used when spatial_idx is -1.
        inverse_uniform_sampling (bool): if True, sample uniformly in
            [1 / max_scale, 1 / min_scale] and take a reciprocal to get the
            scale. If False, take a uniform sample from [min_scale,
            max_scale]. Only used when spatial_idx is -1.
    Returns:
        frames (tensor): spatially sampled frames.
    """
    assert spatial_idx in [-1, 0, 1, 2]

    if spatial_idx != -1:
        # Testing path. The upstream check that min_scale, max_scale and
        # crop_size are all equal is disabled here, so the rescale still goes
        # through the jitter helper with the caller's range.
        frames, _ = transform.random_short_side_scale_jitter(
            frames, min_scale, max_scale
        )
        frames, _ = transform.uniform_crop_2crops(frames, crop_size, spatial_idx)
        return frames

    # Training path: random scale, random crop, optional random flip.
    frames, _ = transform.random_short_side_scale_jitter(
        images=frames,
        min_size=min_scale,
        max_size=max_scale,
        inverse_uniform_sampling=inverse_uniform_sampling,
    )
    frames, _ = transform.random_crop(frames, crop_size)
    if random_horizontal_flip:
        frames, _ = transform.horizontal_flip(0.5, frames)
    return frames
214
+
215
+
216
def as_binary_vector(labels, num_classes):
    """
    Build a multi-hot label vector from a list of label indices.
    Args:
        labels (list): The input label list. Duplicates are ignored.
        num_classes (int): Number of classes of the label vector.
    Returns:
        labels (numpy array): float vector of length `num_classes` holding
            1.0 at every index listed in `labels` and 0.0 elsewhere.
    """
    multi_hot = np.zeros((num_classes,))
    unique_labels = set(labels)
    for class_idx in unique_labels:
        multi_hot[class_idx] = 1.0
    return multi_hot
230
+
231
+
232
def aggregate_labels(label_list):
    """
    Merge a list of label lists into one list of unique labels.
    Args:
        label_list (list): list of label lists to merge.
    Returns:
        labels (list): every distinct label found in the input lists. The
            order follows set iteration and is not guaranteed.
    """
    flattened = [label for labels in label_list for label in labels]
    return list(set(flattened))
245
+
246
+
247
def convert_to_video_level_labels(labels):
    """
    Aggregate annotations from all frames of a video to form video-level labels.
    Args:
        labels (list): per-video lists of per-frame label lists. Modified in
            place.
    Returns:
        labels (list): the same object as the input, with every frame entry
            replaced by the video-level label list. All frames of a video
            share one list object.
    """
    for video_id in range(len(labels)):
        per_frame = labels[video_id]
        merged = aggregate_labels(per_frame)
        for frame_idx in range(len(per_frame)):
            per_frame[frame_idx] = merged
    return labels
261
+
262
+
263
def load_image_lists(frame_list_file, prefix="", return_list=False):
    """
    Load image paths and labels from a "frame list".
    The first line is a header starting with `original_vido_id`; every
    following line contains five whitespace-separated fields:
    `original_vido_id video_id frame_id path labels`
    Args:
        frame_list_file (string): path to the frame list.
        prefix (str): the prefix for the path. When empty, paths are used as
            written in the file.
        return_list (bool): if True, return a list. If False, return a dict.
    Returns:
        image_paths (list or dict): list of list containing path to each frame.
            If return_list is False, then return in a dict form keyed by the
            first field of each line.
        labels (list or dict): list of list containing label of each frame.
            If return_list is False, then return in a dict form.
    """
    image_paths = defaultdict(list)
    labels = defaultdict(list)
    with PathManager.open(frame_list_file, "r") as f:
        # The header line is consumed here and only sanity-checked.
        assert f.readline().startswith("original_vido_id")
        for line in f:
            row = line.split()
            # original_vido_id video_id frame_id path labels
            assert len(row) == 5
            video_name = row[0]
            frame_path = row[3] if prefix == "" else os.path.join(prefix, row[3])
            image_paths[video_name].append(frame_path)

            # Labels are a quoted, comma-separated list of ints (may be empty).
            raw_labels = row[-1].replace('"', "")
            parsed = (
                [int(x) for x in raw_labels.split(",")] if raw_labels != "" else []
            )
            labels[video_name].append(parsed)

    if not return_list:
        return dict(image_paths), dict(labels)

    video_names = list(image_paths.keys())
    path_lists = [image_paths[name] for name in video_names]
    label_lists = [labels[name] for name in video_names]
    return path_lists, label_lists
306
+
307
+
308
def tensor_normalize(tensor, mean, std):
    """
    Normalize a given tensor by subtracting the mean and dividing the std.
    Args:
        tensor (tensor): tensor to normalize. A uint8 tensor is first
            converted to float and scaled to [0, 1] by dividing by 255.
        mean (tensor or list or tuple): mean value to subtract.
        std (tensor or list or tuple): std to divide.
    Returns:
        tensor (tensor): the normalized tensor.
    """
    if tensor.dtype == torch.uint8:
        tensor = tensor.float()
        tensor = tensor / 255.0
    # Build the statistics on the input's device so that tensors living on an
    # accelerator can be normalized with plain Python sequences.
    if isinstance(mean, (list, tuple)):
        mean = torch.tensor(mean, device=tensor.device)
    if isinstance(std, (list, tuple)):
        std = torch.tensor(std, device=tensor.device)
    tensor = tensor - mean
    tensor = tensor / std
    return tensor
326
+
327
+
328
def get_random_sampling_rate(long_cycle_sampling_rate, sampling_rate):
    """
    When multigrid training uses a fewer number of frames, we randomly
    increase the sampling rate so that some clips cover the original span.
    Args:
        long_cycle_sampling_rate (int): upper bound of the random sampling
            rate. A non-positive value disables the randomization.
        sampling_rate (int): the base sampling rate (lower bound).
    Returns:
        (int): `sampling_rate` when randomization is disabled, otherwise a
            random integer in [sampling_rate, long_cycle_sampling_rate].
    """
    if not long_cycle_sampling_rate > 0:
        return sampling_rate
    assert long_cycle_sampling_rate >= sampling_rate
    return random.randint(sampling_rate, long_cycle_sampling_rate)
338
+
339
+
340
def revert_tensor_normalize(tensor, mean, std):
    """
    Revert normalization for a given tensor by multiplying by the std and
    adding the mean.
    Args:
        tensor (tensor): tensor to revert normalization.
        mean (tensor or list or tuple): mean value to add.
        std (tensor or list or tuple): std to multiply.
    Returns:
        tensor (tensor): the de-normalized tensor.
    """
    # Build the statistics on the input's device so that tensors living on an
    # accelerator can be processed with plain Python sequences.
    if isinstance(mean, (list, tuple)):
        mean = torch.tensor(mean, device=tensor.device)
    if isinstance(std, (list, tuple)):
        std = torch.tensor(std, device=tensor.device)
    tensor = tensor * std
    tensor = tensor + mean
    return tensor
355
+
356
+
357
def create_sampler(dataset, shuffle, cfg):
    """
    Create sampler for the given dataset.
    Args:
        dataset (torch.utils.data.Dataset): the given dataset.
        shuffle (bool): accepted for interface compatibility; this function
            does not read it, so the sampler keeps its own default.
        cfg (CfgNode): configs. Only `NUM_GPUS` is read.
    Returns:
        sampler (Sampler): a `DistributedSampler` over the dataset when more
            than one GPU is configured, otherwise None.
    """
    if cfg.NUM_GPUS > 1:
        return DistributedSampler(dataset)
    return None
372
+
373
+
374
def loader_worker_init_fn(dataset):
    """
    Create init function passed to pytorch data loader.
    Args:
        dataset (torch.utils.data.Dataset): the given dataset (unused).
    Returns:
        None: no worker init function is needed, so the data loader falls
            back to its default behavior.
    """
    worker_init_fn = None
    return worker_init_fn
TimeSformer/timesformer/datasets/video_container.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+
3
+ import av
4
+
5
+
6
def get_video_container(path_to_vid, multi_thread_decode=False, backend="pyav"):
    """
    Given the path to the video, return the video container for the backend.
    Args:
        path_to_vid (str): path to the video.
        multi_thread_decode (bool): if True, perform multi-thread decoding.
            Only used by the `pyav` backend.
        backend (str): decoder backend, options include `pyav` and
            `torchvision`, default is `pyav`.
    Returns:
        container (container): for `pyav`, the opened pyav container; for
            `torchvision`, the raw bytes of the video file.
    Raises:
        NotImplementedError: if `backend` is neither `pyav` nor `torchvision`.
    """
    if backend == "pyav":
        # Errors from av.open are deliberately propagated to the caller.
        container = av.open(path_to_vid)
        if multi_thread_decode:
            # Enable multiple threads for decoding.
            container.streams.video[0].thread_type = "AUTO"
        return container

    if backend == "torchvision":
        with open(path_to_vid, "rb") as fp:
            return fp.read()

    raise NotImplementedError("Unknown backend {}".format(backend))
TimeSformer/timesformer/models/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+
3
+ from .build import MODEL_REGISTRY, build_model # noqa
4
+ from .custom_video_model_builder import * # noqa
5
+ from .video_model_builder import ResNet, SlowFast # noqa
TimeSformer/timesformer/models/batchnorm_helper.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+
3
+ """BatchNorm (BN) utility functions and custom batch-size BN implementations"""
4
+
5
+ from functools import partial
6
+ import torch
7
+ import torch.distributed as dist
8
+ import torch.nn as nn
9
+ from torch.autograd.function import Function
10
+
11
+ import timesformer.utils.distributed as du
12
+
13
+
14
def get_norm(cfg):
    """
    Select the 3D normalization layer constructor named by the config.
    Args:
        cfg (CfgNode): model building configs, details are in the comments of
            the config file. Reads BN.NORM_TYPE and, depending on it,
            BN.NUM_SPLITS or BN.NUM_SYNC_DEVICES.
    Returns:
        nn.Module: the normalization layer class, or a `partial` of it with
            the config-dependent argument already bound.
    Raises:
        NotImplementedError: if BN.NORM_TYPE is not a supported type.
    """
    norm_type = cfg.BN.NORM_TYPE
    if norm_type == "batchnorm":
        return nn.BatchNorm3d
    if norm_type == "sub_batchnorm":
        return partial(SubBatchNorm3d, num_splits=cfg.BN.NUM_SPLITS)
    if norm_type == "sync_batchnorm":
        return partial(
            NaiveSyncBatchNorm3d, num_sync_devices=cfg.BN.NUM_SYNC_DEVICES
        )
    raise NotImplementedError(
        "Norm type {} is not supported".format(cfg.BN.NORM_TYPE)
    )
34
+
35
+
36
class SubBatchNorm3d(nn.Module):
    """
    The standard BN layer computes stats across all examples in a GPU. In some
    cases it is desirable to compute stats across only a subset of examples
    (e.g., in multigrid training https://arxiv.org/abs/1912.00998).
    SubBatchNorm3d splits the batch dimension into N splits, and runs BN on
    each of them separately, so that the stats are computed on each subset of
    examples (1/N of batch) independently. During evaluation, it aggregates
    the stats from all splits into one BN.
    """

    def __init__(self, num_splits, **args):
        """
        Args:
            num_splits (int): number of splits.
            args (dict): keyword arguments forwarded to `nn.BatchNorm3d`.
                Must contain `num_features`.
        """
        super().__init__()
        self.num_splits = num_splits
        num_features = args["num_features"]
        # Keep only one set of weight and bias: the wrapped BN layers are
        # built without affine parameters and the shared ones live here.
        self.affine = True if args.get("affine", True) else False
        if self.affine:
            args["affine"] = False
            self.weight = torch.nn.Parameter(torch.ones(num_features))
            self.bias = torch.nn.Parameter(torch.zeros(num_features))
        # `bn` is used in eval mode; `split_bn` sees the splits folded into
        # the channel dimension during training.
        self.bn = nn.BatchNorm3d(**args)
        args["num_features"] = num_features * num_splits
        self.split_bn = nn.BatchNorm3d(**args)

    def _get_aggregated_mean_std(self, means, stds, n):
        """
        Calculate the aggregated mean and stds.
        Args:
            means (tensor): mean values, `n` sets concatenated.
            stds (tensor): standard deviations, `n` sets concatenated.
            n (int): number of sets of means and stds.
        Returns:
            (tuple): detached aggregated mean and aggregated std tensors.
        """
        split_means = means.view(n, -1)
        mean = split_means.sum(0) / n
        # Average of the per-split values plus the spread of the split means.
        within = stds.view(n, -1).sum(0) / n
        between = ((split_means - mean) ** 2).view(n, -1).sum(0) / n
        std = within + between
        return mean.detach(), std.detach()

    def aggregate_stats(self):
        """
        Synchronize running_mean, and running_var. Call this before eval.
        """
        if not self.split_bn.track_running_stats:
            return
        aggregated_mean, aggregated_var = self._get_aggregated_mean_std(
            self.split_bn.running_mean,
            self.split_bn.running_var,
            self.num_splits,
        )
        self.bn.running_mean.data = aggregated_mean
        self.bn.running_var.data = aggregated_var

    def forward(self, x):
        if not self.training:
            x = self.bn(x)
        else:
            n, c, t, h, w = x.shape
            # Fold the splits into the channel dimension so that a single BN
            # call normalizes each split with its own statistics.
            folded = x.view(n // self.num_splits, c * self.num_splits, t, h, w)
            x = self.split_bn(folded).view(n, c, t, h, w)
        if self.affine:
            x = x * self.weight.view((-1, 1, 1, 1))
            x = x + self.bias.view((-1, 1, 1, 1))
        return x
109
+
110
+
111
class GroupGather(Function):
    """
    GroupGather performs all gather on each of the local process/ GPU groups
    and sums the gathered tensors that belong to the caller's group.
    """

    @staticmethod
    def _gather_group_sum(tensor, num_sync_devices, num_groups):
        """
        All-gather `tensor` over the local process group and sum the entries
        of the sync group that the current local rank belongs to.
        """
        gathered = [
            torch.zeros_like(tensor) for _ in range(du.get_local_size())
        ]
        dist.all_gather(
            gathered, tensor, async_op=False, group=du._LOCAL_PROCESS_GROUP
        )

        stacked = torch.stack(gathered, dim=0)
        if num_groups > 1:
            # Keep only the slice of ranks that share this rank's sync group.
            group_idx = du.get_local_rank() // num_sync_devices
            start = group_idx * num_sync_devices
            stop = (group_idx + 1) * num_sync_devices
            stacked = stacked[start:stop]
        return torch.sum(stacked, dim=0)

    @staticmethod
    def forward(ctx, input, num_sync_devices, num_groups):
        """
        Perform forwarding, gathering the stats across different process/ GPU
        group.
        """
        # Remembered for the backward pass, which uses the same grouping.
        ctx.num_sync_devices = num_sync_devices
        ctx.num_groups = num_groups
        return GroupGather._gather_group_sum(input, num_sync_devices, num_groups)

    @staticmethod
    def backward(ctx, grad_output):
        """
        Perform backwarding, gathering the gradients across different process/ GPU
        group.
        """
        grads = GroupGather._gather_group_sum(
            grad_output, ctx.num_sync_devices, ctx.num_groups
        )
        # No gradients for the two integer arguments of forward().
        return grads, None, None
171
+
172
+
173
class NaiveSyncBatchNorm3d(nn.BatchNorm3d):
    def __init__(self, num_sync_devices, **args):
        """
        Naive version of Synchronized 3D BatchNorm.
        Args:
            num_sync_devices (int): number of device to sync. A non-positive
                value means syncing across all local devices.
            args (dict): keyword arguments forwarded to `nn.BatchNorm3d`.
        """
        self.num_sync_devices = num_sync_devices
        if self.num_sync_devices > 0:
            local_size = du.get_local_size()
            # The local devices must split evenly into sync groups.
            assert local_size % self.num_sync_devices == 0, (
                local_size,
                self.num_sync_devices,
            )
            self.num_groups = local_size // self.num_sync_devices
        else:
            self.num_sync_devices = du.get_local_size()
            self.num_groups = 1
        super().__init__(**args)

    def forward(self, input):
        # A single local device or eval mode needs no synchronization.
        if du.get_local_size() == 1 or not self.training:
            return super().forward(input)

        assert input.shape[0] > 0, "SyncBatchNorm does not support empty inputs"
        num_channels = input.shape[1]
        reduce_dims = [0, 2, 3, 4]
        mean = torch.mean(input, dim=reduce_dims)
        meansqr = torch.mean(input * input, dim=reduce_dims)

        # Sum both moments over the sync group in one gather, then average.
        stats = torch.cat([mean, meansqr], dim=0)
        stats = GroupGather.apply(
            stats, self.num_sync_devices, self.num_groups
        ) * (1.0 / self.num_sync_devices)

        mean, meansqr = torch.split(stats, num_channels)
        var = meansqr - mean * mean
        self.running_mean += self.momentum * (mean.detach() - self.running_mean)
        self.running_var += self.momentum * (var.detach() - self.running_var)

        # Fold normalization and the affine transform into scale and bias.
        invstd = torch.rsqrt(var + self.eps)
        scale = self.weight * invstd
        bias = self.bias - mean * scale
        scale = scale.reshape(1, -1, 1, 1, 1)
        bias = bias.reshape(1, -1, 1, 1, 1)
        return input * scale + bias
TimeSformer/timesformer/models/build.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+
3
+ """Model construction functions."""
4
+
5
+ import torch
6
+ from fvcore.common.registry import Registry
7
+
8
+ MODEL_REGISTRY = Registry("MODEL")
9
+ MODEL_REGISTRY.__doc__ = """
10
+ Registry for video model.
11
+
12
+ The registered object will be called with `obj(cfg)`.
13
+ The call should return a `torch.nn.Module` object.
14
+ """
15
+
16
+
17
def build_model(cfg, gpu_id=None):
    """
    Builds the video model.
    Args:
        cfg (configs): configs that contains the hyper-parameters to build the
            backbone. Reads NUM_GPUS and MODEL.MODEL_NAME here.
        gpu_id (Optional[int]): specify the gpu index to build model.
    Returns:
        model: the model looked up by name in MODEL_REGISTRY, moved to the
            selected GPU when NUM_GPUS is non-zero and wrapped in
            DistributedDataParallel when NUM_GPUS is larger than 1.
    """
    if torch.cuda.is_available():
        assert (
            cfg.NUM_GPUS <= torch.cuda.device_count()
        ), "Cannot use more GPU devices than available"
    else:
        assert (
            cfg.NUM_GPUS == 0
        ), "Cuda is not available. Please set `NUM_GPUS: 0 for running on CPUs."

    # Construct the model
    model_factory = MODEL_REGISTRY.get(cfg.MODEL.MODEL_NAME)
    model = model_factory(cfg)

    if cfg.NUM_GPUS:
        # Use the caller's device, or the one bound to the current process.
        cur_device = torch.cuda.current_device() if gpu_id is None else gpu_id
        # Transfer the model to the current GPU device
        model = model.cuda(device=cur_device)

        # Use multi-process data parallel model in the multi-gpu setting
        if cfg.NUM_GPUS > 1:
            # Make model replica operate on the current device
            model = torch.nn.parallel.DistributedDataParallel(
                module=model, device_ids=[cur_device], output_device=cur_device
            )
    return model
TimeSformer/timesformer/models/conv2d_same.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 Ross Wightman
2
+ # Conv2d w/ Same Padding
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from typing import Tuple, Optional
8
+
9
+ import math
10
+ from typing import List, Tuple
11
+ #from .padding import pad_same, get_padding_value
12
+
13
+ # Dynamically pad input x with 'SAME' padding for conv with specified args
14
def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0):
    """Pad the last two dims of `x` with 'SAME' padding for the given conv args.

    Args:
        x: input whose last two dimensions are height and width.
        k: kernel size as (height, width).
        s: stride as (height, width).
        d: dilation as (height, width).
        value: fill value used for the padding.

    Returns:
        `x` itself when no padding is needed, otherwise the padded tensor.
        An odd amount of padding puts the extra element at the bottom/right.
    """
    in_h, in_w = x.size()[-2:]
    pad_h = get_same_padding(in_h, k[0], s[0], d[0])
    pad_w = get_same_padding(in_w, k[1], s[1], d[1])
    if not (pad_h > 0 or pad_w > 0):
        return x
    left = pad_w // 2
    top = pad_h // 2
    # F.pad takes the last dimension first: [left, right, top, bottom].
    return F.pad(x, [left, pad_w - left, top, pad_h - top], value=value)
20
+
21
+ # Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution
22
def get_same_padding(x: int, k: int, s: int, d: int):
    """Return the total TensorFlow-style 'SAME' padding along one dimension.

    Args:
        x: input size along the dimension.
        k: kernel size.
        s: stride.
        d: dilation.

    Returns:
        The non-negative total padding needed so that the convolution output
        has size ceil(x / s).
    """
    out_size = math.ceil(x / s)
    # Input extent needed to produce `out_size` outputs with this kernel.
    needed = (out_size - 1) * s + (k - 1) * d + 1
    return max(needed - x, 0)
24
+
25
def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:
    """Return the symmetric padding for a conv with the given arguments."""
    return ((stride - 1) + dilation * (kernel_size - 1)) // 2


def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> bool:
    """Return True when 'SAME' padding can be expressed as a fixed symmetric pad."""
    return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0


def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:
    """Resolve a padding specification into a concrete value.

    Args:
        padding: either a ready-to-use padding value, which is returned
            unchanged, or a string: 'same' for TF-compatible padding, 'valid'
            for no padding, anything else (including '') for symmetric
            padding computed from the kernel size.
        kernel_size: kernel size of the convolution.
        **kwargs: remaining conv arguments; only `stride` and `dilation` are
            read, anything else is ignored.

    Returns:
        A `(padding, dynamic)` pair. `dynamic` is True only when 'same'
        padding cannot be expressed statically and must be applied at
        runtime; `padding` is 0 in that case.
    """
    dynamic = False
    if isinstance(padding, str):
        # for any string padding, the padding will be calculated for you, one of three ways
        padding = padding.lower()
        if padding == 'same':
            # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
            if is_static_pad(kernel_size, **kwargs):
                # static case, no extra overhead
                padding = get_padding(kernel_size, **kwargs)
            else:
                # dynamic 'SAME' padding, has runtime/GPU memory overhead
                padding = 0
                dynamic = True
        elif padding == 'valid':
            # 'VALID' padding, same as padding=0
            padding = 0
        else:
            # Default to PyTorch style 'same'-ish symmetric padding
            padding = get_padding(kernel_size, **kwargs)
    return padding, dynamic
46
+
47
def conv2d_same(
        x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1),
        padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1):
    """Run a 2D convolution with TensorFlow-like 'SAME' padding.

    The input is padded with `pad_same` first and the convolution itself runs
    without padding. `padding` is accepted for signature compatibility and is
    not used.
    """
    kernel_hw = weight.shape[-2:]
    padded = pad_same(x, kernel_hw, stride, dilation)
    return F.conv2d(padded, weight, bias, stride, (0, 0), dilation, groups)
52
+
53
+
54
class Conv2dSame(nn.Conv2d):
    """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions.

    The `padding` argument is accepted for signature compatibility but the
    underlying `nn.Conv2d` is always built with padding 0; the input is
    padded at call time by `conv2d_same`.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=0,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        return conv2d_same(
            x,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
65
+
66
+
67
def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
    """Create a 2D convolution whose padding is resolved from a padding spec.

    Args:
        in_chs: number of input channels.
        out_chs: number of output channels.
        kernel_size: kernel size of the convolution.
        **kwargs: remaining conv arguments. `padding` (default '') is
            resolved through `get_padding_value`; `bias` defaults to False.

    Returns:
        A `Conv2dSame` when the padding has to be applied dynamically,
        otherwise a plain `nn.Conv2d` with the resolved padding.
    """
    requested_padding = kwargs.pop('padding', '')
    kwargs.setdefault('bias', False)
    resolved_padding, needs_dynamic_pad = get_padding_value(
        requested_padding, kernel_size, **kwargs)
    if not needs_dynamic_pad:
        return nn.Conv2d(in_chs, out_chs, kernel_size, padding=resolved_padding, **kwargs)
    return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)
TimeSformer/timesformer/models/custom_video_model_builder.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+
3
+
4
+ """A More Flexible Video models."""
TimeSformer/timesformer/models/features.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 Ross Wightman
2
+
3
+ from collections import OrderedDict, defaultdict
4
+ from copy import deepcopy
5
+ from functools import partial
6
+ from typing import Dict, List, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+
12
class FeatureInfo:
    """Per-module feature metadata plus the indices selected as outputs."""

    def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]):
        last_reduction = 1
        for entry in feature_info:
            # sanity check the mandatory fields, there may be additional fields depending on the model
            assert 'num_chs' in entry and entry['num_chs'] > 0
            # reductions must not decrease from one feature to the next
            assert 'reduction' in entry and entry['reduction'] >= last_reduction
            last_reduction = entry['reduction']
            assert 'module' in entry
        self.out_indices = out_indices
        self.info = feature_info

    def from_other(self, out_indices: Tuple[int]):
        """ Return a new FeatureInfo over a deep copy of this info with other output indices
        """
        return FeatureInfo(deepcopy(self.info), out_indices)

    def get(self, key, idx=None):
        """ Get value by key at specified index (indices)
        if idx == None, returns value for key at each output index
        if idx is an integer, return value for that feature module index (ignoring output indices)
        if idx is a list/tuple, return value for each module index (ignoring output indices)
        """
        if idx is None:
            return [self.info[i][key] for i in self.out_indices]
        if not isinstance(idx, (tuple, list)):
            return self.info[idx][key]
        return [self.info[i][key] for i in idx]

    def get_dicts(self, keys=None, idx=None):
        """ return info dicts for specified keys (or all if None) at specified indices (or out_indices if None)
        """
        def select(i):
            # Whole info dict, or a new dict restricted to the requested keys.
            if keys is None:
                return self.info[i]
            return {k: self.info[i][k] for k in keys}

        if idx is None:
            return [select(i) for i in self.out_indices]
        if not isinstance(idx, (tuple, list)):
            return select(idx)
        return [select(i) for i in idx]

    def channels(self, idx=None):
        """ feature channels accessor
        """
        return self.get('num_chs', idx)

    def reduction(self, idx=None):
        """ feature reduction (output stride) accessor
        """
        return self.get('reduction', idx)

    def module_name(self, idx=None):
        """ feature module name accessor
        """
        return self.get('module', idx)

    def __getitem__(self, item):
        return self.info[item]

    def __len__(self):
        return len(self.info)
74
+
75
+
76
class FeatureHooks:
    """ Feature Hook Helper
    This module helps with the setup and extraction of hooks for extracting features from
    internal nodes in a model by node name. This works quite well in eager Python but needs
    redesign for torchscript.
    """

    def __init__(self, hooks, named_modules, out_map=None, default_hook_type='forward'):
        # setup feature hooks
        modules_by_name = dict(named_modules)
        for position, hook in enumerate(hooks):
            hook_name = hook['module']
            target = modules_by_name[hook_name]
            # Outputs are keyed by the mapped id when an out_map is given.
            hook_id = out_map[position] if out_map else hook_name
            collector = partial(self._collect_output_hook, hook_id)
            hook_type = hook.get('hook_type', default_hook_type)
            if hook_type == 'forward':
                target.register_forward_hook(collector)
            elif hook_type == 'forward_pre':
                target.register_forward_pre_hook(collector)
            else:
                assert False, "Unsupported hook type"
        # Collected tensors are kept per device, in hook firing order.
        self._feature_outputs = defaultdict(OrderedDict)

    def _collect_output_hook(self, hook_id, *args):
        # tensor we want is last argument, output for fwd, input for fwd_pre
        feature = args[-1]
        if isinstance(feature, tuple):
            feature = feature[0]  # unwrap input tuple
        self._feature_outputs[feature.device][hook_id] = feature

    def get_output(self, device) -> Dict[str, torch.tensor]:
        collected = self._feature_outputs[device]
        # clear after reading
        self._feature_outputs[device] = OrderedDict()
        return collected
110
+
111
+
112
def _module_list(module, flatten_sequential=False):
    """List the direct children of `module` as (new_name, old_name, child) tuples.

    When `flatten_sequential` is True, each direct child that is an nn.Sequential is
    replaced by its own children, named `parent_child` (new) and `parent.child` (old).
    """
    # a yield/iter would be better for this but wouldn't be compatible with torchscript
    entries = []
    for name, child in module.named_children():
        if not (flatten_sequential and isinstance(child, nn.Sequential)):
            entries.append((name, name, child))
            continue
        # first level of Sequential containers is flattened into containing model
        for sub_name, sub_module in child.named_children():
            parts = [name, sub_name]
            entries.append(('_'.join(parts), '.'.join(parts), sub_module))
    return entries
124
+
125
+
126
def _get_feature_info(net, out_indices):
    """Build a FeatureInfo for `out_indices` from `net.feature_info`.

    An existing FeatureInfo is re-derived through `from_other`; a list or tuple is
    wrapped in a new FeatureInfo; anything else fails the assertion.
    """
    info = getattr(net, 'feature_info')
    if isinstance(info, FeatureInfo):
        return info.from_other(out_indices)
    if isinstance(info, (list, tuple)):
        return FeatureInfo(net.feature_info, out_indices)
    assert False, "Provided feature_info is not valid"
134
+
135
+
136
def _get_return_layers(feature_info, out_map):
    """Map each feature module name to its return id.

    The id is `out_map[i]` when `out_map` is given, otherwise `feature_info.out_indices[i]`,
    where `i` is the position of the name in `feature_info.module_name()`.
    """
    names = feature_info.module_name()
    if out_map is not None:
        return {name: out_map[i] for i, name in enumerate(names)}
    return {name: feature_info.out_indices[i] for i, name in enumerate(names)}
142
+
143
+
144
class FeatureDictNet(nn.ModuleDict):
    """ Feature extractor with OrderedDict return
    Wraps a model and returns the features selected by the out indices. The network is
    partially re-built from the model's contained modules.
    The modules are assumed to have been registered into the model in the order they are
    used, and no nn.Module may be reused more than once (this includes trivial modules such
    as `self.relu = nn.ReLU`).
    Only submodules assigned directly to the model (`model.feature1`), or at most one
    Sequential container deep (`model.features.1`, with flatten_sequential=True), can be captured.
    Sequential containers assigned directly to the original model have their children
    registered on this module, with `model.features.1` renamed to `model.features_1`.
    Arguments:
        model (nn.Module): model from which we will extract the features
        out_indices (tuple[int]): model output indices to extract features for
        out_map (sequence): list or tuple specifying desired return id for each out index,
            otherwise str(index) is used
        feature_concat (bool): whether to concatenate intermediate features that are lists or tuples
            vs select element [0]
        flatten_sequential (bool): whether to flatten sequential modules assigned to model
    """
    def __init__(
            self, model,
            out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False):
        super().__init__()
        self.feature_info = _get_feature_info(model, out_indices)
        self.concat = feature_concat
        self.return_layers = {}
        wanted = _get_return_layers(self.feature_info, out_map)
        candidates = _module_list(model, flatten_sequential=flatten_sequential)
        pending = set(wanted.keys())
        kept = OrderedDict()
        # Copy modules in registration order and stop once every requested layer is found.
        for new_name, old_name, child in candidates:
            kept[new_name] = child
            if old_name in pending:
                # return id has to be consistently str type for torchscript
                self.return_layers[new_name] = str(wanted[old_name])
                pending.remove(old_name)
            if not pending:
                break
        assert not pending and len(self.return_layers) == len(wanted), \
            f'Return layers ({pending}) are not present in model'
        self.update(kept)

    def _collect(self, x) -> (Dict[str, torch.Tensor]):
        """Run the kept modules in order and gather the tapped outputs by return id."""
        features = OrderedDict()
        for name, layer in self.items():
            x = layer(x)
            if name in self.return_layers:
                feature_id = self.return_layers[name]
                if isinstance(x, (tuple, list)):
                    # If model tap is a tuple or list, concat or select first element
                    # FIXME this may need to be more generic / flexible for some nets
                    features[feature_id] = torch.cat(x, 1) if self.concat else x[0]
                else:
                    features[feature_id] = x
        return features

    def forward(self, x) -> Dict[str, torch.Tensor]:
        """Return the collected features as a dict keyed by return id."""
        return self._collect(x)
203
+
204
+
205
class FeatureListNet(FeatureDictNet):
    """ Feature extractor with list return
    Behaves like FeatureDictNet (see its docstring) but returns the collected features as a
    list. It is a separate class only to satisfy Torchscript typing constraints; in eager
    Python a member bool could have selected List[Tensor] vs Dict[id, Tensor].
    """
    def __init__(
            self, model,
            out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False):
        super().__init__(
            model,
            out_indices=out_indices,
            out_map=out_map,
            feature_concat=feature_concat,
            flatten_sequential=flatten_sequential,
        )

    def forward(self, x) -> (List[torch.Tensor]):
        """Return the collected feature tensors in collection order."""
        collected = self._collect(x)
        return list(collected.values())
219
+
220
+
221
class FeatureHookNet(nn.ModuleDict):
    """ FeatureHookNet
    Wrap a model and extract features specified by the out indices using forward/forward-pre hooks.
    If `no_rewrite` is True, features are extracted via hooks without modifying the underlying
    network in any way.
    If `no_rewrite` is False, the model will be re-written as in the
    FeatureList/FeatureDict case by folding first to second (Sequential only) level modules into this one.
    FIXME this does not currently work with Torchscript, see FeatureHooks class
    """
    def __init__(
            self, model,
            out_indices=(0, 1, 2, 3, 4), out_map=None, out_as_dict=False, no_rewrite=False,
            feature_concat=False, flatten_sequential=False, default_hook_type='forward'):
        """Build the wrapper and register the feature hooks.

        `default_hook_type` applies to every feature-info entry that has no 'hook_type'
        key. `feature_concat` is accepted but not used by this class.
        """
        super(FeatureHookNet, self).__init__()
        assert not torch.jit.is_scripting()
        self.feature_info = _get_feature_info(model, out_indices)
        self.out_as_dict = out_as_dict
        layers = OrderedDict()
        hooks = []
        if no_rewrite:
            assert not flatten_sequential
            if hasattr(model, 'reset_classifier'):  # make sure classifier is removed?
                model.reset_classifier(0)
            layers['body'] = model
            hooks.extend(self.feature_info.get_dicts())
        else:
            modules = _module_list(model, flatten_sequential=flatten_sequential)
            remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type
                         for f in self.feature_info.get_dicts()}
            for new_name, old_name, module in modules:
                layers[new_name] = module
                for fn, fm in module.named_modules(prefix=old_name):
                    if fn in remaining:
                        hooks.append(dict(module=fn, hook_type=remaining[fn]))
                        del remaining[fn]
                if not remaining:
                    break
            assert not remaining, f'Return layers ({remaining}) are not present in model'
        self.update(layers)
        # Forward default_hook_type so entries without an explicit 'hook_type' (the
        # no_rewrite path passes feature-info dicts through untouched) honour the caller's
        # choice instead of silently falling back to 'forward'.
        self.hooks = FeatureHooks(
            hooks, model.named_modules(), out_map=out_map, default_hook_type=default_hook_type)

    def forward(self, x):
        """Run the wrapped modules, then return the hooked features (dict or list)."""
        for name, module in self.items():
            x = module(x)
        out = self.hooks.get_output(x.device)
        return out if self.out_as_dict else list(out.values())
TimeSformer/timesformer/models/head_helper.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+
3
+ """ResNe(X)t Head helper."""
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
class ResNetBasicHead(nn.Module):
    """
    ResNe(X)t 3D head.
    This layer performs a fully-connected projection during training, when the
    input size is 1x1x1. It performs a convolutional projection during testing
    when the input size is larger than 1x1x1. If the inputs are from multiple
    different pathways, the inputs will be concatenated after pooling.
    """

    def __init__(
        self,
        dim_in,
        num_classes,
        pool_size,
        dropout_rate=0.0,
        act_func="softmax",
    ):
        """
        The `__init__` method of any subclass should also contain these
            arguments.
        ResNetBasicHead takes p pathways as input where p in [1, infty].

        Args:
            dim_in (list): the list of channel dimensions of the p inputs to the
                ResNetHead.
            num_classes (int): the channel dimensions of the p outputs to the
                ResNetHead.
            pool_size (list): the list of kernel sizes of p spatial temporal
                poolings, temporal pool kernel size, spatial pool kernel size,
                spatial pool kernel size in order. A `None` entry selects
                adaptive average pooling to 1x1x1 for that pathway.
            dropout_rate (float): dropout rate. If equal to 0.0, perform no
                dropout.
            act_func (string): activation function to use. 'softmax': applies
                softmax on the output. 'sigmoid': applies sigmoid on the output.

        Raises:
            NotImplementedError: if `act_func` is neither 'softmax' nor 'sigmoid'.
        """
        super(ResNetBasicHead, self).__init__()
        assert (
            len({len(pool_size), len(dim_in)}) == 1
        ), "pathway dimensions are not consistent."
        self.num_pathways = len(pool_size)

        for pathway in range(self.num_pathways):
            if pool_size[pathway] is None:
                avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
            else:
                avg_pool = nn.AvgPool3d(pool_size[pathway], stride=1)
            self.add_module("pathway{}_avgpool".format(pathway), avg_pool)

        if dropout_rate > 0.0:
            self.dropout = nn.Dropout(dropout_rate)
        # Perform FC in a fully convolutional manner. The FC layer will be
        # initialized with a different std comparing to convolutional layers.
        self.projection = nn.Linear(sum(dim_in), num_classes, bias=True)

        # Softmax for evaluation and testing. dim=4 is the channel axis after the
        # (N, T, H, W, C) permute done in forward().
        if act_func == "softmax":
            self.act = nn.Softmax(dim=4)
        elif act_func == "sigmoid":
            self.act = nn.Sigmoid()
        else:
            raise NotImplementedError(
                "{} is not supported as an activation "
                "function.".format(act_func)
            )

    def forward(self, inputs):
        """Pool each pathway, concatenate on channels, project to class scores.

        `inputs` is a sequence with one tensor per pathway. Outside training mode the
        activation is applied and the result is averaged over dims 1-3 before
        flattening to (N, -1).
        """
        assert (
            len(inputs) == self.num_pathways
        ), "Input tensor does not contain {} pathway".format(self.num_pathways)
        pool_out = []
        for pathway in range(self.num_pathways):
            m = getattr(self, "pathway{}_avgpool".format(pathway))
            pool_out.append(m(inputs[pathway]))
        x = torch.cat(pool_out, 1)
        # (N, C, T, H, W) -> (N, T, H, W, C).
        x = x.permute((0, 2, 3, 4, 1))
        # Perform dropout.
        if hasattr(self, "dropout"):
            x = self.dropout(x)
        x = self.projection(x)

        # Performs fully convolutional inference.
        if not self.training:
            x = self.act(x)
            x = x.mean([1, 2, 3])

        x = x.view(x.shape[0], -1)
        return x
96
+
97
+
98
class X3DHead(nn.Module):
    """
    X3D head.
    This layer performs a fully-connected projection during training, when the
    input size is 1x1x1. It performs a convolutional projection during testing
    when the input size is larger than 1x1x1. If the inputs are from multiple
    different pathways, the inputs will be concatenated after pooling.
    """

    def __init__(
        self,
        dim_in,
        dim_inner,
        dim_out,
        num_classes,
        pool_size,
        dropout_rate=0.0,
        act_func="softmax",
        inplace_relu=True,
        eps=1e-5,
        bn_mmt=0.1,
        norm_module=nn.BatchNorm3d,
        bn_lin5_on=False,
    ):
        """
        The `__init__` method of any subclass should also contain these
            arguments.
        X3DHead takes a 5-dim feature tensor (BxCxTxHxW) as input.

        Args:
            dim_in (int): the channel dimension C of the input.
            dim_inner (int): output channels of the first 1x1x1 convolution
                (`conv_5`).
            dim_out (int): output channels of the second 1x1x1 convolution
                (`lin_5`), i.e. the feature size fed to the classifier.
            num_classes (int): the channel dimensions of the output.
            pool_size (list or None): kernel size for spatiotemporal pooling
                over the TxHxW dimensions. `None` selects adaptive average
                pooling to 1x1x1.
            dropout_rate (float): dropout rate. If equal to 0.0, perform no
                dropout.
            act_func (string): activation function to use. 'softmax': applies
                softmax on the output. 'sigmoid': applies sigmoid on the output.
            inplace_relu (bool): if True, calculate the relu on the original
                input without allocating new memory.
            eps (float): epsilon for batch norm.
            bn_mmt (float): momentum for batch norm. Noted that BN momentum in
                PyTorch = 1 - BN momentum in Caffe2.
            norm_module (nn.Module): nn.Module for the normalization layer. The
                default is nn.BatchNorm3d.
            bn_lin5_on (bool): if True, perform normalization on the features
                before the classifier.

        Raises:
            NotImplementedError: if `act_func` is neither 'softmax' nor 'sigmoid'.
        """
        super(X3DHead, self).__init__()
        self.pool_size = pool_size
        self.dropout_rate = dropout_rate
        self.num_classes = num_classes
        self.act_func = act_func
        self.eps = eps
        self.bn_mmt = bn_mmt
        self.inplace_relu = inplace_relu
        self.bn_lin5_on = bn_lin5_on
        self._construct_head(dim_in, dim_inner, dim_out, norm_module)

    def _construct_head(self, dim_in, dim_inner, dim_out, norm_module):
        """Create the conv/norm/pool/projection/activation submodules of the head."""
        self.conv_5 = nn.Conv3d(
            dim_in,
            dim_inner,
            kernel_size=(1, 1, 1),
            stride=(1, 1, 1),
            padding=(0, 0, 0),
            bias=False,
        )
        self.conv_5_bn = norm_module(
            num_features=dim_inner, eps=self.eps, momentum=self.bn_mmt
        )
        self.conv_5_relu = nn.ReLU(self.inplace_relu)

        if self.pool_size is None:
            self.avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
        else:
            self.avg_pool = nn.AvgPool3d(self.pool_size, stride=1)

        self.lin_5 = nn.Conv3d(
            dim_inner,
            dim_out,
            kernel_size=(1, 1, 1),
            stride=(1, 1, 1),
            padding=(0, 0, 0),
            bias=False,
        )
        if self.bn_lin5_on:
            self.lin_5_bn = norm_module(
                num_features=dim_out, eps=self.eps, momentum=self.bn_mmt
            )
        self.lin_5_relu = nn.ReLU(self.inplace_relu)

        if self.dropout_rate > 0.0:
            self.dropout = nn.Dropout(self.dropout_rate)
        # Perform FC in a fully convolutional manner. The FC layer will be
        # initialized with a different std comparing to convolutional layers.
        self.projection = nn.Linear(dim_out, self.num_classes, bias=True)

        # Softmax for evaluation and testing. dim=4 is the channel axis after the
        # (N, T, H, W, C) permute done in forward().
        if self.act_func == "softmax":
            self.act = nn.Softmax(dim=4)
        elif self.act_func == "sigmoid":
            self.act = nn.Sigmoid()
        else:
            raise NotImplementedError(
                "{} is not supported as an activation "
                "function.".format(self.act_func)
            )

    def forward(self, inputs):
        """Apply the head to a single-pathway input list and return class scores.

        Outside training mode the activation is applied and the result is averaged
        over dims 1-3 before flattening to (N, -1).
        """
        # In its current design the X3D head is only useable for a single
        # pathway input.
        assert len(inputs) == 1, "Input tensor does not contain 1 pathway"
        x = self.conv_5(inputs[0])
        x = self.conv_5_bn(x)
        x = self.conv_5_relu(x)
        x = self.avg_pool(x)

        x = self.lin_5(x)
        if self.bn_lin5_on:
            x = self.lin_5_bn(x)
        x = self.lin_5_relu(x)

        # (N, C, T, H, W) -> (N, T, H, W, C).
        x = x.permute((0, 2, 3, 4, 1))
        # Perform dropout.
        if hasattr(self, "dropout"):
            x = self.dropout(x)
        x = self.projection(x)

        # Performs fully convolutional inference.
        if not self.training:
            x = self.act(x)
            x = x.mean([1, 2, 3])

        x = x.view(x.shape[0], -1)
        return x
TimeSformer/timesformer/models/helpers.py ADDED
@@ -0,0 +1,360 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+ # Copyright 2020 Ross Wightman
3
+ # Modified model creation / weight loading / state_dict helpers
4
+
5
+ import logging
6
+ import os
7
+ import math
8
+ from collections import OrderedDict
9
+ from copy import deepcopy
10
+ from typing import Callable
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.utils.model_zoo as model_zoo
15
+ import torch.nn.functional as F
16
+
17
+ from timesformer.models.features import FeatureListNet, FeatureDictNet, FeatureHookNet
18
+ from timesformer.models.conv2d_same import Conv2dSame
19
+ from timesformer.models.linear import Linear
20
+
21
+
22
+ _logger = logging.getLogger(__name__)
23
+
24
def load_state_dict(checkpoint_path, use_ema=False):
    """Load a state dict from a checkpoint file on CPU.

    The weights are taken from, in order of preference:
      * `checkpoint['state_dict_ema']` when `use_ema` is True and the key exists,
      * `checkpoint['state_dict']` (a leading `module.` is stripped from each key),
      * `checkpoint['model_state']` (a leading `model.` is stripped from each key),
      * the loaded object itself.

    Args:
        checkpoint_path (str): path of the checkpoint file.
        use_ema (bool): prefer the EMA weights when the checkpoint has them.

    Returns:
        The selected state dict (an OrderedDict when prefixes were stripped).

    Raises:
        FileNotFoundError: if `checkpoint_path` is empty or not an existing file.
    """
    if not (checkpoint_path and os.path.isfile(checkpoint_path)):
        _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
        raise FileNotFoundError("No checkpoint found at '{}'".format(checkpoint_path))

    def _strip_prefix(entries, prefix):
        # Only strip an exact `prefix` (including the dot). Matching on the bare word
        # would corrupt keys such as `modules.0.weight` or `model_ema.weight`.
        stripped = OrderedDict()
        for k, v in entries.items():
            stripped[k[len(prefix):] if k.startswith(prefix) else k] = v
        return stripped

    # NOTE(security): torch.load unpickles the file; only load checkpoints from
    # trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    state_dict_key = 'state_dict'
    if isinstance(checkpoint, dict):
        if use_ema and 'state_dict_ema' in checkpoint:
            state_dict_key = 'state_dict_ema'
        if state_dict_key in checkpoint:
            state_dict = _strip_prefix(checkpoint[state_dict_key], 'module.')
        elif 'model_state' in checkpoint:
            state_dict_key = 'model_state'
            state_dict = _strip_prefix(checkpoint[state_dict_key], 'model.')
        else:
            state_dict = checkpoint
    else:
        # Not a dict: return the loaded object unchanged.
        state_dict = checkpoint
    _logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path))
    return state_dict
53
+
54
+
55
def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True):
    """Read the weights at `checkpoint_path` via `load_state_dict` and load them into `model`."""
    weights = load_state_dict(checkpoint_path, use_ema)
    model.load_state_dict(weights, strict=strict)
58
+
59
+
60
def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True):
    """Restore model (and optionally optimizer / loss-scaler) state from a checkpoint.

    Args:
        model: object exposing `load_state_dict`.
        checkpoint_path (str): path of the checkpoint file.
        optimizer: optional optimizer restored from `checkpoint['optimizer']` when present.
        loss_scaler: optional scaler restored from `checkpoint[loss_scaler.state_dict_key]`
            when present.
        log_info (bool): emit info-level log messages while restoring.

    Returns:
        The epoch to resume from, or None when the checkpoint stores no epoch. For
        checkpoints with `version > 1` the stored epoch is incremented by one.

    Raises:
        FileNotFoundError: if `checkpoint_path` is not an existing file.
    """
    if not os.path.isfile(checkpoint_path):
        _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
        raise FileNotFoundError("No checkpoint found at '{}'".format(checkpoint_path))

    # NOTE(security): torch.load unpickles the file; only load checkpoints from
    # trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    resume_epoch = None

    if not (isinstance(checkpoint, dict) and 'state_dict' in checkpoint):
        # Bare state dict: nothing but model weights to restore.
        model.load_state_dict(checkpoint)
        if log_info:
            _logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
        return resume_epoch

    if log_info:
        _logger.info('Restoring model state from checkpoint...')
    prefix = 'module.'
    new_state_dict = OrderedDict()
    for k, v in checkpoint['state_dict'].items():
        # Strip only an exact `module.` prefix; matching the bare word would corrupt
        # keys such as `modules.0.weight`.
        name = k[len(prefix):] if k.startswith(prefix) else k
        new_state_dict[name] = v
    model.load_state_dict(new_state_dict)

    if optimizer is not None and 'optimizer' in checkpoint:
        if log_info:
            _logger.info('Restoring optimizer state from checkpoint...')
        optimizer.load_state_dict(checkpoint['optimizer'])

    if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint:
        if log_info:
            _logger.info('Restoring AMP loss scaler state from checkpoint...')
        loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key])

    if 'epoch' in checkpoint:
        resume_epoch = checkpoint['epoch']
        if 'version' in checkpoint and checkpoint['version'] > 1:
            resume_epoch += 1  # start at the next epoch, old checkpoints incremented before save

    if log_info:
        # .get() avoids a KeyError for checkpoints saved without an epoch.
        _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint.get('epoch')))
    return resume_epoch
98
+
99
+
100
def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, img_size=224, num_frames=8, num_patches=196, attention_type='divided_space_time', pretrained_model="", strict=True):
    """Load pretrained weights into `model`, adapting them to the model's configuration.

    Args:
        model: target model; its `default_cfg` attribute is used when `cfg` is None.
        cfg (dict): pretrained config; must hold a non-empty 'url', plus 'first_conv' and
            'classifier' (and 'num_classes') for the adaptations below.
        num_classes (int): number of classes of the created model.
        in_chans (int): number of input channels of the created model.
        filter_fn (callable): optional function applied to the loaded state dict.
        img_size: accepted for interface compatibility; not used in this function.
        num_frames (int): target temporal length for `time_embed`.
        num_patches (int): target number of patches for `pos_embed` (class token excluded).
        attention_type (str): with 'divided_space_time', temporal attention / norm weights
            missing from the checkpoint are initialised from their spatial counterparts.
        pretrained_model (str): local checkpoint path; when empty, `cfg['url']` is downloaded.
        strict: accepted for interface compatibility. NOTE(review): the final load always
            uses `strict=False`, so this value has no effect — confirm whether that is
            intended before honouring it.

    Returns:
        None. Returns early (leaving the model untouched) when no valid URL is configured.
    """
    if cfg is None:
        cfg = getattr(model, 'default_cfg')
    if cfg is None or 'url' not in cfg or not cfg['url']:
        _logger.warning("Pretrained model URL is invalid, using random initialization.")
        return

    if len(pretrained_model) == 0:
        state_dict = model_zoo.load_url(cfg['url'], progress=False, map_location='cpu')
    else:
        # Load the checkpoint once; use its 'model' entry when it has one, otherwise the
        # loaded object itself. Only lookup failures trigger the fallback, so genuine
        # errors (e.g. a missing file) propagate instead of being retried.
        loaded = load_state_dict(pretrained_model)
        try:
            state_dict = loaded['model']
        except (KeyError, TypeError, IndexError):
            state_dict = loaded

    if filter_fn is not None:
        state_dict = filter_fn(state_dict)

    if in_chans == 1:
        conv1_name = cfg['first_conv']
        _logger.info('Converting first conv (%s) pretrained weights from 3 to 1 channel' % conv1_name)
        conv1_weight = state_dict[conv1_name + '.weight']
        conv1_type = conv1_weight.dtype
        conv1_weight = conv1_weight.float()
        O, I, J, K = conv1_weight.shape
        if I > 3:
            assert conv1_weight.shape[1] % 3 == 0
            # For models with space2depth stems
            conv1_weight = conv1_weight.reshape(O, I // 3, 3, J, K)
            conv1_weight = conv1_weight.sum(dim=2, keepdim=False)
        else:
            conv1_weight = conv1_weight.sum(dim=1, keepdim=True)
        conv1_weight = conv1_weight.to(conv1_type)
        state_dict[conv1_name + '.weight'] = conv1_weight
    elif in_chans != 3:
        conv1_name = cfg['first_conv']
        conv1_weight = state_dict[conv1_name + '.weight']
        conv1_type = conv1_weight.dtype
        conv1_weight = conv1_weight.float()
        O, I, J, K = conv1_weight.shape
        if I != 3:
            _logger.warning('Deleting first conv (%s) from pretrained weights.' % conv1_name)
            del state_dict[conv1_name + '.weight']
            strict = False
        else:
            _logger.info('Repeating first conv (%s) weights in channel dim.' % conv1_name)
            repeat = int(math.ceil(in_chans / 3))
            conv1_weight = conv1_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
            # Rescale so the summed response over channels matches the 3-channel filter.
            conv1_weight *= (3 / float(in_chans))
            conv1_weight = conv1_weight.to(conv1_type)
            state_dict[conv1_name + '.weight'] = conv1_weight

    classifier_name = cfg['classifier']
    if num_classes == 1000 and cfg['num_classes'] == 1001:
        # special case for imagenet trained models with extra background class in pretrained weights
        classifier_weight = state_dict[classifier_name + '.weight']
        state_dict[classifier_name + '.weight'] = classifier_weight[1:]
        classifier_bias = state_dict[classifier_name + '.bias']
        state_dict[classifier_name + '.bias'] = classifier_bias[1:]
    elif num_classes != state_dict[classifier_name + '.weight'].size(0):
        # completely discard fully connected for all other differences between pretrained and created model
        del state_dict[classifier_name + '.weight']
        del state_dict[classifier_name + '.bias']
        strict = False

    ## Resizing the positional embeddings in case they don't match
    if num_patches + 1 != state_dict['pos_embed'].size(1):
        pos_embed = state_dict['pos_embed']
        # Keep the first (class-token) embedding as is; interpolate only the rest.
        cls_pos_embed = pos_embed[0, 0, :].unsqueeze(0).unsqueeze(1)
        other_pos_embed = pos_embed[0, 1:, :].unsqueeze(0).transpose(1, 2)
        new_pos_embed = F.interpolate(other_pos_embed, size=(num_patches), mode='nearest')
        new_pos_embed = new_pos_embed.transpose(1, 2)
        new_pos_embed = torch.cat((cls_pos_embed, new_pos_embed), 1)
        state_dict['pos_embed'] = new_pos_embed

    ## Resizing time embeddings in case they don't match
    if 'time_embed' in state_dict and num_frames != state_dict['time_embed'].size(1):
        time_embed = state_dict['time_embed'].transpose(1, 2)
        new_time_embed = F.interpolate(time_embed, size=(num_frames), mode='nearest')
        state_dict['time_embed'] = new_time_embed.transpose(1, 2)

    ## Initializing temporal attention
    if attention_type == 'divided_space_time':
        new_state_dict = state_dict.copy()
        for key in state_dict:
            if 'blocks' in key and 'attn' in key:
                new_key = key.replace('attn', 'temporal_attn')
                if not new_key in state_dict:
                    new_state_dict[new_key] = state_dict[key]
                else:
                    new_state_dict[new_key] = state_dict[new_key]
            if 'blocks' in key and 'norm1' in key:
                new_key = key.replace('norm1', 'temporal_norm1')
                if not new_key in state_dict:
                    new_state_dict[new_key] = state_dict[key]
                else:
                    new_state_dict[new_key] = state_dict[new_key]
        state_dict = new_state_dict

    ## Loading the weights
    model.load_state_dict(state_dict, strict=False)
205
+
206
+
207
def extract_layer(model, layer):
    """Resolve the dotted path `layer` inside `model` and return the module found.

    A `module` wrapper attribute on `model` is entered automatically when the path
    does not start with 'module', and a leading 'module' component is dropped when
    the model has no such attribute. Resolution stops at the first component that is
    not an attribute of the current module, returning the module reached so far.
    """
    parts = layer.split('.')
    wrapped = hasattr(model, 'module')
    current = model
    if wrapped and parts[0] != 'module':
        current = model.module
    if not wrapped and parts[0] == 'module':
        parts = parts[1:]
    for part in parts:
        if not hasattr(current, part):
            break
        # Numeric components index into containers; others are plain attributes.
        current = current[int(part)] if part.isdigit() else getattr(current, part)
    return current
223
+
224
+
225
def set_layer(model, layer, val):
    """Replace the module at dotted path `layer` inside `model` with `val`.

    The path is first probed to count how many components resolve; the parent of the
    last resolved component is then walked to and `val` is set on it under that
    component's name. A `module` wrapper attribute on `model` is entered automatically
    when the path does not start with 'module'.
    """
    parts = layer.split('.')
    root = model
    if hasattr(model, 'module') and parts[0] != 'module':
        root = model.module

    def _descend(node, part):
        # Numeric components index into containers; others are plain attributes.
        return node[int(part)] if part.isdigit() else getattr(node, part)

    # Probe pass: count the components that can be resolved from the root.
    resolved = 0
    probe = root
    for part in parts:
        if hasattr(probe, part):
            probe = _descend(probe, part)
            resolved += 1
    last = resolved - 1

    # Walk to the parent of the last resolved component and assign there.
    parent = root
    for part in parts[:last]:
        parent = _descend(parent, part)
    setattr(parent, parts[last], val)
247
+
248
+
249
def adapt_model_from_string(parent_module, model_string):
    """Return a deep copy of `parent_module` with layers resized per `model_string`.

    `model_string` is a '***'-separated list of `name:[d0,d1,...]` entries giving the
    target weight shape for each parameter name. Conv2d / Conv2dSame, BatchNorm2d and
    Linear layers in the copy are replaced by freshly constructed layers of those sizes.
    Both the copy and `parent_module` are switched to eval mode.
    """
    shapes = {}
    for entry in model_string.split('***'):
        fields = entry.split(':')
        name = fields[0]
        dims = fields[1][1:-1].split(',')  # drop the surrounding brackets
        if dims[0] != '':
            shapes[name] = [int(d) for d in dims]

    new_module = deepcopy(parent_module)
    for n, m in parent_module.named_modules():
        old_module = extract_layer(parent_module, n)

        if isinstance(old_module, (nn.Conv2d, Conv2dSame)):
            conv = Conv2dSame if isinstance(old_module, Conv2dSame) else nn.Conv2d
            weight_shape = shapes[n + '.weight']
            in_channels = weight_shape[1]
            out_channels = weight_shape[0]
            groups = 1
            if old_module.groups > 1:
                # Grouped convs are rebuilt with in_channels == out_channels == groups.
                in_channels = out_channels
                groups = in_channels
            set_layer(new_module, n, conv(
                in_channels=in_channels, out_channels=out_channels, kernel_size=old_module.kernel_size,
                bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation,
                groups=groups, stride=old_module.stride))

        if isinstance(old_module, nn.BatchNorm2d):
            set_layer(new_module, n, nn.BatchNorm2d(
                num_features=shapes[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
                affine=old_module.affine, track_running_stats=True))

        if isinstance(old_module, nn.Linear):
            num_features = shapes[n + '.weight'][1]
            set_layer(new_module, n, Linear(
                in_features=num_features, out_features=old_module.out_features,
                bias=old_module.bias is not None))
            if hasattr(new_module, 'num_features'):
                new_module.num_features = num_features

    new_module.eval()
    parent_module.eval()

    return new_module
296
+
297
+
298
def adapt_model_from_file(parent_module, model_variant):
    """Adapt `parent_module` using the spec stored in `pruned/<model_variant>.txt` next to this file."""
    pruned_dir = os.path.join(os.path.dirname(__file__), 'pruned')
    adapt_file = os.path.join(pruned_dir, model_variant + '.txt')
    with open(adapt_file, 'r') as f:
        spec = f.read().strip()
        return adapt_model_from_string(parent_module, spec)
302
+
303
+
304
def default_cfg_for_features(default_cfg):
    """Return a deep copy of `default_cfg` without the classifier-only fields.

    The fields 'num_classes', 'crop_pct' and 'classifier' are dropped when present;
    the input is left unmodified.
    """
    trimmed = deepcopy(default_cfg)
    # remove default pretrained cfg fields that don't have much relevance for feature backbone
    for field in ('num_classes', 'crop_pct', 'classifier'):  # add default final pool size?
        if field in trimmed:
            del trimmed[field]
    return trimmed
311
+
312
+
313
def build_model_with_cfg(
        model_cls: Callable,
        variant: str,
        pretrained: bool,
        default_cfg: dict,
        model_cfg: dict = None,
        feature_cfg: dict = None,
        pretrained_strict: bool = True,
        pretrained_filter_fn: Callable = None,
        **kwargs):
    """Instantiate a model, optionally prune it, load pretrained weights and wrap it.

    Args:
        model_cls: callable creating the model; called with `**kwargs` (plus
            `cfg=model_cfg` when `model_cfg` is given).
        variant: model variant name; used to locate the pruning spec when `pruned=True`.
        pretrained: load pretrained weights via `load_pretrained`.
        default_cfg: default pretrained config; a deep copy is attached to the model.
        model_cfg: optional architecture config forwarded to `model_cls` as `cfg`.
        feature_cfg: optional kwargs for the feature wrapper. May contain 'feature_cls',
            either a class or a string containing 'hook'. The caller's dict is not modified.
        pretrained_strict: forwarded to `load_pretrained` as `strict`.
        pretrained_filter_fn: forwarded to `load_pretrained` as `filter_fn`.
        **kwargs: consumed here: `pruned`, `features_only`, `out_indices` (the last one
            only when `features_only` is set). Everything else goes to `model_cls`.

    Returns:
        The created model, wrapped in a feature extractor when `features_only=True`.

    Raises:
        AssertionError: if `feature_cfg['feature_cls']` is a string without 'hook' in it.
    """
    pruned = kwargs.pop('pruned', False)
    features = False
    # Work on a copy: the setdefault/pop calls below must not leak into the caller's dict
    # (otherwise reusing the same feature_cfg would silently lose 'feature_cls').
    feature_cfg = dict(feature_cfg) if feature_cfg else {}

    if kwargs.pop('features_only', False):
        features = True
        feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
        if 'out_indices' in kwargs:
            feature_cfg['out_indices'] = kwargs.pop('out_indices')

    model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs)
    model.default_cfg = deepcopy(default_cfg)

    if pruned:
        model = adapt_model_from_file(model, variant)

    # for classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
    num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
    if pretrained:
        load_pretrained(
            model,
            num_classes=num_classes_pretrained, in_chans=kwargs.get('in_chans', 3),
            filter_fn=pretrained_filter_fn, strict=pretrained_strict)

    if features:
        if 'feature_cls' in feature_cfg:
            feature_cls = feature_cfg.pop('feature_cls')
            if isinstance(feature_cls, str):
                feature_cls = feature_cls.lower()
                if 'hook' in feature_cls:
                    feature_cls = FeatureHookNet
                else:
                    # Raised explicitly (same exception type as before) so the check
                    # is not stripped when Python runs with -O.
                    raise AssertionError(f'Unknown feature class {feature_cls}')
        else:
            feature_cls = FeatureListNet
        model = feature_cls(model, **feature_cfg)
        model.default_cfg = default_cfg_for_features(default_cfg)  # add back default_cfg

    return model