Upload 47 files
Browse files- .gitattributes +3 -0
- CONTRIBUTING.md +29 -0
- LICENSE +202 -0
- README.md +276 -12
- WINDOWS_INSTALLATION.md +89 -0
- cog.yaml +23 -0
- datasets/create_middlebury_tfrecord.py +164 -0
- datasets/create_ucf101_tfrecord.py +138 -0
- datasets/create_vimeo90K_tfrecord.py +167 -0
- datasets/create_xiph_tfrecord.py +146 -0
- datasets/util.py +204 -0
- eval/.DS_Store +0 -0
- eval/config/middlebury.gin +18 -0
- eval/config/ucf101.gin +18 -0
- eval/config/vimeo_90K.gin +18 -0
- eval/config/xiph_2K.gin +18 -0
- eval/config/xiph_4K.gin +18 -0
- eval/eval_cli.py +216 -0
- eval/interpolator.py +209 -0
- eval/interpolator_cli.py +197 -0
- eval/interpolator_test.py +109 -0
- eval/util.py +162 -0
- losses/losses.py +266 -0
- losses/vgg19_loss.py +362 -0
- models/.DS_Store +0 -0
- models/film_net/feature_extractor.py +193 -0
- models/film_net/fusion.py +140 -0
- models/film_net/interpolator.py +207 -0
- models/film_net/options.py +81 -0
- models/film_net/pyramid_flow_estimator.py +163 -0
- models/film_net/util.py +143 -0
- moment.gif +3 -0
- photos/one.png +3 -0
- photos/two.png +3 -0
- predict.py +88 -0
- requirements.txt +14 -0
- training/.DS_Store +0 -0
- training/augmentation_lib.py +220 -0
- training/build_saved_model_cli.py +98 -0
- training/config/film_net-L1.gin +55 -0
- training/config/film_net-Style.gin +66 -0
- training/config/film_net-VGG.gin +64 -0
- training/data_lib.py +296 -0
- training/eval_lib.py +131 -0
- training/metrics_lib.py +142 -0
- training/model_lib.py +53 -0
- training/train.py +131 -0
- training/train_lib.py +343 -0
.gitattributes
CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
moment.gif filter=lfs diff=lfs merge=lfs -text
|
37 |
+
photos/one.png filter=lfs diff=lfs merge=lfs -text
|
38 |
+
photos/two.png filter=lfs diff=lfs merge=lfs -text
|
CONTRIBUTING.md
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# How to Contribute
|
2 |
+
|
3 |
+
We'd love to accept your patches and contributions to this project. There are
|
4 |
+
just a few small guidelines you need to follow.
|
5 |
+
|
6 |
+
## Contributor License Agreement
|
7 |
+
|
8 |
+
Contributions to this project must be accompanied by a Contributor License
|
9 |
+
Agreement (CLA). You (or your employer) retain the copyright to your
|
10 |
+
contribution; this simply gives us permission to use and redistribute your
|
11 |
+
contributions as part of the project. Head over to
|
12 |
+
<https://cla.developers.google.com/> to see your current agreements on file or
|
13 |
+
to sign a new one.
|
14 |
+
|
15 |
+
You generally only need to submit a CLA once, so if you've already submitted one
|
16 |
+
(even if it was for a different project), you probably don't need to do it
|
17 |
+
again.
|
18 |
+
|
19 |
+
## Code Reviews
|
20 |
+
|
21 |
+
All submissions, including submissions by project members, require review. We
|
22 |
+
use GitHub pull requests for this purpose. Consult
|
23 |
+
[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
|
24 |
+
information on using pull requests.
|
25 |
+
|
26 |
+
## Community Guidelines
|
27 |
+
|
28 |
+
This project follows
|
29 |
+
[Google's Open Source Community Guidelines](https://opensource.google/conduct/).
|
LICENSE
ADDED
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
Apache License
|
3 |
+
Version 2.0, January 2004
|
4 |
+
http://www.apache.org/licenses/
|
5 |
+
|
6 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
7 |
+
|
8 |
+
1. Definitions.
|
9 |
+
|
10 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
11 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
12 |
+
|
13 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
14 |
+
the copyright owner that is granting the License.
|
15 |
+
|
16 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
17 |
+
other entities that control, are controlled by, or are under common
|
18 |
+
control with that entity. For the purposes of this definition,
|
19 |
+
"control" means (i) the power, direct or indirect, to cause the
|
20 |
+
direction or management of such entity, whether by contract or
|
21 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
22 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
23 |
+
|
24 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
25 |
+
exercising permissions granted by this License.
|
26 |
+
|
27 |
+
"Source" form shall mean the preferred form for making modifications,
|
28 |
+
including but not limited to software source code, documentation
|
29 |
+
source, and configuration files.
|
30 |
+
|
31 |
+
"Object" form shall mean any form resulting from mechanical
|
32 |
+
transformation or translation of a Source form, including but
|
33 |
+
not limited to compiled object code, generated documentation,
|
34 |
+
and conversions to other media types.
|
35 |
+
|
36 |
+
"Work" shall mean the work of authorship, whether in Source or
|
37 |
+
Object form, made available under the License, as indicated by a
|
38 |
+
copyright notice that is included in or attached to the work
|
39 |
+
(an example is provided in the Appendix below).
|
40 |
+
|
41 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
42 |
+
form, that is based on (or derived from) the Work and for which the
|
43 |
+
editorial revisions, annotations, elaborations, or other modifications
|
44 |
+
represent, as a whole, an original work of authorship. For the purposes
|
45 |
+
of this License, Derivative Works shall not include works that remain
|
46 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
47 |
+
the Work and Derivative Works thereof.
|
48 |
+
|
49 |
+
"Contribution" shall mean any work of authorship, including
|
50 |
+
the original version of the Work and any modifications or additions
|
51 |
+
to that Work or Derivative Works thereof, that is intentionally
|
52 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
53 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
54 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
55 |
+
means any form of electronic, verbal, or written communication sent
|
56 |
+
to the Licensor or its representatives, including but not limited to
|
57 |
+
communication on electronic mailing lists, source code control systems,
|
58 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
59 |
+
Licensor for the purpose of discussing and improving the Work, but
|
60 |
+
excluding communication that is conspicuously marked or otherwise
|
61 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
62 |
+
|
63 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
64 |
+
on behalf of whom a Contribution has been received by Licensor and
|
65 |
+
subsequently incorporated within the Work.
|
66 |
+
|
67 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
68 |
+
this License, each Contributor hereby grants to You a perpetual,
|
69 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
70 |
+
copyright license to reproduce, prepare Derivative Works of,
|
71 |
+
publicly display, publicly perform, sublicense, and distribute the
|
72 |
+
Work and such Derivative Works in Source or Object form.
|
73 |
+
|
74 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
75 |
+
this License, each Contributor hereby grants to You a perpetual,
|
76 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
77 |
+
(except as stated in this section) patent license to make, have made,
|
78 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
79 |
+
where such license applies only to those patent claims licensable
|
80 |
+
by such Contributor that are necessarily infringed by their
|
81 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
82 |
+
with the Work to which such Contribution(s) was submitted. If You
|
83 |
+
institute patent litigation against any entity (including a
|
84 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
85 |
+
or a Contribution incorporated within the Work constitutes direct
|
86 |
+
or contributory patent infringement, then any patent licenses
|
87 |
+
granted to You under this License for that Work shall terminate
|
88 |
+
as of the date such litigation is filed.
|
89 |
+
|
90 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
91 |
+
Work or Derivative Works thereof in any medium, with or without
|
92 |
+
modifications, and in Source or Object form, provided that You
|
93 |
+
meet the following conditions:
|
94 |
+
|
95 |
+
(a) You must give any other recipients of the Work or
|
96 |
+
Derivative Works a copy of this License; and
|
97 |
+
|
98 |
+
(b) You must cause any modified files to carry prominent notices
|
99 |
+
stating that You changed the files; and
|
100 |
+
|
101 |
+
(c) You must retain, in the Source form of any Derivative Works
|
102 |
+
that You distribute, all copyright, patent, trademark, and
|
103 |
+
attribution notices from the Source form of the Work,
|
104 |
+
excluding those notices that do not pertain to any part of
|
105 |
+
the Derivative Works; and
|
106 |
+
|
107 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
108 |
+
distribution, then any Derivative Works that You distribute must
|
109 |
+
include a readable copy of the attribution notices contained
|
110 |
+
within such NOTICE file, excluding those notices that do not
|
111 |
+
pertain to any part of the Derivative Works, in at least one
|
112 |
+
of the following places: within a NOTICE text file distributed
|
113 |
+
as part of the Derivative Works; within the Source form or
|
114 |
+
documentation, if provided along with the Derivative Works; or,
|
115 |
+
within a display generated by the Derivative Works, if and
|
116 |
+
wherever such third-party notices normally appear. The contents
|
117 |
+
of the NOTICE file are for informational purposes only and
|
118 |
+
do not modify the License. You may add Your own attribution
|
119 |
+
notices within Derivative Works that You distribute, alongside
|
120 |
+
or as an addendum to the NOTICE text from the Work, provided
|
121 |
+
that such additional attribution notices cannot be construed
|
122 |
+
as modifying the License.
|
123 |
+
|
124 |
+
You may add Your own copyright statement to Your modifications and
|
125 |
+
may provide additional or different license terms and conditions
|
126 |
+
for use, reproduction, or distribution of Your modifications, or
|
127 |
+
for any such Derivative Works as a whole, provided Your use,
|
128 |
+
reproduction, and distribution of the Work otherwise complies with
|
129 |
+
the conditions stated in this License.
|
130 |
+
|
131 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
132 |
+
any Contribution intentionally submitted for inclusion in the Work
|
133 |
+
by You to the Licensor shall be under the terms and conditions of
|
134 |
+
this License, without any additional terms or conditions.
|
135 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
136 |
+
the terms of any separate license agreement you may have executed
|
137 |
+
with Licensor regarding such Contributions.
|
138 |
+
|
139 |
+
6. Trademarks. This License does not grant permission to use the trade
|
140 |
+
names, trademarks, service marks, or product names of the Licensor,
|
141 |
+
except as required for reasonable and customary use in describing the
|
142 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
143 |
+
|
144 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
145 |
+
agreed to in writing, Licensor provides the Work (and each
|
146 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
147 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
148 |
+
implied, including, without limitation, any warranties or conditions
|
149 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
150 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
151 |
+
appropriateness of using or redistributing the Work and assume any
|
152 |
+
risks associated with Your exercise of permissions under this License.
|
153 |
+
|
154 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
155 |
+
whether in tort (including negligence), contract, or otherwise,
|
156 |
+
unless required by applicable law (such as deliberate and grossly
|
157 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
158 |
+
liable to You for damages, including any direct, indirect, special,
|
159 |
+
incidental, or consequential damages of any character arising as a
|
160 |
+
result of this License or out of the use or inability to use the
|
161 |
+
Work (including but not limited to damages for loss of goodwill,
|
162 |
+
work stoppage, computer failure or malfunction, or any and all
|
163 |
+
other commercial damages or losses), even if such Contributor
|
164 |
+
has been advised of the possibility of such damages.
|
165 |
+
|
166 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
167 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
168 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
169 |
+
or other liability obligations and/or rights consistent with this
|
170 |
+
License. However, in accepting such obligations, You may act only
|
171 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
172 |
+
of any other Contributor, and only if You agree to indemnify,
|
173 |
+
defend, and hold each Contributor harmless for any liability
|
174 |
+
incurred by, or claims asserted against, such Contributor by reason
|
175 |
+
of your accepting any such warranty or additional liability.
|
176 |
+
|
177 |
+
END OF TERMS AND CONDITIONS
|
178 |
+
|
179 |
+
APPENDIX: How to apply the Apache License to your work.
|
180 |
+
|
181 |
+
To apply the Apache License to your work, attach the following
|
182 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
183 |
+
replaced with your own identifying information. (Don't include
|
184 |
+
the brackets!) The text should be enclosed in the appropriate
|
185 |
+
comment syntax for the file format. We also recommend that a
|
186 |
+
file or class name and description of purpose be included on the
|
187 |
+
same "printed page" as the copyright notice for easier
|
188 |
+
identification within third-party archives.
|
189 |
+
|
190 |
+
Copyright [yyyy] [name of copyright owner]
|
191 |
+
|
192 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
193 |
+
you may not use this file except in compliance with the License.
|
194 |
+
You may obtain a copy of the License at
|
195 |
+
|
196 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
197 |
+
|
198 |
+
Unless required by applicable law or agreed to in writing, software
|
199 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
200 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
201 |
+
See the License for the specific language governing permissions and
|
202 |
+
limitations under the License.
|
README.md
CHANGED
@@ -1,12 +1,276 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# FILM: Frame Interpolation for Large Motion
|
2 |
+
|
3 |
+
### [Website](https://film-net.github.io/) | [Paper](https://arxiv.org/pdf/2202.04901.pdf) | [Google AI Blog](https://ai.googleblog.com/2022/10/large-motion-frame-interpolation.html) | [Tensorflow Hub Colab](https://www.tensorflow.org/hub/tutorials/tf_hub_film_example) | [YouTube](https://www.youtube.com/watch?v=OAD-BieIjH4) <br>
|
4 |
+
|
5 |
+
The official Tensorflow 2 implementation of our high quality frame interpolation neural network. We present a unified single-network approach that doesn't use additional pre-trained networks, like optical flow or depth, and yet achieve state-of-the-art results. We use a multi-scale feature extractor that shares the same convolution weights across the scales. Our model is trainable from frame triplets alone. <br>
|
6 |
+
|
7 |
+
[FILM: Frame Interpolation for Large Motion](https://arxiv.org/abs/2202.04901) <br />
|
8 |
+
[Fitsum Reda](https://fitsumreda.github.io/)<sup>1</sup>, [Janne Kontkanen](https://scholar.google.com/citations?user=MnXc4JQAAAAJ&hl=en)<sup>1</sup>, [Eric Tabellion](http://www.tabellion.org/et/)<sup>1</sup>, [Deqing Sun](https://deqings.github.io/)<sup>1</sup>, [Caroline Pantofaru](https://scholar.google.com/citations?user=vKAKE1gAAAAJ&hl=en)<sup>1</sup>, [Brian Curless](https://homes.cs.washington.edu/~curless/)<sup>1,2</sup><br />
|
9 |
+
<sup>1</sup>Google Research, <sup>2</sup>University of Washington<br />
|
10 |
+
In ECCV 2022.
|
11 |
+
|
12 |
+
![A sample 2 seconds moment.](https://github.com/googlestaging/frame-interpolation/blob/main/moment.gif)
|
13 |
+
FILM transforms near-duplicate photos into a slow motion footage that look like it is shot with a video camera.
|
14 |
+
|
15 |
+
## Web Demo
|
16 |
+
|
17 |
+
Integrated into [Hugging Face Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio). Try out the Web Demo: [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/johngoad/frame-interpolation)
|
18 |
+
|
19 |
+
Try the interpolation model with the replicate web demo at
|
20 |
+
[![Replicate](https://replicate.com/google-research/frame-interpolation/badge)](https://replicate.com/google-research/frame-interpolation)
|
21 |
+
|
22 |
+
Try FILM to interpolate between two or more images with the PyTTI-Tools at [![PyTTI-Tools:FILM](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.sandbox.google.com/github/pytti-tools/frame-interpolation/blob/main/PyTTI_Tools_FiLM-colab.ipynb#scrollTo=-7TD7YZJbsy_)
|
23 |
+
|
24 |
+
An alternative Colab for running FILM on arbitrarily more input images, not just on two images, [![FILM-Gdrive](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1NuaPPSvUhYafymUf2mEkvhnEtpD5oihs)
|
25 |
+
|
26 |
+
## Change Log
|
27 |
+
* **Nov 28, 2022**: Upgrade `eval.interpolator_cli` for **high resolution frame interpolation**. `--block_height` and `--block_width` determine the total number of patches (`block_height*block_width`) to subdivide the input images. By default, both arguments are set to 1, and so no subdivision will be done.
|
28 |
+
* **Mar 12, 2022**: Support for Windows, see [WINDOWS_INSTALLATION.md](https://github.com/google-research/frame-interpolation/blob/main/WINDOWS_INSTALLATION.md).
|
29 |
+
* **Mar 09, 2022**: Support for **high resolution frame interpolation**. Set `--block_height` and `--block_width` in `eval.interpolator_test` to extract patches from the inputs, and reconstruct the interpolated frame from the iteratively interpolated patches.
|
30 |
+
|
31 |
+
## Installation
|
32 |
+
|
33 |
+
* Get Frame Interpolation source codes
|
34 |
+
|
35 |
+
```
|
36 |
+
git clone https://github.com/google-research/frame-interpolation
|
37 |
+
cd frame-interpolation
|
38 |
+
```
|
39 |
+
|
40 |
+
* Optionally, pull the recommended Docker base image
|
41 |
+
|
42 |
+
```
|
43 |
+
docker pull gcr.io/deeplearning-platform-release/tf2-gpu.2-6:latest
|
44 |
+
```
|
45 |
+
|
46 |
+
* If you do not use Docker, set up your NVIDIA GPU environment with:
|
47 |
+
* [Anaconda Python 3.9](https://www.anaconda.com/products/individual)
|
48 |
+
* [CUDA Toolkit 11.2.1](https://developer.nvidia.com/cuda-11.2.1-download-archive)
|
49 |
+
* [cuDNN 8.1.0](https://developer.nvidia.com/rdp/cudnn-download)
|
50 |
+
|
51 |
+
* Install frame interpolation dependencies
|
52 |
+
|
53 |
+
```
|
54 |
+
pip3 install -r requirements.txt
|
55 |
+
sudo apt-get install -y ffmpeg
|
56 |
+
```
|
57 |
+
|
58 |
+
### See [WINDOWS_INSTALLATION](https://github.com/google-research/frame-interpolation/blob/main/WINDOWS_INSTALLATION.md) for Windows Support
|
59 |
+
|
60 |
+
## Pre-trained Models
|
61 |
+
|
62 |
+
* Create a directory where you can keep large files. Ideally, not in this
|
63 |
+
directory.
|
64 |
+
|
65 |
+
```
|
66 |
+
mkdir -p <pretrained_models>
|
67 |
+
```
|
68 |
+
|
69 |
+
* Download pre-trained TF2 Saved Models from
|
70 |
+
[google drive](https://drive.google.com/drive/folders/1q8110-qp225asX3DQvZnfLfJPkCHmDpy?usp=sharing)
|
71 |
+
and put into `<pretrained_models>`.
|
72 |
+
|
73 |
+
The downloaded folder should have the following structure:
|
74 |
+
|
75 |
+
```
|
76 |
+
<pretrained_models>/
|
77 |
+
├── film_net/
|
78 |
+
│ ├── L1/
|
79 |
+
│ ├── Style/
|
80 |
+
│ ├─�� VGG/
|
81 |
+
├── vgg/
|
82 |
+
│ ├── imagenet-vgg-verydeep-19.mat
|
83 |
+
```
|
84 |
+
|
85 |
+
## Running the Codes
|
86 |
+
|
87 |
+
The following instructions run the interpolator on the photos provided in
|
88 |
+
'frame-interpolation/photos'.
|
89 |
+
|
90 |
+
### One mid-frame interpolation
|
91 |
+
|
92 |
+
To generate an intermediate photo from the input near-duplicate photos, simply run:
|
93 |
+
|
94 |
+
```
|
95 |
+
python3 -m eval.interpolator_test \
|
96 |
+
--frame1 photos/one.png \
|
97 |
+
--frame2 photos/two.png \
|
98 |
+
--model_path <pretrained_models>/film_net/Style/saved_model \
|
99 |
+
--output_frame photos/output_middle.png
|
100 |
+
```
|
101 |
+
|
102 |
+
This will produce the sub-frame at `t=0.5` and save as 'photos/output_middle.png'.
|
103 |
+
|
104 |
+
### Many in-between frames interpolation
|
105 |
+
|
106 |
+
It takes in a set of directories identified by a glob (--pattern). Each directory
|
107 |
+
is expected to contain at least two input frames, with each contiguous frame
|
108 |
+
pair treated as an input to generate in-between frames. Frames should be named such that when sorted (naturally) with `natsort`, their desired order is unchanged.
|
109 |
+
|
110 |
+
```
|
111 |
+
python3 -m eval.interpolator_cli \
|
112 |
+
--pattern "photos" \
|
113 |
+
--model_path <pretrained_models>/film_net/Style/saved_model \
|
114 |
+
--times_to_interpolate 6 \
|
115 |
+
--output_video
|
116 |
+
```
|
117 |
+
|
118 |
+
You will find the interpolated frames (including the input frames) in
|
119 |
+
'photos/interpolated_frames/', and the interpolated video at
|
120 |
+
'photos/interpolated.mp4'.
|
121 |
+
|
122 |
+
The number of frames is determined by `--times_to_interpolate`, which controls
|
123 |
+
the number of times the frame interpolator is invoked. When the number of frames
|
124 |
+
in a directory is `num_frames`, the number of output frames will be
|
125 |
+
`(2^times_to_interpolate+1)*(num_frames-1)`.
|
126 |
+
|
127 |
+
## Datasets
|
128 |
+
|
129 |
+
We use [Vimeo-90K](http://data.csail.mit.edu/tofu/dataset/vimeo_triplet.zip) as
|
130 |
+
our main training dataset. For quantitative evaluations, we rely on commonly
|
131 |
+
used benchmark datasets, specifically:
|
132 |
+
|
133 |
+
* [Vimeo-90K](http://data.csail.mit.edu/tofu/testset/vimeo_interp_test.zip)
|
134 |
+
* [Middlebury-Other](https://vision.middlebury.edu/flow/data)
|
135 |
+
* [UCF101](https://people.cs.umass.edu/~hzjiang/projects/superslomo/UCF101_results.zip)
|
136 |
+
* [Xiph](https://github.com/sniklaus/softmax-splatting/blob/master/benchmark.py)
|
137 |
+
|
138 |
+
### Creating a TFRecord
|
139 |
+
|
140 |
+
The training and benchmark evaluation scripts expect the frame triplets in the
|
141 |
+
[TFRecord](https://www.tensorflow.org/tutorials/load_data/tfrecord) storage format. <br />
|
142 |
+
|
143 |
+
We have included scripts that encode the relevant frame triplets into a
|
144 |
+
[tf.train.Example](https://www.tensorflow.org/api_docs/python/tf/train/Example)
|
145 |
+
data format, and export to a TFRecord file. <br />
|
146 |
+
|
147 |
+
You can use the commands `python3 -m
|
148 |
+
datasets.create_<dataset_name>_tfrecord --help` for more information.
|
149 |
+
|
150 |
+
For example, run the command below to create a TFRecord for the Middlebury-other
|
151 |
+
dataset. Download the [images](https://vision.middlebury.edu/flow/data) and point `--input_dir` to the unzipped folder path.
|
152 |
+
|
153 |
+
```
|
154 |
+
python3 -m datasets.create_middlebury_tfrecord \
|
155 |
+
--input_dir=<root folder of middlebury-other> \
|
156 |
+
--output_tfrecord_filepath=<output tfrecord filepath> \
|
157 |
+
--num_shards=3
|
158 |
+
```
|
159 |
+
|
160 |
+
The above command will output a TFRecord file with 3 shards as `<output tfrecord filepath>@3`.
|
161 |
+
|
162 |
+
## Training
|
163 |
+
|
164 |
+
Below are our training gin configuration files for the different loss function:
|
165 |
+
|
166 |
+
```
|
167 |
+
training/
|
168 |
+
├── config/
|
169 |
+
│ ├── film_net-L1.gin
|
170 |
+
│ ├── film_net-VGG.gin
|
171 |
+
│ ├── film_net-Style.gin
|
172 |
+
```
|
173 |
+
|
174 |
+
To launch a training, simply pass the configuration filepath to the desired
|
175 |
+
experiment. <br />
|
176 |
+
By default, it uses all visible GPUs for training. To debug or train
|
177 |
+
on a CPU, append `--mode cpu`.
|
178 |
+
|
179 |
+
```
|
180 |
+
python3 -m training.train \
|
181 |
+
--gin_config training/config/<config filename>.gin \
|
182 |
+
--base_folder <base folder for all training runs> \
|
183 |
+
--label <descriptive label for the run>
|
184 |
+
```
|
185 |
+
|
186 |
+
* When training finishes, the folder structure will look like this:
|
187 |
+
|
188 |
+
```
|
189 |
+
<base_folder>/
|
190 |
+
├── <label>/
|
191 |
+
│ ├── config.gin
|
192 |
+
│ ├── eval/
|
193 |
+
│ ├── train/
|
194 |
+
│ ├── saved_model/
|
195 |
+
```
|
196 |
+
|
197 |
+
### Build a SavedModel
|
198 |
+
|
199 |
+
Optionally, to build a
|
200 |
+
[SavedModel](https://www.tensorflow.org/guide/saved_model) format from a trained
|
201 |
+
checkpoints folder, you can use this command:
|
202 |
+
|
203 |
+
```
|
204 |
+
python3 -m training.build_saved_model_cli \
|
205 |
+
--base_folder <base folder of training sessions> \
|
206 |
+
--label <the name of the run>
|
207 |
+
```
|
208 |
+
|
209 |
+
* By default, a SavedModel is created when the training loop ends, and it will be saved at
|
210 |
+
`<base_folder>/<label>/saved_model`.
|
211 |
+
|
212 |
+
## Evaluation on Benchmarks
|
213 |
+
|
214 |
+
Below, we provided the evaluation gin configuration files for the benchmarks we
|
215 |
+
have considered:
|
216 |
+
|
217 |
+
```
|
218 |
+
eval/
|
219 |
+
├── config/
|
220 |
+
│ ├── middlebury.gin
|
221 |
+
│ ├── ucf101.gin
|
222 |
+
│ ├── vimeo_90K.gin
|
223 |
+
│ ├── xiph_2K.gin
|
224 |
+
│ ├── xiph_4K.gin
|
225 |
+
```
|
226 |
+
|
227 |
+
To run an evaluation, simply pass the configuration file of the desired evaluation dataset. <br />
|
228 |
+
If a GPU is visible, it runs on it.
|
229 |
+
|
230 |
+
```
|
231 |
+
python3 -m eval.eval_cli \
|
232 |
+
--gin_config eval/config/<eval_dataset>.gin \
|
233 |
+
--model_path <pretrained_models>/film_net/L1/saved_model
|
234 |
+
```
|
235 |
+
|
236 |
+
The above command will produce the PSNR and SSIM scores presented in the paper.
|
237 |
+
|
238 |
+
## Citation
|
239 |
+
|
240 |
+
If you find this implementation useful in your works, please acknowledge it
|
241 |
+
appropriately by citing:
|
242 |
+
|
243 |
+
```
|
244 |
+
@inproceedings{reda2022film,
|
245 |
+
title = {FILM: Frame Interpolation for Large Motion},
|
246 |
+
author = {Fitsum Reda and Janne Kontkanen and Eric Tabellion and Deqing Sun and Caroline Pantofaru and Brian Curless},
|
247 |
+
booktitle = {European Conference on Computer Vision (ECCV)},
|
248 |
+
year = {2022}
|
249 |
+
}
|
250 |
+
```
|
251 |
+
|
252 |
+
```
|
253 |
+
@misc{film-tf,
|
254 |
+
title = {Tensorflow 2 Implementation of "FILM: Frame Interpolation for Large Motion"},
|
255 |
+
author = {Fitsum Reda and Janne Kontkanen and Eric Tabellion and Deqing Sun and Caroline Pantofaru and Brian Curless},
|
256 |
+
year = {2022},
|
257 |
+
publisher = {GitHub},
|
258 |
+
journal = {GitHub repository},
|
259 |
+
howpublished = {\url{https://github.com/google-research/frame-interpolation}}
|
260 |
+
}
|
261 |
+
```
|
262 |
+
|
263 |
+
## Acknowledgments
|
264 |
+
|
265 |
+
We would like to thank Richard Tucker, Jason Lai and David Minnen. We would also
|
266 |
+
like to thank Jamie Aspinall for the imagery included in this repository.
|
267 |
+
|
268 |
+
## Coding style
|
269 |
+
|
270 |
+
* 2 spaces for indentation
|
271 |
+
* 80 character line length
|
272 |
+
* PEP8 formatting
|
273 |
+
|
274 |
+
## Disclaimer
|
275 |
+
|
276 |
+
This is not an officially supported Google product.
|
WINDOWS_INSTALLATION.md
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# [FILM](https://github.com/google-research/frame-interpolation): Windows Installation Instructions
|
2 |
+
|
3 |
+
## Anaconda Python 3.9 (Optional)
|
4 |
+
|
5 |
+
#### Install Anaconda3 Python3.9
|
6 |
+
* Go to [https://www.anaconda.com/products/individual](https://www.anaconda.com/products/individual) and click the "Download" button.
|
7 |
+
* Download the Windows [64-Bit](https://repo.anaconda.com/archive/Anaconda3-2021.11-Windows-x86_64.exe) or [32-bit](https://repo.anaconda.com/archive/Anaconda3-2021.11-Windows-x86.exe) Graphical Installer, depending on your system needs.
|
8 |
+
* Run the downloaded (`.exe`) file to begin the installation.
|
9 |
+
* (Optional) Check the "Add Anaconda3 to my PATH environment variable". You may get a 'red text' warning of its implications, you may ignore it for this setup.
|
10 |
+
|
11 |
+
#### Create a new Anaconda virtual environment
|
12 |
+
* Open a new Terminal
|
13 |
+
* Type the following command:
|
14 |
+
```
|
15 |
+
conda create -n frame_interpolation pip python=3.9
|
16 |
+
```
|
17 |
+
* The above command will create a new virtual environment with the name `frame_interpolation`
|
18 |
+
|
19 |
+
#### Activate the Anaconda virtual environment
|
20 |
+
* Activate the newly created virtual environment by typing in your terminal (Command Prompt or PowerShell)
|
21 |
+
```
|
22 |
+
conda activate frame_interpolation
|
23 |
+
```
|
24 |
+
* Once activated, your terminal should look like:
|
25 |
+
```
|
26 |
+
(frame_interpolation) <present working directory> >
|
27 |
+
```
|
28 |
+
|
29 |
+
## NVIDIA GPU Support
|
30 |
+
#### Install CUDA Toolkit
|
31 |
+
* Go to [https://developer.nvidia.com/cuda-11.2.1-download-archive](https://developer.nvidia.com/cuda-11.2.1-download-archive) and select your `Windows`.
|
32 |
+
* Download and install `CUDA Tookit 11.2.1`.
|
33 |
+
* Additional CUDA installation information available [here](https://docs.nvidia.com/cuda/archive/11.2.2/cuda-installation-guide-microsoft-windows/index.html).
|
34 |
+
|
35 |
+
#### Install cuDNN
|
36 |
+
* Go to [https://developer.nvidia.com/rdp/cudnn-download](https://developer.nvidia.com/rdp/cudnn-download).
|
37 |
+
* Create a user profile (if needed) and login.
|
38 |
+
* Select `cuDNN v8.1.0 (January 26th, 2021), for CUDA 11.0,11.1 and 11.2`.
|
39 |
+
* Download [cuDNN Library for Widnows (x86)](https://developer.nvidia.com/compute/machine-learning/cudnn/secure/8.1.0.77/11.2_20210127/cudnn-11.2-windows-x64-v8.1.0.77.zip).
|
40 |
+
* Extract the contents of the zipped folder (it contains a folder named `cuda`) into `<INSTALL_PATH>\NVIDIA GPU Computing Toolkit\CUDA\v11.2\`. `<INSTALL_PATH>` points to the installation directory specified during CUDA Toolkit installation. By default, `<INSTAL_PATH> = C:\Program Files`.
|
41 |
+
|
42 |
+
#### Environment Setup
|
43 |
+
* Add the following paths to your 'Advanced System Settings' > 'Environment Variables ...' > Edit 'Path', and add:
|
44 |
+
* <INSTALL_PATH>\NVIDIA GPU Computing Toolkit\CUDA\v11.2\bin
|
45 |
+
* <INSTALL_PATH>\NVIDIA GPU Computing Toolkit\CUDA\v11.2\libnvvp
|
46 |
+
* <INSTALL_PATH>\NVIDIA GPU Computing Toolkit\CUDA\v11.2\include
|
47 |
+
* <INSTALL_PATH>\NVIDIA GPU Computing Toolkit\CUDA\v11.2\extras\CUPTI\lib64
|
48 |
+
* <INSTALL_PATH>\NVIDIA GPU Computing Toolkit\CUDA\v11.2\cuda\bin
|
49 |
+
|
50 |
+
#### Verify Installation
|
51 |
+
* Open a **new** terminal and type `conda activate frame_interpolation`.
|
52 |
+
* Install (temporarily) tensorflow and run a simple operation, by typing:
|
53 |
+
```
|
54 |
+
pip install --ignore-installed --upgrade tensorflow==2.6.0
|
55 |
+
python -c "import tensorflow as tf;print(tf.reduce_sum(tf.random.normal([1000, 1000])))"
|
56 |
+
```
|
57 |
+
* You should see success messages: 'Created device /job:localhost/replica:0/task:0/device:GPU:0'.
|
58 |
+
|
59 |
+
## FILM Installation
|
60 |
+
* Get Frame Interpolation source codes
|
61 |
+
```
|
62 |
+
git clone https://github.com/google-research/frame-interpolation
|
63 |
+
cd frame-interpolation
|
64 |
+
```
|
65 |
+
* Install dependencies
|
66 |
+
```
|
67 |
+
pip install -r requirements.txt
|
68 |
+
conda install -c conda-forge ffmpeg
|
69 |
+
```
|
70 |
+
* Download pre-traned models, detailed [here](https://github.com/google-research/frame-interpolation#pre-trained-models).
|
71 |
+
|
72 |
+
## Running the Codes
|
73 |
+
* One mid-frame interpolation. Note: `python3` may not be recognized in Windows, so simply drop `3` as below.
|
74 |
+
```
|
75 |
+
python -m eval.interpolator_test --frame1 photos\one.png --frame2 photos\two.png --model_path <pretrained_models>\film_net\Style\saved_model --output_frame photos\output_middle.png
|
76 |
+
```
|
77 |
+
|
78 |
+
* Large resolution mid-frame interpolation: Set `block_height` and `--block_width` to subdivide along the height and width to create patches, where the interpolator will be run iteratively, and the resulting interpolated mid-patches will be reconstructed into a final mid-frame. In the example below, will create and run on 4 patches (2*2).
|
79 |
+
```
|
80 |
+
python -m eval.interpolator_test --frame1 photos\one.png --frame2 photos\two.png --block_height 2 --block_wdith 2 --model_path <pretrained_models>\film_net\Style\saved_model --output_frame photos\output_middle.png
|
81 |
+
```
|
82 |
+
* Many in-between frames interpolation
|
83 |
+
```
|
84 |
+
python -m eval.interpolator_cli --pattern "photos" --model_path <pretrained_models>\film_net\Style\saved_model --times_to_interpolate 6 --output_video
|
85 |
+
```
|
86 |
+
|
87 |
+
## Acknowledgments
|
88 |
+
|
89 |
+
This windows installation guide is heavily based on [tensorflow-object-detection-api-tutorial](https://tensorflow-object-detection-api-tutorial.readthedocs.io/en/latest/install.html) .
|
cog.yaml
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
build:
|
2 |
+
gpu: true
|
3 |
+
cuda: "11.2"
|
4 |
+
python_version: "3.8"
|
5 |
+
system_packages:
|
6 |
+
- "libgl1-mesa-glx"
|
7 |
+
- "libglib2.0-0"
|
8 |
+
python_packages:
|
9 |
+
- "ipython==7.30.1"
|
10 |
+
- "tensorflow-gpu==2.8.0"
|
11 |
+
- "tensorflow-datasets==4.4.0"
|
12 |
+
- "tensorflow-addons==0.15.0"
|
13 |
+
- "absl-py==0.12.0"
|
14 |
+
- "gin-config==0.5.0"
|
15 |
+
- "parameterized==0.8.1"
|
16 |
+
- "mediapy==1.0.3"
|
17 |
+
- "scikit-image==0.19.1"
|
18 |
+
- "apache-beam==2.34.0"
|
19 |
+
run:
|
20 |
+
- apt-get update && apt-get install -y software-properties-common
|
21 |
+
- apt-get install ffmpeg -y
|
22 |
+
|
23 |
+
predict: "predict.py:Predictor"
|
datasets/create_middlebury_tfrecord.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 Google LLC
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
r"""Beam pipeline that generates Middlebury `Other Datasets` triplet TFRecords.
|
16 |
+
|
17 |
+
Middlebury interpolation evaluation dataset consists of two subsets.
|
18 |
+
|
19 |
+
(1) Two frames only, without the intermediate golden frame. A total of 12 such
|
20 |
+
pairs, with folder names (Army, Backyard, Basketball, Dumptruck,
|
21 |
+
Evergreen, Grove, Mequon, Schefflera, Teddy, Urban, Wooden, Yosemite)
|
22 |
+
|
23 |
+
(2) Two frames together with the intermediate golden frame. A total of 12 such
|
24 |
+
triplets, with folder names (Beanbags, Dimetrodon, DogDance, Grove2,
|
25 |
+
Grove3, Hydrangea, MiniCooper, RubberWhale, Urban2, Urban3, Venus, Walking)
|
26 |
+
|
27 |
+
This script runs on (2), i.e. the dataset with the golden frames. For more
|
28 |
+
information, visit https://vision.middlebury.edu/flow/data.
|
29 |
+
|
30 |
+
Input to the script is the root-folder that contains the unzipped folders
|
31 |
+
of input pairs (other-data) and golen frames (other-gt-interp).
|
32 |
+
|
33 |
+
Output TFRecord is a tf.train.Example proto of each image triplet.
|
34 |
+
The feature_map takes the form:
|
35 |
+
feature_map {
|
36 |
+
'frame_0/encoded':
|
37 |
+
tf.io.FixedLenFeature((), tf.string, default_value=''),
|
38 |
+
'frame_0/format':
|
39 |
+
tf.io.FixedLenFeature((), tf.string, default_value='jpg'),
|
40 |
+
'frame_0/height':
|
41 |
+
tf.io.FixedLenFeature((), tf.int64, default_value=0),
|
42 |
+
'frame_0/width':
|
43 |
+
tf.io.FixedLenFeature((), tf.int64, default_value=0),
|
44 |
+
'frame_1/encoded':
|
45 |
+
tf.io.FixedLenFeature((), tf.string, default_value=''),
|
46 |
+
'frame_1/format':
|
47 |
+
tf.io.FixedLenFeature((), tf.string, default_value='jpg'),
|
48 |
+
'frame_1/height':
|
49 |
+
tf.io.FixedLenFeature((), tf.int64, default_value=0),
|
50 |
+
'frame_1/width':
|
51 |
+
tf.io.FixedLenFeature((), tf.int64, default_value=0),
|
52 |
+
'frame_2/encoded':
|
53 |
+
tf.io.FixedLenFeature((), tf.string, default_value=''),
|
54 |
+
'frame_2/format':
|
55 |
+
tf.io.FixedLenFeature((), tf.string, default_value='jpg'),
|
56 |
+
'frame_2/height':
|
57 |
+
tf.io.FixedLenFeature((), tf.int64, default_value=0),
|
58 |
+
'frame_2/width':
|
59 |
+
tf.io.FixedLenFeature((), tf.int64, default_value=0),
|
60 |
+
'path':
|
61 |
+
tf.io.FixedLenFeature((), tf.string, default_value=''),
|
62 |
+
}
|
63 |
+
|
64 |
+
Usage example:
|
65 |
+
python3 -m frame_interpolation.datasets.create_middlebury_tfrecord \
|
66 |
+
--input_dir=<root folder of middlebury-other> \
|
67 |
+
--output_tfrecord_filepath=<output tfrecord filepath>
|
68 |
+
"""
|
69 |
+
|
70 |
+
import os
|
71 |
+
|
72 |
+
from . import util
|
73 |
+
from absl import app
|
74 |
+
from absl import flags
|
75 |
+
from absl import logging
|
76 |
+
import apache_beam as beam
|
77 |
+
import tensorflow as tf
|
78 |
+
|
79 |
+
_INPUT_DIR = flags.DEFINE_string(
|
80 |
+
'input_dir',
|
81 |
+
default='/root/path/to/middlebury-other',
|
82 |
+
help='Path to the root directory of the `Other Datasets` of the Middlebury '
|
83 |
+
'interpolation evaluation data. '
|
84 |
+
'We expect the data to have been downloaded and unzipped. \n'
|
85 |
+
'Folder structures:\n'
|
86 |
+
'| raw_middlebury_other_dataset/\n'
|
87 |
+
'| other-data/\n'
|
88 |
+
'| | Beanbags\n'
|
89 |
+
'| | | frame10.png\n'
|
90 |
+
'| | | frame11.png\n'
|
91 |
+
'| | Dimetrodon\n'
|
92 |
+
'| | | frame10.png\n'
|
93 |
+
'| | | frame11.png\n'
|
94 |
+
'| | ...\n'
|
95 |
+
'| other-gt-interp/\n'
|
96 |
+
'| | Beanbags\n'
|
97 |
+
'| | | frame10i11.png\n'
|
98 |
+
'| | Dimetrodon\n'
|
99 |
+
'| | | frame10i11.png\n'
|
100 |
+
'| | ...\n')
|
101 |
+
|
102 |
+
_INPUT_PAIRS_FOLDERNAME = flags.DEFINE_string(
|
103 |
+
'input_pairs_foldername',
|
104 |
+
default='other-data',
|
105 |
+
help='Foldername containing the folders of the input frame pairs.')
|
106 |
+
|
107 |
+
_GOLDEN_FOLDERNAME = flags.DEFINE_string(
|
108 |
+
'golden_foldername',
|
109 |
+
default='other-gt-interp',
|
110 |
+
help='Foldername containing the folders of the golden frame.')
|
111 |
+
|
112 |
+
_OUTPUT_TFRECORD_FILEPATH = flags.DEFINE_string(
|
113 |
+
'output_tfrecord_filepath',
|
114 |
+
default=None,
|
115 |
+
required=True,
|
116 |
+
help='Filepath to the output TFRecord file.')
|
117 |
+
|
118 |
+
_NUM_SHARDS = flags.DEFINE_integer('num_shards',
|
119 |
+
default=3,
|
120 |
+
help='Number of shards used for the output.')
|
121 |
+
|
122 |
+
# Image key -> basename for frame interpolator: start / middle / end frames.
|
123 |
+
_INTERPOLATOR_IMAGES_MAP = {
|
124 |
+
'frame_0': 'frame10.png',
|
125 |
+
'frame_1': 'frame10i11.png',
|
126 |
+
'frame_2': 'frame11.png',
|
127 |
+
}
|
128 |
+
|
129 |
+
|
130 |
+
def main(unused_argv):
|
131 |
+
"""Creates and runs a Beam pipeline to write frame triplets as a TFRecord."""
|
132 |
+
# Collect the list of folder paths containing the input and golen frames.
|
133 |
+
pairs_list = tf.io.gfile.listdir(
|
134 |
+
os.path.join(_INPUT_DIR.value, _INPUT_PAIRS_FOLDERNAME.value))
|
135 |
+
|
136 |
+
folder_names = [
|
137 |
+
_INPUT_PAIRS_FOLDERNAME.value, _GOLDEN_FOLDERNAME.value,
|
138 |
+
_INPUT_PAIRS_FOLDERNAME.value
|
139 |
+
]
|
140 |
+
triplet_dicts = []
|
141 |
+
for pair in pairs_list:
|
142 |
+
triplet_dict = {
|
143 |
+
image_key: os.path.join(_INPUT_DIR.value, folder, pair, image_basename)
|
144 |
+
for folder, (image_key, image_basename
|
145 |
+
) in zip(folder_names, _INTERPOLATOR_IMAGES_MAP.items())
|
146 |
+
}
|
147 |
+
triplet_dicts.append(triplet_dict)
|
148 |
+
|
149 |
+
p = beam.Pipeline('DirectRunner')
|
150 |
+
(p | 'ReadInputTripletDicts' >> beam.Create(triplet_dicts) # pylint: disable=expression-not-assigned
|
151 |
+
| 'GenerateSingleExample' >> beam.ParDo(
|
152 |
+
util.ExampleGenerator(_INTERPOLATOR_IMAGES_MAP))
|
153 |
+
| 'WriteToTFRecord' >> beam.io.tfrecordio.WriteToTFRecord(
|
154 |
+
file_path_prefix=_OUTPUT_TFRECORD_FILEPATH.value,
|
155 |
+
num_shards=_NUM_SHARDS.value,
|
156 |
+
coder=beam.coders.BytesCoder()))
|
157 |
+
result = p.run()
|
158 |
+
result.wait_until_finish()
|
159 |
+
|
160 |
+
logging.info('Succeeded in creating the output TFRecord file: \'%s@%s\'.',
|
161 |
+
_OUTPUT_TFRECORD_FILEPATH.value, str(_NUM_SHARDS.value))
|
162 |
+
|
163 |
+
if __name__ == '__main__':
|
164 |
+
app.run(main)
|
datasets/create_ucf101_tfrecord.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 Google LLC
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
r"""Beam pipeline that generates UCF101 `interp_test` triplet TFRecords.
|
16 |
+
|
17 |
+
UCF101 interpolation evaluation dataset consists of 379 triplets, with the
|
18 |
+
middle frame being the golden intermediate. The dataset is available here:
|
19 |
+
https://people.cs.umass.edu/~hzjiang/projects/superslomo/UCF101_results.zip.
|
20 |
+
|
21 |
+
Input to the script is the root folder that contains the unzipped
|
22 |
+
`UCF101_results` folder.
|
23 |
+
|
24 |
+
Output TFRecord is a tf.train.Example proto of each image triplet.
|
25 |
+
The feature_map takes the form:
|
26 |
+
feature_map {
|
27 |
+
'frame_0/encoded':
|
28 |
+
tf.io.FixedLenFeature((), tf.string, default_value=''),
|
29 |
+
'frame_0/format':
|
30 |
+
tf.io.FixedLenFeature((), tf.string, default_value='jpg'),
|
31 |
+
'frame_0/height':
|
32 |
+
tf.io.FixedLenFeature((), tf.int64, default_value=0),
|
33 |
+
'frame_0/width':
|
34 |
+
tf.io.FixedLenFeature((), tf.int64, default_value=0),
|
35 |
+
'frame_1/encoded':
|
36 |
+
tf.io.FixedLenFeature((), tf.string, default_value=''),
|
37 |
+
'frame_1/format':
|
38 |
+
tf.io.FixedLenFeature((), tf.string, default_value='jpg'),
|
39 |
+
'frame_1/height':
|
40 |
+
tf.io.FixedLenFeature((), tf.int64, default_value=0),
|
41 |
+
'frame_1/width':
|
42 |
+
tf.io.FixedLenFeature((), tf.int64, default_value=0),
|
43 |
+
'frame_2/encoded':
|
44 |
+
tf.io.FixedLenFeature((), tf.string, default_value=''),
|
45 |
+
'frame_2/format':
|
46 |
+
tf.io.FixedLenFeature((), tf.string, default_value='jpg'),
|
47 |
+
'frame_2/height':
|
48 |
+
tf.io.FixedLenFeature((), tf.int64, default_value=0),
|
49 |
+
'frame_2/width':
|
50 |
+
tf.io.FixedLenFeature((), tf.int64, default_value=0),
|
51 |
+
'path':
|
52 |
+
tf.io.FixedLenFeature((), tf.string, default_value=''),
|
53 |
+
}
|
54 |
+
|
55 |
+
Usage example:
|
56 |
+
python3 -m frame_interpolation.datasets.create_ucf101_tfrecord \
|
57 |
+
--input_dir=<root folder of UCF101_results> \
|
58 |
+
--output_tfrecord_filepath=<output tfrecord filepath>
|
59 |
+
"""
|
60 |
+
|
61 |
+
import os
|
62 |
+
|
63 |
+
from . import util
|
64 |
+
from absl import app
|
65 |
+
from absl import flags
|
66 |
+
from absl import logging
|
67 |
+
import apache_beam as beam
|
68 |
+
import tensorflow as tf
|
69 |
+
|
70 |
+
_INPUT_DIR = flags.DEFINE_string(
|
71 |
+
'input_dir',
|
72 |
+
default='/root/path/to/UCF101_results/ucf101_interp_ours',
|
73 |
+
help='Path to the root directory of the `UCF101_results` of the UCF101 '
|
74 |
+
'interpolation evaluation data. '
|
75 |
+
'We expect the data to have been downloaded and unzipped. \n'
|
76 |
+
'Folder structures:\n'
|
77 |
+
'| raw_UCF101_results/\n'
|
78 |
+
'| ucf101_interp_ours/\n'
|
79 |
+
'| | 1/\n'
|
80 |
+
'| | | frame_00.png\n'
|
81 |
+
'| | | frame_01_gt.png\n'
|
82 |
+
'| | | frame_01_ours.png\n'
|
83 |
+
'| | | frame_02.png\n'
|
84 |
+
'| | 2/\n'
|
85 |
+
'| | | frame_00.png\n'
|
86 |
+
'| | | frame_01_gt.png\n'
|
87 |
+
'| | | frame_01_ours.png\n'
|
88 |
+
'| | | frame_02.png\n'
|
89 |
+
'| | ...\n'
|
90 |
+
'| ucf101_sepconv/\n'
|
91 |
+
'| ...\n')
|
92 |
+
|
93 |
+
_OUTPUT_TFRECORD_FILEPATH = flags.DEFINE_string(
|
94 |
+
'output_tfrecord_filepath',
|
95 |
+
default=None,
|
96 |
+
required=True,
|
97 |
+
help='Filepath to the output TFRecord file.')
|
98 |
+
|
99 |
+
_NUM_SHARDS = flags.DEFINE_integer('num_shards',
|
100 |
+
default=2,
|
101 |
+
help='Number of shards used for the output.')
|
102 |
+
|
103 |
+
# Image key -> basename for frame interpolator: start / middle / end frames.
|
104 |
+
_INTERPOLATOR_IMAGES_MAP = {
|
105 |
+
'frame_0': 'frame_00.png',
|
106 |
+
'frame_1': 'frame_01_gt.png',
|
107 |
+
'frame_2': 'frame_02.png',
|
108 |
+
}
|
109 |
+
|
110 |
+
|
111 |
+
def main(unused_argv):
|
112 |
+
"""Creates and runs a Beam pipeline to write frame triplets as a TFRecord."""
|
113 |
+
# Collect the list of folder paths containing the input and golden frames.
|
114 |
+
triplets_list = tf.io.gfile.listdir(_INPUT_DIR.value)
|
115 |
+
|
116 |
+
triplet_dicts = []
|
117 |
+
for triplet in triplets_list:
|
118 |
+
triplet_dicts.append({
|
119 |
+
image_key: os.path.join(_INPUT_DIR.value, triplet, image_basename)
|
120 |
+
for image_key, image_basename in _INTERPOLATOR_IMAGES_MAP.items()
|
121 |
+
})
|
122 |
+
|
123 |
+
p = beam.Pipeline('DirectRunner')
|
124 |
+
(p | 'ReadInputTripletDicts' >> beam.Create(triplet_dicts) # pylint: disable=expression-not-assigned
|
125 |
+
| 'GenerateSingleExample' >> beam.ParDo(
|
126 |
+
util.ExampleGenerator(_INTERPOLATOR_IMAGES_MAP))
|
127 |
+
| 'WriteToTFRecord' >> beam.io.tfrecordio.WriteToTFRecord(
|
128 |
+
file_path_prefix=_OUTPUT_TFRECORD_FILEPATH.value,
|
129 |
+
num_shards=_NUM_SHARDS.value,
|
130 |
+
coder=beam.coders.BytesCoder()))
|
131 |
+
result = p.run()
|
132 |
+
result.wait_until_finish()
|
133 |
+
|
134 |
+
logging.info('Succeeded in creating the output TFRecord file: \'%s@%s\'.',
|
135 |
+
_OUTPUT_TFRECORD_FILEPATH.value, str(_NUM_SHARDS.value))
|
136 |
+
|
137 |
+
if __name__ == '__main__':
|
138 |
+
app.run(main)
|
datasets/create_vimeo90K_tfrecord.py
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 Google LLC
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
r"""Beam pipeline that generates Vimeo-90K (train or test) triplet TFRecords.
|
16 |
+
|
17 |
+
Vimeo-90K dataset is built upon 5,846 videos downloaded from vimeo.com. The list
|
18 |
+
of the original video links are available here:
|
19 |
+
https://github.com/anchen1011/toflow/blob/master/data/original_vimeo_links.txt.
|
20 |
+
Each video is further cropped into a fixed spatial size of (448 x 256) to create
|
21 |
+
89,000 video clips.
|
22 |
+
|
23 |
+
The Vimeo-90K dataset is designed for four video processing tasks. This script
|
24 |
+
creates the TFRecords of frame triplets for frame interpolation task.
|
25 |
+
|
26 |
+
Temporal frame interpolation triplet dataset:
|
27 |
+
- 73,171 triplets of size (448x256) extracted from 15K subsets of Vimeo-90K.
|
28 |
+
- The triplets are pre-split into (train,test) = (51313,3782)
|
29 |
+
- Download links:
|
30 |
+
Test-set: http://data.csail.mit.edu/tofu/testset/vimeo_interp_test.zip
|
31 |
+
Train+test-set: http://data.csail.mit.edu/tofu/dataset/vimeo_triplet.zip
|
32 |
+
|
33 |
+
For more information, see the arXiv paper, project page or the GitHub link.
|
34 |
+
@article{xue17toflow,
|
35 |
+
author = {Xue, Tianfan and
|
36 |
+
Chen, Baian and
|
37 |
+
Wu, Jiajun and
|
38 |
+
Wei, Donglai and
|
39 |
+
Freeman, William T},
|
40 |
+
title = {Video Enhancement with Task-Oriented Flow},
|
41 |
+
journal = {arXiv},
|
42 |
+
year = {2017}
|
43 |
+
}
|
44 |
+
Project: http://toflow.csail.mit.edu/
|
45 |
+
GitHub: https://github.com/anchen1011/toflow
|
46 |
+
|
47 |
+
Inputs to the script are (1) the directory to the downloaded and unzipped folder
|
48 |
+
(2) the filepath of the text-file that lists the subfolders of the triplets.
|
49 |
+
|
50 |
+
Output TFRecord is a tf.train.Example proto of each image triplet.
|
51 |
+
The feature_map takes the form:
|
52 |
+
feature_map {
|
53 |
+
'frame_0/encoded':
|
54 |
+
tf.io.FixedLenFeature((), tf.string, default_value=''),
|
55 |
+
'frame_0/format':
|
56 |
+
tf.io.FixedLenFeature((), tf.string, default_value='jpg'),
|
57 |
+
'frame_0/height':
|
58 |
+
tf.io.FixedLenFeature((), tf.int64, default_value=0),
|
59 |
+
'frame_0/width':
|
60 |
+
tf.io.FixedLenFeature((), tf.int64, default_value=0),
|
61 |
+
'frame_1/encoded':
|
62 |
+
tf.io.FixedLenFeature((), tf.string, default_value=''),
|
63 |
+
'frame_1/format':
|
64 |
+
tf.io.FixedLenFeature((), tf.string, default_value='jpg'),
|
65 |
+
'frame_1/height':
|
66 |
+
tf.io.FixedLenFeature((), tf.int64, default_value=0),
|
67 |
+
'frame_1/width':
|
68 |
+
tf.io.FixedLenFeature((), tf.int64, default_value=0),
|
69 |
+
'frame_2/encoded':
|
70 |
+
tf.io.FixedLenFeature((), tf.string, default_value=''),
|
71 |
+
'frame_2/format':
|
72 |
+
tf.io.FixedLenFeature((), tf.string, default_value='jpg'),
|
73 |
+
'frame_2/height':
|
74 |
+
tf.io.FixedLenFeature((), tf.int64, default_value=0),
|
75 |
+
'frame_2/width':
|
76 |
+
tf.io.FixedLenFeature((), tf.int64, default_value=0)
|
77 |
+
'path':
|
78 |
+
tf.io.FixedLenFeature((), tf.string, default_value='')
|
79 |
+
}
|
80 |
+
|
81 |
+
Usage example:
|
82 |
+
python3 -m frame_interpolation.datasets.create_vimeo90K_tfrecord \
|
83 |
+
--input_dir=<root folder of vimeo90K dataset> \
|
84 |
+
--input_triplet_list_filepath=<filepath of tri_{test|train}list.txt> \
|
85 |
+
--output_tfrecord_filepath=<output tfrecord filepath>
|
86 |
+
"""
|
87 |
+
import os
|
88 |
+
|
89 |
+
from . import util
|
90 |
+
from absl import app
|
91 |
+
from absl import flags
|
92 |
+
from absl import logging
|
93 |
+
import apache_beam as beam
|
94 |
+
import numpy as np
|
95 |
+
import tensorflow as tf
|
96 |
+
|
97 |
+
|
98 |
+
_INPUT_DIR = flags.DEFINE_string(
|
99 |
+
'input_dir',
|
100 |
+
default='/path/to/raw_vimeo_interp/sequences',
|
101 |
+
help='Path to the root directory of the vimeo frame interpolation dataset. '
|
102 |
+
'We expect the data to have been downloaded and unzipped.\n'
|
103 |
+
'Folder structures:\n'
|
104 |
+
'| raw_vimeo_dataset/\n'
|
105 |
+
'| sequences/\n'
|
106 |
+
'| | 00001\n'
|
107 |
+
'| | | 0389/\n'
|
108 |
+
'| | | | im1.png\n'
|
109 |
+
'| | | | im2.png\n'
|
110 |
+
'| | | | im3.png\n'
|
111 |
+
'| | | ...\n'
|
112 |
+
'| | 00002/\n'
|
113 |
+
'| | ...\n'
|
114 |
+
'| readme.txt\n'
|
115 |
+
'| tri_trainlist.txt\n'
|
116 |
+
'| tri_testlist.txt \n')
|
117 |
+
|
118 |
+
_INTPUT_TRIPLET_LIST_FILEPATH = flags.DEFINE_string(
|
119 |
+
'input_triplet_list_filepath',
|
120 |
+
default='/path/to/raw_vimeo_dataset/tri_{test|train}list.txt',
|
121 |
+
help='Text file containing a list of sub-directories of input triplets.')
|
122 |
+
|
123 |
+
_OUTPUT_TFRECORD_FILEPATH = flags.DEFINE_string(
|
124 |
+
'output_tfrecord_filepath',
|
125 |
+
default=None,
|
126 |
+
help='Filepath to the output TFRecord file.')
|
127 |
+
|
128 |
+
_NUM_SHARDS = flags.DEFINE_integer('num_shards',
|
129 |
+
default=200, # set to 3 for vimeo_test, and 200 for vimeo_train.
|
130 |
+
help='Number of shards used for the output.')
|
131 |
+
|
132 |
+
# Image key -> basename for frame interpolator: start / middle / end frames.
|
133 |
+
_INTERPOLATOR_IMAGES_MAP = {
|
134 |
+
'frame_0': 'im1.png',
|
135 |
+
'frame_1': 'im2.png',
|
136 |
+
'frame_2': 'im3.png',
|
137 |
+
}
|
138 |
+
|
139 |
+
|
140 |
+
def main(unused_argv):
|
141 |
+
"""Creates and runs a Beam pipeline to write frame triplets as a TFRecord."""
|
142 |
+
with tf.io.gfile.GFile(_INTPUT_TRIPLET_LIST_FILEPATH.value, 'r') as fid:
|
143 |
+
triplets_list = np.loadtxt(fid, dtype=str)
|
144 |
+
|
145 |
+
triplet_dicts = []
|
146 |
+
for triplet in triplets_list:
|
147 |
+
triplet_dict = {
|
148 |
+
image_key: os.path.join(_INPUT_DIR.value, triplet, image_basename)
|
149 |
+
for image_key, image_basename in _INTERPOLATOR_IMAGES_MAP.items()
|
150 |
+
}
|
151 |
+
triplet_dicts.append(triplet_dict)
|
152 |
+
p = beam.Pipeline('DirectRunner')
|
153 |
+
(p | 'ReadInputTripletDicts' >> beam.Create(triplet_dicts) # pylint: disable=expression-not-assigned
|
154 |
+
| 'GenerateSingleExample' >> beam.ParDo(
|
155 |
+
util.ExampleGenerator(_INTERPOLATOR_IMAGES_MAP))
|
156 |
+
| 'WriteToTFRecord' >> beam.io.tfrecordio.WriteToTFRecord(
|
157 |
+
file_path_prefix=_OUTPUT_TFRECORD_FILEPATH.value,
|
158 |
+
num_shards=_NUM_SHARDS.value,
|
159 |
+
coder=beam.coders.BytesCoder()))
|
160 |
+
result = p.run()
|
161 |
+
result.wait_until_finish()
|
162 |
+
|
163 |
+
logging.info('Succeeded in creating the output TFRecord file: \'%s@%s\'.',
|
164 |
+
_OUTPUT_TFRECORD_FILEPATH.value, str(_NUM_SHARDS.value))
|
165 |
+
|
166 |
+
if __name__ == '__main__':
|
167 |
+
app.run(main)
|
datasets/create_xiph_tfrecord.py
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 Google LLC
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
r"""Beam pipeline that generates Xiph triplet TFRecords.
|
16 |
+
|
17 |
+
Xiph is a frame sequence dataset commonly used to assess video compression. See
|
18 |
+
here: https://media.xiph.org/video/derf/
|
19 |
+
|
20 |
+
The SoftSplat paper selected eight 4K clips with the most amount of motion and
|
21 |
+
extracted the first 100 frames from each clip. Each frame is then either resized
|
22 |
+
from 4K to 2K, or a 2K center crop from them is performed before interpolating
|
23 |
+
the even frames from the odd frames. These datasets are denoted as `Xiph-2K`
|
24 |
+
and `Xiph-4K` respectively. For more information see the project page:
|
25 |
+
https://github.com/sniklaus/softmax-splatting
|
26 |
+
|
27 |
+
Input is the root folder that contains the 800 frames of the eight clips. Set
|
28 |
+
center_crop_factor=2 and scale_factor=1 to generate `Xiph-4K`,and scale_factor=2
|
29 |
+
, center_crop_factor=1 to generate `Xiph-2K`. The scripts defaults to `Xiph-2K`.
|
30 |
+
|
31 |
+
Output TFRecord is a tf.train.Example proto of each image triplet.
|
32 |
+
The feature_map takes the form:
|
33 |
+
feature_map {
|
34 |
+
'frame_0/encoded':
|
35 |
+
tf.io.FixedLenFeature((), tf.string, default_value=''),
|
36 |
+
'frame_0/format':
|
37 |
+
tf.io.FixedLenFeature((), tf.string, default_value='jpg'),
|
38 |
+
'frame_0/height':
|
39 |
+
tf.io.FixedLenFeature((), tf.int64, default_value=0),
|
40 |
+
'frame_0/width':
|
41 |
+
tf.io.FixedLenFeature((), tf.int64, default_value=0),
|
42 |
+
'frame_1/encoded':
|
43 |
+
tf.io.FixedLenFeature((), tf.string, default_value=''),
|
44 |
+
'frame_1/format':
|
45 |
+
tf.io.FixedLenFeature((), tf.string, default_value='jpg'),
|
46 |
+
'frame_1/height':
|
47 |
+
tf.io.FixedLenFeature((), tf.int64, default_value=0),
|
48 |
+
'frame_1/width':
|
49 |
+
tf.io.FixedLenFeature((), tf.int64, default_value=0),
|
50 |
+
'frame_2/encoded':
|
51 |
+
tf.io.FixedLenFeature((), tf.string, default_value=''),
|
52 |
+
'frame_2/format':
|
53 |
+
tf.io.FixedLenFeature((), tf.string, default_value='jpg'),
|
54 |
+
'frame_2/height':
|
55 |
+
tf.io.FixedLenFeature((), tf.int64, default_value=0),
|
56 |
+
'frame_2/width':
|
57 |
+
tf.io.FixedLenFeature((), tf.int64, default_value=0),
|
58 |
+
'path':
|
59 |
+
tf.io.FixedLenFeature((), tf.string, default_value=''),
|
60 |
+
}
|
61 |
+
|
62 |
+
Usage example:
|
63 |
+
python3 -m frame_interpolation.datasets.create_xiph_tfrecord \
|
64 |
+
--input_dir=<root folder of xiph dataset> \
|
65 |
+
--scale_factor=<scale factor for image resizing, default=2> \
|
66 |
+
--center_crop_factor=<center cropping factor, default=1> \
|
67 |
+
--output_tfrecord_filepath=<output tfrecord filepath>
|
68 |
+
"""
|
69 |
+
import os
|
70 |
+
|
71 |
+
from . import util
|
72 |
+
from absl import app
|
73 |
+
from absl import flags
|
74 |
+
from absl import logging
|
75 |
+
import apache_beam as beam
|
76 |
+
import tensorflow as tf
|
77 |
+
|
78 |
+
_INPUT_DIR = flags.DEFINE_string(
|
79 |
+
'input_dir',
|
80 |
+
default='/root/path/to/selected/xiph/clips',
|
81 |
+
help='Path to the root directory of the `Xiph` interpolation evaluation '
|
82 |
+
'data. We expect the data to have been downloaded and unzipped.')
|
83 |
+
_CENTER_CROP_FACTOR = flags.DEFINE_integer(
|
84 |
+
'center_crop_factor',
|
85 |
+
default=1,
|
86 |
+
help='Factor to center crop image. If set to 2, an image of the same '
|
87 |
+
'resolution as the inputs but half the size is created.')
|
88 |
+
_SCALE_FACTOR = flags.DEFINE_integer(
|
89 |
+
'scale_factor',
|
90 |
+
default=2,
|
91 |
+
help='Factor to downsample frames.')
|
92 |
+
_NUM_CLIPS = flags.DEFINE_integer(
|
93 |
+
'num_clips', default=8, help='Number of clips.')
|
94 |
+
_NUM_FRAMES = flags.DEFINE_integer(
|
95 |
+
'num_frames', default=100, help='Number of frames per clip.')
|
96 |
+
_OUTPUT_TFRECORD_FILEPATH = flags.DEFINE_string(
|
97 |
+
'output_tfrecord_filepath',
|
98 |
+
default=None,
|
99 |
+
required=True,
|
100 |
+
help='Filepath to the output TFRecord file.')
|
101 |
+
_NUM_SHARDS = flags.DEFINE_integer('num_shards',
|
102 |
+
default=2,
|
103 |
+
help='Number of shards used for the output.')
|
104 |
+
|
105 |
+
# Image key -> offset for frame interpolator: start / middle / end frame offset.
|
106 |
+
_INTERPOLATOR_IMAGES_MAP = {
|
107 |
+
'frame_0': -1,
|
108 |
+
'frame_1': 0,
|
109 |
+
'frame_2': 1,
|
110 |
+
}
|
111 |
+
|
112 |
+
|
113 |
+
def main(unused_argv):
|
114 |
+
"""Creates and runs a Beam pipeline to write frame triplets as a TFRecord."""
|
115 |
+
# Collect the list of frame filenames.
|
116 |
+
frames_list = sorted(tf.io.gfile.listdir(_INPUT_DIR.value))
|
117 |
+
|
118 |
+
# Collect the triplets, even frames serving as golden to interpolate odds.
|
119 |
+
triplets_dict = []
|
120 |
+
for clip_index in range(_NUM_CLIPS.value):
|
121 |
+
for frame_index in range(1, _NUM_FRAMES.value - 1, 2):
|
122 |
+
index = clip_index * _NUM_FRAMES.value + frame_index
|
123 |
+
triplet_dict = {
|
124 |
+
image_key: os.path.join(_INPUT_DIR.value,
|
125 |
+
frames_list[index + image_offset])
|
126 |
+
for image_key, image_offset in _INTERPOLATOR_IMAGES_MAP.items()
|
127 |
+
}
|
128 |
+
triplets_dict.append(triplet_dict)
|
129 |
+
|
130 |
+
p = beam.Pipeline('DirectRunner')
|
131 |
+
(p | 'ReadInputTripletDicts' >> beam.Create(triplets_dict) # pylint: disable=expression-not-assigned
|
132 |
+
| 'GenerateSingleExample' >> beam.ParDo(
|
133 |
+
util.ExampleGenerator(_INTERPOLATOR_IMAGES_MAP, _SCALE_FACTOR.value,
|
134 |
+
_CENTER_CROP_FACTOR.value))
|
135 |
+
| 'WriteToTFRecord' >> beam.io.tfrecordio.WriteToTFRecord(
|
136 |
+
file_path_prefix=_OUTPUT_TFRECORD_FILEPATH.value,
|
137 |
+
num_shards=_NUM_SHARDS.value,
|
138 |
+
coder=beam.coders.BytesCoder()))
|
139 |
+
result = p.run()
|
140 |
+
result.wait_until_finish()
|
141 |
+
|
142 |
+
logging.info('Succeeded in creating the output TFRecord file: \'%s@%s\'.',
|
143 |
+
_OUTPUT_TFRECORD_FILEPATH.value, str(_NUM_SHARDS.value))
|
144 |
+
|
145 |
+
if __name__ == '__main__':
|
146 |
+
app.run(main)
|
datasets/util.py
ADDED
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 Google LLC
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
"""Utility functions for creating a tf.train.Example proto of image triplets."""
|
16 |
+
|
17 |
+
import io
|
18 |
+
import os
|
19 |
+
from typing import Any, List, Mapping, Optional
|
20 |
+
|
21 |
+
from absl import logging
|
22 |
+
import apache_beam as beam
|
23 |
+
import numpy as np
|
24 |
+
import PIL.Image
|
25 |
+
import six
|
26 |
+
from skimage import transform
|
27 |
+
import tensorflow as tf
|
28 |
+
|
29 |
+
_UINT8_MAX_F = float(np.iinfo(np.uint8).max)
|
30 |
+
_GAMMA = 2.2
|
31 |
+
|
32 |
+
|
33 |
+
def _resample_image(image: np.ndarray, resample_image_width: int,
|
34 |
+
resample_image_height: int) -> np.ndarray:
|
35 |
+
"""Re-samples and returns an `image` to be `resample_image_size`."""
|
36 |
+
# Convert image from uint8 gamma [0..255] to float linear [0..1].
|
37 |
+
image = image.astype(np.float32) / _UINT8_MAX_F
|
38 |
+
image = np.power(np.clip(image, 0, 1), _GAMMA)
|
39 |
+
|
40 |
+
# Re-size the image
|
41 |
+
resample_image_size = (resample_image_height, resample_image_width)
|
42 |
+
image = transform.resize_local_mean(image, resample_image_size)
|
43 |
+
|
44 |
+
# Convert back from float linear [0..1] to uint8 gamma [0..255].
|
45 |
+
image = np.power(np.clip(image, 0, 1), 1.0 / _GAMMA)
|
46 |
+
image = np.clip(image * _UINT8_MAX_F + 0.5, 0.0,
|
47 |
+
_UINT8_MAX_F).astype(np.uint8)
|
48 |
+
return image
|
49 |
+
|
50 |
+
|
51 |
+
def generate_image_triplet_example(
|
52 |
+
triplet_dict: Mapping[str, str],
|
53 |
+
scale_factor: int = 1,
|
54 |
+
center_crop_factor: int = 1) -> Optional[tf.train.Example]:
|
55 |
+
"""Generates and serializes a tf.train.Example proto from an image triplet.
|
56 |
+
|
57 |
+
Default setting creates a triplet Example with the input images unchanged.
|
58 |
+
Images are processed in the order of center-crop then downscale.
|
59 |
+
|
60 |
+
Args:
|
61 |
+
triplet_dict: A dict of image key to filepath of the triplet images.
|
62 |
+
scale_factor: An integer scale factor to isotropically downsample images.
|
63 |
+
center_crop_factor: An integer cropping factor to center crop images with
|
64 |
+
the original resolution but isotropically downsized by the factor.
|
65 |
+
|
66 |
+
Returns:
|
67 |
+
tf.train.Example proto, or None upon error.
|
68 |
+
|
69 |
+
Raises:
|
70 |
+
ValueError if triplet_dict length is different from three or the scale input
|
71 |
+
arguments are non-positive.
|
72 |
+
"""
|
73 |
+
if len(triplet_dict) != 3:
|
74 |
+
raise ValueError(
|
75 |
+
f'Length of triplet_dict must be exactly 3, not {len(triplet_dict)}.')
|
76 |
+
|
77 |
+
if scale_factor <= 0 or center_crop_factor <= 0:
|
78 |
+
raise ValueError(f'(scale_factor, center_crop_factor) must be positive, '
|
79 |
+
f'Not ({scale_factor}, {center_crop_factor}).')
|
80 |
+
|
81 |
+
feature = {}
|
82 |
+
|
83 |
+
# Keep track of the path where the images came from for debugging purposes.
|
84 |
+
mid_frame_path = os.path.dirname(triplet_dict['frame_1'])
|
85 |
+
feature['path'] = tf.train.Feature(
|
86 |
+
bytes_list=tf.train.BytesList(value=[six.ensure_binary(mid_frame_path)]))
|
87 |
+
|
88 |
+
for image_key, image_path in triplet_dict.items():
|
89 |
+
if not tf.io.gfile.exists(image_path):
|
90 |
+
logging.error('File not found: %s', image_path)
|
91 |
+
return None
|
92 |
+
|
93 |
+
# Note: we need both the raw bytes and the image size.
|
94 |
+
# PIL.Image does not expose a method to grab the original bytes.
|
95 |
+
# (Also it is not aware of non-local file systems.)
|
96 |
+
# So we read with tf.io.gfile.GFile to get the bytes, and then wrap the
|
97 |
+
# bytes in BytesIO to let PIL.Image open the image.
|
98 |
+
try:
|
99 |
+
byte_array = tf.io.gfile.GFile(image_path, 'rb').read()
|
100 |
+
except tf.errors.InvalidArgumentError:
|
101 |
+
logging.exception('Cannot read image file: %s', image_path)
|
102 |
+
return None
|
103 |
+
try:
|
104 |
+
pil_image = PIL.Image.open(io.BytesIO(byte_array))
|
105 |
+
except PIL.UnidentifiedImageError:
|
106 |
+
logging.exception('Cannot decode image file: %s', image_path)
|
107 |
+
return None
|
108 |
+
width, height = pil_image.size
|
109 |
+
pil_image_format = pil_image.format
|
110 |
+
|
111 |
+
# Optionally center-crop images and downsize images
|
112 |
+
# by `center_crop_factor`.
|
113 |
+
if center_crop_factor > 1:
|
114 |
+
image = np.array(pil_image)
|
115 |
+
quarter_height = image.shape[0] // (2 * center_crop_factor)
|
116 |
+
quarter_width = image.shape[1] // (2 * center_crop_factor)
|
117 |
+
image = image[quarter_height:-quarter_height,
|
118 |
+
quarter_width:-quarter_width, :]
|
119 |
+
pil_image = PIL.Image.fromarray(image)
|
120 |
+
|
121 |
+
# Update image properties.
|
122 |
+
height, width, _ = image.shape
|
123 |
+
buffer = io.BytesIO()
|
124 |
+
try:
|
125 |
+
pil_image.save(buffer, format='PNG')
|
126 |
+
except OSError:
|
127 |
+
logging.exception('Cannot encode image file: %s', image_path)
|
128 |
+
return None
|
129 |
+
byte_array = buffer.getvalue()
|
130 |
+
|
131 |
+
# Optionally downsample images by `scale_factor`.
|
132 |
+
if scale_factor > 1:
|
133 |
+
image = np.array(pil_image)
|
134 |
+
image = _resample_image(image, image.shape[1] // scale_factor,
|
135 |
+
image.shape[0] // scale_factor)
|
136 |
+
pil_image = PIL.Image.fromarray(image)
|
137 |
+
|
138 |
+
# Update image properties.
|
139 |
+
height, width, _ = image.shape
|
140 |
+
buffer = io.BytesIO()
|
141 |
+
try:
|
142 |
+
pil_image.save(buffer, format='PNG')
|
143 |
+
except OSError:
|
144 |
+
logging.exception('Cannot encode image file: %s', image_path)
|
145 |
+
return None
|
146 |
+
byte_array = buffer.getvalue()
|
147 |
+
|
148 |
+
# Create tf Features.
|
149 |
+
image_feature = tf.train.Feature(
|
150 |
+
bytes_list=tf.train.BytesList(value=[byte_array]))
|
151 |
+
height_feature = tf.train.Feature(
|
152 |
+
int64_list=tf.train.Int64List(value=[height]))
|
153 |
+
width_feature = tf.train.Feature(
|
154 |
+
int64_list=tf.train.Int64List(value=[width]))
|
155 |
+
encoding = tf.train.Feature(
|
156 |
+
bytes_list=tf.train.BytesList(
|
157 |
+
value=[six.ensure_binary(pil_image_format.lower())]))
|
158 |
+
|
159 |
+
# Update feature map.
|
160 |
+
feature[f'{image_key}/encoded'] = image_feature
|
161 |
+
feature[f'{image_key}/format'] = encoding
|
162 |
+
feature[f'{image_key}/height'] = height_feature
|
163 |
+
feature[f'{image_key}/width'] = width_feature
|
164 |
+
|
165 |
+
# Create tf Example.
|
166 |
+
features = tf.train.Features(feature=feature)
|
167 |
+
example = tf.train.Example(features=features)
|
168 |
+
return example
|
169 |
+
|
170 |
+
|
171 |
+
class ExampleGenerator(beam.DoFn):
|
172 |
+
"""Generate a tf.train.Example per input image triplet filepaths."""
|
173 |
+
|
174 |
+
def __init__(self,
|
175 |
+
images_map: Mapping[str, Any],
|
176 |
+
scale_factor: int = 1,
|
177 |
+
center_crop_factor: int = 1):
|
178 |
+
"""Initializes the map of 3 images to add to each tf.train.Example.
|
179 |
+
|
180 |
+
Args:
|
181 |
+
images_map: Map from image key to image filepath.
|
182 |
+
scale_factor: A scale factor to downsample frames.
|
183 |
+
center_crop_factor: A factor to centercrop and downsize frames.
|
184 |
+
"""
|
185 |
+
super().__init__()
|
186 |
+
self._images_map = images_map
|
187 |
+
self._scale_factor = scale_factor
|
188 |
+
self._center_crop_factor = center_crop_factor
|
189 |
+
|
190 |
+
def process(self, triplet_dict: Mapping[str, str]) -> List[bytes]:
|
191 |
+
"""Generates a serialized tf.train.Example for a triplet of images.
|
192 |
+
|
193 |
+
Args:
|
194 |
+
triplet_dict: A dict of image key to filepath of the triplet images.
|
195 |
+
|
196 |
+
Returns:
|
197 |
+
A serialized tf.train.Example proto. No shuffling is applied.
|
198 |
+
"""
|
199 |
+
example = generate_image_triplet_example(triplet_dict, self._scale_factor,
|
200 |
+
self._center_crop_factor)
|
201 |
+
if example:
|
202 |
+
return [example.SerializeToString()]
|
203 |
+
else:
|
204 |
+
return []
|
eval/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
eval/config/middlebury.gin
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 Google LLC
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
experiment.name = 'middlebury'
|
16 |
+
evaluation.max_examples = -1
|
17 |
+
evaluation.metrics = ['l1', 'l2', 'ssim', 'psnr']
|
18 |
+
evaluation.tfrecord = 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/middlebury_other.tfrecord@3'
|
eval/config/ucf101.gin
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 Google LLC
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
experiment.name = 'ucf101'
|
16 |
+
evaluation.max_examples = -1
|
17 |
+
evaluation.metrics = ['l1', 'l2', 'ssim', 'psnr']
|
18 |
+
evaluation.tfrecord = 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/UCF101_interp_test.tfrecord@2'
|
eval/config/vimeo_90K.gin
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 Google LLC
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
experiment.name = 'vimeo_90K'
|
16 |
+
evaluation.max_examples = -1
|
17 |
+
evaluation.metrics = ['l1', 'l2', 'ssim', 'psnr']
|
18 |
+
evaluation.tfrecord = 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/vimeo_interp_test.tfrecord@3'
|
eval/config/xiph_2K.gin
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 Google LLC
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
experiment.name = 'xiph_2K'
|
16 |
+
evaluation.max_examples = -1
|
17 |
+
evaluation.metrics = ['l1', 'l2', 'ssim', 'psnr']
|
18 |
+
evaluation.tfrecord = 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/xiph_2K.tfrecord@2'
|
eval/config/xiph_4K.gin
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 Google LLC
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
experiment.name = 'xiph_4K'
|
16 |
+
evaluation.max_examples = -1
|
17 |
+
evaluation.metrics = ['l1', 'l2', 'ssim', 'psnr']
|
18 |
+
evaluation.tfrecord = 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/xiph_4K.tfrecord@2'
|
eval/eval_cli.py
ADDED
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 Google LLC
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
r"""Evaluate the frame interpolation model from a tfrecord and store results.
|
16 |
+
|
17 |
+
This script runs the inference on examples in a tfrecord and generates images
|
18 |
+
and numeric results according to the gin config. For details, see the
|
19 |
+
run_evaluation() function below.
|
20 |
+
|
21 |
+
Usage example:
|
22 |
+
python3 -m frame_interpolation.eval.eval_cli -- \
|
23 |
+
--gin_config <path to eval_dataset.gin> \
|
24 |
+
--base_folder <the root directory to all training sessions> \
|
25 |
+
--label < the foldername of the training session>
|
26 |
+
|
27 |
+
or
|
28 |
+
|
29 |
+
python3 -m frame_interpolation.eval.eval_cli -- \
|
30 |
+
--gin_config <path to eval_dataset.gin> \
|
31 |
+
--model_path <The filepath of the TF2 saved model>
|
32 |
+
|
33 |
+
The output is saved at the parent directory of the `model_path`:
|
34 |
+
<parent directory of model_path>/batch_eval.
|
35 |
+
|
36 |
+
The evaluation is run on a GPU by default. Add the `--mode` argument for others.
|
37 |
+
"""
|
38 |
+
import collections
|
39 |
+
import os
|
40 |
+
from typing import Any, Dict
|
41 |
+
|
42 |
+
from . import util
|
43 |
+
from absl import app
|
44 |
+
from absl import flags
|
45 |
+
from absl import logging
|
46 |
+
import gin.tf
|
47 |
+
from ..losses import losses
|
48 |
+
import numpy as np
|
49 |
+
import tensorflow as tf
|
50 |
+
from ..training import data_lib
|
51 |
+
|
52 |
+
|
53 |
+
_GIN_CONFIG = flags.DEFINE_string('gin_config', None, 'Gin config file.')
|
54 |
+
_LABEL = flags.DEFINE_string(
|
55 |
+
'label', None, 'Descriptive label for the training session to eval.')
|
56 |
+
_BASE_FOLDER = flags.DEFINE_string('base_folder', None,
|
57 |
+
'Root folder of training sessions.')
|
58 |
+
_MODEL_PATH = flags.DEFINE_string(
|
59 |
+
name='model_path',
|
60 |
+
default=None,
|
61 |
+
help='The path of the TF2 saved model to use. If _MODEL_PATH argument is '
|
62 |
+
'directly specified, _LABEL and _BASE_FOLDER arguments will be ignored.')
|
63 |
+
_OUTPUT_FRAMES = flags.DEFINE_boolean(
|
64 |
+
name='output_frames',
|
65 |
+
default=False,
|
66 |
+
help='If true, saves the the inputs, groud-truth and interpolated frames.')
|
67 |
+
_MODE = flags.DEFINE_enum('mode', 'gpu', ['cpu', 'gpu'],
|
68 |
+
'Device to run evaluations.')
|
69 |
+
|
70 |
+
|
71 |
+
@gin.configurable('experiment')
|
72 |
+
def _get_experiment_config(name) -> Dict[str, Any]:
|
73 |
+
"""Fetches the gin config."""
|
74 |
+
return {
|
75 |
+
'name': name,
|
76 |
+
}
|
77 |
+
|
78 |
+
|
79 |
+
def _set_visible_devices():
|
80 |
+
"""Set the visible devices according to running mode."""
|
81 |
+
mode_devices = tf.config.list_physical_devices(_MODE.value.upper())
|
82 |
+
tf.config.set_visible_devices([], 'GPU')
|
83 |
+
tf.config.set_visible_devices([], 'TPU')
|
84 |
+
tf.config.set_visible_devices(mode_devices, _MODE.value.upper())
|
85 |
+
return
|
86 |
+
|
87 |
+
|
88 |
+
@gin.configurable('evaluation')
|
89 |
+
def run_evaluation(model_path, tfrecord, output_dir, max_examples, metrics):
|
90 |
+
"""Runs the eval loop for examples in the tfrecord.
|
91 |
+
|
92 |
+
The evaluation is run for the first 'max_examples' number of examples, and
|
93 |
+
resulting images are stored into the given output_dir. Any tensor that
|
94 |
+
appears like an image is stored with its name -- this may include intermediate
|
95 |
+
results, depending on what the model outputs.
|
96 |
+
|
97 |
+
Additionally, numeric results are stored into results.csv file within the same
|
98 |
+
directory. This includes per-example metrics and the mean across the whole
|
99 |
+
dataset.
|
100 |
+
|
101 |
+
Args:
|
102 |
+
model_path: Directory TF2 saved model.
|
103 |
+
tfrecord: Directory to the tfrecord eval data.
|
104 |
+
output_dir: Directory to store the results into.
|
105 |
+
max_examples: Maximum examples to evaluate.
|
106 |
+
metrics: The names of loss functions to use.
|
107 |
+
"""
|
108 |
+
model = tf.saved_model.load(model_path)
|
109 |
+
|
110 |
+
# Store a 'readme.txt' that contains information on where the data came from.
|
111 |
+
with tf.io.gfile.GFile(os.path.join(output_dir, 'readme.txt'), mode='w') as f:
|
112 |
+
print('Results for:', file=f)
|
113 |
+
print(f' model: {model_path}', file=f)
|
114 |
+
print(f' tfrecord: {tfrecord}', file=f)
|
115 |
+
|
116 |
+
with tf.io.gfile.GFile(
|
117 |
+
os.path.join(output_dir, 'results.csv'), mode='w') as csv_file:
|
118 |
+
test_losses = losses.test_losses(metrics, [
|
119 |
+
1.0,
|
120 |
+
] * len(metrics))
|
121 |
+
title_row = ['key'] + list(test_losses)
|
122 |
+
print(', '.join(title_row), file=csv_file)
|
123 |
+
|
124 |
+
datasets = data_lib.create_eval_datasets(
|
125 |
+
batch_size=1,
|
126 |
+
files=[tfrecord],
|
127 |
+
names=[os.path.basename(output_dir)],
|
128 |
+
max_examples=max_examples)
|
129 |
+
dataset = datasets[os.path.basename(output_dir)]
|
130 |
+
|
131 |
+
all_losses = collections.defaultdict(list)
|
132 |
+
for example in dataset:
|
133 |
+
inputs = {
|
134 |
+
'x0': example['x0'],
|
135 |
+
'x1': example['x1'],
|
136 |
+
'time': example['time'][..., tf.newaxis],
|
137 |
+
}
|
138 |
+
prediction = model(inputs, training=False)
|
139 |
+
|
140 |
+
# Get the key from encoded mid-frame path.
|
141 |
+
path = example['path'][0].numpy().decode('utf-8')
|
142 |
+
key = path.rsplit('.', 1)[0].rsplit(os.sep)[-1]
|
143 |
+
|
144 |
+
# Combines both inputs and outputs into a single dictionary:
|
145 |
+
combined = {**prediction, **example} if _OUTPUT_FRAMES.value else {}
|
146 |
+
for name in combined:
|
147 |
+
image = combined[name]
|
148 |
+
if isinstance(image, tf.Tensor):
|
149 |
+
# This saves any tensor that has a shape that can be interpreted
|
150 |
+
# as an image, e.g. (1, H, W, C), where the batch dimension is always
|
151 |
+
# 1, H and W are the image height and width, and C is either 1 or 3
|
152 |
+
# (grayscale or color image).
|
153 |
+
if len(image.shape) == 4 and (image.shape[-1] == 1 or
|
154 |
+
image.shape[-1] == 3):
|
155 |
+
util.write_image(
|
156 |
+
os.path.join(output_dir, f'{key}_{name}.png'), image[0].numpy())
|
157 |
+
|
158 |
+
# Evaluate losses if the dataset has ground truth 'y', otherwise just do
|
159 |
+
# a visual eval.
|
160 |
+
if 'y' in example:
|
161 |
+
loss_values = []
|
162 |
+
# Clip interpolator output to the range [0,1]. Clipping is done only
|
163 |
+
# on the eval loop to get better metrics, but not on the training loop
|
164 |
+
# so gradients are not killed.
|
165 |
+
prediction['image'] = tf.clip_by_value(prediction['image'], 0., 1.)
|
166 |
+
for loss_name, (loss_value_fn, loss_weight_fn) in test_losses.items():
|
167 |
+
loss_value = loss_value_fn(example, prediction) * loss_weight_fn(0)
|
168 |
+
loss_values.append(loss_value.numpy())
|
169 |
+
all_losses[loss_name].append(loss_value.numpy())
|
170 |
+
print(f'{key}, {str(loss_values)[1:-1]}', file=csv_file)
|
171 |
+
|
172 |
+
if all_losses:
|
173 |
+
totals = [np.mean(all_losses[loss_name]) for loss_name in test_losses]
|
174 |
+
print(f'mean, {str(totals)[1:-1]}', file=csv_file)
|
175 |
+
totals_dict = {
|
176 |
+
loss_name: np.mean(all_losses[loss_name]) for loss_name in test_losses
|
177 |
+
}
|
178 |
+
logging.info('mean, %s', totals_dict)
|
179 |
+
|
180 |
+
|
181 |
+
def main(argv):
|
182 |
+
if len(argv) > 1:
|
183 |
+
raise app.UsageError('Too many command-line arguments.')
|
184 |
+
|
185 |
+
if _MODEL_PATH.value is not None:
|
186 |
+
model_path = _MODEL_PATH.value
|
187 |
+
else:
|
188 |
+
model_path = os.path.join(_BASE_FOLDER.value, _LABEL.value, 'saved_model')
|
189 |
+
|
190 |
+
gin.parse_config_files_and_bindings(
|
191 |
+
config_files=[_GIN_CONFIG.value],
|
192 |
+
bindings=None,
|
193 |
+
skip_unknown=True)
|
194 |
+
|
195 |
+
config = _get_experiment_config() # pylint: disable=no-value-for-parameter
|
196 |
+
eval_name = config['name']
|
197 |
+
output_dir = os.path.join(
|
198 |
+
os.path.dirname(model_path), 'batch_eval', eval_name)
|
199 |
+
logging.info('Creating output_dir @ %s ...', output_dir)
|
200 |
+
|
201 |
+
# Copy config file to <base_folder>/<label>/batch_eval/<eval_name>/config.gin.
|
202 |
+
tf.io.gfile.makedirs(output_dir)
|
203 |
+
tf.io.gfile.copy(
|
204 |
+
_GIN_CONFIG.value, os.path.join(output_dir, 'config.gin'), overwrite=True)
|
205 |
+
|
206 |
+
_set_visible_devices()
|
207 |
+
logging.info('Evaluating %s on %s ...', eval_name, [
|
208 |
+
el.name.split('/physical_device:')[-1]
|
209 |
+
for el in tf.config.get_visible_devices()
|
210 |
+
])
|
211 |
+
run_evaluation(model_path=model_path, output_dir=output_dir) # pylint: disable=no-value-for-parameter
|
212 |
+
|
213 |
+
logging.info('Done. Evaluations saved @ %s.', output_dir)
|
214 |
+
|
215 |
+
if __name__ == '__main__':
|
216 |
+
app.run(main)
|
eval/interpolator.py
ADDED
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 Google LLC
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
"""A wrapper class for running a frame interpolation TF2 saved model.
|
16 |
+
|
17 |
+
Usage:
|
18 |
+
model_path='/tmp/saved_model/'
|
19 |
+
it = Interpolator(model_path)
|
20 |
+
result_batch = it.interpolate(image_batch_0, image_batch_1, batch_dt)
|
21 |
+
|
22 |
+
Where image_batch_1 and image_batch_2 are numpy tensors with TF standard
|
23 |
+
(B,H,W,C) layout, batch_dt is the sub-frame time in range [0,1], (B,) layout.
|
24 |
+
"""
|
25 |
+
from typing import List, Optional
|
26 |
+
import numpy as np
|
27 |
+
import tensorflow as tf
|
28 |
+
|
29 |
+
|
30 |
+
def _pad_to_align(x, align):
|
31 |
+
"""Pad image batch x so width and height divide by align.
|
32 |
+
|
33 |
+
Args:
|
34 |
+
x: Image batch to align.
|
35 |
+
align: Number to align to.
|
36 |
+
|
37 |
+
Returns:
|
38 |
+
1) An image padded so width % align == 0 and height % align == 0.
|
39 |
+
2) A bounding box that can be fed readily to tf.image.crop_to_bounding_box
|
40 |
+
to undo the padding.
|
41 |
+
"""
|
42 |
+
# Input checking.
|
43 |
+
assert np.ndim(x) == 4
|
44 |
+
assert align > 0, 'align must be a positive number.'
|
45 |
+
|
46 |
+
height, width = x.shape[-3:-1]
|
47 |
+
height_to_pad = (align - height % align) if height % align != 0 else 0
|
48 |
+
width_to_pad = (align - width % align) if width % align != 0 else 0
|
49 |
+
|
50 |
+
bbox_to_pad = {
|
51 |
+
'offset_height': height_to_pad // 2,
|
52 |
+
'offset_width': width_to_pad // 2,
|
53 |
+
'target_height': height + height_to_pad,
|
54 |
+
'target_width': width + width_to_pad
|
55 |
+
}
|
56 |
+
padded_x = tf.image.pad_to_bounding_box(x, **bbox_to_pad)
|
57 |
+
bbox_to_crop = {
|
58 |
+
'offset_height': height_to_pad // 2,
|
59 |
+
'offset_width': width_to_pad // 2,
|
60 |
+
'target_height': height,
|
61 |
+
'target_width': width
|
62 |
+
}
|
63 |
+
return padded_x, bbox_to_crop
|
64 |
+
|
65 |
+
|
66 |
+
def image_to_patches(image: np.ndarray, block_shape: List[int]) -> np.ndarray:
|
67 |
+
"""Folds an image into patches and stacks along the batch dimension.
|
68 |
+
|
69 |
+
Args:
|
70 |
+
image: The input image of shape [B, H, W, C].
|
71 |
+
block_shape: The number of patches along the height and width to extract.
|
72 |
+
Each patch is shaped (H/block_shape[0], W/block_shape[1])
|
73 |
+
|
74 |
+
Returns:
|
75 |
+
The extracted patches shaped [num_blocks, patch_height, patch_width,...],
|
76 |
+
with num_blocks = block_shape[0] * block_shape[1].
|
77 |
+
"""
|
78 |
+
block_height, block_width = block_shape
|
79 |
+
num_blocks = block_height * block_width
|
80 |
+
|
81 |
+
height, width, channel = image.shape[-3:]
|
82 |
+
patch_height, patch_width = height//block_height, width//block_width
|
83 |
+
|
84 |
+
assert height == (
|
85 |
+
patch_height * block_height
|
86 |
+
), 'block_height=%d should evenly divide height=%d.'%(block_height, height)
|
87 |
+
assert width == (
|
88 |
+
patch_width * block_width
|
89 |
+
), 'block_width=%d should evenly divide width=%d.'%(block_width, width)
|
90 |
+
|
91 |
+
patch_size = patch_height * patch_width
|
92 |
+
paddings = 2*[[0, 0]]
|
93 |
+
|
94 |
+
patches = tf.space_to_batch(image, [patch_height, patch_width], paddings)
|
95 |
+
patches = tf.split(patches, patch_size, 0)
|
96 |
+
patches = tf.stack(patches, axis=3)
|
97 |
+
patches = tf.reshape(patches,
|
98 |
+
[num_blocks, patch_height, patch_width, channel])
|
99 |
+
return patches.numpy()
|
100 |
+
|
101 |
+
|
102 |
+
def patches_to_image(patches: np.ndarray, block_shape: List[int]) -> np.ndarray:
|
103 |
+
"""Unfolds patches (stacked along batch) into an image.
|
104 |
+
|
105 |
+
Args:
|
106 |
+
patches: The input patches, shaped [num_patches, patch_H, patch_W, C].
|
107 |
+
block_shape: The number of patches along the height and width to unfold.
|
108 |
+
Each patch assumed to be shaped (H/block_shape[0], W/block_shape[1]).
|
109 |
+
|
110 |
+
Returns:
|
111 |
+
The unfolded image shaped [B, H, W, C].
|
112 |
+
"""
|
113 |
+
block_height, block_width = block_shape
|
114 |
+
paddings = 2 * [[0, 0]]
|
115 |
+
|
116 |
+
patch_height, patch_width, channel = patches.shape[-3:]
|
117 |
+
patch_size = patch_height * patch_width
|
118 |
+
|
119 |
+
patches = tf.reshape(patches,
|
120 |
+
[1, block_height, block_width, patch_size, channel])
|
121 |
+
patches = tf.split(patches, patch_size, axis=3)
|
122 |
+
patches = tf.stack(patches, axis=0)
|
123 |
+
patches = tf.reshape(patches,
|
124 |
+
[patch_size, block_height, block_width, channel])
|
125 |
+
image = tf.batch_to_space(patches, [patch_height, patch_width], paddings)
|
126 |
+
return image.numpy()
|
127 |
+
|
128 |
+
|
129 |
+
class Interpolator:
|
130 |
+
"""A class for generating interpolated frames between two input frames.
|
131 |
+
|
132 |
+
Uses TF2 saved model format.
|
133 |
+
"""
|
134 |
+
|
135 |
+
def __init__(self, model_path: str,
|
136 |
+
align: Optional[int] = None,
|
137 |
+
block_shape: Optional[List[int]] = None) -> None:
|
138 |
+
"""Loads a saved model.
|
139 |
+
|
140 |
+
Args:
|
141 |
+
model_path: Path to the saved model. If none are provided, uses the
|
142 |
+
default model.
|
143 |
+
align: 'If >1, pad the input size so it divides with this before
|
144 |
+
inference.'
|
145 |
+
block_shape: Number of patches along the (height, width) to sid-divide
|
146 |
+
input images.
|
147 |
+
"""
|
148 |
+
self._model = tf.compat.v2.saved_model.load(model_path)
|
149 |
+
self._align = align or None
|
150 |
+
self._block_shape = block_shape or None
|
151 |
+
|
152 |
+
def interpolate(self, x0: np.ndarray, x1: np.ndarray,
|
153 |
+
dt: np.ndarray) -> np.ndarray:
|
154 |
+
"""Generates an interpolated frame between given two batches of frames.
|
155 |
+
|
156 |
+
All input tensors should be np.float32 datatype.
|
157 |
+
|
158 |
+
Args:
|
159 |
+
x0: First image batch. Dimensions: (batch_size, height, width, channels)
|
160 |
+
x1: Second image batch. Dimensions: (batch_size, height, width, channels)
|
161 |
+
dt: Sub-frame time. Range [0,1]. Dimensions: (batch_size,)
|
162 |
+
|
163 |
+
Returns:
|
164 |
+
The result with dimensions (batch_size, height, width, channels).
|
165 |
+
"""
|
166 |
+
if self._align is not None:
|
167 |
+
x0, bbox_to_crop = _pad_to_align(x0, self._align)
|
168 |
+
x1, _ = _pad_to_align(x1, self._align)
|
169 |
+
|
170 |
+
inputs = {'x0': x0, 'x1': x1, 'time': dt[..., np.newaxis]}
|
171 |
+
result = self._model(inputs, training=False)
|
172 |
+
image = result['image']
|
173 |
+
|
174 |
+
if self._align is not None:
|
175 |
+
image = tf.image.crop_to_bounding_box(image, **bbox_to_crop)
|
176 |
+
return image.numpy()
|
177 |
+
|
178 |
+
def __call__(self, x0: np.ndarray, x1: np.ndarray,
|
179 |
+
dt: np.ndarray) -> np.ndarray:
|
180 |
+
"""Generates an interpolated frame between given two batches of frames.
|
181 |
+
|
182 |
+
All input tensors should be np.float32 datatype.
|
183 |
+
|
184 |
+
Args:
|
185 |
+
x0: First image batch. Dimensions: (batch_size, height, width, channels)
|
186 |
+
x1: Second image batch. Dimensions: (batch_size, height, width, channels)
|
187 |
+
dt: Sub-frame time. Range [0,1]. Dimensions: (batch_size,)
|
188 |
+
|
189 |
+
Returns:
|
190 |
+
The result with dimensions (batch_size, height, width, channels).
|
191 |
+
"""
|
192 |
+
if self._block_shape is not None and np.prod(self._block_shape) > 1:
|
193 |
+
# Subdivide high-res images into managable non-overlapping patches.
|
194 |
+
x0_patches = image_to_patches(x0, self._block_shape)
|
195 |
+
x1_patches = image_to_patches(x1, self._block_shape)
|
196 |
+
|
197 |
+
# Run the interpolator on each patch pair.
|
198 |
+
output_patches = []
|
199 |
+
for image_0, image_1 in zip(x0_patches, x1_patches):
|
200 |
+
mid_patch = self.interpolate(image_0[np.newaxis, ...],
|
201 |
+
image_1[np.newaxis, ...], dt)
|
202 |
+
output_patches.append(mid_patch)
|
203 |
+
|
204 |
+
# Reconstruct interpolated image by stitching interpolated patches.
|
205 |
+
output_patches = np.concatenate(output_patches, axis=0)
|
206 |
+
return patches_to_image(output_patches, self._block_shape)
|
207 |
+
else:
|
208 |
+
# Invoke the interpolator once.
|
209 |
+
return self.interpolate(x0, x1, dt)
|
eval/interpolator_cli.py
ADDED
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 Google LLC
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
r"""Runs the FILM frame interpolator on a pair of frames on beam.
|
16 |
+
|
17 |
+
This script is used evaluate the output quality of the FILM Tensorflow frame
|
18 |
+
interpolator. Optionally, it outputs a video of the interpolated frames.
|
19 |
+
|
20 |
+
A beam pipeline for invoking the frame interpolator on a set of directories
|
21 |
+
identified by a glob (--pattern). Each directory is expected to contain two
|
22 |
+
input frames that are the inputs to the frame interpolator. If a directory has
|
23 |
+
more than two frames, then each contiguous frame pair is treated as input to
|
24 |
+
generate in-between frames.
|
25 |
+
|
26 |
+
The output video is stored to interpolator.mp4 in each directory. The number of
|
27 |
+
frames is determined by --times_to_interpolate, which controls the number of
|
28 |
+
times the frame interpolator is invoked. When the number of input frames is 2,
|
29 |
+
the number of output frames is 2^times_to_interpolate+1.
|
30 |
+
|
31 |
+
This expects a directory structure such as:
|
32 |
+
<root directory of the eval>/01/frame1.png
|
33 |
+
frame2.png
|
34 |
+
<root directory of the eval>/02/frame1.png
|
35 |
+
frame2.png
|
36 |
+
<root directory of the eval>/03/frame1.png
|
37 |
+
frame2.png
|
38 |
+
...
|
39 |
+
|
40 |
+
And will produce:
|
41 |
+
<root directory of the eval>/01/interpolated_frames/frame0.png
|
42 |
+
frame1.png
|
43 |
+
frame2.png
|
44 |
+
<root directory of the eval>/02/interpolated_frames/frame0.png
|
45 |
+
frame1.png
|
46 |
+
frame2.png
|
47 |
+
<root directory of the eval>/03/interpolated_frames/frame0.png
|
48 |
+
frame1.png
|
49 |
+
frame2.png
|
50 |
+
...
|
51 |
+
|
52 |
+
And optionally will produce:
|
53 |
+
<root directory of the eval>/01/interpolated.mp4
|
54 |
+
<root directory of the eval>/02/interpolated.mp4
|
55 |
+
<root directory of the eval>/03/interpolated.mp4
|
56 |
+
...
|
57 |
+
|
58 |
+
Usage example:
|
59 |
+
python3 -m frame_interpolation.eval.interpolator_cli \
|
60 |
+
--model_path <path to TF2 saved model> \
|
61 |
+
--pattern "<root directory of the eval>/*" \
|
62 |
+
--times_to_interpolate <Number of times to interpolate>
|
63 |
+
"""
|
64 |
+
|
65 |
+
import functools
|
66 |
+
import os
|
67 |
+
from typing import List, Sequence
|
68 |
+
|
69 |
+
from . import interpolator as interpolator_lib
|
70 |
+
from . import util
|
71 |
+
from absl import app
|
72 |
+
from absl import flags
|
73 |
+
from absl import logging
|
74 |
+
import apache_beam as beam
|
75 |
+
import mediapy as media
|
76 |
+
import natsort
|
77 |
+
import numpy as np
|
78 |
+
import tensorflow as tf
|
79 |
+
from tqdm.auto import tqdm
|
80 |
+
|
81 |
+
# Controls TF_CCP log level.
|
82 |
+
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
|
83 |
+
|
84 |
+
|
85 |
+
_PATTERN = flags.DEFINE_string(
|
86 |
+
name='pattern',
|
87 |
+
default=None,
|
88 |
+
help='The pattern to determine the directories with the input frames.',
|
89 |
+
required=True)
|
90 |
+
_MODEL_PATH = flags.DEFINE_string(
|
91 |
+
name='model_path',
|
92 |
+
default=None,
|
93 |
+
help='The path of the TF2 saved model to use.')
|
94 |
+
_TIMES_TO_INTERPOLATE = flags.DEFINE_integer(
|
95 |
+
name='times_to_interpolate',
|
96 |
+
default=5,
|
97 |
+
help='The number of times to run recursive midpoint interpolation. '
|
98 |
+
'The number of output frames will be 2^times_to_interpolate+1.')
|
99 |
+
_FPS = flags.DEFINE_integer(
|
100 |
+
name='fps',
|
101 |
+
default=30,
|
102 |
+
help='Frames per second to play interpolated videos in slow motion.')
|
103 |
+
_ALIGN = flags.DEFINE_integer(
|
104 |
+
name='align',
|
105 |
+
default=64,
|
106 |
+
help='If >1, pad the input size so it is evenly divisible by this value.')
|
107 |
+
_BLOCK_HEIGHT = flags.DEFINE_integer(
|
108 |
+
name='block_height',
|
109 |
+
default=1,
|
110 |
+
help='An int >= 1, number of patches along height, '
|
111 |
+
'patch_height = height//block_height, should be evenly divisible.')
|
112 |
+
_BLOCK_WIDTH = flags.DEFINE_integer(
|
113 |
+
name='block_width',
|
114 |
+
default=1,
|
115 |
+
help='An int >= 1, number of patches along width, '
|
116 |
+
'patch_width = width//block_width, should be evenly divisible.')
|
117 |
+
_OUTPUT_VIDEO = flags.DEFINE_boolean(
|
118 |
+
name='output_video',
|
119 |
+
default=False,
|
120 |
+
help='If true, creates a video of the frames in the interpolated_frames/ '
|
121 |
+
'subdirectory')
|
122 |
+
|
123 |
+
# Add other extensions, if not either.
|
124 |
+
_INPUT_EXT = ['png', 'jpg', 'jpeg']
|
125 |
+
|
126 |
+
|
127 |
+
def _output_frames(frames: List[np.ndarray], frames_dir: str):
|
128 |
+
"""Writes PNG-images to a directory.
|
129 |
+
|
130 |
+
If frames_dir doesn't exist, it is created. If frames_dir contains existing
|
131 |
+
PNG-files, they are removed before saving the new ones.
|
132 |
+
|
133 |
+
Args:
|
134 |
+
frames: List of images to save.
|
135 |
+
frames_dir: The output directory to save the images.
|
136 |
+
|
137 |
+
"""
|
138 |
+
if tf.io.gfile.isdir(frames_dir):
|
139 |
+
old_frames = tf.io.gfile.glob(f'{frames_dir}/frame_*.png')
|
140 |
+
if old_frames:
|
141 |
+
logging.info('Removing existing frames from %s.', frames_dir)
|
142 |
+
for old_frame in old_frames:
|
143 |
+
tf.io.gfile.remove(old_frame)
|
144 |
+
else:
|
145 |
+
tf.io.gfile.makedirs(frames_dir)
|
146 |
+
for idx, frame in tqdm(
|
147 |
+
enumerate(frames), total=len(frames), ncols=100, colour='green'):
|
148 |
+
util.write_image(f'{frames_dir}/frame_{idx:03d}.png', frame)
|
149 |
+
logging.info('Output frames saved in %s.', frames_dir)
|
150 |
+
|
151 |
+
|
152 |
+
class ProcessDirectory(beam.DoFn):
|
153 |
+
"""DoFn for running the interpolator on a single directory at the time."""
|
154 |
+
|
155 |
+
def setup(self):
|
156 |
+
self.interpolator = interpolator_lib.Interpolator(
|
157 |
+
_MODEL_PATH.value, _ALIGN.value,
|
158 |
+
[_BLOCK_HEIGHT.value, _BLOCK_WIDTH.value])
|
159 |
+
|
160 |
+
if _OUTPUT_VIDEO.value:
|
161 |
+
ffmpeg_path = util.get_ffmpeg_path()
|
162 |
+
media.set_ffmpeg(ffmpeg_path)
|
163 |
+
|
164 |
+
def process(self, directory: str):
|
165 |
+
input_frames_list = [
|
166 |
+
natsort.natsorted(tf.io.gfile.glob(f'{directory}/*.{ext}'))
|
167 |
+
for ext in _INPUT_EXT
|
168 |
+
]
|
169 |
+
input_frames = functools.reduce(lambda x, y: x + y, input_frames_list)
|
170 |
+
logging.info('Generating in-between frames for %s.', directory)
|
171 |
+
frames = list(
|
172 |
+
util.interpolate_recursively_from_files(
|
173 |
+
input_frames, _TIMES_TO_INTERPOLATE.value, self.interpolator))
|
174 |
+
_output_frames(frames, f'{directory}/interpolated_frames')
|
175 |
+
if _OUTPUT_VIDEO.value:
|
176 |
+
media.write_video(f'{directory}/interpolated.mp4', frames, fps=_FPS.value)
|
177 |
+
logging.info('Output video saved at %s/interpolated.mp4.', directory)
|
178 |
+
|
179 |
+
|
180 |
+
def _run_pipeline() -> None:
|
181 |
+
directories = tf.io.gfile.glob(_PATTERN.value)
|
182 |
+
pipeline = beam.Pipeline('DirectRunner')
|
183 |
+
(pipeline | 'Create directory names' >> beam.Create(directories) # pylint: disable=expression-not-assigned
|
184 |
+
| 'Process directories' >> beam.ParDo(ProcessDirectory()))
|
185 |
+
|
186 |
+
result = pipeline.run()
|
187 |
+
result.wait_until_finish()
|
188 |
+
|
189 |
+
|
190 |
+
def main(argv: Sequence[str]) -> None:
|
191 |
+
if len(argv) > 1:
|
192 |
+
raise app.UsageError('Too many command-line arguments.')
|
193 |
+
_run_pipeline()
|
194 |
+
|
195 |
+
|
196 |
+
if __name__ == '__main__':
|
197 |
+
app.run(main)
|
eval/interpolator_test.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 Google LLC
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
r"""A test script for mid frame interpolation from two input frames.
|
16 |
+
|
17 |
+
Usage example:
|
18 |
+
python3 -m frame_interpolation.eval.interpolator_test \
|
19 |
+
--frame1 <filepath of the first frame> \
|
20 |
+
--frame2 <filepath of the second frame> \
|
21 |
+
--model_path <The filepath of the TF2 saved model to use>
|
22 |
+
|
23 |
+
The output is saved to <the directory of the input frames>/output_frame.png. If
|
24 |
+
`--output_frame` filepath is provided, it will be used instead.
|
25 |
+
"""
|
26 |
+
import os
|
27 |
+
from typing import Sequence
|
28 |
+
|
29 |
+
from . import interpolator as interpolator_lib
|
30 |
+
from . import util
|
31 |
+
from absl import app
|
32 |
+
from absl import flags
|
33 |
+
import numpy as np
|
34 |
+
|
35 |
+
# Controls TF_CCP log level.
|
36 |
+
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
|
37 |
+
|
38 |
+
|
39 |
+
_FRAME1 = flags.DEFINE_string(
|
40 |
+
name='frame1',
|
41 |
+
default=None,
|
42 |
+
help='The filepath of the first input frame.',
|
43 |
+
required=True)
|
44 |
+
_FRAME2 = flags.DEFINE_string(
|
45 |
+
name='frame2',
|
46 |
+
default=None,
|
47 |
+
help='The filepath of the second input frame.',
|
48 |
+
required=True)
|
49 |
+
_MODEL_PATH = flags.DEFINE_string(
|
50 |
+
name='model_path',
|
51 |
+
default=None,
|
52 |
+
help='The path of the TF2 saved model to use.')
|
53 |
+
_OUTPUT_FRAME = flags.DEFINE_string(
|
54 |
+
name='output_frame',
|
55 |
+
default=None,
|
56 |
+
help='The output filepath of the interpolated mid-frame.')
|
57 |
+
_ALIGN = flags.DEFINE_integer(
|
58 |
+
name='align',
|
59 |
+
default=64,
|
60 |
+
help='If >1, pad the input size so it is evenly divisible by this value.')
|
61 |
+
_BLOCK_HEIGHT = flags.DEFINE_integer(
|
62 |
+
name='block_height',
|
63 |
+
default=1,
|
64 |
+
help='An int >= 1, number of patches along height, '
|
65 |
+
'patch_height = height//block_height, should be evenly divisible.')
|
66 |
+
_BLOCK_WIDTH = flags.DEFINE_integer(
|
67 |
+
name='block_width',
|
68 |
+
default=1,
|
69 |
+
help='An int >= 1, number of patches along width, '
|
70 |
+
'patch_width = width//block_width, should be evenly divisible.')
|
71 |
+
|
72 |
+
|
73 |
+
def _run_interpolator() -> None:
|
74 |
+
"""Writes interpolated mid frame from a given two input frame filepaths."""
|
75 |
+
|
76 |
+
interpolator = interpolator_lib.Interpolator(
|
77 |
+
model_path=_MODEL_PATH.value,
|
78 |
+
align=_ALIGN.value,
|
79 |
+
block_shape=[_BLOCK_HEIGHT.value, _BLOCK_WIDTH.value])
|
80 |
+
|
81 |
+
# First batched image.
|
82 |
+
image_1 = util.read_image(_FRAME1.value)
|
83 |
+
image_batch_1 = np.expand_dims(image_1, axis=0)
|
84 |
+
|
85 |
+
# Second batched image.
|
86 |
+
image_2 = util.read_image(_FRAME2.value)
|
87 |
+
image_batch_2 = np.expand_dims(image_2, axis=0)
|
88 |
+
|
89 |
+
# Batched time.
|
90 |
+
batch_dt = np.full(shape=(1,), fill_value=0.5, dtype=np.float32)
|
91 |
+
|
92 |
+
# Invoke the model for one mid-frame interpolation.
|
93 |
+
mid_frame = interpolator(image_batch_1, image_batch_2, batch_dt)[0]
|
94 |
+
|
95 |
+
# Write interpolated mid-frame.
|
96 |
+
mid_frame_filepath = _OUTPUT_FRAME.value
|
97 |
+
if not mid_frame_filepath:
|
98 |
+
mid_frame_filepath = f'{os.path.dirname(_FRAME1.value)}/output_frame.png'
|
99 |
+
util.write_image(mid_frame_filepath, mid_frame)
|
100 |
+
|
101 |
+
|
102 |
+
def main(argv: Sequence[str]) -> None:
|
103 |
+
if len(argv) > 1:
|
104 |
+
raise app.UsageError('Too many command-line arguments.')
|
105 |
+
_run_interpolator()
|
106 |
+
|
107 |
+
|
108 |
+
if __name__ == '__main__':
|
109 |
+
app.run(main)
|
eval/util.py
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 Google LLC
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
"""Utility functions for frame interpolation on a set of video frames."""
|
16 |
+
import os
|
17 |
+
import shutil
|
18 |
+
from typing import Generator, Iterable, List, Optional
|
19 |
+
|
20 |
+
from . import interpolator as interpolator_lib
|
21 |
+
import numpy as np
|
22 |
+
import tensorflow as tf
|
23 |
+
from tqdm import tqdm
|
24 |
+
|
25 |
+
_UINT8_MAX_F = float(np.iinfo(np.uint8).max)
|
26 |
+
_CONFIG_FFMPEG_NAME_OR_PATH = 'ffmpeg'
|
27 |
+
|
28 |
+
|
29 |
+
def read_image(filename: str) -> np.ndarray:
|
30 |
+
"""Reads an sRgb 8-bit image.
|
31 |
+
|
32 |
+
Args:
|
33 |
+
filename: The input filename to read.
|
34 |
+
|
35 |
+
Returns:
|
36 |
+
A float32 3-channel (RGB) ndarray with colors in the [0..1] range.
|
37 |
+
"""
|
38 |
+
image_data = tf.io.read_file(filename)
|
39 |
+
image = tf.io.decode_image(image_data, channels=3)
|
40 |
+
image_numpy = tf.cast(image, dtype=tf.float32).numpy()
|
41 |
+
return image_numpy / _UINT8_MAX_F
|
42 |
+
|
43 |
+
|
44 |
+
def write_image(filename: str, image: np.ndarray) -> None:
|
45 |
+
"""Writes a float32 3-channel RGB ndarray image, with colors in range [0..1].
|
46 |
+
|
47 |
+
Args:
|
48 |
+
filename: The output filename to save.
|
49 |
+
image: A float32 3-channel (RGB) ndarray with colors in the [0..1] range.
|
50 |
+
"""
|
51 |
+
image_in_uint8_range = np.clip(image * _UINT8_MAX_F, 0.0, _UINT8_MAX_F)
|
52 |
+
image_in_uint8 = (image_in_uint8_range + 0.5).astype(np.uint8)
|
53 |
+
|
54 |
+
extension = os.path.splitext(filename)[1]
|
55 |
+
if extension == '.jpg':
|
56 |
+
image_data = tf.io.encode_jpeg(image_in_uint8)
|
57 |
+
else:
|
58 |
+
image_data = tf.io.encode_png(image_in_uint8)
|
59 |
+
tf.io.write_file(filename, image_data)
|
60 |
+
|
61 |
+
|
62 |
+
def _recursive_generator(
|
63 |
+
frame1: np.ndarray, frame2: np.ndarray, num_recursions: int,
|
64 |
+
interpolator: interpolator_lib.Interpolator,
|
65 |
+
bar: Optional[tqdm] = None
|
66 |
+
) -> Generator[np.ndarray, None, None]:
|
67 |
+
"""Splits halfway to repeatedly generate more frames.
|
68 |
+
|
69 |
+
Args:
|
70 |
+
frame1: Input image 1.
|
71 |
+
frame2: Input image 2.
|
72 |
+
num_recursions: How many times to interpolate the consecutive image pairs.
|
73 |
+
interpolator: The frame interpolator instance.
|
74 |
+
|
75 |
+
Yields:
|
76 |
+
The interpolated frames, including the first frame (frame1), but excluding
|
77 |
+
the final frame2.
|
78 |
+
"""
|
79 |
+
if num_recursions == 0:
|
80 |
+
yield frame1
|
81 |
+
else:
|
82 |
+
# Adds the batch dimension to all inputs before calling the interpolator,
|
83 |
+
# and remove it afterwards.
|
84 |
+
time = np.full(shape=(1,), fill_value=0.5, dtype=np.float32)
|
85 |
+
mid_frame = interpolator(frame1[np.newaxis, ...], frame2[np.newaxis, ...],
|
86 |
+
time)[0]
|
87 |
+
bar.update(1) if bar is not None else bar
|
88 |
+
yield from _recursive_generator(frame1, mid_frame, num_recursions - 1,
|
89 |
+
interpolator, bar)
|
90 |
+
yield from _recursive_generator(mid_frame, frame2, num_recursions - 1,
|
91 |
+
interpolator, bar)
|
92 |
+
|
93 |
+
|
94 |
+
def interpolate_recursively_from_files(
|
95 |
+
frames: List[str], times_to_interpolate: int,
|
96 |
+
interpolator: interpolator_lib.Interpolator) -> Iterable[np.ndarray]:
|
97 |
+
"""Generates interpolated frames by repeatedly interpolating the midpoint.
|
98 |
+
|
99 |
+
Loads the files on demand and uses the yield paradigm to return the frames
|
100 |
+
to allow streamed processing of longer videos.
|
101 |
+
|
102 |
+
Recursive interpolation is useful if the interpolator is trained to predict
|
103 |
+
frames at midpoint only and is thus expected to perform poorly elsewhere.
|
104 |
+
|
105 |
+
Args:
|
106 |
+
frames: List of input frames. Expected shape (H, W, 3). The colors should be
|
107 |
+
in the range[0, 1] and in gamma space.
|
108 |
+
times_to_interpolate: Number of times to do recursive midpoint
|
109 |
+
interpolation.
|
110 |
+
interpolator: The frame interpolation model to use.
|
111 |
+
|
112 |
+
Yields:
|
113 |
+
The interpolated frames (including the inputs).
|
114 |
+
"""
|
115 |
+
n = len(frames)
|
116 |
+
num_frames = (n - 1) * (2**(times_to_interpolate) - 1)
|
117 |
+
bar = tqdm(total=num_frames, ncols=100, colour='green')
|
118 |
+
for i in range(1, n):
|
119 |
+
yield from _recursive_generator(
|
120 |
+
read_image(frames[i - 1]), read_image(frames[i]), times_to_interpolate,
|
121 |
+
interpolator, bar)
|
122 |
+
# Separately yield the final frame.
|
123 |
+
yield read_image(frames[-1])
|
124 |
+
|
125 |
+
def interpolate_recursively_from_memory(
|
126 |
+
frames: List[np.ndarray], times_to_interpolate: int,
|
127 |
+
interpolator: interpolator_lib.Interpolator) -> Iterable[np.ndarray]:
|
128 |
+
"""Generates interpolated frames by repeatedly interpolating the midpoint.
|
129 |
+
|
130 |
+
This is functionally equivalent to interpolate_recursively_from_files(), but
|
131 |
+
expects the inputs frames in memory, instead of loading them on demand.
|
132 |
+
|
133 |
+
Recursive interpolation is useful if the interpolator is trained to predict
|
134 |
+
frames at midpoint only and is thus expected to perform poorly elsewhere.
|
135 |
+
|
136 |
+
Args:
|
137 |
+
frames: List of input frames. Expected shape (H, W, 3). The colors should be
|
138 |
+
in the range[0, 1] and in gamma space.
|
139 |
+
times_to_interpolate: Number of times to do recursive midpoint
|
140 |
+
interpolation.
|
141 |
+
interpolator: The frame interpolation model to use.
|
142 |
+
|
143 |
+
Yields:
|
144 |
+
The interpolated frames (including the inputs).
|
145 |
+
"""
|
146 |
+
n = len(frames)
|
147 |
+
num_frames = (n - 1) * (2**(times_to_interpolate) - 1)
|
148 |
+
bar = tqdm(total=num_frames, ncols=100, colour='green')
|
149 |
+
for i in range(1, n):
|
150 |
+
yield from _recursive_generator(frames[i - 1], frames[i],
|
151 |
+
times_to_interpolate, interpolator, bar)
|
152 |
+
# Separately yield the final frame.
|
153 |
+
yield frames[-1]
|
154 |
+
|
155 |
+
|
156 |
+
def get_ffmpeg_path() -> str:
|
157 |
+
path = shutil.which(_CONFIG_FFMPEG_NAME_OR_PATH)
|
158 |
+
if not path:
|
159 |
+
raise RuntimeError(
|
160 |
+
f"Program '{_CONFIG_FFMPEG_NAME_OR_PATH}' is not found;"
|
161 |
+
" perhaps install ffmpeg using 'apt-get install ffmpeg'.")
|
162 |
+
return path
|
losses/losses.py
ADDED
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 Google LLC
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
"""Loss functions used to train the FILM interpolation model.
|
16 |
+
|
17 |
+
The losses for training and test loops are configurable via gin. Training can
|
18 |
+
use more than one loss function. Test loop can also evaluate one ore more loss
|
19 |
+
functions, each of which can be summarized separately.
|
20 |
+
"""
|
21 |
+
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple
|
22 |
+
|
23 |
+
from . import vgg19_loss as vgg19
|
24 |
+
import gin.tf
|
25 |
+
import numpy as np
|
26 |
+
import tensorflow as tf
|
27 |
+
|
28 |
+
|
29 |
+
@gin.configurable('vgg', denylist=['example', 'prediction'])
|
30 |
+
def vgg_loss(example: Mapping[str, tf.Tensor],
|
31 |
+
prediction: Mapping[str, tf.Tensor],
|
32 |
+
vgg_model_file: str,
|
33 |
+
weights: Optional[List[float]] = None) -> tf.Tensor:
|
34 |
+
"""Perceptual loss for images in [0,1] color range.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
example: A dictionary with the ground truth image as 'y'.
|
38 |
+
prediction: The prediction dictionary with the image as 'image'.
|
39 |
+
vgg_model_file: The path containing the vgg19 weights in MATLAB format.
|
40 |
+
weights: An optional array of weights for different VGG layers. If None, the
|
41 |
+
default weights are used (see vgg19.vgg_loss documentation).
|
42 |
+
|
43 |
+
Returns:
|
44 |
+
The perceptual loss.
|
45 |
+
"""
|
46 |
+
return vgg19.vgg_loss(prediction['image'], example['y'], vgg_model_file,
|
47 |
+
weights)
|
48 |
+
|
49 |
+
|
50 |
+
@gin.configurable('style', denylist=['example', 'prediction'])
|
51 |
+
def style_loss(example: Mapping[str, tf.Tensor],
|
52 |
+
prediction: Mapping[str, tf.Tensor],
|
53 |
+
vgg_model_file: str,
|
54 |
+
weights: Optional[List[float]] = None) -> tf.Tensor:
|
55 |
+
"""Computes style loss from images in [0..1] color range.
|
56 |
+
|
57 |
+
Args:
|
58 |
+
example: A dictionary with the ground truth image as 'y'.
|
59 |
+
prediction: The prediction dictionary with the image as 'image'.
|
60 |
+
vgg_model_file: The path containing the vgg19 weights in MATLAB format.
|
61 |
+
weights: An optional array of weights for different VGG layers. If None, the
|
62 |
+
default weights are used (see vgg19.vgg_loss documentation).
|
63 |
+
|
64 |
+
Returns:
|
65 |
+
A tf.Tensor of a scalar representing the style loss computed over multiple
|
66 |
+
vgg layer features.
|
67 |
+
"""
|
68 |
+
return vgg19.style_loss(prediction['image'], example['y'], vgg_model_file,
|
69 |
+
weights)
|
70 |
+
|
71 |
+
|
72 |
+
def l1_loss(example: Mapping[str, tf.Tensor],
|
73 |
+
prediction: Mapping[str, tf.Tensor]) -> tf.Tensor:
|
74 |
+
return tf.reduce_mean(tf.abs(prediction['image'] - example['y']))
|
75 |
+
|
76 |
+
|
77 |
+
def l1_warped_loss(example: Mapping[str, tf.Tensor],
|
78 |
+
prediction: Mapping[str, tf.Tensor]) -> tf.Tensor:
|
79 |
+
"""Computes an l1 loss using only warped images.
|
80 |
+
|
81 |
+
Args:
|
82 |
+
example: A dictionary with the ground truth image as 'y'.
|
83 |
+
prediction: The prediction dictionary with the image(s) as 'x0_warped'
|
84 |
+
and/or 'x1_warped'.
|
85 |
+
|
86 |
+
Returns:
|
87 |
+
A tf.Tensor of a scalar representing the linear combination of l1 losses
|
88 |
+
between prediction images and y.
|
89 |
+
"""
|
90 |
+
loss = tf.constant(0.0, dtype=tf.float32)
|
91 |
+
if 'x0_warped' in prediction:
|
92 |
+
loss += tf.reduce_mean(tf.abs(prediction['x0_warped'] - example['y']))
|
93 |
+
if 'x1_warped' in prediction:
|
94 |
+
loss += tf.reduce_mean(tf.abs(prediction['x1_warped'] - example['y']))
|
95 |
+
return loss
|
96 |
+
|
97 |
+
|
98 |
+
def l2_loss(example: Mapping[str, tf.Tensor],
|
99 |
+
prediction: Mapping[str, tf.Tensor]) -> tf.Tensor:
|
100 |
+
return tf.reduce_mean(tf.square(prediction['image'] - example['y']))
|
101 |
+
|
102 |
+
|
103 |
+
def ssim_loss(example: Mapping[str, tf.Tensor],
|
104 |
+
prediction: Mapping[str, tf.Tensor]) -> tf.Tensor:
|
105 |
+
image = prediction['image']
|
106 |
+
y = example['y']
|
107 |
+
return tf.reduce_mean(tf.image.ssim(image, y, max_val=1.0))
|
108 |
+
|
109 |
+
|
110 |
+
def psnr_loss(example: Mapping[str, tf.Tensor],
|
111 |
+
prediction: Mapping[str, tf.Tensor]) -> tf.Tensor:
|
112 |
+
return tf.reduce_mean(
|
113 |
+
tf.image.psnr(prediction['image'], example['y'], max_val=1.0))
|
114 |
+
|
115 |
+
|
116 |
+
def get_loss(loss_name: str) -> Callable[[Any, Any], tf.Tensor]:
|
117 |
+
"""Returns the loss function corresponding to the given name."""
|
118 |
+
if loss_name == 'l1':
|
119 |
+
return l1_loss
|
120 |
+
elif loss_name == 'l2':
|
121 |
+
return l2_loss
|
122 |
+
elif loss_name == 'ssim':
|
123 |
+
return ssim_loss
|
124 |
+
elif loss_name == 'vgg':
|
125 |
+
return vgg_loss
|
126 |
+
elif loss_name == 'style':
|
127 |
+
return style_loss
|
128 |
+
elif loss_name == 'psnr':
|
129 |
+
return psnr_loss
|
130 |
+
elif loss_name == 'l1_warped':
|
131 |
+
return l1_warped_loss
|
132 |
+
else:
|
133 |
+
raise ValueError('Invalid loss function %s' % loss_name)
|
134 |
+
|
135 |
+
|
136 |
+
# pylint: disable=unnecessary-lambda
|
137 |
+
def get_loss_op(loss_name):
|
138 |
+
"""Returns a function for creating a loss calculation op."""
|
139 |
+
loss = get_loss(loss_name)
|
140 |
+
return lambda example, prediction: loss(example, prediction)
|
141 |
+
|
142 |
+
|
143 |
+
def get_weight_op(weight_schedule):
|
144 |
+
"""Returns a function for creating an iteration dependent loss weight op."""
|
145 |
+
return lambda iterations: weight_schedule(iterations)
|
146 |
+
|
147 |
+
|
148 |
+
def create_losses(
|
149 |
+
loss_names: List[str], loss_weight_schedules: List[
|
150 |
+
tf.keras.optimizers.schedules.LearningRateSchedule]
|
151 |
+
) -> Dict[str, Tuple[Callable[[Any, Any], tf.Tensor], Callable[[Any],
|
152 |
+
tf.Tensor]]]:
|
153 |
+
"""Returns a dictionary of functions for creating loss and loss_weight ops.
|
154 |
+
|
155 |
+
As an example, create_losses(['l1', 'l2'], [PiecewiseConstantDecay(),
|
156 |
+
PiecewiseConstantDecay()]) returns a dictionary with two keys, and each value
|
157 |
+
being a tuple of ops for loss calculation and loss_weight sampling.
|
158 |
+
|
159 |
+
Args:
|
160 |
+
loss_names: Names of the losses.
|
161 |
+
loss_weight_schedules: Instances of loss weight schedules.
|
162 |
+
|
163 |
+
Returns:
|
164 |
+
A dictionary that contains the loss and weight schedule ops keyed by the
|
165 |
+
names.
|
166 |
+
"""
|
167 |
+
losses = dict()
|
168 |
+
for name, weight_schedule in zip(loss_names, loss_weight_schedules):
|
169 |
+
unique_values = np.unique(weight_schedule.values)
|
170 |
+
if len(unique_values) == 1 and unique_values[0] == 1.0:
|
171 |
+
# Special case 'no weight' for prettier TensorBoard summaries.
|
172 |
+
weighted_name = name
|
173 |
+
else:
|
174 |
+
# Weights are variable/scheduled, a constant "k" is used to
|
175 |
+
# indicate weights are iteration dependent.
|
176 |
+
weighted_name = 'k*' + name
|
177 |
+
losses[weighted_name] = (get_loss_op(name), get_weight_op(weight_schedule))
|
178 |
+
return losses
|
179 |
+
|
180 |
+
|
181 |
+
@gin.configurable
|
182 |
+
def training_losses(
|
183 |
+
loss_names: List[str],
|
184 |
+
loss_weights: Optional[List[float]] = None,
|
185 |
+
loss_weight_schedules: Optional[List[
|
186 |
+
tf.keras.optimizers.schedules.LearningRateSchedule]] = None,
|
187 |
+
loss_weight_parameters: Optional[List[Mapping[str, List[Any]]]] = None
|
188 |
+
) -> Mapping[str, Tuple[Callable[[Any, Any], tf.Tensor], Callable[[Any],
|
189 |
+
tf.Tensor]]]:
|
190 |
+
"""Creates the training loss functions and loss weight schedules."""
|
191 |
+
weight_schedules = []
|
192 |
+
if not loss_weights:
|
193 |
+
for weight_schedule, weight_parameters in zip(loss_weight_schedules,
|
194 |
+
loss_weight_parameters):
|
195 |
+
weight_schedules.append(weight_schedule(**weight_parameters))
|
196 |
+
else:
|
197 |
+
for loss_weight in loss_weights:
|
198 |
+
weight_parameters = {
|
199 |
+
'boundaries': [0],
|
200 |
+
'values': 2 * [
|
201 |
+
loss_weight,
|
202 |
+
]
|
203 |
+
}
|
204 |
+
weight_schedules.append(
|
205 |
+
tf.keras.optimizers.schedules.PiecewiseConstantDecay(
|
206 |
+
**weight_parameters))
|
207 |
+
|
208 |
+
return create_losses(loss_names, weight_schedules)
|
209 |
+
|
210 |
+
|
211 |
+
@gin.configurable
|
212 |
+
def test_losses(
|
213 |
+
loss_names: List[str],
|
214 |
+
loss_weights: Optional[List[float]] = None,
|
215 |
+
loss_weight_schedules: Optional[List[
|
216 |
+
tf.keras.optimizers.schedules.LearningRateSchedule]] = None,
|
217 |
+
loss_weight_parameters: Optional[List[Mapping[str, List[Any]]]] = None
|
218 |
+
) -> Mapping[str, Tuple[Callable[[Any, Any], tf.Tensor], Callable[[Any],
|
219 |
+
tf.Tensor]]]:
|
220 |
+
"""Creates the test loss functions and loss weight schedules."""
|
221 |
+
weight_schedules = []
|
222 |
+
if not loss_weights:
|
223 |
+
for weight_schedule, weight_parameters in zip(loss_weight_schedules,
|
224 |
+
loss_weight_parameters):
|
225 |
+
weight_schedules.append(weight_schedule(**weight_parameters))
|
226 |
+
else:
|
227 |
+
for loss_weight in loss_weights:
|
228 |
+
weight_parameters = {
|
229 |
+
'boundaries': [0],
|
230 |
+
'values': 2 * [
|
231 |
+
loss_weight,
|
232 |
+
]
|
233 |
+
}
|
234 |
+
weight_schedules.append(
|
235 |
+
tf.keras.optimizers.schedules.PiecewiseConstantDecay(
|
236 |
+
**weight_parameters))
|
237 |
+
|
238 |
+
return create_losses(loss_names, weight_schedules)
|
239 |
+
|
240 |
+
|
241 |
+
def aggregate_batch_losses(
|
242 |
+
batch_losses: List[Mapping[str, float]]) -> Mapping[str, float]:
|
243 |
+
"""Averages per batch losses into single dictionary for the whole epoch.
|
244 |
+
|
245 |
+
As an example, if the batch_losses contained per batch losses:
|
246 |
+
batch_losses = { {'l1': 0.2, 'ssim': 0.9}, {'l1': 0.3, 'ssim': 0.8}}
|
247 |
+
The returned dictionary would look like: { 'l1': 0.25, 'ssim': 0.95 }
|
248 |
+
|
249 |
+
Args:
|
250 |
+
batch_losses: A list of dictionary objects, with one entry for each loss.
|
251 |
+
|
252 |
+
Returns:
|
253 |
+
Single dictionary with the losses aggregated.
|
254 |
+
"""
|
255 |
+
transp_losses = {}
|
256 |
+
# Loop through all losses
|
257 |
+
for batch_loss in batch_losses:
|
258 |
+
# Loop through per batch losses of a single type:
|
259 |
+
for loss_name, loss in batch_loss.items():
|
260 |
+
if loss_name not in transp_losses:
|
261 |
+
transp_losses[loss_name] = []
|
262 |
+
transp_losses[loss_name].append(loss)
|
263 |
+
aggregate_losses = {}
|
264 |
+
for loss_name in transp_losses:
|
265 |
+
aggregate_losses[loss_name] = np.mean(transp_losses[loss_name])
|
266 |
+
return aggregate_losses
|
losses/vgg19_loss.py
ADDED
@@ -0,0 +1,362 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 Google LLC
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
"""Feature loss based on 19 layer VGG network.
|
16 |
+
|
17 |
+
|
18 |
+
The network layers in the feature loss is weighted as described in
|
19 |
+
'Stereo Magnification: Learning View Synthesis using Multiplane Images',
|
20 |
+
Tinghui Zhou, Richard Tucker, Flynn, Graham Fyffe, Noah Snavely, SIGGRAPH 2018.
|
21 |
+
"""
|
22 |
+
|
23 |
+
from typing import Any, Callable, Dict, Optional, Sequence, Tuple
|
24 |
+
|
25 |
+
import numpy as np
|
26 |
+
import scipy.io as sio
|
27 |
+
import tensorflow.compat.v1 as tf
|
28 |
+
|
29 |
+
|
30 |
+
def _build_net(layer_type: str,
|
31 |
+
input_tensor: tf.Tensor,
|
32 |
+
weight_bias: Optional[Tuple[tf.Tensor, tf.Tensor]] = None,
|
33 |
+
name: Optional[str] = None) -> Callable[[Any], Any]:
|
34 |
+
"""Build a layer of the VGG network.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
layer_type: A string, type of this layer.
|
38 |
+
input_tensor: A tensor.
|
39 |
+
weight_bias: A tuple of weight and bias.
|
40 |
+
name: A string, name of this layer.
|
41 |
+
|
42 |
+
Returns:
|
43 |
+
A callable function of the tensorflow layer.
|
44 |
+
|
45 |
+
Raises:
|
46 |
+
ValueError: If layer_type is not conv or pool.
|
47 |
+
"""
|
48 |
+
|
49 |
+
if layer_type == 'conv':
|
50 |
+
return tf.nn.relu(
|
51 |
+
tf.nn.conv2d(
|
52 |
+
input_tensor,
|
53 |
+
weight_bias[0],
|
54 |
+
strides=[1, 1, 1, 1],
|
55 |
+
padding='SAME',
|
56 |
+
name=name) + weight_bias[1])
|
57 |
+
elif layer_type == 'pool':
|
58 |
+
return tf.nn.avg_pool(
|
59 |
+
input_tensor, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
|
60 |
+
else:
|
61 |
+
raise ValueError('Unsupported layer %s' % layer_type)
|
62 |
+
|
63 |
+
|
64 |
+
def _get_weight_and_bias(vgg_layers: np.ndarray,
|
65 |
+
index: int) -> Tuple[tf.Tensor, tf.Tensor]:
|
66 |
+
"""Get the weight and bias of a specific layer from the VGG pretrained model.
|
67 |
+
|
68 |
+
Args:
|
69 |
+
vgg_layers: An array, the VGG pretrained model.
|
70 |
+
index: An integer, index of the layer.
|
71 |
+
|
72 |
+
Returns:
|
73 |
+
weights: A tensor.
|
74 |
+
bias: A tensor.
|
75 |
+
"""
|
76 |
+
|
77 |
+
weights = vgg_layers[index][0][0][2][0][0]
|
78 |
+
weights = tf.constant(weights)
|
79 |
+
bias = vgg_layers[index][0][0][2][0][1]
|
80 |
+
bias = tf.constant(np.reshape(bias, (bias.size)))
|
81 |
+
|
82 |
+
return weights, bias
|
83 |
+
|
84 |
+
|
85 |
+
def _build_vgg19(image: tf.Tensor, model_filepath: str) -> Dict[str, tf.Tensor]:
|
86 |
+
"""Builds the VGG network given the model weights.
|
87 |
+
|
88 |
+
The weights are loaded only for the first time this code is invoked.
|
89 |
+
|
90 |
+
Args:
|
91 |
+
image: A tensor, input image.
|
92 |
+
model_filepath: A string, path to the VGG pretrained model.
|
93 |
+
|
94 |
+
Returns:
|
95 |
+
net: A dict mapping a layer name to a tensor.
|
96 |
+
"""
|
97 |
+
|
98 |
+
with tf.variable_scope('vgg', reuse=True):
|
99 |
+
net = {}
|
100 |
+
if not hasattr(_build_vgg19, 'vgg_rawnet'):
|
101 |
+
with tf.io.gfile.GFile(model_filepath, 'rb') as f:
|
102 |
+
_build_vgg19.vgg_rawnet = sio.loadmat(f)
|
103 |
+
vgg_layers = _build_vgg19.vgg_rawnet['layers'][0]
|
104 |
+
imagenet_mean = tf.constant([123.6800, 116.7790, 103.9390],
|
105 |
+
shape=[1, 1, 1, 3])
|
106 |
+
net['input'] = image - imagenet_mean
|
107 |
+
net['conv1_1'] = _build_net(
|
108 |
+
'conv',
|
109 |
+
net['input'],
|
110 |
+
_get_weight_and_bias(vgg_layers, 0),
|
111 |
+
name='vgg_conv1_1')
|
112 |
+
net['conv1_2'] = _build_net(
|
113 |
+
'conv',
|
114 |
+
net['conv1_1'],
|
115 |
+
_get_weight_and_bias(vgg_layers, 2),
|
116 |
+
name='vgg_conv1_2')
|
117 |
+
net['pool1'] = _build_net('pool', net['conv1_2'])
|
118 |
+
net['conv2_1'] = _build_net(
|
119 |
+
'conv',
|
120 |
+
net['pool1'],
|
121 |
+
_get_weight_and_bias(vgg_layers, 5),
|
122 |
+
name='vgg_conv2_1')
|
123 |
+
net['conv2_2'] = _build_net(
|
124 |
+
'conv',
|
125 |
+
net['conv2_1'],
|
126 |
+
_get_weight_and_bias(vgg_layers, 7),
|
127 |
+
name='vgg_conv2_2')
|
128 |
+
net['pool2'] = _build_net('pool', net['conv2_2'])
|
129 |
+
net['conv3_1'] = _build_net(
|
130 |
+
'conv',
|
131 |
+
net['pool2'],
|
132 |
+
_get_weight_and_bias(vgg_layers, 10),
|
133 |
+
name='vgg_conv3_1')
|
134 |
+
net['conv3_2'] = _build_net(
|
135 |
+
'conv',
|
136 |
+
net['conv3_1'],
|
137 |
+
_get_weight_and_bias(vgg_layers, 12),
|
138 |
+
name='vgg_conv3_2')
|
139 |
+
net['conv3_3'] = _build_net(
|
140 |
+
'conv',
|
141 |
+
net['conv3_2'],
|
142 |
+
_get_weight_and_bias(vgg_layers, 14),
|
143 |
+
name='vgg_conv3_3')
|
144 |
+
net['conv3_4'] = _build_net(
|
145 |
+
'conv',
|
146 |
+
net['conv3_3'],
|
147 |
+
_get_weight_and_bias(vgg_layers, 16),
|
148 |
+
name='vgg_conv3_4')
|
149 |
+
net['pool3'] = _build_net('pool', net['conv3_4'])
|
150 |
+
net['conv4_1'] = _build_net(
|
151 |
+
'conv',
|
152 |
+
net['pool3'],
|
153 |
+
_get_weight_and_bias(vgg_layers, 19),
|
154 |
+
name='vgg_conv4_1')
|
155 |
+
net['conv4_2'] = _build_net(
|
156 |
+
'conv',
|
157 |
+
net['conv4_1'],
|
158 |
+
_get_weight_and_bias(vgg_layers, 21),
|
159 |
+
name='vgg_conv4_2')
|
160 |
+
net['conv4_3'] = _build_net(
|
161 |
+
'conv',
|
162 |
+
net['conv4_2'],
|
163 |
+
_get_weight_and_bias(vgg_layers, 23),
|
164 |
+
name='vgg_conv4_3')
|
165 |
+
net['conv4_4'] = _build_net(
|
166 |
+
'conv',
|
167 |
+
net['conv4_3'],
|
168 |
+
_get_weight_and_bias(vgg_layers, 25),
|
169 |
+
name='vgg_conv4_4')
|
170 |
+
net['pool4'] = _build_net('pool', net['conv4_4'])
|
171 |
+
net['conv5_1'] = _build_net(
|
172 |
+
'conv',
|
173 |
+
net['pool4'],
|
174 |
+
_get_weight_and_bias(vgg_layers, 28),
|
175 |
+
name='vgg_conv5_1')
|
176 |
+
net['conv5_2'] = _build_net(
|
177 |
+
'conv',
|
178 |
+
net['conv5_1'],
|
179 |
+
_get_weight_and_bias(vgg_layers, 30),
|
180 |
+
name='vgg_conv5_2')
|
181 |
+
|
182 |
+
return net
|
183 |
+
|
184 |
+
|
185 |
+
def _compute_error(fake: tf.Tensor,
|
186 |
+
real: tf.Tensor,
|
187 |
+
mask: Optional[tf.Tensor] = None) -> tf.Tensor:
|
188 |
+
"""Computes the L1 loss and reweights by the mask."""
|
189 |
+
if mask is None:
|
190 |
+
return tf.reduce_mean(tf.abs(fake - real))
|
191 |
+
else:
|
192 |
+
# Resizes mask to the same size as the input.
|
193 |
+
size = (tf.shape(fake)[1], tf.shape(fake)[2])
|
194 |
+
resized_mask = tf.image.resize(
|
195 |
+
mask, size, method=tf.image.ResizeMethod.BILINEAR)
|
196 |
+
return tf.reduce_mean(tf.abs(fake - real) * resized_mask)
|
197 |
+
|
198 |
+
|
199 |
+
# Normalized VGG loss (from
|
200 |
+
# https://github.com/CQFIO/PhotographicImageSynthesis)
|
201 |
+
def vgg_loss(image: tf.Tensor,
|
202 |
+
reference: tf.Tensor,
|
203 |
+
vgg_model_file: str,
|
204 |
+
weights: Optional[Sequence[float]] = None,
|
205 |
+
mask: Optional[tf.Tensor] = None) -> tf.Tensor:
|
206 |
+
"""Computes the VGG loss for an image pair.
|
207 |
+
|
208 |
+
The VGG loss is the average feature vector difference between the two images.
|
209 |
+
|
210 |
+
The input images must be in [0, 1] range in (B, H, W, 3) RGB format and
|
211 |
+
the recommendation seems to be to have them in gamma space.
|
212 |
+
|
213 |
+
The pretrained weights are publicly available in
|
214 |
+
http://www.vlfeat.org/matconvnet/models/imagenet-vgg-verydeep-19.mat
|
215 |
+
|
216 |
+
Args:
|
217 |
+
image: A tensor, typically the prediction from a network.
|
218 |
+
reference: A tensor, the image to compare against, i.e. the golden image.
|
219 |
+
vgg_model_file: A string, filename for the VGG 19 network weights in MATLAB
|
220 |
+
format.
|
221 |
+
weights: A list of float, optional weights for the layers. The defaults are
|
222 |
+
from Qifeng Chen and Vladlen Koltun, "Photographic image synthesis with
|
223 |
+
cascaded refinement networks," ICCV 2017.
|
224 |
+
mask: An optional image-shape and single-channel tensor, the mask values are
|
225 |
+
per-pixel weights to be applied on the losses. The mask will be resized to
|
226 |
+
the same spatial resolution with the feature maps before been applied to
|
227 |
+
the losses. When the mask value is zero, pixels near the boundary of the
|
228 |
+
mask can still influence the loss if they fall into the receptive field of
|
229 |
+
the VGG convolutional layers.
|
230 |
+
|
231 |
+
Returns:
|
232 |
+
vgg_loss: The linear combination of losses from five VGG layers.
|
233 |
+
"""
|
234 |
+
|
235 |
+
if not weights:
|
236 |
+
weights = [1.0 / 2.6, 1.0 / 4.8, 1.0 / 3.7, 1.0 / 5.6, 10.0 / 1.5]
|
237 |
+
|
238 |
+
vgg_ref = _build_vgg19(reference * 255.0, vgg_model_file)
|
239 |
+
vgg_img = _build_vgg19(image * 255.0, vgg_model_file)
|
240 |
+
p1 = _compute_error(vgg_ref['conv1_2'], vgg_img['conv1_2'], mask) * weights[0]
|
241 |
+
p2 = _compute_error(vgg_ref['conv2_2'], vgg_img['conv2_2'], mask) * weights[1]
|
242 |
+
p3 = _compute_error(vgg_ref['conv3_2'], vgg_img['conv3_2'], mask) * weights[2]
|
243 |
+
p4 = _compute_error(vgg_ref['conv4_2'], vgg_img['conv4_2'], mask) * weights[3]
|
244 |
+
p5 = _compute_error(vgg_ref['conv5_2'], vgg_img['conv5_2'], mask) * weights[4]
|
245 |
+
|
246 |
+
final_loss = p1 + p2 + p3 + p4 + p5
|
247 |
+
|
248 |
+
# Scale to range [0..1].
|
249 |
+
final_loss /= 255.0
|
250 |
+
|
251 |
+
return final_loss
|
252 |
+
|
253 |
+
|
254 |
+
def _compute_gram_matrix(input_features: tf.Tensor,
|
255 |
+
mask: tf.Tensor) -> tf.Tensor:
|
256 |
+
"""Computes Gram matrix of `input_features`.
|
257 |
+
|
258 |
+
Gram matrix described in https://en.wikipedia.org/wiki/Gramian_matrix.
|
259 |
+
|
260 |
+
Args:
|
261 |
+
input_features: A tf.Tensor of shape (B, H, W, C) representing a feature map
|
262 |
+
obtained by a convolutional layer of a VGG network.
|
263 |
+
mask: A tf.Tensor of shape (B, H, W, 1) representing the per-pixel weights
|
264 |
+
to be applied on the `input_features`. The mask will be resized to the
|
265 |
+
same spatial resolution as the `input_featues`. When the mask value is
|
266 |
+
zero, pixels near the boundary of the mask can still influence the loss if
|
267 |
+
they fall into the receptive field of the VGG convolutional layers.
|
268 |
+
|
269 |
+
Returns:
|
270 |
+
A tf.Tensor of shape (B, C, C) representing the gram matrix of the masked
|
271 |
+
`input_features`.
|
272 |
+
"""
|
273 |
+
_, h, w, c = tuple([
|
274 |
+
i if (isinstance(i, int) or i is None) else i.value
|
275 |
+
for i in input_features.shape
|
276 |
+
])
|
277 |
+
if mask is None:
|
278 |
+
reshaped_features = tf.reshape(input_features, (-1, h * w, c))
|
279 |
+
else:
|
280 |
+
# Resize mask to match the shape of `input_features`
|
281 |
+
resized_mask = tf.image.resize(
|
282 |
+
mask, (h, w), method=tf.image.ResizeMethod.BILINEAR)
|
283 |
+
reshaped_features = tf.reshape(input_features * resized_mask,
|
284 |
+
(-1, h * w, c))
|
285 |
+
return tf.matmul(
|
286 |
+
reshaped_features, reshaped_features, transpose_a=True) / float(h * w)
|
287 |
+
|
288 |
+
|
289 |
+
def style_loss(image: tf.Tensor,
|
290 |
+
reference: tf.Tensor,
|
291 |
+
vgg_model_file: str,
|
292 |
+
weights: Optional[Sequence[float]] = None,
|
293 |
+
mask: Optional[tf.Tensor] = None) -> tf.Tensor:
|
294 |
+
"""Computes style loss as used in `A Neural Algorithm of Artistic Style`.
|
295 |
+
|
296 |
+
Based on the work in https://github.com/cysmith/neural-style-tf. Weights are
|
297 |
+
first initilaized to the inverse of the number of elements in each VGG layer
|
298 |
+
considerd. After 1.5M iterations, they are rescaled to normalize the
|
299 |
+
contribution of the Style loss to be equal to other losses (L1/VGG). This is
|
300 |
+
based on the works of image inpainting (https://arxiv.org/abs/1804.07723)
|
301 |
+
and frame prediction (https://arxiv.org/abs/1811.00684).
|
302 |
+
|
303 |
+
The style loss is the average gram matrix difference between `image` and
|
304 |
+
`reference`. The gram matrix is the inner product of a feature map of shape
|
305 |
+
(B, H*W, C) with itself. Results in a symmetric gram matrix shaped (B, C, C).
|
306 |
+
|
307 |
+
The input images must be in [0, 1] range in (B, H, W, 3) RGB format and
|
308 |
+
the recommendation seems to be to have them in gamma space.
|
309 |
+
|
310 |
+
The pretrained weights are publicly available in
|
311 |
+
http://www.vlfeat.org/matconvnet/models/imagenet-vgg-verydeep-19.mat
|
312 |
+
|
313 |
+
Args:
|
314 |
+
image: A tensor, typically the prediction from a network.
|
315 |
+
reference: A tensor, the image to compare against, i.e. the golden image.
|
316 |
+
vgg_model_file: A string, filename for the VGG 19 network weights in MATLAB
|
317 |
+
format.
|
318 |
+
weights: A list of float, optional weights for the layers. The defaults are
|
319 |
+
from Qifeng Chen and Vladlen Koltun, "Photographic image synthesis with
|
320 |
+
cascaded refinement networks," ICCV 2017.
|
321 |
+
mask: An optional image-shape and single-channel tensor, the mask values are
|
322 |
+
per-pixel weights to be applied on the losses. The mask will be resized to
|
323 |
+
the same spatial resolution with the feature maps before been applied to
|
324 |
+
the losses. When the mask value is zero, pixels near the boundary of the
|
325 |
+
mask can still influence the loss if they fall into the receptive field of
|
326 |
+
the VGG convolutional layers.
|
327 |
+
|
328 |
+
Returns:
|
329 |
+
Style loss, a linear combination of gram matrix L2 differences of from five
|
330 |
+
VGG layer features.
|
331 |
+
"""
|
332 |
+
|
333 |
+
if not weights:
|
334 |
+
weights = [1.0 / 2.6, 1.0 / 4.8, 1.0 / 3.7, 1.0 / 5.6, 10.0 / 1.5]
|
335 |
+
|
336 |
+
vgg_ref = _build_vgg19(reference * 255.0, vgg_model_file)
|
337 |
+
vgg_img = _build_vgg19(image * 255.0, vgg_model_file)
|
338 |
+
|
339 |
+
p1 = tf.reduce_mean(
|
340 |
+
tf.squared_difference(
|
341 |
+
_compute_gram_matrix(vgg_ref['conv1_2'] / 255.0, mask),
|
342 |
+
_compute_gram_matrix(vgg_img['conv1_2'] / 255.0, mask))) * weights[0]
|
343 |
+
p2 = tf.reduce_mean(
|
344 |
+
tf.squared_difference(
|
345 |
+
_compute_gram_matrix(vgg_ref['conv2_2'] / 255.0, mask),
|
346 |
+
_compute_gram_matrix(vgg_img['conv2_2'] / 255.0, mask))) * weights[1]
|
347 |
+
p3 = tf.reduce_mean(
|
348 |
+
tf.squared_difference(
|
349 |
+
_compute_gram_matrix(vgg_ref['conv3_2'] / 255.0, mask),
|
350 |
+
_compute_gram_matrix(vgg_img['conv3_2'] / 255.0, mask))) * weights[2]
|
351 |
+
p4 = tf.reduce_mean(
|
352 |
+
tf.squared_difference(
|
353 |
+
_compute_gram_matrix(vgg_ref['conv4_2'] / 255.0, mask),
|
354 |
+
_compute_gram_matrix(vgg_img['conv4_2'] / 255.0, mask))) * weights[3]
|
355 |
+
p5 = tf.reduce_mean(
|
356 |
+
tf.squared_difference(
|
357 |
+
_compute_gram_matrix(vgg_ref['conv5_2'] / 255.0, mask),
|
358 |
+
_compute_gram_matrix(vgg_img['conv5_2'] / 255.0, mask))) * weights[4]
|
359 |
+
|
360 |
+
final_loss = p1 + p2 + p3 + p4 + p5
|
361 |
+
|
362 |
+
return final_loss
|
models/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
models/film_net/feature_extractor.py
ADDED
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 Google LLC
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
"""TF2 layer for extracting image features for the film_net interpolator.
|
16 |
+
|
17 |
+
The feature extractor implemented here converts an image pyramid into a pyramid
|
18 |
+
of deep features. The feature pyramid serves a similar purpose as U-Net
|
19 |
+
architecture's encoder, but we use a special cascaded architecture described in
|
20 |
+
Multi-view Image Fusion [1].
|
21 |
+
|
22 |
+
For comprehensiveness, below is a short description of the idea. While the
|
23 |
+
description is a bit involved, the cascaded feature pyramid can be used just
|
24 |
+
like any image feature pyramid.
|
25 |
+
|
26 |
+
Why cascaded architeture?
|
27 |
+
=========================
|
28 |
+
To understand the concept it is worth reviewing a traditional feature pyramid
|
29 |
+
first: *A traditional feature pyramid* as in U-net or in many optical flow
|
30 |
+
networks is built by alternating between convolutions and pooling, starting
|
31 |
+
from the input image.
|
32 |
+
|
33 |
+
It is well known that early features of such architecture correspond to low
|
34 |
+
level concepts such as edges in the image whereas later layers extract
|
35 |
+
semantically higher level concepts such as object classes etc. In other words,
|
36 |
+
the meaning of the filters in each resolution level is different. For problems
|
37 |
+
such as semantic segmentation and many others this is a desirable property.
|
38 |
+
|
39 |
+
However, the asymmetric features preclude sharing weights across resolution
|
40 |
+
levels in the feature extractor itself and in any subsequent neural networks
|
41 |
+
that follow. This can be a downside, since optical flow prediction, for
|
42 |
+
instance is symmetric across resolution levels. The cascaded feature
|
43 |
+
architecture addresses this shortcoming.
|
44 |
+
|
45 |
+
How is it built?
|
46 |
+
================
|
47 |
+
The *cascaded* feature pyramid contains feature vectors that have constant
|
48 |
+
length and meaning on each resolution level, except few of the finest ones. The
|
49 |
+
advantage of this is that the subsequent optical flow layer can learn
|
50 |
+
synergically from many resolutions. This means that coarse level prediction can
|
51 |
+
benefit from finer resolution training examples, which can be useful with
|
52 |
+
moderately sized datasets to avoid overfitting.
|
53 |
+
|
54 |
+
The cascaded feature pyramid is built by extracting shallower subtree pyramids,
|
55 |
+
each one of them similar to the traditional architecture. Each subtree
|
56 |
+
pyramid S_i is extracted starting from each resolution level:
|
57 |
+
|
58 |
+
image resolution 0 -> S_0
|
59 |
+
image resolution 1 -> S_1
|
60 |
+
image resolution 2 -> S_2
|
61 |
+
...
|
62 |
+
|
63 |
+
If we denote the features at level j of subtree i as S_i_j, the cascaded pyramid
|
64 |
+
is constructed by concatenating features as follows (assuming subtree depth=3):
|
65 |
+
|
66 |
+
lvl
|
67 |
+
feat_0 = concat( S_0_0 )
|
68 |
+
feat_1 = concat( S_1_0 S_0_1 )
|
69 |
+
feat_2 = concat( S_2_0 S_1_1 S_0_2 )
|
70 |
+
feat_3 = concat( S_3_0 S_2_1 S_1_2 )
|
71 |
+
feat_4 = concat( S_4_0 S_3_1 S_2_2 )
|
72 |
+
feat_5 = concat( S_5_0 S_4_1 S_3_2 )
|
73 |
+
....
|
74 |
+
|
75 |
+
In above, all levels except feat_0 and feat_1 have the same number of features
|
76 |
+
with similar semantic meaning. This enables training a single optical flow
|
77 |
+
predictor module shared by levels 2,3,4,5... . For more details and evaluation
|
78 |
+
see [1].
|
79 |
+
|
80 |
+
[1] Multi-view Image Fusion, Trinidad et al. 2019
|
81 |
+
"""
|
82 |
+
|
83 |
+
from typing import List
|
84 |
+
|
85 |
+
from . import options
|
86 |
+
import tensorflow as tf
|
87 |
+
|
88 |
+
|
89 |
+
def _relu(x: tf.Tensor) -> tf.Tensor:
|
90 |
+
return tf.nn.leaky_relu(x, alpha=0.2)
|
91 |
+
|
92 |
+
|
93 |
+
def _conv(filters: int, name: str):
|
94 |
+
return tf.keras.layers.Conv2D(
|
95 |
+
name=name,
|
96 |
+
filters=filters,
|
97 |
+
kernel_size=3,
|
98 |
+
padding='same',
|
99 |
+
activation=_relu)
|
100 |
+
|
101 |
+
|
102 |
+
class SubTreeExtractor(tf.keras.layers.Layer):
|
103 |
+
"""Extracts a hierarchical set of features from an image.
|
104 |
+
|
105 |
+
This is a conventional, hierarchical image feature extractor, that extracts
|
106 |
+
[k, k*2, k*4... ] filters for the image pyramid where k=options.sub_levels.
|
107 |
+
Each level is followed by average pooling.
|
108 |
+
|
109 |
+
Attributes:
|
110 |
+
name: Name for the layer
|
111 |
+
config: Options for the fusion_net frame interpolator
|
112 |
+
"""
|
113 |
+
|
114 |
+
def __init__(self, name: str, config: options.Options):
|
115 |
+
super().__init__(name=name)
|
116 |
+
k = config.filters
|
117 |
+
n = config.sub_levels
|
118 |
+
self.convs = []
|
119 |
+
for i in range(n):
|
120 |
+
self.convs.append(
|
121 |
+
_conv(filters=(k << i), name='cfeat_conv_{}'.format(2 * i)))
|
122 |
+
self.convs.append(
|
123 |
+
_conv(filters=(k << i), name='cfeat_conv_{}'.format(2 * i + 1)))
|
124 |
+
|
125 |
+
def call(self, image: tf.Tensor, n: int) -> List[tf.Tensor]:
|
126 |
+
"""Extracts a pyramid of features from the image.
|
127 |
+
|
128 |
+
Args:
|
129 |
+
image: tf.Tensor with shape BATCH_SIZE x HEIGHT x WIDTH x CHANNELS.
|
130 |
+
n: number of pyramid levels to extract. This can be less or equal to
|
131 |
+
options.sub_levels given in the __init__.
|
132 |
+
Returns:
|
133 |
+
The pyramid of features, starting from the finest level. Each element
|
134 |
+
contains the output after the last convolution on the corresponding
|
135 |
+
pyramid level.
|
136 |
+
"""
|
137 |
+
head = image
|
138 |
+
pool = tf.keras.layers.AveragePooling2D(
|
139 |
+
pool_size=2, strides=2, padding='valid')
|
140 |
+
pyramid = []
|
141 |
+
for i in range(n):
|
142 |
+
head = self.convs[2*i](head)
|
143 |
+
head = self.convs[2*i+1](head)
|
144 |
+
pyramid.append(head)
|
145 |
+
if i < n-1:
|
146 |
+
head = pool(head)
|
147 |
+
return pyramid
|
148 |
+
|
149 |
+
|
150 |
+
class FeatureExtractor(tf.keras.layers.Layer):
|
151 |
+
"""Extracts features from an image pyramid using a cascaded architecture.
|
152 |
+
|
153 |
+
Attributes:
|
154 |
+
name: Name of the layer
|
155 |
+
config: Options for the fusion_net frame interpolator
|
156 |
+
"""
|
157 |
+
|
158 |
+
def __init__(self, name: str, config: options.Options):
|
159 |
+
super().__init__(name=name)
|
160 |
+
self.extract_sublevels = SubTreeExtractor('sub_extractor', config)
|
161 |
+
self.options = config
|
162 |
+
|
163 |
+
def call(self, image_pyramid: List[tf.Tensor]) -> List[tf.Tensor]:
|
164 |
+
"""Extracts a cascaded feature pyramid.
|
165 |
+
|
166 |
+
Args:
|
167 |
+
image_pyramid: Image pyramid as a list, starting from the finest level.
|
168 |
+
Returns:
|
169 |
+
A pyramid of cascaded features.
|
170 |
+
"""
|
171 |
+
sub_pyramids = []
|
172 |
+
for i in range(len(image_pyramid)):
|
173 |
+
# At each level of the image pyramid, creates a sub_pyramid of features
|
174 |
+
# with 'sub_levels' pyramid levels, re-using the same SubTreeExtractor.
|
175 |
+
# We use the same instance since we want to share the weights.
|
176 |
+
#
|
177 |
+
# However, we cap the depth of the sub_pyramid so we don't create features
|
178 |
+
# that are beyond the coarsest level of the cascaded feature pyramid we
|
179 |
+
# want to generate.
|
180 |
+
capped_sub_levels = min(len(image_pyramid) - i, self.options.sub_levels)
|
181 |
+
sub_pyramids.append(
|
182 |
+
self.extract_sublevels(image_pyramid[i], capped_sub_levels))
|
183 |
+
# Below we generate the cascades of features on each level of the feature
|
184 |
+
# pyramid. Assuming sub_levels=3, The layout of the features will be
|
185 |
+
# as shown in the example on file documentation above.
|
186 |
+
feature_pyramid = []
|
187 |
+
for i in range(len(image_pyramid)):
|
188 |
+
features = sub_pyramids[i][0]
|
189 |
+
for j in range(1, self.options.sub_levels):
|
190 |
+
if j <= i:
|
191 |
+
features = tf.concat([features, sub_pyramids[i - j][j]], axis=-1)
|
192 |
+
feature_pyramid.append(features)
|
193 |
+
return feature_pyramid
|
models/film_net/fusion.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 Google LLC
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
"""The final fusion stage for the film_net frame interpolator.
|
16 |
+
|
17 |
+
The inputs to this module are the warped input images, image features and
|
18 |
+
flow fields, all aligned to the target frame (often midway point between the
|
19 |
+
two original inputs). The output is the final image. FILM has no explicit
|
20 |
+
occlusion handling -- instead using the abovementioned information this module
|
21 |
+
automatically decides how to best blend the inputs together to produce content
|
22 |
+
in areas where the pixels can only be borrowed from one of the inputs.
|
23 |
+
|
24 |
+
Similarly, this module also decides on how much to blend in each input in case
|
25 |
+
of fractional timestep that is not at the halfway point. For example, if the two
|
26 |
+
inputs images are at t=0 and t=1, and we were to synthesize a frame at t=0.1,
|
27 |
+
it often makes most sense to favor the first input. However, this is not
|
28 |
+
always the case -- in particular in occluded pixels.
|
29 |
+
|
30 |
+
The architecture of the Fusion module follows U-net [1] architecture's decoder
|
31 |
+
side, e.g. each pyramid level consists of concatenation with upsampled coarser
|
32 |
+
level output, and two 3x3 convolutions.
|
33 |
+
|
34 |
+
The upsampling is implemented as 'resize convolution', e.g. nearest neighbor
|
35 |
+
upsampling followed by 2x2 convolution as explained in [2]. The classic U-net
|
36 |
+
uses max-pooling which has a tendency to create checkerboard artifacts.
|
37 |
+
|
38 |
+
[1] Ronneberger et al. U-Net: Convolutional Networks for Biomedical Image
|
39 |
+
Segmentation, 2015, https://arxiv.org/pdf/1505.04597.pdf
|
40 |
+
[2] https://distill.pub/2016/deconv-checkerboard/
|
41 |
+
"""
|
42 |
+
|
43 |
+
from typing import List
|
44 |
+
|
45 |
+
from . import options
|
46 |
+
import tensorflow as tf
|
47 |
+
|
48 |
+
|
49 |
+
def _relu(x: tf.Tensor) -> tf.Tensor:
|
50 |
+
return tf.nn.leaky_relu(x, alpha=0.2)
|
51 |
+
|
52 |
+
|
53 |
+
_NUMBER_OF_COLOR_CHANNELS = 3
|
54 |
+
|
55 |
+
|
56 |
+
class Fusion(tf.keras.layers.Layer):
|
57 |
+
"""The decoder."""
|
58 |
+
|
59 |
+
def __init__(self, name: str, config: options.Options):
|
60 |
+
super().__init__(name=name)
|
61 |
+
|
62 |
+
# Each item 'convs[i]' will contain the list of convolutions to be applied
|
63 |
+
# for pyramid level 'i'.
|
64 |
+
self.convs: List[List[tf.keras.layers.Layer]] = []
|
65 |
+
|
66 |
+
# Store the levels, so we can verify right number of levels in call().
|
67 |
+
self.levels = config.fusion_pyramid_levels
|
68 |
+
|
69 |
+
# Create the convolutions. Roughly following the feature extractor, we
|
70 |
+
# double the number of filters when the resolution halves, but only up to
|
71 |
+
# the specialized_levels, after which we use the same number of filters on
|
72 |
+
# all levels.
|
73 |
+
#
|
74 |
+
# We create the convs in fine-to-coarse order, so that the array index
|
75 |
+
# for the convs will correspond to our normal indexing (0=finest level).
|
76 |
+
for i in range(config.fusion_pyramid_levels - 1):
|
77 |
+
m = config.specialized_levels
|
78 |
+
k = config.filters
|
79 |
+
num_filters = (k << i) if i < m else (k << m)
|
80 |
+
|
81 |
+
convs: List[tf.keras.layers.Layer] = []
|
82 |
+
convs.append(
|
83 |
+
tf.keras.layers.Conv2D(
|
84 |
+
filters=num_filters, kernel_size=[2, 2], padding='same'))
|
85 |
+
convs.append(
|
86 |
+
tf.keras.layers.Conv2D(
|
87 |
+
filters=num_filters,
|
88 |
+
kernel_size=[3, 3],
|
89 |
+
padding='same',
|
90 |
+
activation=_relu))
|
91 |
+
convs.append(
|
92 |
+
tf.keras.layers.Conv2D(
|
93 |
+
filters=num_filters,
|
94 |
+
kernel_size=[3, 3],
|
95 |
+
padding='same',
|
96 |
+
activation=_relu))
|
97 |
+
self.convs.append(convs)
|
98 |
+
|
99 |
+
# The final convolution that outputs RGB:
|
100 |
+
self.output_conv = tf.keras.layers.Conv2D(
|
101 |
+
filters=_NUMBER_OF_COLOR_CHANNELS, kernel_size=1)
|
102 |
+
|
103 |
+
def call(self, pyramid: List[tf.Tensor]) -> tf.Tensor:
|
104 |
+
"""Runs the fusion module.
|
105 |
+
|
106 |
+
Args:
|
107 |
+
pyramid: The input feature pyramid as list of tensors. Each tensor being
|
108 |
+
in (B x H x W x C) format, with finest level tensor first.
|
109 |
+
|
110 |
+
Returns:
|
111 |
+
A batch of RGB images.
|
112 |
+
Raises:
|
113 |
+
ValueError, if len(pyramid) != config.fusion_pyramid_levels as provided in
|
114 |
+
the constructor.
|
115 |
+
"""
|
116 |
+
if len(pyramid) != self.levels:
|
117 |
+
raise ValueError(
|
118 |
+
'Fusion called with different number of pyramid levels '
|
119 |
+
f'{len(pyramid)} than it was configured for, {self.levels}.')
|
120 |
+
|
121 |
+
# As a slight difference to a conventional decoder (e.g. U-net), we don't
|
122 |
+
# apply any extra convolutions to the coarsest level, but just pass it
|
123 |
+
# to finer levels for concatenation. This choice has not been thoroughly
|
124 |
+
# evaluated, but is motivated by the educated guess that the fusion part
|
125 |
+
# probably does not need large spatial context, because at this point the
|
126 |
+
# features are spatially aligned by the preceding warp.
|
127 |
+
net = pyramid[-1]
|
128 |
+
|
129 |
+
# Loop starting from the 2nd coarsest level:
|
130 |
+
for i in reversed(range(0, self.levels - 1)):
|
131 |
+
# Resize the tensor from coarser level to match for concatenation.
|
132 |
+
level_size = tf.shape(pyramid[i])[1:3]
|
133 |
+
net = tf.image.resize(net, level_size,
|
134 |
+
tf.image.ResizeMethod.NEAREST_NEIGHBOR)
|
135 |
+
net = self.convs[i][0](net)
|
136 |
+
net = tf.concat([pyramid[i], net], axis=-1)
|
137 |
+
net = self.convs[i][1](net)
|
138 |
+
net = self.convs[i][2](net)
|
139 |
+
net = self.output_conv(net)
|
140 |
+
return net
|
models/film_net/interpolator.py
ADDED
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 Google LLC
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
"""The film_net frame interpolator main model code.
|
16 |
+
|
17 |
+
Basics
|
18 |
+
======
|
19 |
+
The film_net is an end-to-end learned neural frame interpolator implemented as
|
20 |
+
a TF2 model. It has the following inputs and outputs:
|
21 |
+
|
22 |
+
Inputs:
|
23 |
+
x0: image A.
|
24 |
+
x1: image B.
|
25 |
+
time: desired sub-frame time.
|
26 |
+
|
27 |
+
Outputs:
|
28 |
+
image: the predicted in-between image at the chosen time in range [0, 1].
|
29 |
+
|
30 |
+
Additional outputs include forward and backward warped image pyramids, flow
|
31 |
+
pyramids, etc., that can be visualized for debugging and analysis.
|
32 |
+
|
33 |
+
Note that many training sets only contain triplets with ground truth at
|
34 |
+
time=0.5. If a model has been trained with such training set, it will only work
|
35 |
+
well for synthesizing frames at time=0.5. Such models can only generate more
|
36 |
+
in-between frames using recursion.
|
37 |
+
|
38 |
+
Architecture
|
39 |
+
============
|
40 |
+
The inference consists of three main stages: 1) feature extraction 2) warping
|
41 |
+
3) fusion. On high-level, the architecture has similarities to Context-aware
|
42 |
+
Synthesis for Video Frame Interpolation [1], but the exact architecture is
|
43 |
+
closer to Multi-view Image Fusion [2] with some modifications for the frame
|
44 |
+
interpolation use-case.
|
45 |
+
|
46 |
+
Feature extraction stage employs the cascaded multi-scale architecture described
|
47 |
+
in [2]. The advantage of this architecture is that coarse level flow prediction
|
48 |
+
can be learned from finer resolution image samples. This is especially useful
|
49 |
+
to avoid overfitting with moderately sized datasets.
|
50 |
+
|
51 |
+
The warping stage uses a residual flow prediction idea that is similar to
|
52 |
+
PWC-Net [3], Multi-view Image Fusion [2] and many others.
|
53 |
+
|
54 |
+
The fusion stage is similar to U-Net's decoder where the skip connections are
|
55 |
+
connected to warped image and feature pyramids. This is described in [2].
|
56 |
+
|
57 |
+
Implementation Conventions
|
58 |
+
====================
|
59 |
+
Pyramids
|
60 |
+
--------
|
61 |
+
Throughtout the model, all image and feature pyramids are stored as python lists
|
62 |
+
with finest level first followed by downscaled versions obtained by successively
|
63 |
+
halving the resolution. The depths of all pyramids are determined by
|
64 |
+
options.pyramid_levels. The only exception to this is internal to the feature
|
65 |
+
extractor, where smaller feature pyramids are temporarily constructed with depth
|
66 |
+
options.sub_levels.
|
67 |
+
|
68 |
+
Color ranges & gamma
|
69 |
+
--------------------
|
70 |
+
The model code makes no assumptions on whether the images are in gamma or
|
71 |
+
linearized space or what is the range of RGB color values. So a model can be
|
72 |
+
trained with different choices. This does not mean that all the choices lead to
|
73 |
+
similar results. In practice the model has been proven to work well with RGB
|
74 |
+
scale = [0,1] with gamma-space images (i.e. not linearized).
|
75 |
+
|
76 |
+
[1] Context-aware Synthesis for Video Frame Interpolation, Niklaus and Liu, 2018
|
77 |
+
[2] Multi-view Image Fusion, Trinidad et al, 2019
|
78 |
+
[3] PWC-Net: CNNs for Optical Flow Using Pyramid, Warping, and Cost Volume
|
79 |
+
"""
|
80 |
+
|
81 |
+
from . import feature_extractor
|
82 |
+
from . import fusion
|
83 |
+
from . import options
|
84 |
+
from . import pyramid_flow_estimator
|
85 |
+
from . import util
|
86 |
+
import tensorflow as tf
|
87 |
+
|
88 |
+
|
89 |
+
def create_model(x0: tf.Tensor, x1: tf.Tensor, time: tf.Tensor,
|
90 |
+
config: options.Options) -> tf.keras.Model:
|
91 |
+
"""Creates a frame interpolator model.
|
92 |
+
|
93 |
+
The frame interpolator is used to warp the two images to the in-between frame
|
94 |
+
at given time. Note that training data is often restricted such that
|
95 |
+
supervision only exists at 'time'=0.5. If trained with such data, the model
|
96 |
+
will overfit to predicting images that are halfway between the two inputs and
|
97 |
+
will not be as accurate elsewhere.
|
98 |
+
|
99 |
+
Args:
|
100 |
+
x0: first input image as BxHxWxC tensor.
|
101 |
+
x1: second input image as BxHxWxC tensor.
|
102 |
+
time: ignored by film_net. We always infer a frame at t = 0.5.
|
103 |
+
config: FilmNetOptions object.
|
104 |
+
|
105 |
+
Returns:
|
106 |
+
A tf.Model that takes 'x0', 'x1', and 'time' as input and returns a
|
107 |
+
dictionary with the interpolated result in 'image'. For additional
|
108 |
+
diagnostics or supervision, the following intermediate results are
|
109 |
+
also stored in the dictionary:
|
110 |
+
'x0_warped': an intermediate result obtained by warping from x0
|
111 |
+
'x1_warped': an intermediate result obtained by warping from x1
|
112 |
+
'forward_residual_flow_pyramid': pyramid with forward residual flows
|
113 |
+
'backward_residual_flow_pyramid': pyramid with backward residual flows
|
114 |
+
'forward_flow_pyramid': pyramid with forward flows
|
115 |
+
'backward_flow_pyramid': pyramid with backward flows
|
116 |
+
|
117 |
+
Raises:
|
118 |
+
ValueError, if config.pyramid_levels < config.fusion_pyramid_levels.
|
119 |
+
"""
|
120 |
+
if config.pyramid_levels < config.fusion_pyramid_levels:
|
121 |
+
raise ValueError('config.pyramid_levels must be greater than or equal to '
|
122 |
+
'config.fusion_pyramid_levels.')
|
123 |
+
|
124 |
+
x0_decoded = x0
|
125 |
+
x1_decoded = x1
|
126 |
+
|
127 |
+
# shuffle images
|
128 |
+
image_pyramids = [
|
129 |
+
util.build_image_pyramid(x0_decoded, config),
|
130 |
+
util.build_image_pyramid(x1_decoded, config)
|
131 |
+
]
|
132 |
+
|
133 |
+
# Siamese feature pyramids:
|
134 |
+
extract = feature_extractor.FeatureExtractor('feat_net', config)
|
135 |
+
feature_pyramids = [extract(image_pyramids[0]), extract(image_pyramids[1])]
|
136 |
+
|
137 |
+
predict_flow = pyramid_flow_estimator.PyramidFlowEstimator(
|
138 |
+
'predict_flow', config)
|
139 |
+
|
140 |
+
# Predict forward flow.
|
141 |
+
forward_residual_flow_pyramid = predict_flow(feature_pyramids[0],
|
142 |
+
feature_pyramids[1])
|
143 |
+
# Predict backward flow.
|
144 |
+
backward_residual_flow_pyramid = predict_flow(feature_pyramids[1],
|
145 |
+
feature_pyramids[0])
|
146 |
+
|
147 |
+
# Concatenate features and images:
|
148 |
+
|
149 |
+
# Note that we keep up to 'fusion_pyramid_levels' levels as only those
|
150 |
+
# are used by the fusion module.
|
151 |
+
fusion_pyramid_levels = config.fusion_pyramid_levels
|
152 |
+
|
153 |
+
forward_flow_pyramid = util.flow_pyramid_synthesis(
|
154 |
+
forward_residual_flow_pyramid)[:fusion_pyramid_levels]
|
155 |
+
backward_flow_pyramid = util.flow_pyramid_synthesis(
|
156 |
+
backward_residual_flow_pyramid)[:fusion_pyramid_levels]
|
157 |
+
|
158 |
+
# We multiply the flows with t and 1-t to warp to the desired fractional time.
|
159 |
+
#
|
160 |
+
# Note: In film_net we fix time to be 0.5, and recursively invoke the interpo-
|
161 |
+
# lator for multi-frame interpolation. Below, we create a constant tensor of
|
162 |
+
# shape [B]. We use the `time` tensor to infer the batch size.
|
163 |
+
mid_time = tf.keras.layers.Lambda(lambda x: tf.ones_like(x) * 0.5)(time)
|
164 |
+
backward_flow = util.multiply_pyramid(backward_flow_pyramid, mid_time[:, 0])
|
165 |
+
forward_flow = util.multiply_pyramid(forward_flow_pyramid, 1 - mid_time[:, 0])
|
166 |
+
|
167 |
+
pyramids_to_warp = [
|
168 |
+
util.concatenate_pyramids(image_pyramids[0][:fusion_pyramid_levels],
|
169 |
+
feature_pyramids[0][:fusion_pyramid_levels]),
|
170 |
+
util.concatenate_pyramids(image_pyramids[1][:fusion_pyramid_levels],
|
171 |
+
feature_pyramids[1][:fusion_pyramid_levels])
|
172 |
+
]
|
173 |
+
|
174 |
+
# Warp features and images using the flow. Note that we use backward warping
|
175 |
+
# and backward flow is used to read from image 0 and forward flow from
|
176 |
+
# image 1.
|
177 |
+
forward_warped_pyramid = util.pyramid_warp(pyramids_to_warp[0], backward_flow)
|
178 |
+
backward_warped_pyramid = util.pyramid_warp(pyramids_to_warp[1], forward_flow)
|
179 |
+
|
180 |
+
aligned_pyramid = util.concatenate_pyramids(forward_warped_pyramid,
|
181 |
+
backward_warped_pyramid)
|
182 |
+
aligned_pyramid = util.concatenate_pyramids(aligned_pyramid, backward_flow)
|
183 |
+
aligned_pyramid = util.concatenate_pyramids(aligned_pyramid, forward_flow)
|
184 |
+
|
185 |
+
fuse = fusion.Fusion('fusion', config)
|
186 |
+
prediction = fuse(aligned_pyramid)
|
187 |
+
|
188 |
+
output_color = prediction[..., :3]
|
189 |
+
outputs = {'image': output_color}
|
190 |
+
|
191 |
+
if config.use_aux_outputs:
|
192 |
+
outputs.update({
|
193 |
+
'x0_warped': forward_warped_pyramid[0][..., 0:3],
|
194 |
+
'x1_warped': backward_warped_pyramid[0][..., 0:3],
|
195 |
+
'forward_residual_flow_pyramid': forward_residual_flow_pyramid,
|
196 |
+
'backward_residual_flow_pyramid': backward_residual_flow_pyramid,
|
197 |
+
'forward_flow_pyramid': forward_flow_pyramid,
|
198 |
+
'backward_flow_pyramid': backward_flow_pyramid,
|
199 |
+
})
|
200 |
+
|
201 |
+
model = tf.keras.Model(
|
202 |
+
inputs={
|
203 |
+
'x0': x0,
|
204 |
+
'x1': x1,
|
205 |
+
'time': time
|
206 |
+
}, outputs=outputs)
|
207 |
+
return model
|
models/film_net/options.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 Google LLC
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
"""Options for the film_net video frame interpolator."""
|
16 |
+
|
17 |
+
import gin.tf
|
18 |
+
|
19 |
+
|
20 |
+
@gin.configurable('film_net')
|
21 |
+
class Options(object):
|
22 |
+
"""Options for the film_net video frame interpolator.
|
23 |
+
|
24 |
+
To further understand these options, see the paper here:
|
25 |
+
https://augmentedperception.github.io/pixelfusion/.
|
26 |
+
|
27 |
+
The default values are suitable for up to 64 pixel motions. For larger motions
|
28 |
+
the number of flow convolutions and/or pyramid levels can be increased, but
|
29 |
+
usually with the cost of accuracy on solving the smaller motions.
|
30 |
+
|
31 |
+
The maximum motion in pixels that the system can resolve is equivalent to
|
32 |
+
2^(pyramid_levels-1) * flow_convs[-1]. I.e. the downsampling factor times
|
33 |
+
the receptive field radius on the coarsest pyramid level. This, of course,
|
34 |
+
assumes that the training data contains such motions.
|
35 |
+
|
36 |
+
Note that to avoid a run-time error, the input image width and height have to
|
37 |
+
be divisible by 2^(pyramid_levels-1).
|
38 |
+
|
39 |
+
Attributes:
|
40 |
+
pyramid_levels: How many pyramid levels to use for the feature pyramid and
|
41 |
+
the flow prediction.
|
42 |
+
fusion_pyramid_levels: How many pyramid levels to use for the fusion module
|
43 |
+
this must be less or equal to 'pyramid_levels'.
|
44 |
+
specialized_levels: How many fine levels of the pyramid shouldn't share the
|
45 |
+
weights. If specialized_levels = 3, it means that two finest levels are
|
46 |
+
independently learned, whereas the third will be learned together with the
|
47 |
+
rest of the pyramid. Valid range [1, pyramid_levels].
|
48 |
+
flow_convs: Convolutions per residual flow predictor. This array should have
|
49 |
+
specialized_levels+1 items on it, the last item representing the number of
|
50 |
+
convs used by any pyramid level that uses shared weights.
|
51 |
+
flow_filters: Base number of filters in residual flow predictors. This array
|
52 |
+
should have specialized_levels+1 items on it, the last item representing
|
53 |
+
the number of filters used by any pyramid level that uses shared weights.
|
54 |
+
sub_levels: The depth of the cascaded feature tree each pyramid level
|
55 |
+
concatenates together to compute the flow. This must be within range [1,
|
56 |
+
specialized_level+1]. It is recommended to set this to specialized_levels
|
57 |
+
+ 1
|
58 |
+
filters: Base number of features to extract. On each pyramid level the
|
59 |
+
number doubles. This is used by both feature extraction and fusion stages.
|
60 |
+
use_aux_outputs: Set to True to include auxiliary outputs along with the
|
61 |
+
predicted image.
|
62 |
+
"""
|
63 |
+
|
64 |
+
def __init__(self,
|
65 |
+
pyramid_levels=5,
|
66 |
+
fusion_pyramid_levels=5,
|
67 |
+
specialized_levels=3,
|
68 |
+
flow_convs=None,
|
69 |
+
flow_filters=None,
|
70 |
+
sub_levels=4,
|
71 |
+
filters=16,
|
72 |
+
use_aux_outputs=True):
|
73 |
+
self.pyramid_levels = pyramid_levels
|
74 |
+
self.fusion_pyramid_levels = fusion_pyramid_levels
|
75 |
+
self.specialized_levels = specialized_levels
|
76 |
+
self.flow_convs = flow_convs or [4, 4, 4, 4]
|
77 |
+
self.flow_filters = flow_filters or [64, 128, 256, 256]
|
78 |
+
self.sub_levels = sub_levels
|
79 |
+
self.filters = filters
|
80 |
+
self.use_aux_outputs = use_aux_outputs
|
81 |
+
|
models/film_net/pyramid_flow_estimator.py
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 Google LLC
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
"""TF2 layer for estimating optical flow by a residual flow pyramid.
|
16 |
+
|
17 |
+
This approach of estimating optical flow between two images can be traced back
|
18 |
+
to [1], but is also used by later neural optical flow computation methods such
|
19 |
+
as SpyNet [2] and PWC-Net [3].
|
20 |
+
|
21 |
+
The basic idea is that the optical flow is first estimated in a coarse
|
22 |
+
resolution, then the flow is upsampled to warp the higher resolution image and
|
23 |
+
then a residual correction is computed and added to the estimated flow. This
|
24 |
+
process is repeated in a pyramid on coarse to fine order to successively
|
25 |
+
increase the resolution of both optical flow and the warped image.
|
26 |
+
|
27 |
+
In here, the optical flow predictor is used as an internal component for the
|
28 |
+
film_net frame interpolator, to warp the two input images into the inbetween,
|
29 |
+
target frame.
|
30 |
+
|
31 |
+
[1] F. Glazer, Hierarchical motion detection. PhD thesis, 1987.
|
32 |
+
[2] A. Ranjan and M. J. Black, Optical Flow Estimation using a Spatial Pyramid
|
33 |
+
Network. 2016
|
34 |
+
[3] D. Sun X. Yang, M-Y. Liu and J. Kautz, PWC-Net: CNNs for Optical Flow Using
|
35 |
+
Pyramid, Warping, and Cost Volume, 2017
|
36 |
+
"""
|
37 |
+
|
38 |
+
from typing import List
|
39 |
+
|
40 |
+
from . import options
|
41 |
+
from . import util
|
42 |
+
import tensorflow as tf
|
43 |
+
|
44 |
+
|
45 |
+
def _relu(x: tf.Tensor) -> tf.Tensor:
|
46 |
+
return tf.nn.leaky_relu(x, alpha=0.2)
|
47 |
+
|
48 |
+
|
49 |
+
class FlowEstimator(tf.keras.layers.Layer):
|
50 |
+
"""Small-receptive field predictor for computing the flow between two images.
|
51 |
+
|
52 |
+
This is used to compute the residual flow fields in PyramidFlowEstimator.
|
53 |
+
|
54 |
+
Note that while the number of 3x3 convolutions & filters to apply is
|
55 |
+
configurable, two extra 1x1 convolutions are appended to extract the flow in
|
56 |
+
the end.
|
57 |
+
|
58 |
+
Attributes:
|
59 |
+
name: The name of the layer
|
60 |
+
num_convs: Number of 3x3 convolutions to apply
|
61 |
+
num_filters: Number of filters in each 3x3 convolution
|
62 |
+
"""
|
63 |
+
|
64 |
+
def __init__(self, name: str, num_convs: int, num_filters: int):
|
65 |
+
super(FlowEstimator, self).__init__(name=name)
|
66 |
+
def conv(filters, size, name, activation=_relu):
|
67 |
+
return tf.keras.layers.Conv2D(
|
68 |
+
name=name,
|
69 |
+
filters=filters,
|
70 |
+
kernel_size=size,
|
71 |
+
padding='same',
|
72 |
+
activation=activation)
|
73 |
+
|
74 |
+
self._convs = []
|
75 |
+
for i in range(num_convs):
|
76 |
+
self._convs.append(conv(filters=num_filters, size=3, name=f'conv_{i}'))
|
77 |
+
self._convs.append(conv(filters=num_filters/2, size=1, name=f'conv_{i+1}'))
|
78 |
+
# For the final convolution, we want no activation at all to predict the
|
79 |
+
# optical flow vector values. We have done extensive testing on explicitly
|
80 |
+
# bounding these values using sigmoid, but it turned out that having no
|
81 |
+
# activation gives better results.
|
82 |
+
self._convs.append(
|
83 |
+
conv(filters=2, size=1, name=f'conv_{i+2}', activation=None))
|
84 |
+
|
85 |
+
def call(self, features_a: tf.Tensor, features_b: tf.Tensor) -> tf.Tensor:
|
86 |
+
"""Estimates optical flow between two images.
|
87 |
+
|
88 |
+
Args:
|
89 |
+
features_a: per pixel feature vectors for image A (B x H x W x C)
|
90 |
+
features_b: per pixel feature vectors for image B (B x H x W x C)
|
91 |
+
|
92 |
+
Returns:
|
93 |
+
A tensor with optical flow from A to B
|
94 |
+
"""
|
95 |
+
net = tf.concat([features_a, features_b], axis=-1)
|
96 |
+
for conv in self._convs:
|
97 |
+
net = conv(net)
|
98 |
+
return net
|
99 |
+
|
100 |
+
|
101 |
+
class PyramidFlowEstimator(tf.keras.layers.Layer):
|
102 |
+
"""Predicts optical flow by coarse-to-fine refinement.
|
103 |
+
|
104 |
+
Attributes:
|
105 |
+
name: The name of the layer
|
106 |
+
config: Options for the film_net frame interpolator
|
107 |
+
"""
|
108 |
+
|
109 |
+
def __init__(self, name: str, config: options.Options):
|
110 |
+
super(PyramidFlowEstimator, self).__init__(name=name)
|
111 |
+
self._predictors = []
|
112 |
+
for i in range(config.specialized_levels):
|
113 |
+
self._predictors.append(
|
114 |
+
FlowEstimator(
|
115 |
+
name=f'flow_predictor_{i}',
|
116 |
+
num_convs=config.flow_convs[i],
|
117 |
+
num_filters=config.flow_filters[i]))
|
118 |
+
shared_predictor = FlowEstimator(
|
119 |
+
name='flow_predictor_shared',
|
120 |
+
num_convs=config.flow_convs[-1],
|
121 |
+
num_filters=config.flow_filters[-1])
|
122 |
+
for i in range(config.specialized_levels, config.pyramid_levels):
|
123 |
+
self._predictors.append(shared_predictor)
|
124 |
+
|
125 |
+
def call(self, feature_pyramid_a: List[tf.Tensor],
|
126 |
+
feature_pyramid_b: List[tf.Tensor]) -> List[tf.Tensor]:
|
127 |
+
"""Estimates residual flow pyramids between two image pyramids.
|
128 |
+
|
129 |
+
Each image pyramid is represented as a list of tensors in fine-to-coarse
|
130 |
+
order. Each individual image is represented as a tensor where each pixel is
|
131 |
+
a vector of image features.
|
132 |
+
|
133 |
+
util.flow_pyramid_synthesis can be used to convert the residual flow
|
134 |
+
pyramid returned by this method into a flow pyramid, where each level
|
135 |
+
encodes the flow instead of a residual correction.
|
136 |
+
|
137 |
+
Args:
|
138 |
+
feature_pyramid_a: image pyramid as a list in fine-to-coarse order
|
139 |
+
feature_pyramid_b: image pyramid as a list in fine-to-coarse order
|
140 |
+
|
141 |
+
Returns:
|
142 |
+
List of flow tensors, in fine-to-coarse order, each level encoding the
|
143 |
+
difference against the bilinearly upsampled version from the coarser
|
144 |
+
level. The coarsest flow tensor, e.g. the last element in the array is the
|
145 |
+
'DC-term', e.g. not a residual (alternatively you can think of it being a
|
146 |
+
residual against zero).
|
147 |
+
"""
|
148 |
+
levels = len(feature_pyramid_a)
|
149 |
+
v = self._predictors[-1](feature_pyramid_a[-1], feature_pyramid_b[-1])
|
150 |
+
residuals = [v]
|
151 |
+
for i in reversed(range(0, levels-1)):
|
152 |
+
# Upsamples the flow to match the current pyramid level. Also, scales the
|
153 |
+
# magnitude by two to reflect the new size.
|
154 |
+
level_size = tf.shape(feature_pyramid_a[i])[1:3]
|
155 |
+
v = tf.image.resize(images=2*v, size=level_size)
|
156 |
+
# Warp feature_pyramid_b[i] image based on the current flow estimate.
|
157 |
+
warped = util.warp(feature_pyramid_b[i], v)
|
158 |
+
# Estimate the residual flow between pyramid_a[i] and warped image:
|
159 |
+
v_residual = self._predictors[i](feature_pyramid_a[i], warped)
|
160 |
+
residuals.append(v_residual)
|
161 |
+
v = v_residual + v
|
162 |
+
# Use reversed() to return in the 'standard' finest-first-order:
|
163 |
+
return list(reversed(residuals))
|
models/film_net/util.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 Google LLC
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
"""Various utilities used in the film_net frame interpolator model."""
|
16 |
+
from typing import List
|
17 |
+
|
18 |
+
from .options import Options
|
19 |
+
import tensorflow as tf
|
20 |
+
import tensorflow_addons.image as tfa_image
|
21 |
+
|
22 |
+
|
23 |
+
def build_image_pyramid(image: tf.Tensor,
|
24 |
+
options: Options) -> List[tf.Tensor]:
|
25 |
+
"""Builds an image pyramid from a given image.
|
26 |
+
|
27 |
+
The original image is included in the pyramid and the rest are generated by
|
28 |
+
successively halving the resolution.
|
29 |
+
|
30 |
+
Args:
|
31 |
+
image: the input image.
|
32 |
+
options: film_net options object
|
33 |
+
|
34 |
+
Returns:
|
35 |
+
A list of images starting from the finest with options.pyramid_levels items
|
36 |
+
"""
|
37 |
+
levels = options.pyramid_levels
|
38 |
+
pyramid = []
|
39 |
+
pool = tf.keras.layers.AveragePooling2D(
|
40 |
+
pool_size=2, strides=2, padding='valid')
|
41 |
+
for i in range(0, levels):
|
42 |
+
pyramid.append(image)
|
43 |
+
if i < levels-1:
|
44 |
+
image = pool(image)
|
45 |
+
return pyramid
|
46 |
+
|
47 |
+
|
48 |
+
def warp(image: tf.Tensor, flow: tf.Tensor) -> tf.Tensor:
|
49 |
+
"""Backward warps the image using the given flow.
|
50 |
+
|
51 |
+
Specifically, the output pixel in batch b, at position x, y will be computed
|
52 |
+
as follows:
|
53 |
+
(flowed_y, flowed_x) = (y+flow[b, y, x, 1], x+flow[b, y, x, 0])
|
54 |
+
output[b, y, x] = bilinear_lookup(image, b, flowed_y, flowed_x)
|
55 |
+
|
56 |
+
Note that the flow vectors are expected as [x, y], e.g. x in position 0 and
|
57 |
+
y in position 1.
|
58 |
+
|
59 |
+
Args:
|
60 |
+
image: An image with shape BxHxWxC.
|
61 |
+
flow: A flow with shape BxHxWx2, with the two channels denoting the relative
|
62 |
+
offset in order: (dx, dy).
|
63 |
+
Returns:
|
64 |
+
A warped image.
|
65 |
+
"""
|
66 |
+
# tfa_image.dense_image_warp expects unconventional negated optical flow, so
|
67 |
+
# negate the flow here. Also revert x and y for compatibility with older saved
|
68 |
+
# models trained with custom warp op that stored (x, y) instead of (y, x) flow
|
69 |
+
# vectors.
|
70 |
+
flow = -flow[..., ::-1]
|
71 |
+
|
72 |
+
# Note: we have to wrap tfa_image.dense_image_warp into a Keras Lambda,
|
73 |
+
# because it is not compatible with Keras symbolic tensors and we want to use
|
74 |
+
# this code as part of a Keras model. Wrapping it into a lambda has the
|
75 |
+
# consequence that tfa_image.dense_image_warp is only called once the tensors
|
76 |
+
# are concrete, e.g. actually contain data. The inner lambda is a workaround
|
77 |
+
# for passing two parameters, e.g you would really want to write:
|
78 |
+
# tf.keras.layers.Lambda(tfa_image.dense_image_warp)(image, flow), but this is
|
79 |
+
# not supported by the Keras Lambda.
|
80 |
+
warped = tf.keras.layers.Lambda(
|
81 |
+
lambda x: tfa_image.dense_image_warp(*x))((image, flow))
|
82 |
+
return tf.reshape(warped, shape=tf.shape(image))
|
83 |
+
|
84 |
+
|
85 |
+
def multiply_pyramid(pyramid: List[tf.Tensor],
|
86 |
+
scalar: tf.Tensor) -> List[tf.Tensor]:
|
87 |
+
"""Multiplies all image batches in the pyramid by a batch of scalars.
|
88 |
+
|
89 |
+
Args:
|
90 |
+
pyramid: Pyramid of image batches.
|
91 |
+
scalar: Batch of scalars.
|
92 |
+
|
93 |
+
Returns:
|
94 |
+
An image pyramid with all images multiplied by the scalar.
|
95 |
+
"""
|
96 |
+
# To multiply each image with its corresponding scalar, we first transpose
|
97 |
+
# the batch of images from BxHxWxC-format to CxHxWxB. This can then be
|
98 |
+
# multiplied with a batch of scalars, then we transpose back to the standard
|
99 |
+
# BxHxWxC form.
|
100 |
+
return [
|
101 |
+
tf.transpose(tf.transpose(image, [3, 1, 2, 0]) * scalar, [3, 1, 2, 0])
|
102 |
+
for image in pyramid
|
103 |
+
]
|
104 |
+
|
105 |
+
|
106 |
+
def flow_pyramid_synthesis(
|
107 |
+
residual_pyramid: List[tf.Tensor]) -> List[tf.Tensor]:
|
108 |
+
"""Converts a residual flow pyramid into a flow pyramid."""
|
109 |
+
flow = residual_pyramid[-1]
|
110 |
+
flow_pyramid = [flow]
|
111 |
+
for residual_flow in reversed(residual_pyramid[:-1]):
|
112 |
+
level_size = tf.shape(residual_flow)[1:3]
|
113 |
+
flow = tf.image.resize(images=2*flow, size=level_size)
|
114 |
+
flow = residual_flow + flow
|
115 |
+
flow_pyramid.append(flow)
|
116 |
+
# Use reversed() to return in the 'standard' finest-first-order:
|
117 |
+
return list(reversed(flow_pyramid))
|
118 |
+
|
119 |
+
|
120 |
+
def pyramid_warp(feature_pyramid: List[tf.Tensor],
|
121 |
+
flow_pyramid: List[tf.Tensor]) -> List[tf.Tensor]:
|
122 |
+
"""Warps the feature pyramid using the flow pyramid.
|
123 |
+
|
124 |
+
Args:
|
125 |
+
feature_pyramid: feature pyramid starting from the finest level.
|
126 |
+
flow_pyramid: flow fields, starting from the finest level.
|
127 |
+
|
128 |
+
Returns:
|
129 |
+
Reverse warped feature pyramid.
|
130 |
+
"""
|
131 |
+
warped_feature_pyramid = []
|
132 |
+
for features, flow in zip(feature_pyramid, flow_pyramid):
|
133 |
+
warped_feature_pyramid.append(warp(features, flow))
|
134 |
+
return warped_feature_pyramid
|
135 |
+
|
136 |
+
|
137 |
+
def concatenate_pyramids(pyramid1: List[tf.Tensor],
|
138 |
+
pyramid2: List[tf.Tensor]) -> List[tf.Tensor]:
|
139 |
+
"""Concatenates each pyramid level together in the channel dimension."""
|
140 |
+
result = []
|
141 |
+
for features1, features2 in zip(pyramid1, pyramid2):
|
142 |
+
result.append(tf.concat([features1, features2], axis=-1))
|
143 |
+
return result
|
moment.gif
ADDED
Git LFS Details
|
photos/one.png
ADDED
Git LFS Details
|
photos/two.png
ADDED
Git LFS Details
|
predict.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from pathlib import Path
|
3 |
+
import numpy as np
|
4 |
+
import tempfile
|
5 |
+
import tensorflow as tf
|
6 |
+
import mediapy
|
7 |
+
from PIL import Image
|
8 |
+
import cog
|
9 |
+
|
10 |
+
from eval import interpolator, util
|
11 |
+
|
12 |
+
_UINT8_MAX_F = float(np.iinfo(np.uint8).max)
|
13 |
+
|
14 |
+
|
15 |
+
class Predictor(cog.Predictor):
|
16 |
+
def setup(self):
|
17 |
+
import tensorflow as tf
|
18 |
+
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))
|
19 |
+
self.interpolator = interpolator.Interpolator("pretrained_models/film_net/Style/saved_model", None)
|
20 |
+
|
21 |
+
# Batched time.
|
22 |
+
self.batch_dt = np.full(shape=(1,), fill_value=0.5, dtype=np.float32)
|
23 |
+
|
24 |
+
@cog.input(
|
25 |
+
"frame1",
|
26 |
+
type=Path,
|
27 |
+
help="The first input frame",
|
28 |
+
)
|
29 |
+
@cog.input(
|
30 |
+
"frame2",
|
31 |
+
type=Path,
|
32 |
+
help="The second input frame",
|
33 |
+
)
|
34 |
+
@cog.input(
|
35 |
+
"times_to_interpolate",
|
36 |
+
type=int,
|
37 |
+
default=1,
|
38 |
+
min=1,
|
39 |
+
max=8,
|
40 |
+
help="Controls the number of times the frame interpolator is invoked If set to 1, the output will be the "
|
41 |
+
"sub-frame at t=0.5; when set to > 1, the output will be the interpolation video with "
|
42 |
+
"(2^times_to_interpolate + 1) frames, fps of 30.",
|
43 |
+
)
|
44 |
+
def predict(self, frame1, frame2, times_to_interpolate):
|
45 |
+
INPUT_EXT = ['.png', '.jpg', '.jpeg']
|
46 |
+
assert os.path.splitext(str(frame1))[-1] in INPUT_EXT and os.path.splitext(str(frame2))[-1] in INPUT_EXT, \
|
47 |
+
"Please provide png, jpg or jpeg images."
|
48 |
+
|
49 |
+
# make sure 2 images are the same size
|
50 |
+
img1 = Image.open(str(frame1))
|
51 |
+
img2 = Image.open(str(frame2))
|
52 |
+
if not img1.size == img2.size:
|
53 |
+
img1 = img1.crop((0, 0, min(img1.size[0], img2.size[0]), min(img1.size[1], img2.size[1])))
|
54 |
+
img2 = img2.crop((0, 0, min(img1.size[0], img2.size[0]), min(img1.size[1], img2.size[1])))
|
55 |
+
frame1 = 'new_frame1.png'
|
56 |
+
frame2 = 'new_frame2.png'
|
57 |
+
img1.save(frame1)
|
58 |
+
img2.save(frame2)
|
59 |
+
|
60 |
+
if times_to_interpolate == 1:
|
61 |
+
# First batched image.
|
62 |
+
image_1 = util.read_image(str(frame1))
|
63 |
+
image_batch_1 = np.expand_dims(image_1, axis=0)
|
64 |
+
|
65 |
+
# Second batched image.
|
66 |
+
image_2 = util.read_image(str(frame2))
|
67 |
+
image_batch_2 = np.expand_dims(image_2, axis=0)
|
68 |
+
|
69 |
+
# Invoke the model once.
|
70 |
+
|
71 |
+
mid_frame = self.interpolator.interpolate(image_batch_1, image_batch_2, self.batch_dt)[0]
|
72 |
+
out_path = Path(tempfile.mkdtemp()) / "out.png"
|
73 |
+
util.write_image(str(out_path), mid_frame)
|
74 |
+
return out_path
|
75 |
+
|
76 |
+
|
77 |
+
input_frames = [str(frame1), str(frame2)]
|
78 |
+
|
79 |
+
frames = list(
|
80 |
+
util.interpolate_recursively_from_files(
|
81 |
+
input_frames, times_to_interpolate, self.interpolator))
|
82 |
+
print('Interpolated frames generated, saving now as output video.')
|
83 |
+
|
84 |
+
ffmpeg_path = util.get_ffmpeg_path()
|
85 |
+
mediapy.set_ffmpeg(ffmpeg_path)
|
86 |
+
out_path = Path(tempfile.mkdtemp()) / "out.mp4"
|
87 |
+
mediapy.write_video(str(out_path), frames, fps=30)
|
88 |
+
return out_path
|
requirements.txt
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Docker base image: `gcr.io/deeplearning-platform-release/tf2-gpu.2-6:latest`
|
2 |
+
tensorflow==2.6.2 # The latest should include tensorflow-gpu
|
3 |
+
tensorflow-datasets==4.4.0
|
4 |
+
tensorflow-addons==0.15.0
|
5 |
+
absl-py==0.12.0
|
6 |
+
gin-config==0.5.0
|
7 |
+
parameterized==0.8.1
|
8 |
+
mediapy==1.0.3
|
9 |
+
scikit-image==0.19.1
|
10 |
+
apache-beam==2.34.0
|
11 |
+
google-cloud-bigquery-storage==1.1.0 # Suppresses a harmless error from beam
|
12 |
+
natsort==8.1.0
|
13 |
+
gdown==4.5.4
|
14 |
+
tqdm==4.64.1
|
training/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
training/augmentation_lib.py
ADDED
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 Google LLC
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
"""Dataset augmentation for frame interpolation."""
|
16 |
+
from typing import Callable, Dict, List
|
17 |
+
|
18 |
+
import gin.tf
|
19 |
+
import numpy as np
|
20 |
+
import tensorflow as tf
|
21 |
+
import tensorflow.math as tfm
|
22 |
+
import tensorflow_addons.image as tfa_image
|
23 |
+
|
24 |
+
_PI = 3.141592653589793
|
25 |
+
|
26 |
+
|
27 |
+
def _rotate_flow_vectors(flow: tf.Tensor, angle_rad: float) -> tf.Tensor:
|
28 |
+
r"""Rotate the (u,v) vector of each pixel with angle in radians.
|
29 |
+
|
30 |
+
Flow matrix system of coordinates.
|
31 |
+
. . . . u (x)
|
32 |
+
.
|
33 |
+
.
|
34 |
+
. v (-y)
|
35 |
+
|
36 |
+
Rotation system of coordinates.
|
37 |
+
. y
|
38 |
+
.
|
39 |
+
.
|
40 |
+
. . . . x
|
41 |
+
Args:
|
42 |
+
flow: Flow map which has been image-rotated.
|
43 |
+
angle_rad: The rotation angle in radians.
|
44 |
+
|
45 |
+
Returns:
|
46 |
+
A flow with the same map but each (u,v) vector rotated by angle_rad.
|
47 |
+
"""
|
48 |
+
u, v = tf.split(flow, 2, axis=-1)
|
49 |
+
# rotu = u * cos(angle) - (-v) * sin(angle)
|
50 |
+
rot_u = tfm.cos(angle_rad) * u + tfm.sin(angle_rad) * v
|
51 |
+
# rotv = -(u * sin(theta) + (-v) * cos(theta))
|
52 |
+
rot_v = -tfm.sin(angle_rad) * u + tfm.cos(angle_rad) * v
|
53 |
+
return tf.concat((rot_u, rot_v), axis=-1)
|
54 |
+
|
55 |
+
|
56 |
+
def flow_rot90(flow: tf.Tensor, k: int) -> tf.Tensor:
|
57 |
+
"""Rotates a flow by a multiple of 90 degrees.
|
58 |
+
|
59 |
+
Args:
|
60 |
+
flow: The flow image shaped (H, W, 2) to rotate by multiples of 90 degrees.
|
61 |
+
k: The multiplier factor.
|
62 |
+
|
63 |
+
Returns:
|
64 |
+
A flow image of the same shape as the input rotated by multiples of 90
|
65 |
+
degrees.
|
66 |
+
"""
|
67 |
+
angle_rad = tf.cast(k, dtype=tf.float32) * 90. * (_PI/180.)
|
68 |
+
flow = tf.image.rot90(flow, k)
|
69 |
+
return _rotate_flow_vectors(flow, angle_rad)
|
70 |
+
|
71 |
+
|
72 |
+
def rotate_flow(flow: tf.Tensor, angle_rad: float) -> tf.Tensor:
|
73 |
+
"""Rotates a flow by a the provided angle in radians.
|
74 |
+
|
75 |
+
Args:
|
76 |
+
flow: The flow image shaped (H, W, 2) to rotate by multiples of 90 degrees.
|
77 |
+
angle_rad: The angle to ratate the flow in radians.
|
78 |
+
|
79 |
+
Returns:
|
80 |
+
A flow image of the same shape as the input rotated by the provided angle in
|
81 |
+
radians.
|
82 |
+
"""
|
83 |
+
flow = tfa_image.rotate(
|
84 |
+
flow,
|
85 |
+
angles=angle_rad,
|
86 |
+
interpolation='bilinear',
|
87 |
+
fill_mode='reflect')
|
88 |
+
return _rotate_flow_vectors(flow, angle_rad)
|
89 |
+
|
90 |
+
|
91 |
+
def flow_flip(flow: tf.Tensor) -> tf.Tensor:
|
92 |
+
"""Flips a flow left to right.
|
93 |
+
|
94 |
+
Args:
|
95 |
+
flow: The flow image shaped (H, W, 2) to flip left to right.
|
96 |
+
|
97 |
+
Returns:
|
98 |
+
A flow image of the same shape as the input flipped left to right.
|
99 |
+
"""
|
100 |
+
flow = tf.image.flip_left_right(tf.identity(flow))
|
101 |
+
flow_u, flow_v = tf.split(flow, 2, axis=-1)
|
102 |
+
return tf.stack([-1 * flow_u, flow_v], axis=-1)
|
103 |
+
|
104 |
+
|
105 |
+
def random_image_rot90(images: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
|
106 |
+
"""Rotates a stack of images by a random multiples of 90 degrees.
|
107 |
+
|
108 |
+
Args:
|
109 |
+
images: A tf.Tensor shaped (H, W, num_channels) of images stacked along the
|
110 |
+
channel's axis.
|
111 |
+
Returns:
|
112 |
+
A tf.Tensor of the same rank as the `images` after random rotation by
|
113 |
+
multiples of 90 degrees applied counter-clock wise.
|
114 |
+
"""
|
115 |
+
random_k = tf.random.uniform((), minval=0, maxval=4, dtype=tf.int32)
|
116 |
+
for key in images:
|
117 |
+
images[key] = tf.image.rot90(images[key], k=random_k)
|
118 |
+
return images
|
119 |
+
|
120 |
+
|
121 |
+
def random_flip(images: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
|
122 |
+
"""Flips a stack of images randomly.
|
123 |
+
|
124 |
+
Args:
|
125 |
+
images: A tf.Tensor shaped (H, W, num_channels) of images stacked along the
|
126 |
+
channel's axis.
|
127 |
+
|
128 |
+
Returns:
|
129 |
+
A tf.Tensor of the images after random left to right flip.
|
130 |
+
"""
|
131 |
+
prob = tf.random.uniform((), minval=0, maxval=2, dtype=tf.int32)
|
132 |
+
prob = tf.cast(prob, tf.bool)
|
133 |
+
|
134 |
+
def _identity(image):
|
135 |
+
return image
|
136 |
+
|
137 |
+
def _flip_left_right(image):
|
138 |
+
return tf.image.flip_left_right(image)
|
139 |
+
|
140 |
+
# pylint: disable=cell-var-from-loop
|
141 |
+
for key in images:
|
142 |
+
images[key] = tf.cond(prob, lambda: _flip_left_right(images[key]),
|
143 |
+
lambda: _identity(images[key]))
|
144 |
+
return images
|
145 |
+
|
146 |
+
|
147 |
+
def random_reverse(images: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
|
148 |
+
"""Reverses a stack of images randomly.
|
149 |
+
|
150 |
+
Args:
|
151 |
+
images: A dictionary of tf.Tensors, each shaped (H, W, num_channels), with
|
152 |
+
each tensor being a stack of iamges along the last channel axis.
|
153 |
+
|
154 |
+
Returns:
|
155 |
+
A dictionary of tf.Tensors, each shaped the same as the input images dict.
|
156 |
+
"""
|
157 |
+
prob = tf.random.uniform((), minval=0, maxval=2, dtype=tf.int32)
|
158 |
+
prob = tf.cast(prob, tf.bool)
|
159 |
+
|
160 |
+
def _identity(images):
|
161 |
+
return images
|
162 |
+
|
163 |
+
def _reverse(images):
|
164 |
+
images['x0'], images['x1'] = images['x1'], images['x0']
|
165 |
+
return images
|
166 |
+
|
167 |
+
return tf.cond(prob, lambda: _reverse(images), lambda: _identity(images))
|
168 |
+
|
169 |
+
|
170 |
+
def random_rotate(images: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
|
171 |
+
"""Rotates image randomly with [-45 to 45 degrees].
|
172 |
+
|
173 |
+
Args:
|
174 |
+
images: A tf.Tensor shaped (H, W, num_channels) of images stacked along the
|
175 |
+
channel's axis.
|
176 |
+
|
177 |
+
Returns:
|
178 |
+
A tf.Tensor of the images after random rotation with a bound of -72 to 72
|
179 |
+
degrees.
|
180 |
+
"""
|
181 |
+
prob = tf.random.uniform((), minval=0, maxval=2, dtype=tf.int32)
|
182 |
+
prob = tf.cast(prob, tf.float32)
|
183 |
+
random_angle = tf.random.uniform((),
|
184 |
+
minval=-0.25 * np.pi,
|
185 |
+
maxval=0.25 * np.pi,
|
186 |
+
dtype=tf.float32)
|
187 |
+
|
188 |
+
for key in images:
|
189 |
+
images[key] = tfa_image.rotate(
|
190 |
+
images[key],
|
191 |
+
angles=random_angle * prob,
|
192 |
+
interpolation='bilinear',
|
193 |
+
fill_mode='constant')
|
194 |
+
return images
|
195 |
+
|
196 |
+
|
197 |
+
@gin.configurable('data_augmentation')
|
198 |
+
def data_augmentations(
|
199 |
+
names: List[str]) -> Dict[str, Callable[..., tf.Tensor]]:
|
200 |
+
"""Creates the data augmentation functions.
|
201 |
+
|
202 |
+
Args:
|
203 |
+
names: The list of augmentation function names.
|
204 |
+
Returns:
|
205 |
+
A dictionary of Callables to the augmentation functions, keyed by their
|
206 |
+
names.
|
207 |
+
"""
|
208 |
+
augmentations = dict()
|
209 |
+
for name in names:
|
210 |
+
if name == 'random_image_rot90':
|
211 |
+
augmentations[name] = random_image_rot90
|
212 |
+
elif name == 'random_rotate':
|
213 |
+
augmentations[name] = random_rotate
|
214 |
+
elif name == 'random_flip':
|
215 |
+
augmentations[name] = random_flip
|
216 |
+
elif name == 'random_reverse':
|
217 |
+
augmentations[name] = random_reverse
|
218 |
+
else:
|
219 |
+
raise AttributeError('Invalid augmentation function %s' % name)
|
220 |
+
return augmentations
|
training/build_saved_model_cli.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 Google LLC
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
r"""Converts TF2 training checkpoint to a saved model.
|
16 |
+
|
17 |
+
The model must match the checkpoint, so the gin config must be given.
|
18 |
+
|
19 |
+
Usage example:
|
20 |
+
python3 -m frame_interpolation.training.build_saved_model_cli \
|
21 |
+
--gin_config <filepath of the gin config the training session was based> \
|
22 |
+
--base_folder <base folder of training sessions> \
|
23 |
+
--label <the name of the run>
|
24 |
+
|
25 |
+
This will produce a saved model into: <base_folder>/<label>/saved_model
|
26 |
+
"""
|
27 |
+
import os
|
28 |
+
from typing import Sequence
|
29 |
+
|
30 |
+
from . import model_lib
|
31 |
+
from absl import app
|
32 |
+
from absl import flags
|
33 |
+
from absl import logging
|
34 |
+
import gin.tf
|
35 |
+
import tensorflow as tf
|
36 |
+
tf.get_logger().setLevel('ERROR')
|
37 |
+
|
38 |
+
_GIN_CONFIG = flags.DEFINE_string(
|
39 |
+
name='gin_config',
|
40 |
+
default='config.gin',
|
41 |
+
help='Gin config file, saved in the training session <root folder>.')
|
42 |
+
_LABEL = flags.DEFINE_string(
|
43 |
+
name='label',
|
44 |
+
default=None,
|
45 |
+
required=True,
|
46 |
+
help='Descriptive label for the training session.')
|
47 |
+
_BASE_FOLDER = flags.DEFINE_string(
|
48 |
+
name='base_folder',
|
49 |
+
default=None,
|
50 |
+
help='Path to all training sessions.')
|
51 |
+
_MODE = flags.DEFINE_enum(
|
52 |
+
name='mode',
|
53 |
+
default=None,
|
54 |
+
enum_values=['cpu', 'gpu', 'tpu'],
|
55 |
+
help='Distributed strategy approach.')
|
56 |
+
|
57 |
+
|
58 |
+
def _build_saved_model(checkpoint_path: str, config_files: Sequence[str],
|
59 |
+
output_model_path: str):
|
60 |
+
"""Builds a saved model based on the checkpoint directory."""
|
61 |
+
gin.parse_config_files_and_bindings(
|
62 |
+
config_files=config_files,
|
63 |
+
bindings=None,
|
64 |
+
skip_unknown=True)
|
65 |
+
model = model_lib.create_model()
|
66 |
+
checkpoint = tf.train.Checkpoint(model=model)
|
67 |
+
checkpoint_file = tf.train.latest_checkpoint(checkpoint_path)
|
68 |
+
try:
|
69 |
+
logging.info('Restoring from %s', checkpoint_file)
|
70 |
+
status = checkpoint.restore(checkpoint_file)
|
71 |
+
status.assert_existing_objects_matched()
|
72 |
+
status.expect_partial()
|
73 |
+
model.save(output_model_path)
|
74 |
+
except (tf.errors.NotFoundError, AssertionError) as err:
|
75 |
+
logging.info('Failed to restore checkpoint from %s. Error:\n%s',
|
76 |
+
checkpoint_file, err)
|
77 |
+
|
78 |
+
|
79 |
+
def main(argv):
|
80 |
+
if len(argv) > 1:
|
81 |
+
raise app.UsageError('Too many command-line arguments.')
|
82 |
+
|
83 |
+
checkpoint_path = os.path.join(_BASE_FOLDER.value, _LABEL.value, 'train')
|
84 |
+
if not tf.io.gfile.exists(_GIN_CONFIG.value):
|
85 |
+
config_file = os.path.join(_BASE_FOLDER.value, _LABEL.value,
|
86 |
+
_GIN_CONFIG.value)
|
87 |
+
else:
|
88 |
+
config_file = _GIN_CONFIG.value
|
89 |
+
output_model_path = os.path.join(_BASE_FOLDER.value, _LABEL.value,
|
90 |
+
'saved_model')
|
91 |
+
_build_saved_model(
|
92 |
+
checkpoint_path=checkpoint_path,
|
93 |
+
config_files=[config_file],
|
94 |
+
output_model_path=output_model_path)
|
95 |
+
logging.info('The saved model stored into %s/.', output_model_path)
|
96 |
+
|
97 |
+
if __name__ == '__main__':
|
98 |
+
app.run(main)
|
training/config/film_net-L1.gin
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 Google LLC
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
model.name = 'film_net'
|
16 |
+
|
17 |
+
film_net.pyramid_levels = 7
|
18 |
+
film_net.fusion_pyramid_levels = 5
|
19 |
+
film_net.specialized_levels = 3
|
20 |
+
film_net.sub_levels = 4
|
21 |
+
film_net.flow_convs = [3, 3, 3, 3]
|
22 |
+
film_net.flow_filters = [32, 64, 128, 256]
|
23 |
+
film_net.filters = 64
|
24 |
+
|
25 |
+
training.learning_rate = 0.0001
|
26 |
+
training.learning_rate_decay_steps = 750000
|
27 |
+
training.learning_rate_decay_rate = 0.464158
|
28 |
+
training.learning_rate_staircase = True
|
29 |
+
training.num_steps = 3000000
|
30 |
+
|
31 |
+
# in the sweep
|
32 |
+
training_dataset.file = 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/vimeo_interp_train.tfrecord@200'
|
33 |
+
training_dataset.batch_size = 8
|
34 |
+
training_dataset.crop_size = 256
|
35 |
+
|
36 |
+
eval_datasets.batch_size = 1
|
37 |
+
eval_datasets.max_examples = -1
|
38 |
+
# eval_datasets.files = ['gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/vimeo_interp_test.tfrecord@3',
|
39 |
+
# 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/middlebury_other.tfrecord@3',
|
40 |
+
# 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/UCF101_interp_test.tfrecord@2',
|
41 |
+
# 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/xiph_2K.tfrecord@2',
|
42 |
+
# 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/xiph_4K.tfrecord@2']
|
43 |
+
# eval_datasets.names = ['vimeo90K', 'middlebury', 'ucf101', 'xiph2K', 'xiph4K']
|
44 |
+
eval_datasets.files = []
|
45 |
+
eval_datasets.names = []
|
46 |
+
|
47 |
+
# Training augmentation (in addition to random crop)
|
48 |
+
data_augmentation.names = ['random_image_rot90', 'random_flip', 'random_rotate', 'random_reverse']
|
49 |
+
|
50 |
+
# Loss functions
|
51 |
+
training_losses.loss_names = ['l1']
|
52 |
+
training_losses.loss_weights = [1.0]
|
53 |
+
|
54 |
+
test_losses.loss_names = ['l1', 'psnr', 'ssim']
|
55 |
+
test_losses.loss_weights = [1.0, 1.0, 1.0]
|
training/config/film_net-Style.gin
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 Google LLC
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
model.name = 'film_net'
|
16 |
+
|
17 |
+
film_net.pyramid_levels = 7
|
18 |
+
film_net.fusion_pyramid_levels = 5
|
19 |
+
film_net.specialized_levels = 3
|
20 |
+
film_net.sub_levels = 4
|
21 |
+
film_net.flow_convs = [3, 3, 3, 3]
|
22 |
+
film_net.flow_filters = [32, 64, 128, 256]
|
23 |
+
film_net.filters = 64
|
24 |
+
|
25 |
+
training.learning_rate = 0.0001
|
26 |
+
training.learning_rate_decay_steps = 750000
|
27 |
+
training.learning_rate_decay_rate = 0.464158
|
28 |
+
training.learning_rate_staircase = True
|
29 |
+
training.num_steps = 3000000
|
30 |
+
|
31 |
+
# in the sweep
|
32 |
+
training_dataset.file = 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/vimeo_interp_train.tfrecord@200'
|
33 |
+
training_dataset.batch_size = 8
|
34 |
+
training_dataset.crop_size = 256
|
35 |
+
|
36 |
+
eval_datasets.batch_size = 1
|
37 |
+
eval_datasets.max_examples = -1
|
38 |
+
# eval_datasets.files = ['gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/vimeo_interp_test.tfrecord@3',
|
39 |
+
# 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/middlebury_other.tfrecord@3',
|
40 |
+
# 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/UCF101_interp_test.tfrecord@2',
|
41 |
+
# 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/xiph_2K.tfrecord@2',
|
42 |
+
# 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/xiph_4K.tfrecord@2']
|
43 |
+
# eval_datasets.names = ['vimeo90K', 'middlebury', 'ucf101', 'xiph2K', 'xiph4K']
|
44 |
+
eval_datasets.files = []
|
45 |
+
eval_datasets.names = []
|
46 |
+
|
47 |
+
# Training augmentation (in addition to random crop)
|
48 |
+
data_augmentation.names = ['random_image_rot90', 'random_flip', 'random_rotate', 'random_reverse']
|
49 |
+
|
50 |
+
# Loss functions
|
51 |
+
training_losses.loss_names = ['l1', 'vgg', 'style']
|
52 |
+
training_losses.loss_weight_schedules = [
|
53 |
+
@tf.keras.optimizers.schedules.PiecewiseConstantDecay,
|
54 |
+
@tf.keras.optimizers.schedules.PiecewiseConstantDecay,
|
55 |
+
@tf.keras.optimizers.schedules.PiecewiseConstantDecay]
|
56 |
+
# Increase the weight of style loss at 1.5M steps.
|
57 |
+
training_losses.loss_weight_parameters = [
|
58 |
+
{'boundaries':[0], 'values':[1.0, 1.0]},
|
59 |
+
{'boundaries':[1500000], 'values':[1.0, 0.25]},
|
60 |
+
{'boundaries':[1500000], 'values':[0.0, 40.0]}]
|
61 |
+
|
62 |
+
test_losses.loss_names = ['l1', 'psnr', 'ssim']
|
63 |
+
test_losses.loss_weights = [1.0, 1.0, 1.0]
|
64 |
+
|
65 |
+
vgg.vgg_model_file = 'gs://xcloud-shared/fitsumreda/frame_interpolation/pretrained_models/vgg/imagenet-vgg-verydeep-19.mat'
|
66 |
+
style.vgg_model_file = 'gs://xcloud-shared/fitsumreda/frame_interpolation/pretrained_models/vgg/imagenet-vgg-verydeep-19.mat'
|
training/config/film_net-VGG.gin
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 Google LLC
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
model.name = 'film_net'
|
16 |
+
|
17 |
+
film_net.pyramid_levels = 7
|
18 |
+
film_net.fusion_pyramid_levels = 5
|
19 |
+
film_net.specialized_levels = 3
|
20 |
+
film_net.sub_levels = 4
|
21 |
+
film_net.flow_convs = [3, 3, 3, 3]
|
22 |
+
film_net.flow_filters = [32, 64, 128, 256]
|
23 |
+
film_net.filters = 64
|
24 |
+
|
25 |
+
training.learning_rate = 0.0001
|
26 |
+
training.learning_rate_decay_steps = 750000
|
27 |
+
training.learning_rate_decay_rate = 0.464158
|
28 |
+
training.learning_rate_staircase = True
|
29 |
+
training.num_steps = 3000000
|
30 |
+
|
31 |
+
# in the sweep
|
32 |
+
training_dataset.file = 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/vimeo_interp_train.tfrecord@200'
|
33 |
+
training_dataset.batch_size = 8
|
34 |
+
training_dataset.crop_size = 256
|
35 |
+
|
36 |
+
eval_datasets.batch_size = 1
|
37 |
+
eval_datasets.max_examples = -1
|
38 |
+
# eval_datasets.files = ['gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/vimeo_interp_test.tfrecord@3',
|
39 |
+
# 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/middlebury_other.tfrecord@3',
|
40 |
+
# 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/UCF101_interp_test.tfrecord@2',
|
41 |
+
# 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/xiph_2K.tfrecord@2',
|
42 |
+
# 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/xiph_4K.tfrecord@2']
|
43 |
+
# eval_datasets.names = ['vimeo90K', 'middlebury', 'ucf101', 'xiph2K', 'xiph4K']
|
44 |
+
eval_datasets.files = []
|
45 |
+
eval_datasets.names = []
|
46 |
+
|
47 |
+
# Training augmentation (in addition to random crop)
|
48 |
+
data_augmentation.names = ['random_image_rot90', 'random_flip', 'random_rotate', 'random_reverse']
|
49 |
+
|
50 |
+
# Loss functions
|
51 |
+
training_losses.loss_names = ['l1', 'vgg']
|
52 |
+
training_losses.loss_weight_schedules = [
|
53 |
+
@tf.keras.optimizers.schedules.PiecewiseConstantDecay,
|
54 |
+
@tf.keras.optimizers.schedules.PiecewiseConstantDecay]
|
55 |
+
|
56 |
+
# Decrease the weight of VGG loss at 1.5M steps.
|
57 |
+
training_losses.loss_weight_parameters = [
|
58 |
+
{'boundaries':[0], 'values':[1.0, 1.0]},
|
59 |
+
{'boundaries':[1500000], 'values':[1.0, 0.25]}]
|
60 |
+
|
61 |
+
test_losses.loss_names = ['l1', 'psnr', 'ssim']
|
62 |
+
test_losses.loss_weights = [1.0, 1.0, 1.0]
|
63 |
+
|
64 |
+
vgg.vgg_model_file = 'gs://xcloud-shared/fitsumreda/frame_interpolation/pretrained_models/vgg/imagenet-vgg-verydeep-19.mat'
|
training/data_lib.py
ADDED
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 Google LLC
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
"""Dataset creation for frame interpolation."""
|
16 |
+
from typing import Callable, Dict, List, Optional
|
17 |
+
|
18 |
+
from absl import logging
|
19 |
+
import gin.tf
|
20 |
+
import tensorflow as tf
|
21 |
+
|
22 |
+
|
23 |
+
def _create_feature_map() -> Dict[str, tf.io.FixedLenFeature]:
|
24 |
+
"""Creates the feature map for extracting the frame triplet."""
|
25 |
+
feature_map = {
|
26 |
+
'frame_0/encoded':
|
27 |
+
tf.io.FixedLenFeature((), tf.string, default_value=''),
|
28 |
+
'frame_0/format':
|
29 |
+
tf.io.FixedLenFeature((), tf.string, default_value='jpg'),
|
30 |
+
'frame_0/height':
|
31 |
+
tf.io.FixedLenFeature((), tf.int64, default_value=0),
|
32 |
+
'frame_0/width':
|
33 |
+
tf.io.FixedLenFeature((), tf.int64, default_value=0),
|
34 |
+
'frame_1/encoded':
|
35 |
+
tf.io.FixedLenFeature((), tf.string, default_value=''),
|
36 |
+
'frame_1/format':
|
37 |
+
tf.io.FixedLenFeature((), tf.string, default_value='jpg'),
|
38 |
+
'frame_1/height':
|
39 |
+
tf.io.FixedLenFeature((), tf.int64, default_value=0),
|
40 |
+
'frame_1/width':
|
41 |
+
tf.io.FixedLenFeature((), tf.int64, default_value=0),
|
42 |
+
'frame_2/encoded':
|
43 |
+
tf.io.FixedLenFeature((), tf.string, default_value=''),
|
44 |
+
'frame_2/format':
|
45 |
+
tf.io.FixedLenFeature((), tf.string, default_value='jpg'),
|
46 |
+
'frame_2/height':
|
47 |
+
tf.io.FixedLenFeature((), tf.int64, default_value=0),
|
48 |
+
'frame_2/width':
|
49 |
+
tf.io.FixedLenFeature((), tf.int64, default_value=0),
|
50 |
+
'path':
|
51 |
+
tf.io.FixedLenFeature((), tf.string, default_value=''),
|
52 |
+
}
|
53 |
+
return feature_map
|
54 |
+
|
55 |
+
|
56 |
+
def _parse_example(sample):
|
57 |
+
"""Parses a serialized sample.
|
58 |
+
|
59 |
+
Args:
|
60 |
+
sample: A serialized tf.Example to be parsed.
|
61 |
+
|
62 |
+
Returns:
|
63 |
+
dictionary containing the following:
|
64 |
+
encoded_image
|
65 |
+
image_height
|
66 |
+
image_width
|
67 |
+
"""
|
68 |
+
feature_map = _create_feature_map()
|
69 |
+
features = tf.io.parse_single_example(sample, feature_map)
|
70 |
+
output_dict = {
|
71 |
+
'x0': tf.io.decode_image(features['frame_0/encoded'], dtype=tf.float32),
|
72 |
+
'x1': tf.io.decode_image(features['frame_2/encoded'], dtype=tf.float32),
|
73 |
+
'y': tf.io.decode_image(features['frame_1/encoded'], dtype=tf.float32),
|
74 |
+
# The fractional time value of frame_1 is not included in our tfrecords,
|
75 |
+
# but is always at 0.5. The model will expect this to be specificed, so
|
76 |
+
# we insert it here.
|
77 |
+
'time': 0.5,
|
78 |
+
# Store the original mid frame filepath for identifying examples.
|
79 |
+
'path': features['path'],
|
80 |
+
}
|
81 |
+
|
82 |
+
return output_dict
|
83 |
+
|
84 |
+
|
85 |
+
def _random_crop_images(crop_size: int, images: tf.Tensor,
|
86 |
+
total_channel_size: int) -> tf.Tensor:
|
87 |
+
"""Crops the tensor with random offset to the given size."""
|
88 |
+
if crop_size > 0:
|
89 |
+
crop_shape = tf.constant([crop_size, crop_size, total_channel_size])
|
90 |
+
images = tf.image.random_crop(images, crop_shape)
|
91 |
+
return images
|
92 |
+
|
93 |
+
|
94 |
+
def crop_example(example: tf.Tensor, crop_size: int,
|
95 |
+
crop_keys: Optional[List[str]] = None):
|
96 |
+
"""Random crops selected images in the example to given size and keys.
|
97 |
+
|
98 |
+
Args:
|
99 |
+
example: Input tensor representing images to be cropped.
|
100 |
+
crop_size: The size to crop images to. This value is used for both
|
101 |
+
height and width.
|
102 |
+
crop_keys: The images in the input example to crop.
|
103 |
+
|
104 |
+
Returns:
|
105 |
+
Example with cropping applied to selected images.
|
106 |
+
"""
|
107 |
+
if crop_keys is None:
|
108 |
+
crop_keys = ['x0', 'x1', 'y']
|
109 |
+
channels = [3, 3, 3]
|
110 |
+
|
111 |
+
# Stack images along channel axis, and perform a random crop once.
|
112 |
+
image_to_crop = [example[key] for key in crop_keys]
|
113 |
+
stacked_images = tf.concat(image_to_crop, axis=-1)
|
114 |
+
cropped_images = _random_crop_images(crop_size, stacked_images, sum(channels))
|
115 |
+
cropped_images = tf.split(
|
116 |
+
cropped_images, num_or_size_splits=channels, axis=-1)
|
117 |
+
for key, cropped_image in zip(crop_keys, cropped_images):
|
118 |
+
example[key] = cropped_image
|
119 |
+
return example
|
120 |
+
|
121 |
+
|
122 |
+
def apply_data_augmentation(
|
123 |
+
augmentation_fns: Dict[str, Callable[..., tf.Tensor]],
|
124 |
+
example: tf.Tensor,
|
125 |
+
augmentation_keys: Optional[List[str]] = None) -> tf.Tensor:
|
126 |
+
"""Applies random augmentation in succession to selected image keys.
|
127 |
+
|
128 |
+
Args:
|
129 |
+
augmentation_fns: A Dict of Callables to data augmentation functions.
|
130 |
+
example: Input tensor representing images to be augmented.
|
131 |
+
augmentation_keys: The images in the input example to augment.
|
132 |
+
|
133 |
+
Returns:
|
134 |
+
Example with augmentation applied to selected images.
|
135 |
+
"""
|
136 |
+
if augmentation_keys is None:
|
137 |
+
augmentation_keys = ['x0', 'x1', 'y']
|
138 |
+
|
139 |
+
# Apply each augmentation in sequence
|
140 |
+
augmented_images = {key: example[key] for key in augmentation_keys}
|
141 |
+
for augmentation_function in augmentation_fns.values():
|
142 |
+
augmented_images = augmentation_function(augmented_images)
|
143 |
+
|
144 |
+
for key in augmentation_keys:
|
145 |
+
example[key] = augmented_images[key]
|
146 |
+
return example
|
147 |
+
|
148 |
+
|
149 |
+
def _create_from_tfrecord(batch_size, file, augmentation_fns,
|
150 |
+
crop_size) -> tf.data.Dataset:
|
151 |
+
"""Creates a dataset from TFRecord."""
|
152 |
+
dataset = tf.data.TFRecordDataset(file)
|
153 |
+
dataset = dataset.map(
|
154 |
+
_parse_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)
|
155 |
+
|
156 |
+
# Perform data_augmentation before cropping and batching
|
157 |
+
if augmentation_fns is not None:
|
158 |
+
dataset = dataset.map(
|
159 |
+
lambda x: apply_data_augmentation(augmentation_fns, x),
|
160 |
+
num_parallel_calls=tf.data.experimental.AUTOTUNE)
|
161 |
+
|
162 |
+
if crop_size > 0:
|
163 |
+
dataset = dataset.map(
|
164 |
+
lambda x: crop_example(x, crop_size=crop_size),
|
165 |
+
num_parallel_calls=tf.data.experimental.AUTOTUNE)
|
166 |
+
dataset = dataset.batch(batch_size, drop_remainder=True)
|
167 |
+
return dataset
|
168 |
+
|
169 |
+
|
170 |
+
def _generate_sharded_filenames(filename: str) -> List[str]:
|
171 |
+
"""Generates filenames of the each file in the sharded filepath.
|
172 |
+
|
173 |
+
Based on github.com/google/revisiting-self-supervised/blob/master/datasets.py.
|
174 |
+
|
175 |
+
Args:
|
176 |
+
filename: The sharded filepath.
|
177 |
+
|
178 |
+
Returns:
|
179 |
+
A list of filepaths for each file in the shard.
|
180 |
+
"""
|
181 |
+
base, count = filename.split('@')
|
182 |
+
count = int(count)
|
183 |
+
return ['{}-{:05d}-of-{:05d}'.format(base, i, count) for i in range(count)]
|
184 |
+
|
185 |
+
|
186 |
+
def _create_from_sharded_tfrecord(batch_size,
|
187 |
+
train_mode,
|
188 |
+
file,
|
189 |
+
augmentation_fns,
|
190 |
+
crop_size,
|
191 |
+
max_examples=-1) -> tf.data.Dataset:
|
192 |
+
"""Creates a dataset from a sharded tfrecord."""
|
193 |
+
dataset = tf.data.Dataset.from_tensor_slices(
|
194 |
+
_generate_sharded_filenames(file))
|
195 |
+
|
196 |
+
# pylint: disable=g-long-lambda
|
197 |
+
dataset = dataset.interleave(
|
198 |
+
lambda x: _create_from_tfrecord(
|
199 |
+
batch_size,
|
200 |
+
file=x,
|
201 |
+
augmentation_fns=augmentation_fns,
|
202 |
+
crop_size=crop_size),
|
203 |
+
num_parallel_calls=tf.data.AUTOTUNE,
|
204 |
+
deterministic=not train_mode)
|
205 |
+
# pylint: enable=g-long-lambda
|
206 |
+
dataset = dataset.prefetch(buffer_size=2)
|
207 |
+
if max_examples > 0:
|
208 |
+
return dataset.take(max_examples)
|
209 |
+
return dataset
|
210 |
+
|
211 |
+
|
212 |
+
@gin.configurable('training_dataset')
|
213 |
+
def create_training_dataset(
|
214 |
+
batch_size: int,
|
215 |
+
file: Optional[str] = None,
|
216 |
+
files: Optional[List[str]] = None,
|
217 |
+
crop_size: int = -1,
|
218 |
+
crop_sizes: Optional[List[int]] = None,
|
219 |
+
augmentation_fns: Optional[Dict[str, Callable[..., tf.Tensor]]] = None
|
220 |
+
) -> tf.data.Dataset:
|
221 |
+
"""Creates the training dataset.
|
222 |
+
|
223 |
+
The given tfrecord should contain data in a format produced by
|
224 |
+
frame_interpolation/datasets/create_*_tfrecord.py
|
225 |
+
|
226 |
+
Args:
|
227 |
+
batch_size: The number of images to batch per example.
|
228 |
+
file: (deprecated) A path to a sharded tfrecord in <tfrecord>@N format.
|
229 |
+
Deprecated. Use 'files' instead.
|
230 |
+
files: A list of paths to sharded tfrecords in <tfrecord>@N format.
|
231 |
+
crop_size: (deprecated) If > 0, images are cropped to crop_size x crop_size
|
232 |
+
using tensorflow's random cropping. Deprecated: use 'files' and
|
233 |
+
'crop_sizes' instead.
|
234 |
+
crop_sizes: List of crop sizes. If > 0, images are cropped to
|
235 |
+
crop_size x crop_size using tensorflow's random cropping.
|
236 |
+
augmentation_fns: A Dict of Callables to data augmentation functions.
|
237 |
+
Returns:
|
238 |
+
A tensorflow dataset for accessing examples that contain the input images
|
239 |
+
'x0', 'x1', ground truth 'y' and time of the ground truth 'time'=[0,1] in a
|
240 |
+
dictionary of tensors.
|
241 |
+
"""
|
242 |
+
if file:
|
243 |
+
logging.warning('gin-configurable training_dataset.file is deprecated. '
|
244 |
+
'Use training_dataset.files instead.')
|
245 |
+
return _create_from_sharded_tfrecord(batch_size, True, file,
|
246 |
+
augmentation_fns, crop_size)
|
247 |
+
else:
|
248 |
+
if not crop_sizes or len(crop_sizes) != len(files):
|
249 |
+
raise ValueError('Please pass crop_sizes[] with training_dataset.files.')
|
250 |
+
if crop_size > 0:
|
251 |
+
raise ValueError(
|
252 |
+
'crop_size should not be used with files[], use crop_sizes[] instead.'
|
253 |
+
)
|
254 |
+
tables = []
|
255 |
+
for file, crop_size in zip(files, crop_sizes):
|
256 |
+
tables.append(
|
257 |
+
_create_from_sharded_tfrecord(batch_size, True, file,
|
258 |
+
augmentation_fns, crop_size))
|
259 |
+
return tf.data.experimental.sample_from_datasets(tables)
|
260 |
+
|
261 |
+
|
262 |
+
@gin.configurable('eval_datasets')
|
263 |
+
def create_eval_datasets(batch_size: int,
|
264 |
+
files: List[str],
|
265 |
+
names: List[str],
|
266 |
+
crop_size: int = -1,
|
267 |
+
max_examples: int = -1) -> Dict[str, tf.data.Dataset]:
|
268 |
+
"""Creates the evaluation datasets.
|
269 |
+
|
270 |
+
As opposed to create_training_dataset this function makes sure that the
|
271 |
+
examples for each dataset are always read in a deterministic (same) order.
|
272 |
+
|
273 |
+
Each given tfrecord should contain data in a format produced by
|
274 |
+
frame_interpolation/datasets/create_*_tfrecord.py
|
275 |
+
|
276 |
+
The (batch_size, crop_size, max_examples) are specified for all eval datasets.
|
277 |
+
|
278 |
+
Args:
|
279 |
+
batch_size: The number of images to batch per example.
|
280 |
+
files: List of paths to a sharded tfrecord in <tfrecord>@N format.
|
281 |
+
names: List of names of eval datasets.
|
282 |
+
crop_size: If > 0, images are cropped to crop_size x crop_size using
|
283 |
+
tensorflow's random cropping.
|
284 |
+
max_examples: If > 0, truncate the dataset to 'max_examples' in length. This
|
285 |
+
can be useful for speeding up evaluation loop in case the tfrecord for the
|
286 |
+
evaluation set is very large.
|
287 |
+
Returns:
|
288 |
+
A dict of name to tensorflow dataset for accessing examples that contain the
|
289 |
+
input images 'x0', 'x1', ground truth 'y' and time of the ground truth
|
290 |
+
'time'=[0,1] in a dictionary of tensors.
|
291 |
+
"""
|
292 |
+
return {
|
293 |
+
name: _create_from_sharded_tfrecord(batch_size, False, file, None,
|
294 |
+
crop_size, max_examples)
|
295 |
+
for name, file in zip(names, files)
|
296 |
+
}
|
training/eval_lib.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 Google LLC
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
"""Evaluation library for frame interpolation."""
|
16 |
+
from typing import Dict, Mapping, Text
|
17 |
+
|
18 |
+
from absl import logging
|
19 |
+
import tensorflow as tf
|
20 |
+
|
21 |
+
|
22 |
+
def _collect_tensors(tensors: tf.Tensor) -> tf.Tensor:
|
23 |
+
"""Collect tensors of the different replicas into a list."""
|
24 |
+
return tf.nest.flatten(tensors, expand_composites=True)
|
25 |
+
|
26 |
+
|
27 |
+
@tf.function
|
28 |
+
def _distributed_eval_step(strategy: tf.distribute.Strategy,
|
29 |
+
batch: Dict[Text, tf.Tensor], model: tf.keras.Model,
|
30 |
+
metrics: Dict[Text, tf.keras.metrics.Metric],
|
31 |
+
checkpoint_step: int) -> Dict[Text, tf.Tensor]:
|
32 |
+
"""Distributed eval step.
|
33 |
+
|
34 |
+
Args:
|
35 |
+
strategy: A Tensorflow distribution strategy.
|
36 |
+
batch: A batch of training examples.
|
37 |
+
model: The Keras model to evaluate.
|
38 |
+
metrics: The Keras metrics used for evaluation (a dictionary).
|
39 |
+
checkpoint_step: The iteration number at which the checkpoint is restored.
|
40 |
+
|
41 |
+
Returns:
|
42 |
+
list of predictions from each replica.
|
43 |
+
"""
|
44 |
+
|
45 |
+
def _eval_step(
|
46 |
+
batch: Dict[Text, tf.Tensor]) -> Dict[Text, tf.Tensor]:
|
47 |
+
"""Eval for one step."""
|
48 |
+
predictions = model(batch, training=False)
|
49 |
+
# Note: these metrics expect batch and prediction dictionaries rather than
|
50 |
+
# tensors like standard TF metrics do. This allows our losses and metrics to
|
51 |
+
# use a richer set of inputs than just the predicted final image.
|
52 |
+
for metric in metrics.values():
|
53 |
+
metric.update_state(batch, predictions, checkpoint_step=checkpoint_step)
|
54 |
+
return predictions
|
55 |
+
|
56 |
+
return strategy.run(_eval_step, args=(batch,))
|
57 |
+
|
58 |
+
|
59 |
+
def _summarize_image_tensors(combined, prefix, step):
|
60 |
+
for name in combined:
|
61 |
+
image = combined[name]
|
62 |
+
if isinstance(image, tf.Tensor):
|
63 |
+
if len(image.shape) == 4 and (image.shape[-1] == 1 or
|
64 |
+
image.shape[-1] == 3):
|
65 |
+
tf.summary.image(prefix + '/' + name, image, step=step)
|
66 |
+
|
67 |
+
|
68 |
+
def eval_loop(strategy: tf.distribute.Strategy,
|
69 |
+
eval_base_folder: str,
|
70 |
+
model: tf.keras.Model,
|
71 |
+
metrics: Dict[str, tf.keras.metrics.Metric],
|
72 |
+
datasets: Mapping[str, tf.data.Dataset],
|
73 |
+
summary_writer: tf.summary.SummaryWriter,
|
74 |
+
checkpoint_step: int):
|
75 |
+
"""Eval function that is strategy agnostic.
|
76 |
+
|
77 |
+
Args:
|
78 |
+
strategy: A Tensorflow distributed strategy.
|
79 |
+
eval_base_folder: A path to where the summaries event files and
|
80 |
+
checkpoints will be saved.
|
81 |
+
model: A function that returns the model.
|
82 |
+
metrics: A function that returns the metrics dictionary.
|
83 |
+
datasets: A dict of tf.data.Dataset to evaluate on.
|
84 |
+
summary_writer: Eval summary writer.
|
85 |
+
checkpoint_step: The number of iterations completed.
|
86 |
+
"""
|
87 |
+
logging.info('Saving eval summaries to: %s...', eval_base_folder)
|
88 |
+
summary_writer.set_as_default()
|
89 |
+
|
90 |
+
for dataset_name, dataset in datasets.items():
|
91 |
+
for metric in metrics.values():
|
92 |
+
metric.reset_states()
|
93 |
+
|
94 |
+
logging.info('Loading %s testing data ...', dataset_name)
|
95 |
+
dataset = strategy.experimental_distribute_dataset(dataset)
|
96 |
+
|
97 |
+
logging.info('Evaluating %s ...', dataset_name)
|
98 |
+
batch_idx = 0
|
99 |
+
max_batches_to_summarize = 10
|
100 |
+
for batch in dataset:
|
101 |
+
predictions = _distributed_eval_step(strategy, batch, model, metrics,
|
102 |
+
checkpoint_step)
|
103 |
+
# Clip interpolator output to [0,1]. Clipping is done only
|
104 |
+
# on the eval loop to get better metrics, but not on the training loop
|
105 |
+
# so gradients are not killed.
|
106 |
+
if strategy.num_replicas_in_sync > 1:
|
107 |
+
predictions = {
|
108 |
+
'image': tf.concat(predictions['image'].values, axis=0)
|
109 |
+
}
|
110 |
+
predictions['image'] = tf.clip_by_value(predictions['image'], 0., 1.)
|
111 |
+
if batch_idx % 10 == 0:
|
112 |
+
logging.info('Evaluating batch %s', batch_idx)
|
113 |
+
batch_idx = batch_idx + 1
|
114 |
+
if batch_idx < max_batches_to_summarize:
|
115 |
+
# Loop through the global batch:
|
116 |
+
prefix = f'{dataset_name}/eval_{batch_idx}'
|
117 |
+
# Find all tensors that look like images, and summarize:
|
118 |
+
combined = {**batch, **predictions}
|
119 |
+
_summarize_image_tensors(combined, prefix, step=checkpoint_step)
|
120 |
+
|
121 |
+
elif batch_idx == max_batches_to_summarize:
|
122 |
+
tf.summary.flush()
|
123 |
+
|
124 |
+
for name, metric in metrics.items():
|
125 |
+
tf.summary.scalar(
|
126 |
+
f'{dataset_name}/{name}', metric.result(), step=checkpoint_step)
|
127 |
+
tf.summary.flush()
|
128 |
+
logging.info('Step {:2}, {} {}'.format(checkpoint_step,
|
129 |
+
f'{dataset_name}/{name}',
|
130 |
+
metric.result().numpy()))
|
131 |
+
metric.reset_states()
|
training/metrics_lib.py
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 Google LLC
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
"""A library for instantiating frame interpolation evaluation metrics."""
|
16 |
+
|
17 |
+
from typing import Callable, Dict, Text
|
18 |
+
|
19 |
+
from ..losses import losses
|
20 |
+
import tensorflow as tf
|
21 |
+
|
22 |
+
|
23 |
+
class TrainLossMetric(tf.keras.metrics.Metric):
|
24 |
+
"""Compute training loss for our example and prediction format.
|
25 |
+
|
26 |
+
The purpose of this is to ensure that we always include a loss that is exactly
|
27 |
+
like the training loss into the evaluation in order to detect possible
|
28 |
+
overfitting.
|
29 |
+
"""
|
30 |
+
|
31 |
+
def __init__(self, name='eval_loss', **kwargs):
|
32 |
+
super(TrainLossMetric, self).__init__(name=name, **kwargs)
|
33 |
+
self.acc = self.add_weight(name='train_metric_acc', initializer='zeros')
|
34 |
+
self.count = self.add_weight(name='train_metric_count', initializer='zeros')
|
35 |
+
|
36 |
+
def update_state(self,
|
37 |
+
batch,
|
38 |
+
predictions,
|
39 |
+
sample_weight=None,
|
40 |
+
checkpoint_step=0):
|
41 |
+
loss_functions = losses.training_losses()
|
42 |
+
loss_list = []
|
43 |
+
for (loss_value, loss_weight) in loss_functions.values():
|
44 |
+
loss_list.append(
|
45 |
+
loss_value(batch, predictions) * loss_weight(checkpoint_step))
|
46 |
+
loss = tf.add_n(loss_list)
|
47 |
+
self.acc.assign_add(loss)
|
48 |
+
self.count.assign_add(1)
|
49 |
+
|
50 |
+
def result(self):
|
51 |
+
return self.acc / self.count
|
52 |
+
|
53 |
+
def reset_states(self):
|
54 |
+
self.acc.assign(0)
|
55 |
+
self.count.assign(0)
|
56 |
+
|
57 |
+
|
58 |
+
class L1Metric(tf.keras.metrics.Metric):
|
59 |
+
"""Compute L1 over our training example and prediction format.
|
60 |
+
|
61 |
+
The purpose of this is to ensure that we have at least one metric that is
|
62 |
+
compatible across all eval the session and allows us to quickly compare models
|
63 |
+
against each other.
|
64 |
+
"""
|
65 |
+
|
66 |
+
def __init__(self, name='eval_loss', **kwargs):
|
67 |
+
super(L1Metric, self).__init__(name=name, **kwargs)
|
68 |
+
self.acc = self.add_weight(name='l1_metric_acc', initializer='zeros')
|
69 |
+
self.count = self.add_weight(name='l1_metric_count', initializer='zeros')
|
70 |
+
|
71 |
+
def update_state(self, batch, prediction, sample_weight=None,
|
72 |
+
checkpoint_step=0):
|
73 |
+
self.acc.assign_add(losses.l1_loss(batch, prediction))
|
74 |
+
self.count.assign_add(1)
|
75 |
+
|
76 |
+
def result(self):
|
77 |
+
return self.acc / self.count
|
78 |
+
|
79 |
+
def reset_states(self):
|
80 |
+
self.acc.assign(0)
|
81 |
+
self.count.assign(0)
|
82 |
+
|
83 |
+
|
84 |
+
class GenericLossMetric(tf.keras.metrics.Metric):
|
85 |
+
"""Metric based on any loss function."""
|
86 |
+
|
87 |
+
def __init__(self, name: str, loss: Callable[..., tf.Tensor],
|
88 |
+
weight: Callable[..., tf.Tensor], **kwargs):
|
89 |
+
"""Initializes a metric based on a loss function and a weight schedule.
|
90 |
+
|
91 |
+
Args:
|
92 |
+
name: The name of the metric.
|
93 |
+
loss: The callable loss that calculates a loss value for a (prediction,
|
94 |
+
target) pair.
|
95 |
+
weight: The callable weight scheduling function that samples a weight
|
96 |
+
based on iteration.
|
97 |
+
**kwargs: Any additional keyword arguments to be passed.
|
98 |
+
"""
|
99 |
+
super(GenericLossMetric, self).__init__(name=name, **kwargs)
|
100 |
+
self.acc = self.add_weight(name='loss_metric_acc', initializer='zeros')
|
101 |
+
self.count = self.add_weight(name='loss_metric_count', initializer='zeros')
|
102 |
+
self.loss = loss
|
103 |
+
self.weight = weight
|
104 |
+
|
105 |
+
def update_state(self,
|
106 |
+
batch,
|
107 |
+
predictions,
|
108 |
+
sample_weight=None,
|
109 |
+
checkpoint_step=0):
|
110 |
+
self.acc.assign_add(
|
111 |
+
self.loss(batch, predictions) * self.weight(checkpoint_step))
|
112 |
+
self.count.assign_add(1)
|
113 |
+
|
114 |
+
def result(self):
|
115 |
+
return self.acc / self.count
|
116 |
+
|
117 |
+
def reset_states(self):
|
118 |
+
self.acc.assign(0)
|
119 |
+
self.count.assign(0)
|
120 |
+
|
121 |
+
|
122 |
+
def create_metrics_fn() -> Dict[Text, tf.keras.metrics.Metric]:
|
123 |
+
"""Create evaluation metrics.
|
124 |
+
|
125 |
+
L1 and total training loss are added by default.
|
126 |
+
The rest are the configured by the test_losses item via gin.
|
127 |
+
|
128 |
+
Returns:
|
129 |
+
A dictionary from metric name to Keras Metric object.
|
130 |
+
"""
|
131 |
+
metrics = {}
|
132 |
+
# L1 is explicitly added just so we always have some consistent numbers around
|
133 |
+
# to compare across sessions.
|
134 |
+
metrics['l1'] = L1Metric()
|
135 |
+
# We also always include training loss for the eval set to detect overfitting:
|
136 |
+
metrics['training_loss'] = TrainLossMetric()
|
137 |
+
|
138 |
+
test_losses = losses.test_losses()
|
139 |
+
for loss_name, (loss_value, loss_weight) in test_losses.items():
|
140 |
+
metrics[loss_name] = GenericLossMetric(
|
141 |
+
name=loss_name, loss=loss_value, weight=loss_weight)
|
142 |
+
return metrics
|
training/model_lib.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 Google LLC
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
"""A library for instantiating the model for training frame interpolation.
|
16 |
+
|
17 |
+
All models are expected to use three inputs: input image batches 'x0' and 'x1'
|
18 |
+
and 'time', the fractional time where the output should be generated.
|
19 |
+
|
20 |
+
The models are expected to output the prediction as a dictionary that contains
|
21 |
+
at least the predicted image batch as 'image' plus optional data for debug,
|
22 |
+
analysis or custom losses.
|
23 |
+
"""
|
24 |
+
|
25 |
+
import gin.tf
|
26 |
+
from ..models.film_net import interpolator as film_net_interpolator
|
27 |
+
from ..models.film_net import options as film_net_options
|
28 |
+
|
29 |
+
import tensorflow as tf
|
30 |
+
|
31 |
+
|
32 |
+
@gin.configurable('model')
|
33 |
+
def create_model(name: str) -> tf.keras.Model:
|
34 |
+
"""Creates the frame interpolation model based on given model name."""
|
35 |
+
if name == 'film_net':
|
36 |
+
return _create_film_net_model() # pylint: disable=no-value-for-parameter
|
37 |
+
else:
|
38 |
+
raise ValueError(f'Model {name} not implemented.')
|
39 |
+
|
40 |
+
|
41 |
+
def _create_film_net_model() -> tf.keras.Model:
|
42 |
+
"""Creates the film_net interpolator."""
|
43 |
+
# Options are gin-configured in the Options class directly.
|
44 |
+
options = film_net_options.Options()
|
45 |
+
|
46 |
+
x0 = tf.keras.Input(
|
47 |
+
shape=(None, None, 3), batch_size=None, dtype=tf.float32, name='x0')
|
48 |
+
x1 = tf.keras.Input(
|
49 |
+
shape=(None, None, 3), batch_size=None, dtype=tf.float32, name='x1')
|
50 |
+
time = tf.keras.Input(
|
51 |
+
shape=(1,), batch_size=None, dtype=tf.float32, name='time')
|
52 |
+
|
53 |
+
return film_net_interpolator.create_model(x0, x1, time, options)
|
training/train.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 Google LLC
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
r"""The training loop for frame interpolation.
|
16 |
+
|
17 |
+
gin_config: The gin configuration file containing model, losses and datasets.
|
18 |
+
|
19 |
+
To run on GPUs:
|
20 |
+
python3 -m frame_interpolation.training.train \
|
21 |
+
--gin_config <path to network.gin> \
|
22 |
+
--base_folder <base folder for all training runs> \
|
23 |
+
--label <descriptive label for the run>
|
24 |
+
|
25 |
+
To debug the training loop on CPU:
|
26 |
+
python3 -m frame_interpolation.training.train \
|
27 |
+
--gin_config <path to config.gin> \
|
28 |
+
--base_folder /tmp
|
29 |
+
--label test_run \
|
30 |
+
--mode cpu
|
31 |
+
|
32 |
+
The training output directory will be created at <base_folder>/<label>.
|
33 |
+
"""
|
34 |
+
import os
|
35 |
+
|
36 |
+
from . import augmentation_lib
|
37 |
+
from . import data_lib
|
38 |
+
from . import eval_lib
|
39 |
+
from . import metrics_lib
|
40 |
+
from . import model_lib
|
41 |
+
from . import train_lib
|
42 |
+
from absl import app
|
43 |
+
from absl import flags
|
44 |
+
from absl import logging
|
45 |
+
import gin.tf
|
46 |
+
from ..losses import losses
|
47 |
+
|
48 |
+
# Reduce tensorflow logs to ERRORs only.
|
49 |
+
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
|
50 |
+
import tensorflow as tf # pylint: disable=g-import-not-at-top
|
51 |
+
tf.get_logger().setLevel('ERROR')
|
52 |
+
|
53 |
+
|
54 |
+
_GIN_CONFIG = flags.DEFINE_string('gin_config', None, 'Gin config file.')
|
55 |
+
_LABEL = flags.DEFINE_string('label', 'run0',
|
56 |
+
'Descriptive label for this run.')
|
57 |
+
_BASE_FOLDER = flags.DEFINE_string('base_folder', None,
|
58 |
+
'Path to checkpoints/summaries.')
|
59 |
+
_MODE = flags.DEFINE_enum('mode', 'gpu', ['cpu', 'gpu'],
|
60 |
+
'Distributed strategy approach.')
|
61 |
+
|
62 |
+
|
63 |
+
@gin.configurable('training')
|
64 |
+
class TrainingOptions(object):
|
65 |
+
"""Training-related options."""
|
66 |
+
|
67 |
+
def __init__(self, learning_rate: float, learning_rate_decay_steps: int,
|
68 |
+
learning_rate_decay_rate: int, learning_rate_staircase: int,
|
69 |
+
num_steps: int):
|
70 |
+
self.learning_rate = learning_rate
|
71 |
+
self.learning_rate_decay_steps = learning_rate_decay_steps
|
72 |
+
self.learning_rate_decay_rate = learning_rate_decay_rate
|
73 |
+
self.learning_rate_staircase = learning_rate_staircase
|
74 |
+
self.num_steps = num_steps
|
75 |
+
|
76 |
+
|
77 |
+
def main(argv):
|
78 |
+
if len(argv) > 1:
|
79 |
+
raise app.UsageError('Too many command-line arguments.')
|
80 |
+
|
81 |
+
output_dir = os.path.join(_BASE_FOLDER.value, _LABEL.value)
|
82 |
+
logging.info('Creating output_dir @ %s ...', output_dir)
|
83 |
+
|
84 |
+
# Copy config file to <base_folder>/<label>/config.gin.
|
85 |
+
tf.io.gfile.makedirs(output_dir)
|
86 |
+
tf.io.gfile.copy(
|
87 |
+
_GIN_CONFIG.value, os.path.join(output_dir, 'config.gin'), overwrite=True)
|
88 |
+
|
89 |
+
gin.external_configurable(
|
90 |
+
tf.keras.optimizers.schedules.PiecewiseConstantDecay,
|
91 |
+
module='tf.keras.optimizers.schedules')
|
92 |
+
|
93 |
+
gin_configs = [_GIN_CONFIG.value]
|
94 |
+
gin.parse_config_files_and_bindings(
|
95 |
+
config_files=gin_configs, bindings=None, skip_unknown=True)
|
96 |
+
|
97 |
+
training_options = TrainingOptions() # pylint: disable=no-value-for-parameter
|
98 |
+
|
99 |
+
learning_rate = tf.keras.optimizers.schedules.ExponentialDecay(
|
100 |
+
training_options.learning_rate,
|
101 |
+
training_options.learning_rate_decay_steps,
|
102 |
+
training_options.learning_rate_decay_rate,
|
103 |
+
training_options.learning_rate_staircase,
|
104 |
+
name='learning_rate')
|
105 |
+
|
106 |
+
# Initialize data augmentation functions
|
107 |
+
augmentation_fns = augmentation_lib.data_augmentations()
|
108 |
+
|
109 |
+
saved_model_folder = os.path.join(_BASE_FOLDER.value, _LABEL.value,
|
110 |
+
'saved_model')
|
111 |
+
train_folder = os.path.join(_BASE_FOLDER.value, _LABEL.value, 'train')
|
112 |
+
eval_folder = os.path.join(_BASE_FOLDER.value, _LABEL.value, 'eval')
|
113 |
+
|
114 |
+
train_lib.train(
|
115 |
+
strategy=train_lib.get_strategy(_MODE.value),
|
116 |
+
train_folder=train_folder,
|
117 |
+
saved_model_folder=saved_model_folder,
|
118 |
+
n_iterations=training_options.num_steps,
|
119 |
+
create_model_fn=model_lib.create_model,
|
120 |
+
create_losses_fn=losses.training_losses,
|
121 |
+
create_metrics_fn=metrics_lib.create_metrics_fn,
|
122 |
+
dataset=data_lib.create_training_dataset(
|
123 |
+
augmentation_fns=augmentation_fns),
|
124 |
+
learning_rate=learning_rate,
|
125 |
+
eval_loop_fn=eval_lib.eval_loop,
|
126 |
+
eval_folder=eval_folder,
|
127 |
+
eval_datasets=data_lib.create_eval_datasets() or None)
|
128 |
+
|
129 |
+
|
130 |
+
if __name__ == '__main__':
|
131 |
+
app.run(main)
|
training/train_lib.py
ADDED
@@ -0,0 +1,343 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 Google LLC
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
r"""Training library for frame interpolation using distributed strategy."""
|
16 |
+
import functools
|
17 |
+
from typing import Any, Callable, Dict, Text, Tuple
|
18 |
+
|
19 |
+
from absl import logging
|
20 |
+
import tensorflow as tf
|
21 |
+
|
22 |
+
|
23 |
+
def _concat_tensors(tensors: tf.Tensor) -> tf.Tensor:
|
24 |
+
"""Concat tensors of the different replicas."""
|
25 |
+
return tf.concat(tf.nest.flatten(tensors, expand_composites=True), axis=0)
|
26 |
+
|
27 |
+
|
28 |
+
@tf.function
|
29 |
+
def _distributed_train_step(strategy: tf.distribute.Strategy,
|
30 |
+
batch: Dict[Text, tf.Tensor], model: tf.keras.Model,
|
31 |
+
loss_functions: Dict[Text,
|
32 |
+
Tuple[Callable[..., tf.Tensor],
|
33 |
+
Callable[...,
|
34 |
+
tf.Tensor]]],
|
35 |
+
optimizer: tf.keras.optimizers.Optimizer,
|
36 |
+
iterations: int) -> Dict[Text, Any]:
|
37 |
+
"""Distributed training step.
|
38 |
+
|
39 |
+
Args:
|
40 |
+
strategy: A Tensorflow distribution strategy.
|
41 |
+
batch: A batch of training examples.
|
42 |
+
model: The Keras model to train.
|
43 |
+
loss_functions: The list of Keras losses used to train the model.
|
44 |
+
optimizer: The Keras optimizer used to train the model.
|
45 |
+
iterations: Iteration number used to sample weights to each loss.
|
46 |
+
|
47 |
+
Returns:
|
48 |
+
A dictionary of train step outputs.
|
49 |
+
"""
|
50 |
+
|
51 |
+
def _train_step(batch: Dict[Text, tf.Tensor]) -> Dict[Text, tf.Tensor]:
|
52 |
+
"""Train for one step."""
|
53 |
+
with tf.GradientTape() as tape:
|
54 |
+
predictions = model(batch, training=True)
|
55 |
+
losses = []
|
56 |
+
for (loss_value, loss_weight) in loss_functions.values():
|
57 |
+
losses.append(loss_value(batch, predictions) * loss_weight(iterations))
|
58 |
+
loss = tf.add_n(losses)
|
59 |
+
grads = tape.gradient(loss, model.trainable_variables)
|
60 |
+
optimizer.apply_gradients(zip(grads, model.trainable_variables))
|
61 |
+
# post process for visualization
|
62 |
+
all_data = {'loss': loss}
|
63 |
+
all_data.update(batch)
|
64 |
+
all_data.update(predictions)
|
65 |
+
return all_data
|
66 |
+
|
67 |
+
step_outputs = strategy.run(_train_step, args=(batch,))
|
68 |
+
|
69 |
+
loss = strategy.reduce(
|
70 |
+
tf.distribute.ReduceOp.MEAN, step_outputs['loss'], axis=None)
|
71 |
+
|
72 |
+
x0 = _concat_tensors(step_outputs['x0'])
|
73 |
+
x1 = _concat_tensors(step_outputs['x1'])
|
74 |
+
y = _concat_tensors(step_outputs['y'])
|
75 |
+
pred_y = _concat_tensors(step_outputs['image'])
|
76 |
+
|
77 |
+
scalar_summaries = {'training_loss': loss}
|
78 |
+
|
79 |
+
image_summaries = {
|
80 |
+
'x0': x0,
|
81 |
+
'x1': x1,
|
82 |
+
'y': y,
|
83 |
+
'pred_y': pred_y
|
84 |
+
}
|
85 |
+
|
86 |
+
extra_images = {
|
87 |
+
'importance0', 'importance1', 'x0_warped', 'x1_warped', 'fg_image',
|
88 |
+
'bg_image', 'fg_alpha', 'x1_unfiltered_warped'
|
89 |
+
}
|
90 |
+
for image in extra_images:
|
91 |
+
if image in step_outputs:
|
92 |
+
image_summaries[image] = _concat_tensors(step_outputs[image])
|
93 |
+
|
94 |
+
return {
|
95 |
+
'loss': loss,
|
96 |
+
'scalar_summaries': scalar_summaries,
|
97 |
+
'image_summaries': {
|
98 |
+
f'training/{name}': value for name, value in image_summaries.items()
|
99 |
+
}
|
100 |
+
}
|
101 |
+
|
102 |
+
|
103 |
+
def _summary_writer(summaries_dict: Dict[Text, Any]) -> None:
|
104 |
+
"""Adds scalar and image summaries."""
|
105 |
+
# Adds scalar summaries.
|
106 |
+
for key, scalars in summaries_dict['scalar_summaries'].items():
|
107 |
+
tf.summary.scalar(key, scalars)
|
108 |
+
# Adds image summaries.
|
109 |
+
for key, images in summaries_dict['image_summaries'].items():
|
110 |
+
tf.summary.image(key, tf.clip_by_value(images, 0.0, 1.0))
|
111 |
+
tf.summary.histogram(key + '_h', images)
|
112 |
+
|
113 |
+
|
114 |
+
def train_loop(
|
115 |
+
strategy: tf.distribute.Strategy,
|
116 |
+
train_set: tf.data.Dataset,
|
117 |
+
create_model_fn: Callable[..., tf.keras.Model],
|
118 |
+
create_losses_fn: Callable[..., Dict[str, Tuple[Callable[..., tf.Tensor],
|
119 |
+
Callable[..., tf.Tensor]]]],
|
120 |
+
create_optimizer_fn: Callable[..., tf.keras.optimizers.Optimizer],
|
121 |
+
distributed_train_step_fn: Callable[[
|
122 |
+
tf.distribute.Strategy, Dict[str, tf.Tensor], tf.keras.Model, Dict[
|
123 |
+
str,
|
124 |
+
Tuple[Callable[..., tf.Tensor],
|
125 |
+
Callable[..., tf.Tensor]]], tf.keras.optimizers.Optimizer, int
|
126 |
+
], Dict[str, Any]],
|
127 |
+
eval_loop_fn: Callable[..., None],
|
128 |
+
create_metrics_fn: Callable[..., Dict[str, tf.keras.metrics.Metric]],
|
129 |
+
eval_folder: Dict[str, Any],
|
130 |
+
eval_datasets: Dict[str, tf.data.Dataset],
|
131 |
+
summary_writer_fn: Callable[[Dict[str, Any]], None],
|
132 |
+
train_folder: str,
|
133 |
+
saved_model_folder: str,
|
134 |
+
num_iterations: int,
|
135 |
+
save_summaries_frequency: int = 500,
|
136 |
+
save_checkpoint_frequency: int = 500,
|
137 |
+
checkpoint_max_to_keep: int = 10,
|
138 |
+
checkpoint_save_every_n_hours: float = 2.,
|
139 |
+
timing_frequency: int = 100,
|
140 |
+
logging_frequency: int = 10):
|
141 |
+
"""A Tensorflow 2 eager mode training loop.
|
142 |
+
|
143 |
+
Args:
|
144 |
+
strategy: A Tensorflow distributed strategy.
|
145 |
+
train_set: A tf.data.Dataset to loop through for training.
|
146 |
+
create_model_fn: A callable that returns a tf.keras.Model.
|
147 |
+
create_losses_fn: A callable that returns a tf.keras.losses.Loss.
|
148 |
+
create_optimizer_fn: A callable that returns a
|
149 |
+
tf.keras.optimizers.Optimizer.
|
150 |
+
distributed_train_step_fn: A callable that takes a distribution strategy, a
|
151 |
+
Dict[Text, tf.Tensor] holding the batch of training data, a
|
152 |
+
tf.keras.Model, a tf.keras.losses.Loss, a tf.keras.optimizers.Optimizer,
|
153 |
+
iteartion number to sample a weight value to loos functions,
|
154 |
+
and returns a dictionary to be passed to the summary_writer_fn.
|
155 |
+
eval_loop_fn: Eval loop function.
|
156 |
+
create_metrics_fn: create_metric_fn.
|
157 |
+
eval_folder: A path to where the summary event files and checkpoints will be
|
158 |
+
saved.
|
159 |
+
eval_datasets: A dictionary of evalution tf.data.Dataset to loop through for
|
160 |
+
evaluation.
|
161 |
+
summary_writer_fn: A callable that takes the output of
|
162 |
+
distributed_train_step_fn and writes summaries to be visualized in
|
163 |
+
TensorBoard.
|
164 |
+
train_folder: A path to where the summaries event files and checkpoints
|
165 |
+
will be saved.
|
166 |
+
saved_model_folder: A path to where the saved models are stored.
|
167 |
+
num_iterations: An integer, the number of iterations to train for.
|
168 |
+
save_summaries_frequency: The iteration frequency with which summaries are
|
169 |
+
saved.
|
170 |
+
save_checkpoint_frequency: The iteration frequency with which model
|
171 |
+
checkpoints are saved.
|
172 |
+
checkpoint_max_to_keep: The maximum number of checkpoints to keep.
|
173 |
+
checkpoint_save_every_n_hours: The frequency in hours to keep checkpoints.
|
174 |
+
timing_frequency: The iteration frequency with which to log timing.
|
175 |
+
logging_frequency: How often to output with logging.info().
|
176 |
+
"""
|
177 |
+
logging.info('Creating training tensorboard summaries ...')
|
178 |
+
summary_writer = tf.summary.create_file_writer(train_folder)
|
179 |
+
|
180 |
+
if eval_datasets is not None:
|
181 |
+
logging.info('Creating eval tensorboard summaries ...')
|
182 |
+
eval_summary_writer = tf.summary.create_file_writer(eval_folder)
|
183 |
+
|
184 |
+
train_set = strategy.experimental_distribute_dataset(train_set)
|
185 |
+
with strategy.scope():
|
186 |
+
logging.info('Building model ...')
|
187 |
+
model = create_model_fn()
|
188 |
+
loss_functions = create_losses_fn()
|
189 |
+
optimizer = create_optimizer_fn()
|
190 |
+
if eval_datasets is not None:
|
191 |
+
metrics = create_metrics_fn()
|
192 |
+
|
193 |
+
logging.info('Creating checkpoint ...')
|
194 |
+
checkpoint = tf.train.Checkpoint(
|
195 |
+
model=model,
|
196 |
+
optimizer=optimizer,
|
197 |
+
step=optimizer.iterations,
|
198 |
+
epoch=tf.Variable(0, dtype=tf.int64, trainable=False),
|
199 |
+
training_finished=tf.Variable(False, dtype=tf.bool, trainable=False))
|
200 |
+
|
201 |
+
logging.info('Restoring old model (if exists) ...')
|
202 |
+
checkpoint_manager = tf.train.CheckpointManager(
|
203 |
+
checkpoint,
|
204 |
+
directory=train_folder,
|
205 |
+
max_to_keep=checkpoint_max_to_keep,
|
206 |
+
keep_checkpoint_every_n_hours=checkpoint_save_every_n_hours)
|
207 |
+
|
208 |
+
with strategy.scope():
|
209 |
+
if checkpoint_manager.latest_checkpoint:
|
210 |
+
checkpoint.restore(checkpoint_manager.latest_checkpoint)
|
211 |
+
|
212 |
+
logging.info('Creating Timer ...')
|
213 |
+
timer = tf.estimator.SecondOrStepTimer(every_steps=timing_frequency)
|
214 |
+
timer.update_last_triggered_step(optimizer.iterations.numpy())
|
215 |
+
|
216 |
+
logging.info('Training on devices: %s.', [
|
217 |
+
el.name.split('/physical_device:')[-1]
|
218 |
+
for el in tf.config.get_visible_devices()
|
219 |
+
])
|
220 |
+
|
221 |
+
# Re-assign training_finished=False, in case we restored a checkpoint.
|
222 |
+
checkpoint.training_finished.assign(False)
|
223 |
+
while optimizer.iterations.numpy() < num_iterations:
|
224 |
+
for i_batch, batch in enumerate(train_set):
|
225 |
+
summary_writer.set_as_default()
|
226 |
+
iterations = optimizer.iterations.numpy()
|
227 |
+
|
228 |
+
if iterations % logging_frequency == 0:
|
229 |
+
# Log epoch, total iterations and batch index.
|
230 |
+
logging.info('epoch %d; iterations %d; i_batch %d',
|
231 |
+
checkpoint.epoch.numpy(), iterations,
|
232 |
+
i_batch)
|
233 |
+
|
234 |
+
# Break if the number of iterations exceeds the max.
|
235 |
+
if iterations >= num_iterations:
|
236 |
+
break
|
237 |
+
|
238 |
+
# Compute distributed step outputs.
|
239 |
+
distributed_step_outputs = distributed_train_step_fn(
|
240 |
+
strategy, batch, model, loss_functions, optimizer, iterations)
|
241 |
+
|
242 |
+
# Save checkpoint, and optionally run the eval loops.
|
243 |
+
if iterations % save_checkpoint_frequency == 0:
|
244 |
+
checkpoint_manager.save(checkpoint_number=iterations)
|
245 |
+
if eval_datasets is not None:
|
246 |
+
eval_loop_fn(
|
247 |
+
strategy=strategy,
|
248 |
+
eval_base_folder=eval_folder,
|
249 |
+
model=model,
|
250 |
+
metrics=metrics,
|
251 |
+
datasets=eval_datasets,
|
252 |
+
summary_writer=eval_summary_writer,
|
253 |
+
checkpoint_step=iterations)
|
254 |
+
|
255 |
+
# Write summaries.
|
256 |
+
if iterations % save_summaries_frequency == 0:
|
257 |
+
tf.summary.experimental.set_step(step=iterations)
|
258 |
+
summary_writer_fn(distributed_step_outputs)
|
259 |
+
tf.summary.scalar('learning_rate',
|
260 |
+
optimizer.learning_rate(iterations).numpy())
|
261 |
+
|
262 |
+
# Log steps/sec.
|
263 |
+
if timer.should_trigger_for_step(iterations):
|
264 |
+
elapsed_time, elapsed_steps = timer.update_last_triggered_step(
|
265 |
+
iterations)
|
266 |
+
if elapsed_time is not None:
|
267 |
+
steps_per_second = elapsed_steps / elapsed_time
|
268 |
+
tf.summary.scalar(
|
269 |
+
'steps/sec', steps_per_second, step=optimizer.iterations)
|
270 |
+
|
271 |
+
# Increment epoch.
|
272 |
+
checkpoint.epoch.assign_add(1)
|
273 |
+
|
274 |
+
# Assign training_finished variable to True after training is finished and
|
275 |
+
# save the last checkpoint.
|
276 |
+
checkpoint.training_finished.assign(True)
|
277 |
+
checkpoint_manager.save(checkpoint_number=optimizer.iterations.numpy())
|
278 |
+
|
279 |
+
# Generate a saved model.
|
280 |
+
model.save(saved_model_folder)
|
281 |
+
|
282 |
+
|
283 |
+
def train(strategy: tf.distribute.Strategy, train_folder: str,
|
284 |
+
saved_model_folder: str, n_iterations: int,
|
285 |
+
create_model_fn: Callable[..., tf.keras.Model],
|
286 |
+
create_losses_fn: Callable[..., Dict[str,
|
287 |
+
Tuple[Callable[..., tf.Tensor],
|
288 |
+
Callable[...,
|
289 |
+
tf.Tensor]]]],
|
290 |
+
create_metrics_fn: Callable[..., Dict[str, tf.keras.metrics.Metric]],
|
291 |
+
dataset: tf.data.Dataset,
|
292 |
+
learning_rate: tf.keras.optimizers.schedules.LearningRateSchedule,
|
293 |
+
eval_loop_fn: Callable[..., None],
|
294 |
+
eval_folder: str,
|
295 |
+
eval_datasets: Dict[str, tf.data.Dataset]):
|
296 |
+
"""Training function that is strategy agnostic.
|
297 |
+
|
298 |
+
Args:
|
299 |
+
strategy: A Tensorflow distributed strategy.
|
300 |
+
train_folder: A path to where the summaries event files and checkpoints
|
301 |
+
will be saved.
|
302 |
+
saved_model_folder: A path to where the saved models are stored.
|
303 |
+
n_iterations: An integer, the number of iterations to train for.
|
304 |
+
create_model_fn: A callable that returns tf.keras.Model.
|
305 |
+
create_losses_fn: A callable that returns the losses.
|
306 |
+
create_metrics_fn: A function that returns the metrics dictionary.
|
307 |
+
dataset: The tensorflow dataset object.
|
308 |
+
learning_rate: Keras learning rate schedule object.
|
309 |
+
eval_loop_fn: eval loop function.
|
310 |
+
eval_folder: A path to where eval summaries event files and checkpoints
|
311 |
+
will be saved.
|
312 |
+
eval_datasets: The tensorflow evaluation dataset objects.
|
313 |
+
"""
|
314 |
+
train_loop(
|
315 |
+
strategy=strategy,
|
316 |
+
train_set=dataset,
|
317 |
+
create_model_fn=create_model_fn,
|
318 |
+
create_losses_fn=create_losses_fn,
|
319 |
+
create_optimizer_fn=functools.partial(
|
320 |
+
tf.keras.optimizers.Adam, learning_rate=learning_rate),
|
321 |
+
distributed_train_step_fn=_distributed_train_step,
|
322 |
+
eval_loop_fn=eval_loop_fn,
|
323 |
+
create_metrics_fn=create_metrics_fn,
|
324 |
+
eval_folder=eval_folder,
|
325 |
+
eval_datasets=eval_datasets,
|
326 |
+
summary_writer_fn=_summary_writer,
|
327 |
+
train_folder=train_folder,
|
328 |
+
saved_model_folder=saved_model_folder,
|
329 |
+
num_iterations=n_iterations,
|
330 |
+
save_summaries_frequency=3000,
|
331 |
+
save_checkpoint_frequency=3000)
|
332 |
+
|
333 |
+
|
334 |
+
def get_strategy(mode) -> tf.distribute.Strategy:
|
335 |
+
"""Creates a distributed strategy."""
|
336 |
+
strategy = None
|
337 |
+
if mode == 'cpu':
|
338 |
+
strategy = tf.distribute.OneDeviceStrategy('/cpu:0')
|
339 |
+
elif mode == 'gpu':
|
340 |
+
strategy = tf.distribute.MirroredStrategy()
|
341 |
+
else:
|
342 |
+
raise ValueError('Unsupported distributed mode.')
|
343 |
+
return strategy
|