lmzjms committed on
Commit
0b32ad6
·
verified ·
1 Parent(s): 608dc1e

Upload 1162 files

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +4 -0
  2. s3prl_s3prl_main/Dockerfile +42 -0
  3. s3prl_s3prl_main/LICENSE +201 -0
  4. s3prl_s3prl_main/README.md +357 -0
  5. s3prl_s3prl_main/__pycache__/hubconf.cpython-310.pyc +0 -0
  6. s3prl_s3prl_main/__pycache__/hubconf.cpython-39.pyc +0 -0
  7. s3prl_s3prl_main/ci/format.py +84 -0
  8. s3prl_s3prl_main/docs/Makefile +20 -0
  9. s3prl_s3prl_main/docs/README.md +27 -0
  10. s3prl_s3prl_main/docs/from_scratch_tutorial.md +146 -0
  11. s3prl_s3prl_main/docs/make.bat +35 -0
  12. s3prl_s3prl_main/docs/rebuild_docs.sh +9 -0
  13. s3prl_s3prl_main/docs/source/_static/css/custom.css +3 -0
  14. s3prl_s3prl_main/docs/source/_static/js/custom.js +7 -0
  15. s3prl_s3prl_main/docs/source/_templates/custom-module-template.rst +81 -0
  16. s3prl_s3prl_main/docs/source/conf.py +120 -0
  17. s3prl_s3prl_main/docs/source/contribute/general.rst +167 -0
  18. s3prl_s3prl_main/docs/source/contribute/private.rst +104 -0
  19. s3prl_s3prl_main/docs/source/contribute/public.rst +29 -0
  20. s3prl_s3prl_main/docs/source/contribute/upstream.rst +100 -0
  21. s3prl_s3prl_main/docs/source/index.rst +75 -0
  22. s3prl_s3prl_main/docs/source/tutorial/installation.rst +55 -0
  23. s3prl_s3prl_main/docs/source/tutorial/problem.rst +122 -0
  24. s3prl_s3prl_main/docs/source/tutorial/upstream_collection.rst +1457 -0
  25. s3prl_s3prl_main/docs/util/is_valid.py +21 -0
  26. s3prl_s3prl_main/example/customize.py +43 -0
  27. s3prl_s3prl_main/example/run_asr.sh +3 -0
  28. s3prl_s3prl_main/example/run_sid.sh +4 -0
  29. s3prl_s3prl_main/example/ssl/pretrain.py +285 -0
  30. s3prl_s3prl_main/example/superb/train.py +291 -0
  31. s3prl_s3prl_main/example/superb_asr/inference.py +40 -0
  32. s3prl_s3prl_main/example/superb_asr/train.py +241 -0
  33. s3prl_s3prl_main/example/superb_asr/train_with_lightning.py +127 -0
  34. s3prl_s3prl_main/example/superb_sid/inference.py +40 -0
  35. s3prl_s3prl_main/example/superb_sid/train.py +235 -0
  36. s3prl_s3prl_main/example/superb_sid/train_with_lightning.py +127 -0
  37. s3prl_s3prl_main/example/superb_sv/inference.py +47 -0
  38. s3prl_s3prl_main/example/superb_sv/train.py +266 -0
  39. s3prl_s3prl_main/example/superb_sv/train_with_lightning.py +184 -0
  40. s3prl_s3prl_main/external_tools/install_espnet.sh +20 -0
  41. s3prl_s3prl_main/file/S3PRL-integration.png +0 -0
  42. s3prl_s3prl_main/file/S3PRL-interface.png +0 -0
  43. s3prl_s3prl_main/file/S3PRL-logo.png +0 -0
  44. s3prl_s3prl_main/file/license.svg +1 -0
  45. s3prl_s3prl_main/find_content.sh +8 -0
  46. s3prl_s3prl_main/hubconf.py +4 -0
  47. s3prl_s3prl_main/pyrightconfig.json +5 -0
  48. s3prl_s3prl_main/pytest.ini +5 -0
  49. s3prl_s3prl_main/requirements/all.txt +33 -0
  50. s3prl_s3prl_main/requirements/dev.txt +11 -0
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ s3prl_s3prl_main/s3prl/downstream/phone_linear/data/converted_aligned_phones.txt filter=lfs diff=lfs merge=lfs -text
37
+ s3prl_s3prl_main/s3prl/downstream/voxceleb2_amsoftmax_segment_eval/cache_wav_paths/cache_test_segment.p filter=lfs diff=lfs merge=lfs -text
38
+ s3prl_s3prl_main/s3prl/downstream/voxceleb2_amsoftmax_segment_eval/cache_wav_paths/cache_Voxceleb1.p filter=lfs diff=lfs merge=lfs -text
39
+ s3prl_s3prl_main/s3prl/downstream/voxceleb2_amsoftmax_segment_eval/cache_wav_paths/cache_Voxceleb2.p filter=lfs diff=lfs merge=lfs -text
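The four new `.gitattributes` entries route these large data files (a phone-alignment text file and pickled VoxCeleb wav-path caches) through Git LFS instead of storing them as regular blobs. A minimal sketch of how such an entry is typically produced, assuming `git-lfs` is installed (the path below is one of the files added in this commit):

```sh
# `git lfs track` appends the matching filter/diff/merge attributes to .gitattributes.
git lfs install
git lfs track "s3prl_s3prl_main/s3prl/downstream/phone_linear/data/converted_aligned_phones.txt"
git add .gitattributes
```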
s3prl_s3prl_main/Dockerfile ADDED
@@ -0,0 +1,42 @@
1
+ # We need this to use GPUs inside the container
2
+ FROM nvidia/cuda:11.2.2-base
3
+ # Using a multi-stage build simplifies the s3prl installation
4
+ # TODO: Find a slimmer base image that also "just works"
5
+ FROM tiangolo/uvicorn-gunicorn:python3.8
6
+
7
+ RUN apt-get update --fix-missing && apt-get install -y wget \
8
+ libsndfile1 \
9
+ sox \
10
+ git \
11
+ git-lfs
12
+
13
+ RUN python -m pip install --upgrade pip
14
+ RUN python -m pip --no-cache-dir install fairseq@git+https://github.com/pytorch/fairseq.git@f2146bdc7abf293186de9449bfa2272775e39e1d#egg=fairseq
15
+ RUN python -m pip --no-cache-dir install git+https://github.com/s3prl/s3prl.git#egg=s3prl
16
+
17
+ COPY s3prl/ /app/s3prl
18
+ COPY src/ /app/src
19
+
20
+ # Setup filesystem
21
+ RUN mkdir /app/data
22
+
23
+ # Configure Git
24
+ # TODO: Create a dedicated SUPERB account for the project?
25
+ RUN git config --global user.email "[email protected]"
26
+ RUN git config --global user.name "SUPERB Admin"
27
+
28
+ # Default args for fine-tuning
29
+ ENV upstream_model osanseviero/hubert_base
30
+ ENV downstream_task asr
31
+ ENV hub huggingface
32
+ ENV hf_hub_org None
33
+ ENV push_to_hf_hub True
34
+ ENV override None
35
+
36
+ WORKDIR /app/s3prl
37
+ # Each task's config.yaml is used to set all the training parameters, but can be overridden with the `override` argument
38
+ # The results of each training run are stored in /app/s3prl/result/downstream/{downstream_task}
39
+ # and pushed to the Hugging Face Hub with name:
40
+ # Default behaviour - {hf_hub_username}/superb-s3prl-{upstream_model}-{downstream_task}-uuid
41
+ # With hf_hub_org set - {hf_hub_org}/superb-s3prl-{upstream_model}-{downstream_task}-uuid
42
+ CMD python run_downstream.py -n ${downstream_task} -m train -u ${upstream_model} -d ${downstream_task} --hub ${hub} --hf_hub_org ${hf_hub_org} --push_to_hf_hub ${push_to_hf_hub} --override ${override}
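Because every training knob is declared as an `ENV` default and consumed by the final `CMD`, the container can be retargeted at run time without rebuilding the image. A usage sketch, assuming the NVIDIA container runtime is available; the image tag and the chosen upstream/downstream values below are illustrative only:

```sh
# Build from the directory containing this Dockerfile (plus the s3prl/ and src/ folders it copies).
docker build -t superb-s3prl .

# Fine-tune a different upstream/downstream pair by overriding the ENV defaults.
docker run --gpus all \
  -e upstream_model=osanseviero/hubert_base \
  -e downstream_task=asr \
  -e hub=huggingface \
  superb-s3prl
```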
s3prl_s3prl_main/LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "{}"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2022 Andy T. Liu (Ting-Wei Liu) and Shu-wen (Leo) Yang
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
s3prl_s3prl_main/README.md ADDED
@@ -0,0 +1,357 @@
1
+ <p align="center">
2
+ <img src="https://raw.githubusercontent.com/s3prl/s3prl/main/file/S3PRL-logo.png" width="900"/>
3
+ <br>
4
+ <br>
5
+ <a href="./LICENSE.txt"><img alt="Apache License 2.0" src="https://raw.githubusercontent.com/s3prl/s3prl/main/file/license.svg" /></a>
6
+ <a href="https://creativecommons.org/licenses/by-nc/4.0/"><img alt="CC_BY_NC License" src="https://img.shields.io/badge/License-CC%20BY--NC%204.0-lightgrey.svg" /></a>
7
+ <a href="https://github.com/s3prl/s3prl/actions/workflows/ci.yml"><img alt="CI" src="https://github.com/s3prl/s3prl/actions/workflows/ci.yml/badge.svg?branch=main&event=push"></a>
8
+ <a href="#development-pattern-for-contributors"><img alt="Codecov" src="https://img.shields.io/badge/contributions-welcome-brightgreen.svg"></a>
9
+ <a href="https://github.com/s3prl/s3prl/issues"><img alt="Bitbucket open issues" src="https://img.shields.io/github/issues/s3prl/s3prl"></a>
10
+ </p>
11
+
12
+ ## Contact
13
+
14
+ We prefer to have discussions directly on the GitHub issue page, so that all the information is transparent to all contributors and is auto-archived on GitHub.
15
+ If you wish to use email, please contact:
16
+
17
+ - [Shu-wen (Leo) Yang](https://leo19941227.github.io/) ([email protected])
18
+ - [Andy T. Liu](https://andi611.github.io/) ([email protected])
19
+
20
+ Please refer to the [legacy citation](https://scholar.google.com/citations?view_op=view_citation&hl=en&user=R1mNI8QAAAAJ&citation_for_view=R1mNI8QAAAAJ:LkGwnXOMwfcC) of S3PRL and the timeline below, which justify our initiative on this project. This information is used to protect us from half-truths. We encourage you to cite the individual papers most related to the functions you are using, to give fair credit to their developers. You can find the names in the [Change Log](#change-log). Finally, we would like to thank our advisor, [Prof. Hung-yi Lee](https://speech.ee.ntu.edu.tw/~hylee/index.php), for his advice. The project would be impossible without his support.
21
+
22
+ If you have any questions (e.g., about who came up with or developed which ideas or functions, or how the project started), feel free to engage in an open and responsible conversation on the GitHub issue page, and we'll be happy to help!
23
+
24
+ ## Contribution (pull request)
25
+
26
+ **Guideline**
27
+
28
+ - Starting in 2024, we will only accept new contributions in the form of new upstream models, so we can save bandwidth for developing new techniques (which will not be in S3PRL).
29
+ - S3PRL has transitioned into pure maintenance mode, ensuring the long-term maintenance of all existing functions.
30
+ - Reporting bugs or the PR fixing the bugs is always welcome! Thanks!
31
+
32
+ **Tutorials**
33
+
34
+ - [General tutorial](https://s3prl.github.io/s3prl/contribute/general.html)
35
+ - [Tutorial for adding new upstream models](https://s3prl.github.io/s3prl/contribute/upstream.html)
36
+
37
+ ## Environment compatibilities [![CI](https://github.com/s3prl/s3prl/actions/workflows/ci.yml/badge.svg?branch=main&event=push)](https://github.com/s3prl/s3prl/actions/workflows/ci.yml)
38
+
39
+ We support the following environments. The test cases are run with **[tox](./tox.ini)** locally and on **[GitHub Actions](.github/workflows/ci.yml)**:
40
+
41
+ | Env | versions |
42
+ | --- | --- |
43
+ | os | `ubuntu-18.04`, `ubuntu-20.04` |
44
+ | python | `3.7`, `3.8`, `3.9`, `3.10` |
45
+ | pytorch | `1.8.1`, `1.9.1`, `1.10.2`, `1.11.0`, `1.12.1` , `1.13.1` , `2.0.1` , `2.1.0` |
46
+
47
+ ## Star History
48
+
49
+ [![Star History Chart](https://api.star-history.com/svg?repos=s3prl/s3prl&type=Date)](https://star-history.com/#s3prl/s3prl&Date)
50
+
51
+ ## Change Log
52
+
53
+ > We only list the major contributors here for conciseness. However, we are deeply grateful for all the contributions. Please see the [Contributors](https://github.com/s3prl/s3prl/graphs/contributors) page for the full list.
54
+
55
+ * *Sep 2024*: Support MS-HuBERT (see [MS-HuBERT](https://arxiv.org/pdf/2406.05661))
56
+ * *Dec 2023*: Support Multi-resolution HuBERT (MR-HuBERT, see [Multiresolution HuBERT](https://arxiv.org/pdf/2310.02720.pdf))
57
+ * *Oct 2023*: Support ESPnet pre-trained upstream models (see [ESPnet HuBERT](https://arxiv.org/abs/2306.06672) and [WavLabLM](https://arxiv.org/abs/2309.15317))
58
+ * *Sep 2022*: In [JSALT 2022](https://jsalt-2022-ssl.github.io/member), We upgrade the codebase to support testing, documentation and a new [S3PRL PyPI package](https://pypi.org/project/s3prl/) for easy installation and usage for upstream models. See our [online doc](https://s3prl.github.io/s3prl/) for more information. The package is now used by many [open-source projects](https://github.com/s3prl/s3prl/network/dependents), including [ESPNet](https://github.com/espnet/espnet/blob/master/espnet2/asr/frontend/s3prl.py). Contributors: [Shu-wen Yang](https://leo19941227.github.io/) ***(NTU)***, [Andy T. Liu](https://andi611.github.io/) ***(NTU)***, [Heng-Jui Chang](https://people.csail.mit.edu/hengjui/) ***(MIT)***, [Haibin Wu](https://hbwu-ntu.github.io/) ***(NTU)*** and [Xuankai Chang](https://www.xuankaic.com/) ***(CMU)***.
59
+ * *Mar 2022*: Introduce [**SUPERB-SG**](https://arxiv.org/abs/2203.06849), see [Speech Translation](./s3prl/downstream/speech_translation) by [Hsiang-Sheng Tsai](https://github.com/bearhsiang) ***(NTU)***, [Out-of-domain ASR](./s3prl/downstream/ctc/) by [Heng-Jui Chang](https://people.csail.mit.edu/hengjui/) ***(NTU)***, [Voice Conversion](./s3prl/downstream/a2o-vc-vcc2020/) by [Wen-Chin Huang](https://unilight.github.io/) ***(Nagoya)***, [Speech Separation](./s3prl/downstream/separation_stft/) and [Speech Enhancement](./s3prl/downstream/enhancement_stft/) by [Zili Huang](https://scholar.google.com/citations?user=iQ-S0fQAAAAJ&hl=en) ***(JHU)*** for more info.
60
+ * *Mar 2022*: Introduce [**SSL for SE/SS**](https://arxiv.org/abs/2203.07960) by [Zili Huang](https://scholar.google.com/citations?user=iQ-S0fQAAAAJ&hl=en) ***(JHU)***. See [SE1](https://github.com/s3prl/s3prl/tree/main/s3prl/downstream/enhancement_stft) and [SS1](https://github.com/s3prl/s3prl/tree/main/s3prl/downstream/separation_stft) folders for more details. Note that the improved performances can be achieved by the later introduced [SE2](https://github.com/s3prl/s3prl/tree/main/s3prl/downstream/enhancement_stft2) and [SS2](https://github.com/s3prl/s3prl/tree/main/s3prl/downstream/separation_stft2). However, for aligning with [SUPERB-SG](https://arxiv.org/abs/2203.06849) benchmarking, please use the version 1.
61
+ * *Nov 2021*: Introduce [**S3PRL-VC**](https://arxiv.org/abs/2110.06280) by [Wen-Chin Huang](https://unilight.github.io/) ***(Nagoya)***, see [Any-to-one](https://github.com/s3prl/s3prl/tree/master/s3prl/downstream/a2o-vc-vcc2020) for more info. We highly recommend considering the [newly released official repo of S3PRL-VC](https://github.com/unilight/s3prl-vc), which is developed and actively maintained by [Wen-Chin Huang](https://unilight.github.io/). The standalone repo contains many more recipes for the VC experiments. In S3PRL we only include the Any-to-one recipe for reproducing the SUPERB results.
62
+ * *Oct 2021*: Support [**DistilHuBERT**](https://arxiv.org/abs/2110.01900) by [Heng-Jui Chang](https://people.csail.mit.edu/hengjui/) ***(NTU)***, see [docs](./s3prl/upstream/distiller/README.md) for more info.
63
+ * *Sep 2021:* We host a *challenge* in [*AAAI workshop: The 2nd Self-supervised Learning for Audio and Speech Processing*](https://aaai-sas-2022.github.io/)! See [**SUPERB official site**](https://superbbenchmark.org/) for the challenge details and the [**SUPERB documentation**](./s3prl/downstream/docs/superb.md) in this toolkit!
64
+ * *Aug 2021:* [Andy T. Liu](https://andi611.github.io/) ***(NTU)*** and [Shu-wen Yang](https://leo19941227.github.io/) ***(NTU)*** introduced the S3PRL toolkit at [MLSS 2021](https://ai.ntu.edu.tw/%e7%b7%9a%e4%b8%8a%e5%ad%b8%e7%bf%92-2/mlss-2021/); you can also **[watch it on YouTube](https://youtu.be/PkMFnS6cjAc)**!
65
+ * *Aug 2021:* [**TERA**](https://ieeexplore.ieee.org/document/9478264) by [Andy T. Liu](https://andi611.github.io/) ***(NTU)*** is accepted to TASLP!
66
+ * *July 2021:* We are now working on packaging s3prl and reorganizing the file structure in **v0.3**. Please consider using the stable **v0.2.0** for now. We will test and release **v0.3** before August.
67
+ * *June 2021:* Support [**SUPERB:** **S**peech processing **U**niversal **PER**formance **B**enchmark](https://arxiv.org/abs/2105.01051), submitted to Interspeech 2021. Use the tag **superb-interspeech2021** or **v0.2.0**. Contributors: [Shu-wen Yang](https://leo19941227.github.io/) ***(NTU)***, [Pohan Chi](https://scholar.google.com/citations?user=SiyicoEAAAAJ&hl=zh-TW) ***(NTU)***, [Yist Lin](https://scholar.google.com/citations?user=0lrZq9MAAAAJ&hl=en) ***(NTU)***, [Yung-Sung Chuang](https://scholar.google.com/citations?user=3ar1DOwAAAAJ&hl=zh-TW) ***(NTU)***, [Jiatong Shi](https://scholar.google.com/citations?user=FEDNbgkAAAAJ&hl=en) ***(CMU)***, [Xuankai Chang](https://www.xuankaic.com/) ***(CMU)***, [Wei-Cheng Tseng](https://scholar.google.com.tw/citations?user=-d6aNP0AAAAJ&hl=zh-TW) ***(NTU)***, Tzu-Hsien Huang ***(NTU)*** and [Kushal Lakhotia](https://scholar.google.com/citations?user=w9W6zXUAAAAJ&hl=en) ***(Meta)***.
68
+ * *June 2021:* Support extracting multiple hidden states for all the SSL pretrained models by [Shu-wen Yang](https://leo19941227.github.io/) ***(NTU)***.
69
+ * *Jan 2021:* Readme updated with detailed instructions on how to use our latest version!
70
+ * *Dec 2020:* We are migrating to a newer version for a more general, flexible, and scalable codebase. See the introduction below for more information! The legacy version can be accessed at the tag **v0.1.0**.
71
+ * *Oct 2020:* [Shu-wen Yang](https://leo19941227.github.io/) ***(NTU)*** and [Andy T. Liu](https://andi611.github.io/) ***(NTU)*** added various classic upstream models, including PASE+, APC, VQ-APC, NPC, wav2vec, vq-wav2vec, etc.
72
+ * *Oct 2019:* The birth of S3PRL! The repository was created for the [**Mockingjay**](https://arxiv.org/abs/1910.12638) development. [Andy T. Liu](https://andi611.github.io/) ***(NTU)***, [Shu-wen Yang](https://leo19941227.github.io/) ***(NTU)*** and [Pohan Chi](https://scholar.google.com/citations?user=SiyicoEAAAAJ&hl=zh-TW) ***(NTU)*** implemented the pre-training scripts and several simple downstream evaluation tasks. This work was the very start of the S3PRL project, which established many of the fundamental modules and coding styles. Feel free to check out the old commits to explore our [legacy codebase](https://github.com/s3prl/s3prl/tree/6a53ee92bffeaa75fc2fb56071050bcf71e93785)!
73
+
74
+ ****
75
+
76
+ ## Introduction and Usages
77
+
78
+ This is an open source toolkit called **s3prl**, which stands for **S**elf-**S**upervised **S**peech **P**re-training and **R**epresentation **L**earning.
79
+ Self-supervised speech pre-trained models are called **upstream** in this toolkit, and are utilized in various **downstream** tasks.
80
+
81
+ The toolkit has **three major usages**:
82
+
83
+ ### Pretrain
84
+
85
+ - Pretrain upstream models, including Mockingjay, Audio ALBERT and TERA.
86
+ - Document: [**pretrain/README.md**](./s3prl/pretrain/README.md)
87
+
88
+ ### Upstream
89
+
90
+ - Easily load most of the existing upstream models with pretrained weights in a unified I/O interface.
91
+ - Pretrained models are registered through **torch.hub**, which means you can use these models in your own project by one-line plug-and-play without depending on this toolkit's coding style.
92
+ - Document: [**upstream/README.md**](./s3prl/upstream/README.md)
93
+
94
+ ### Downstream
95
+
96
+ - Utilize upstream models in a wide range of downstream tasks (see the command sketch below)
97
+ - Benchmark upstream models with [**SUPERB Benchmark**](./s3prl/downstream/docs/superb.md)
98
+ - Document: [**downstream/README.md**](./s3prl/downstream/README.md)
99
+
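+ As a rough command sketch (the experiment name and the chosen upstream/downstream identifiers below are placeholders; see the SUPERB documentation for the full options):
+
+ ```sh
+ cd s3prl
+ # -n names the experiment, -m picks the mode, -u the upstream model, -d the downstream task;
+ # results are typically written under result/downstream/<experiment name>.
+ python run_downstream.py -n my_exp -m train -u hubert -d asr
+ ```
+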
100
+ ---
101
+
102
+ Here is a high-level illustration of how S3PRL might help you. We support leveraging numerous SSL representations on numerous speech processing tasks in our [GitHub codebase](https://github.com/s3prl/s3prl):
103
+
104
+ ![interface](file/S3PRL-interface.png)
105
+
106
+ ---
107
+
108
+ We also modularize all the SSL models into a standalone [PyPI package](https://pypi.org/project/s3prl/) so that you can easily install and use it without depending on our entire codebase. The following shows a simple example, and you can find more details in our [documentation](https://s3prl.github.io/s3prl/).
109
+
110
+ 1. Install the S3PRL package:
111
+
112
+ ```sh
113
+ pip install s3prl
114
+ ```
115
+
116
+ 2. Use it to extract representations for your own audio:
117
+
118
+ ```python
119
+ import torch
120
+ from s3prl.nn import S3PRLUpstream
121
+
122
+ model = S3PRLUpstream("hubert")
123
+ model.eval()
124
+
125
+ with torch.no_grad():
126
+ wavs = torch.randn(2, 16000 * 2)
127
+ wavs_len = torch.LongTensor([16000 * 1, 16000 * 2])
128
+ all_hs, all_hs_len = model(wavs, wavs_len)
129
+
130
+ for hs, hs_len in zip(all_hs, all_hs_len):
131
+ assert isinstance(hs, torch.FloatTensor)
132
+ assert isinstance(hs_len, torch.LongTensor)
133
+
134
+ batch_size, max_seq_len, hidden_size = hs.shape
135
+ assert hs_len.dim() == 1
136
+ ```
137
+
138
+ ---
139
+
140
+ With this modularization, we have achieved close integration with the general speech processing toolkit [ESPNet](https://github.com/espnet/espnet), enabling the use of SSL models for a broader range of speech processing tasks and corpora to achieve state-of-the-art (SOTA) results (kudos to the [ESPNet Team](https://www.wavlab.org/open_source)):
141
+
142
+ ![integration](file/S3PRL-integration.png)
143
+
144
+ You can start the journey of SSL with the following entry points:
145
+
146
+ - S3PRL: [A simple SUPERB downstream task](https://github.com/s3prl/s3prl/blob/main/s3prl/downstream/docs/superb.md#pr-phoneme-recognition)
147
+ - ESPNet: [Leveraging S3PRL for ASR](https://github.com/espnet/espnet/tree/master/egs2/librispeech/asr1#self-supervised-learning-features-hubert_large_ll60k-conformer-utt_mvn-with-transformer-lm)
148
+
149
+ ---
150
+
151
+ Feel free to use or modify our toolkit in your research. Here is a [list of papers using our toolkit](#used-by). Questions, bug reports, and improvement suggestions are welcome through [opening a new issue](https://github.com/s3prl/s3prl/issues).
152
+
153
+ If you find this toolkit helpful to your research, please do consider citing [our papers](#citation), thanks!
154
+
155
+ ## Installation
156
+
157
+ 1. **Python** >= 3.6
158
+ 2. Install **sox** on your OS
159
+ 3. Install s3prl: [Read doc](https://s3prl.github.io/s3prl/tutorial/installation.html#) or `pip install -e ".[all]"` (a combined from-source sketch follows this list)
160
+ 4. (Optional) Some upstream models require special dependencies. If you encounter errors with a specific upstream model, you can look into the `README.md` under each `upstream` folder. E.g., `upstream/pase/README.md`
161
+
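+ A combined from-source sketch of steps 2-3 (the `apt-get` package name for sox is an assumption of this example; use your OS's package manager as appropriate):
+
+ ```sh
+ sudo apt-get install -y sox      # step 2: system dependency
+ git clone https://github.com/s3prl/s3prl.git
+ cd s3prl
+ pip install -e ".[all]"          # step 3: editable install with all extras
+ ```
+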
162
+ ## Reference Repositories
163
+
164
+ * [Pytorch](https://github.com/pytorch/pytorch), Pytorch.
165
+ * [Audio](https://github.com/pytorch/audio), Pytorch.
166
+ * [Kaldi](https://github.com/kaldi-asr/kaldi), Kaldi-ASR.
167
+ * [Transformers](https://github.com/huggingface/transformers), Hugging Face.
168
+ * [PyTorch-Kaldi](https://github.com/mravanelli/pytorch-kaldi), Mirco Ravanelli.
169
+ * [fairseq](https://github.com/pytorch/fairseq), Facebook AI Research.
170
+ * [CPC](https://github.com/facebookresearch/CPC_audio), Facebook AI Research.
171
+ * [APC](https://github.com/iamyuanchung/Autoregressive-Predictive-Coding), Yu-An Chung.
172
+ * [VQ-APC](https://github.com/s3prl/VQ-APC), Yu-An Chung.
173
+ * [NPC](https://github.com/Alexander-H-Liu/NPC), Alexander-H-Liu.
174
+ * [End-to-end-ASR-Pytorch](https://github.com/Alexander-H-Liu/End-to-end-ASR-Pytorch), Alexander-H-Liu
175
+ * [Mockingjay](https://github.com/andi611/Mockingjay-Speech-Representation), Andy T. Liu.
176
+ * [ESPnet](https://github.com/espnet/espnet), Shinji Watanabe
177
+ * [speech-representations](https://github.com/awslabs/speech-representations), aws lab
178
+ * [PASE](https://github.com/santi-pdp/pase), Santiago Pascual and Mirco Ravanelli
179
+ * [LibriMix](https://github.com/JorisCos/LibriMix), Joris Cosentino and Manuel Pariente
180
+
181
+ ## License
182
+
183
+ The majority of the S3PRL toolkit is licensed under the Apache License, Version 2.0; however, all files authored by Facebook, Inc. (which have an explicit copyright statement at the top) are licensed under CC-BY-NC.
184
+
185
+ ## Used by
186
+ <details><summary>List of papers that used our toolkit (Feel free to add your own paper by making a pull request)</summary><p>
187
+
188
+ ### Self-Supervised Pretraining
189
+
190
+ + [Mockingjay: Unsupervised Speech Representation Learning with Deep Bidirectional Transformer Encoders (Liu et al., 2020)](https://arxiv.org/abs/1910.12638)
191
+ ```
192
+ @article{mockingjay,
193
+ title={Mockingjay: Unsupervised Speech Representation Learning with Deep Bidirectional Transformer Encoders},
194
+ ISBN={9781509066315},
195
+ url={http://dx.doi.org/10.1109/ICASSP40776.2020.9054458},
196
+ DOI={10.1109/icassp40776.2020.9054458},
197
+ journal={ICASSP 2020 - 2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
198
+ publisher={IEEE},
199
+ author={Liu, Andy T. and Yang, Shu-wen and Chi, Po-Han and Hsu, Po-chun and Lee, Hung-yi},
200
+ year={2020},
201
+ month={May}
202
+ }
203
+ ```
204
+ + [TERA: Self-Supervised Learning of Transformer Encoder Representation for Speech (Liu et al., 2020)](https://arxiv.org/abs/2007.06028)
205
+ ```
206
+ @misc{tera,
207
+ title={TERA: Self-Supervised Learning of Transformer Encoder Representation for Speech},
208
+ author={Andy T. Liu and Shang-Wen Li and Hung-yi Lee},
209
+ year={2020},
210
+ eprint={2007.06028},
211
+ archivePrefix={arXiv},
212
+ primaryClass={eess.AS}
213
+ }
214
+ ```
215
+ + [Audio ALBERT: A Lite BERT for Self-supervised Learning of Audio Representation (Chi et al., 2020)](https://arxiv.org/abs/2005.08575)
216
+ ```
217
+ @inproceedings{audio_albert,
218
+ title={Audio ALBERT: A Lite BERT for Self-supervised Learning of Audio Representation},
219
+ author={Po-Han Chi and Pei-Hung Chung and Tsung-Han Wu and Chun-Cheng Hsieh and Shang-Wen Li and Hung-yi Lee},
220
+ year={2020},
221
+ booktitle={SLT 2020},
222
+ }
223
+ ```
224
+
225
+ ### Explainability
226
+
227
+ + [Understanding Self-Attention of Self-Supervised Audio Transformers (Yang et al., 2020)](https://arxiv.org/abs/2006.03265)
228
+ ```
229
+ @inproceedings{understanding_sat,
230
+ author={Shu-wen Yang and Andy T. Liu and Hung-yi Lee},
231
+ title={{Understanding Self-Attention of Self-Supervised Audio Transformers}},
232
+ year=2020,
233
+ booktitle={Proc. Interspeech 2020},
234
+ pages={3785--3789},
235
+ doi={10.21437/Interspeech.2020-2231},
236
+ url={http://dx.doi.org/10.21437/Interspeech.2020-2231}
237
+ }
238
+ ```
239
+
240
+ ### Adversarial Attack
241
+
242
+ + [Defense for Black-box Attacks on Anti-spoofing Models by Self-Supervised Learning (Wu et al., 2020)](https://arxiv.org/abs/2006.03214), code for computing LNSR: [utility/observe_lnsr.py](https://github.com/s3prl/s3prl/blob/master/utility/observe_lnsr.py)
243
+ ```
244
+ @inproceedings{mockingjay_defense,
245
+ author={Haibin Wu and Andy T. Liu and Hung-yi Lee},
246
+ title={{Defense for Black-Box Attacks on Anti-Spoofing Models by Self-Supervised Learning}},
247
+ year=2020,
248
+ booktitle={Proc. Interspeech 2020},
249
+ pages={3780--3784},
250
+ doi={10.21437/Interspeech.2020-2026},
251
+ url={http://dx.doi.org/10.21437/Interspeech.2020-2026}
252
+ }
253
+ ```
254
+
255
+ + [Adversarial Defense for Automatic Speaker Verification by Cascaded Self-Supervised Learning Models (Wu et al., 2021)](https://arxiv.org/abs/2102.07047)
256
+ ```
257
+ @misc{asv_ssl,
258
+ title={Adversarial defense for automatic speaker verification by cascaded self-supervised learning models},
259
+ author={Haibin Wu and Xu Li and Andy T. Liu and Zhiyong Wu and Helen Meng and Hung-yi Lee},
260
+ year={2021},
261
+ eprint={2102.07047},
262
+ archivePrefix={arXiv},
263
+ primaryClass={eess.AS}
264
+ ```
265
+
266
+ ### Voice Conversion
267
+
268
+ + [S2VC: A Framework for Any-to-Any Voice Conversion with Self-Supervised Pretrained Representations (Lin et al., 2021)](https://arxiv.org/abs/2104.02901)
269
+ ```
270
+ @misc{s2vc,
271
+ title={S2VC: A Framework for Any-to-Any Voice Conversion with Self-Supervised Pretrained Representations},
272
+ author={Jheng-hao Lin and Yist Y. Lin and Chung-Ming Chien and Hung-yi Lee},
273
+ year={2021},
274
+ eprint={2104.02901},
275
+ archivePrefix={arXiv},
276
+ primaryClass={eess.AS}
277
+ }
278
+ ```
279
+
280
+ ### Benchmark and Evaluation
281
+
282
+ + [SUPERB: Speech processing Universal PERformance Benchmark (Yang et al., 2021)](https://arxiv.org/abs/2105.01051)
283
+ ```
284
+ @misc{superb,
285
+ title={SUPERB: Speech processing Universal PERformance Benchmark},
286
+ author={Shu-wen Yang and Po-Han Chi and Yung-Sung Chuang and Cheng-I Jeff Lai and Kushal Lakhotia and Yist Y. Lin and Andy T. Liu and Jiatong Shi and Xuankai Chang and Guan-Ting Lin and Tzu-Hsien Huang and Wei-Cheng Tseng and Ko-tik Lee and Da-Rong Liu and Zili Huang and Shuyan Dong and Shang-Wen Li and Shinji Watanabe and Abdelrahman Mohamed and Hung-yi Lee},
287
+ year={2021},
288
+ eprint={2105.01051},
289
+ archivePrefix={arXiv},
290
+ primaryClass={cs.CL}
291
+ }
292
+ ```
293
+
294
+ + [Utilizing Self-supervised Representations for MOS Prediction (Tseng et al., 2021)](https://arxiv.org/abs/2104.03017)
295
+ ```
296
+ @misc{ssr_mos,
297
+ title={Utilizing Self-supervised Representations for MOS Prediction},
298
+ author={Wei-Cheng Tseng and Chien-yu Huang and Wei-Tsung Kao and Yist Y. Lin and Hung-yi Lee},
299
+ year={2021},
300
+ eprint={2104.03017},
301
+ archivePrefix={arXiv},
302
+ primaryClass={eess.AS}
303
+ }
304
+ ```
306
+
307
+ </p></details>
308
+
309
+ ## Citation
310
+
311
+ If you find this toolkit useful, please consider citing the following papers.
312
+
313
+ - If you use our pre-training scripts, or the downstream tasks considered in *TERA* and *Mockingjay*, please consider citing the following:
314
+ ```
315
+ @misc{tera,
316
+ title={TERA: Self-Supervised Learning of Transformer Encoder Representation for Speech},
317
+ author={Andy T. Liu and Shang-Wen Li and Hung-yi Lee},
318
+ year={2020},
319
+ eprint={2007.06028},
320
+ archivePrefix={arXiv},
321
+ primaryClass={eess.AS}
322
+ }
323
+ ```
324
+ ```
325
+ @article{mockingjay,
326
+ title={Mockingjay: Unsupervised Speech Representation Learning with Deep Bidirectional Transformer Encoders},
327
+ ISBN={9781509066315},
328
+ url={http://dx.doi.org/10.1109/ICASSP40776.2020.9054458},
329
+ DOI={10.1109/icassp40776.2020.9054458},
330
+ journal={ICASSP 2020 - 2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
331
+ publisher={IEEE},
332
+ author={Liu, Andy T. and Yang, Shu-wen and Chi, Po-Han and Hsu, Po-chun and Lee, Hung-yi},
333
+ year={2020},
334
+ month={May}
335
+ }
336
+ ```
337
+
338
+ - If you use our organized upstream interface and features, or the *SUPERB* downstream benchmark, please consider citing the following:
339
+ ```
340
+ @article{yang2024large,
341
+ title={A Large-Scale Evaluation of Speech Foundation Models},
342
+ author={Yang, Shu-wen and Chang, Heng-Jui and Huang, Zili and Liu, Andy T and Lai, Cheng-I and Wu, Haibin and Shi, Jiatong and Chang, Xuankai and Tsai, Hsiang-Sheng and Huang, Wen-Chin and others},
343
+ journal={IEEE/ACM Transactions on Audio, Speech, and Language Processing},
344
+ year={2024},
345
+ publisher={IEEE}
346
+ }
347
+ ```
348
+ ```
349
+ @inproceedings{yang21c_interspeech,
350
+ author={Shu-wen Yang and Po-Han Chi and Yung-Sung Chuang and Cheng-I Jeff Lai and Kushal Lakhotia and Yist Y. Lin and Andy T. Liu and Jiatong Shi and Xuankai Chang and Guan-Ting Lin and Tzu-Hsien Huang and Wei-Cheng Tseng and Ko-tik Lee and Da-Rong Liu and Zili Huang and Shuyan Dong and Shang-Wen Li and Shinji Watanabe and Abdelrahman Mohamed and Hung-yi Lee},
351
+ title={{SUPERB: Speech Processing Universal PERformance Benchmark}},
352
+ year=2021,
353
+ booktitle={Proc. Interspeech 2021},
354
+ pages={1194--1198},
355
+ doi={10.21437/Interspeech.2021-1775}
356
+ }
357
+ ```
s3prl_s3prl_main/__pycache__/hubconf.cpython-310.pyc ADDED
Binary file (248 Bytes). View file
 
s3prl_s3prl_main/__pycache__/hubconf.cpython-39.pyc ADDED
Binary file (244 Bytes). View file
 
s3prl_s3prl_main/ci/format.py ADDED
@@ -0,0 +1,84 @@
1
+ #!/usr/bin/env python3
2
+
3
+ import argparse
4
+ from pathlib import Path
5
+ from subprocess import CalledProcessError, check_output
6
+
7
+
8
+ def load_valid_paths():
9
+ with open("./valid_paths.txt", "r") as fp:
10
+ paths = [line.strip() for line in fp if line.strip() != ""]
11
+ return paths
12
+
13
+
14
+ def get_third_party():
15
+ txt_files = list(Path("./requirements").rglob("*.txt"))
16
+ package_list = []
17
+ for file in txt_files:
18
+ with open(file, "r") as fp:
19
+ for line in fp:
20
+ line = line.strip()
21
+ if line == "":
22
+ continue
23
+ package_list.append(line.split(" ")[0])
24
+ return package_list
25
+
26
+
27
+ def run_command(command: str):
28
+ try:
29
+ check_output(command.split(" "))
30
+ except CalledProcessError as e:
31
+ print(e.output.decode("utf-8"))
32
+ raise
33
+
34
+
35
+ def main():
36
+ parser = argparse.ArgumentParser()
37
+ parser.add_argument(
38
+ "files",
39
+ type=str,
40
+ nargs="*",
41
+ default=[],
42
+ help="If no file is given, use the files under ./valid_paths.txt",
43
+ )
44
+ parser.add_argument("--check", action="store_true", help="Only checks the files")
45
+ args = parser.parse_args()
46
+
47
+ if len(args.files) == 0:
48
+ args.files = load_valid_paths()
49
+
50
+ print(f"Formatting files: {args.files}")
51
+ args.files = " ".join(args.files)
52
+
53
+ print("Run flake8")
54
+ # stop the build if there are Python syntax errors or undefined names
55
+ run_command(
56
+ f"flake8 {args.files} --count --select=E9,F63,F7,F82 --show-source --statistics"
57
+ )
58
+ # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
59
+ run_command(
60
+ f"flake8 {args.files} --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics"
61
+ )
62
+
63
+ print("Run black")
64
+ if args.check:
65
+ run_command(f"black --check {args.files}")
66
+ else:
67
+ run_command(f"black {args.files}")
68
+
69
+ print("Run isort")
70
+ third_party = get_third_party()
71
+ third_party = ",".join(third_party)
72
+ if args.check:
73
+ run_command(
74
+ f"isort --profile black --thirdparty {third_party} --check {args.files}"
75
+ )
76
+ else:
77
+ run_command(f"isort --profile black --thirdparty {third_party} {args.files}")
78
+
79
+ if args.check:
80
+ print("Successfully passed the format check!")
81
+
82
+
83
+ if __name__ == "__main__":
84
+ main()
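`ci/format.py` wraps `flake8`, `black`, and `isort` behind a single entry point; when no file arguments are given it falls back to the paths listed in `valid_paths.txt`. A usage sketch, assuming it is run from the s3prl repository root (where `valid_paths.txt` and `requirements/` live) and the formatting tools are installed:

```sh
# Check-only mode, as CI would run it: fails on flake8 errors and on files
# that black or isort would change, without rewriting anything.
python ci/format.py --check

# Reformat specific files in place.
python ci/format.py hubconf.py
```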
s3prl_s3prl_main/docs/Makefile ADDED
@@ -0,0 +1,20 @@
1
+ # Minimal makefile for Sphinx documentation
2
+ #
3
+
4
+ # You can set these variables from the command line, and also
5
+ # from the environment for the first two.
6
+ SPHINXOPTS ?=
7
+ SPHINXBUILD ?= sphinx-build
8
+ SOURCEDIR = source
9
+ BUILDDIR = build
10
+
11
+ # Put it first so that "make" without argument is like "make help".
12
+ help:
13
+ @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14
+
15
+ .PHONY: help Makefile
16
+
17
+ # Catch-all target: route all unknown targets to Sphinx using the new
18
+ # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19
+ %: Makefile
20
+ @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
s3prl_s3prl_main/docs/README.md ADDED
@@ -0,0 +1,27 @@
1
+ # Documentation
2
+
3
+ To auto-generate documentation for S3PRL, please follow these steps:
4
+
5
+ 1. Activate a Python environment in which the doc-generating tool can import all the modules for auto-documentation. There are various ways to achieve this; for example:
6
+
7
+ ```sh
8
+ conda create -y -n doc python=3.8
9
+ conda activate doc
10
+
11
+ cd $S3PRL_ROOT
12
+ pip install ".[dev]"
13
+ ```
14
+
15
+ 2. Auto-generate HTML files for all the packages, modules and their submodules listed in `$S3PRL_ROOT/valid_paths.txt`. The HTML files will appear in `$S3PRL_ROOT/docs/build/html`
16
+
17
+ ```sh
18
+ cd $S3PRL_ROOT/docs
19
+ ./rebuild_docs.sh
20
+ ```
21
+
22
+ 3. Launch the simple webserver to see the documentation.
23
+
24
+ ```sh
25
+ cd $S3PRL_ROOT/docs/build/html
26
+ python3 -m http.server
27
+ ```
s3prl_s3prl_main/docs/from_scratch_tutorial.md ADDED
@@ -0,0 +1,146 @@
1
+ # README
2
+ ## install sphinx
3
+ ```shell
4
+ $ pip install -U sphinx
5
+ ```
6
+
7
+ ## set up
8
+ at the s3prl root directory (s3prl-private/):
9
+ ```shell
10
+ $ mkdir docs
11
+ $ cd docs
12
+ $ sphinx-quickstart
13
+ $ cd ..
14
+ $ sphinx-apidoc -d 3 --separate --implicit-namespace -o docs ./s3prl s3prl/downstream s3prl/interface s3prl/preprocess s3prl/pretrain s3prl/problem s3prl/sampler s3prl/submit s3prl/superb s3prl/upstream s3prl/utility s3prl/wrapper s3prl/__init__.py s3prl/hub.py s3prl/optimizers.py s3prl/run_downstream.py s3prl/run_pretrain.py s3prl/run_while.sh s3prl/schedulers.py
15
+ ```
16
+
17
+ ## install dependencies
18
+ ```shell
19
+ $ cd s3prl-private/docs
20
+ $ echo "furo
21
+ torch
22
+ numpy
23
+ pandas
24
+ tqdm
25
+ pytorch_lightning
26
+ matplotlib
27
+ ipdb>=0.13.9
28
+ PyYAML
29
+ transformers
30
+ torchaudio
31
+ gdown
32
+ sklearn
33
+ joblib
34
+ tensorboardX
35
+ librosa
36
+ scipy
37
+ lxml
38
+ h5py
39
+ dtw
40
+ catalyst
41
+ sox
42
+ six
43
+ easydict
44
+ Resemblyzer
45
+ sentencepiece
46
+ pysoundfile
47
+ asteroid
48
+ sacrebleu
49
+ speechbrain
50
+ omegaconf
51
+ editdistance" > requirement.txt
52
+ $ pip install -r requirement.txt
53
+ ```
54
+
55
+ add custom.js at s3prl-private/docs/_static/js/
56
+ (just paste the following lines)
57
+ ```javascript
58
+ /*
59
+ change the default sphinx.ext.linkcode's [source] to [Github]
60
+ */
61
+ document.querySelectorAll(".reference.external .viewcode-link .pre").forEach(item => {
62
+ item.innerHTML = "[Github]"
63
+ item.style.marginRight = "3px"
64
+ })
65
+
66
+ ```
67
+
68
+ modify s3prl-private/docs/index.rst
69
+ ```diff
70
+ # remove these lines
71
+ - .. toctree::
72
+ - :maxdepth: 2
73
+ - :caption: Contents:
74
+
75
+ # replaced with this line
76
+ + .. include:: s3prl.rst
77
+ ```
78
+
79
+ modify s3prl-private/docs/conf.py
80
+ ```python
81
+ # add these lines at top
82
+ import inspect
83
+ import os
84
+ import sys
85
+ for x in os.walk('..'):
86
+ sys.path.insert(0, x[0])
87
+
88
+ # add extensions
89
+ extensions = [
90
+ 'sphinx.ext.autodoc',
91
+ 'sphinx.ext.napoleon',
92
+ 'sphinx.ext.viewcode',
93
+ 'sphinx.ext.linkcode'
94
+ ]
95
+
96
+ html_js_files = [
97
+ 'js/custom.js',
98
+ ]
99
+
100
+ def linkcode_resolve(domain, info):
101
+ def find_source():
102
+ obj = sys.modules[info['module']]
103
+ if info['fullname'] == 'InitConfig.args': return None
104
+ if info['fullname'] == 'InitConfig.kwargs': return None
105
+ for part in info['fullname'].split('.'):
106
+ obj = getattr(obj, part)
107
+
108
+ if isinstance(obj, property): return None
109
+
110
+ fn = inspect.getsourcefile(obj)
111
+ fn = os.path.relpath(fn, start=os.path.dirname(os.path.abspath(__file__))[:-4])
112
+
113
+ source, lineno = inspect.getsourcelines(obj)
114
+ return fn, lineno, lineno + len(source) - 1
115
+
116
+ if domain != 'py' or not info['module']: return None
117
+
118
+ tag = 'master' if 'dev' in release else ('v' + release) # s3prl github version
119
+
120
+ try:
121
+ filename = '%s#L%d-L%d' % find_source() # specify file page with line number
122
+ except Exception:
123
+ filename = info['module'].replace('.', '/') + '.py' # cannot find corresponding codeblock, use the file page instead
124
+
125
+ return "https://github.com/s3prl/s3prl-private/blob/%s/%s" % (tag, filename)
126
+ ```
127
+
128
+ to use the furo theme, add this line in s3prl-private/docs/conf.py (replace the original alabaster theme)
129
+ ```python
130
+ html_theme = "furo"
131
+ ```
132
+
133
+ ## generate html files
134
+ at s3prl-private/docs/
135
+ ```shell
136
+ $ make html
137
+ ```
138
+
139
+ the html files will be generated at **s3prl-private/docs/_build/html/**
140
+ click on **index.html** to view the doc page in your browser
141
+
142
+ if you want to see how your modified code looks, simply run
143
+ ```shell
144
+ $ make clean html # this removes the old html files
145
+ $ make html # generate new html files
146
+ ```
s3prl_s3prl_main/docs/make.bat ADDED
@@ -0,0 +1,35 @@
1
+ @ECHO OFF
2
+
3
+ pushd %~dp0
4
+
5
+ REM Command file for Sphinx documentation
6
+
7
+ if "%SPHINXBUILD%" == "" (
8
+ set SPHINXBUILD=sphinx-build
9
+ )
10
+ set SOURCEDIR=source
11
+ set BUILDDIR=build
12
+
13
+ %SPHINXBUILD% >NUL 2>NUL
14
+ if errorlevel 9009 (
15
+ echo.
16
+ echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
17
+ echo.installed, then set the SPHINXBUILD environment variable to point
18
+ echo.to the full path of the 'sphinx-build' executable. Alternatively you
19
+ echo.may add the Sphinx directory to PATH.
20
+ echo.
21
+ echo.If you don't have Sphinx installed, grab it from
22
+ echo.https://www.sphinx-doc.org/
23
+ exit /b 1
24
+ )
25
+
26
+ if "%1" == "" goto help
27
+
28
+ %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29
+ goto end
30
+
31
+ :help
32
+ %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33
+
34
+ :end
35
+ popd
s3prl_s3prl_main/docs/rebuild_docs.sh ADDED
@@ -0,0 +1,9 @@
1
+ #!/bin/bash
2
+
3
+ rm -rf ./source/s3prl.*
4
+ rm -rf ./source/_autosummary
5
+
6
+ make clean html
7
+
8
+ touch build/html/.nojekyll
9
+
s3prl_s3prl_main/docs/source/_static/css/custom.css ADDED
@@ -0,0 +1,3 @@
1
+ .toctree-wrapper.compound .caption-text {
2
+ display: none;
3
+ }
s3prl_s3prl_main/docs/source/_static/js/custom.js ADDED
@@ -0,0 +1,7 @@
1
+ /*
2
+ change the default sphinx.ext.linkcode's [source] to [Github]
3
+ */
4
+ document.querySelectorAll(".reference.external .viewcode-link .pre").forEach(item => {
5
+ item.innerHTML = "[Github]"
6
+ item.style.marginRight = "3px"
7
+ })
s3prl_s3prl_main/docs/source/_templates/custom-module-template.rst ADDED
@@ -0,0 +1,81 @@
1
+ {{ name | escape | underline}}
2
+
3
+ **({{ fullname }})**
4
+
5
+ .. automodule:: {{ fullname }}
6
+
7
+ {% block modules %}
8
+ {% if modules %}
9
+
10
+ .. autosummary::
11
+ :toctree:
12
+ :template: custom-module-template.rst
13
+ :recursive:
14
+
15
+ {% for item in modules %}
16
+ {{ item }}
17
+ {%- endfor %}
18
+ {% endif %}
19
+ {% endblock %}
20
+
21
+
22
+ {% block attributes %}
23
+ {% if attributes %}
24
+ {% for item in attributes %}
25
+
26
+ {{ item }}
27
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
28
+
29
+ .. autoattribute:: {{ item }}
30
+
31
+ {%- endfor %}
32
+ {% endif %}
33
+ {% endblock %}
34
+
35
+ {% block classes %}
36
+ {% if classes %}
37
+ {% for item in classes %}
38
+
39
+ {{ item }}
40
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
41
+
42
+ .. autoclass:: {{ item }}
43
+ :member-order: bysource
44
+ :members:
45
+ :undoc-members:
46
+ :inherited-members: torch.nn.Module,nn.Module,Module
47
+ :show-inheritance:
48
+
49
+ {%- endfor %}
50
+ {% endif %}
51
+ {% endblock %}
52
+
53
+ {% block functions %}
54
+ {% if functions %}
55
+ {% for item in functions %}
56
+
57
+ {{ item }}
58
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
59
+
60
+ .. autofunction:: {{ item }}
61
+
62
+ {%- endfor %}
63
+ {% endif %}
64
+ {% endblock %}
65
+
66
+ {% block exceptions %}
67
+ {% if exceptions %}
68
+ {% for item in exceptions %}
69
+
70
+ {{ item }}
71
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
72
+
73
+ .. autoexception:: {{ item }}
74
+ :members:
75
+ :undoc-members:
76
+ :show-inheritance:
77
+ :member-order: bysource
78
+
79
+ {%- endfor %}
80
+ {% endif %}
81
+ {% endblock %}
s3prl_s3prl_main/docs/source/conf.py ADDED
@@ -0,0 +1,120 @@
1
+ # Configuration file for the Sphinx documentation builder.
2
+ #
3
+ # This file only contains a selection of the most common options. For a full
4
+ # list see the documentation:
5
+ # https://www.sphinx-doc.org/en/master/usage/configuration.html
6
+
7
+ # -- Path setup --------------------------------------------------------------
8
+
9
+ # If extensions (or modules to document with autodoc) are in another directory,
10
+ # add these directories to sys.path here. If the directory is relative to the
11
+ # documentation root, use os.path.abspath to make it absolute, like shown here.
12
+ #
13
+ import inspect
14
+ import os
15
+ import sys
16
+ from pathlib import Path
17
+
18
+ # -- Project information -----------------------------------------------------
19
+
20
+ project = "S3PRL"
21
+ copyright = "2022, S3PRL Team"
22
+ author = "S3PRL Team"
23
+
24
+ # The full version, including alpha/beta/rc tags
25
+ with (Path(__file__).parent.parent.parent / "s3prl" / "version.txt").open() as f:
26
+ release = f.read()
27
+
28
+
29
+ def linkcode_resolve(domain, info):
30
+ def find_source():
31
+ obj = sys.modules[info["module"]]
32
+ for part in info["fullname"].split("."):
33
+ obj = getattr(obj, part)
34
+
35
+ if isinstance(obj, property):
36
+ return None
37
+
38
+ file_parts = Path(inspect.getsourcefile(obj)).parts
39
+ reversed_parts = []
40
+ for part in reversed(file_parts):
41
+ if part == "s3prl":
42
+ reversed_parts.append(part)
43
+ break
44
+ else:
45
+ reversed_parts.append(part)
46
+ fn = "/".join(reversed(reversed_parts))
47
+
48
+ source, lineno = inspect.getsourcelines(obj)
49
+ return fn, lineno, lineno + len(source) - 1
50
+
51
+ if domain != "py" or not info["module"]:
52
+ return None
53
+
54
+ tag = "master" if "dev" in release else ("v" + release) # s3prl github version
55
+
56
+ try:
57
+ filename = "%s#L%d-L%d" % find_source() # specify file page with line number
58
+ except Exception:
59
+ filename = (
60
+ info["module"].replace(".", "/") + ".py"
61
+ ) # cannot find corresponding codeblock, use the file page instead
62
+
63
+ return "https://github.com/s3prl/s3prl/blob/%s/%s" % (tag, filename)
64
+
65
+
66
+ # -- General configuration ---------------------------------------------------
67
+
68
+ # Add any Sphinx extension module names here, as strings. They can be
69
+ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
70
+ # ones.
71
+ # add extensions
72
+ extensions = [
73
+ "sphinx.ext.autodoc",
74
+ "sphinx.ext.napoleon",
75
+ "sphinx.ext.viewcode",
76
+ "sphinx.ext.linkcode",
77
+ "sphinx.ext.autosummary",
78
+ ]
79
+
80
+ autosummary_generate = True
81
+
82
+ html_js_files = [
83
+ "js/custom.js",
84
+ ]
85
+ html_css_files = ["css/custom.css"]
86
+
87
+ # Add any paths that contain templates here, relative to this directory.
88
+ templates_path = ["_templates"]
89
+
90
+ # List of patterns, relative to source directory, that match files and
91
+ # directories to ignore when looking for source files.
92
+ # This pattern also affects html_static_path and html_extra_path.
93
+ exclude_patterns = []
94
+
95
+
96
+ # -- Options for HTML output -------------------------------------------------
97
+
98
+ # The theme to use for HTML and HTML Help pages. See the documentation for
99
+ # a list of builtin themes.
100
+ #
101
+ html_theme = "furo"
102
+
103
+ # Add any paths that contain custom static files (such as style sheets) here,
104
+ # relative to this directory. They are copied after the builtin static files,
105
+ # so a file named "default.css" will overwrite the builtin "default.css".
106
+ html_static_path = ["_static"]
107
+
108
+ # Uncomment the following if you want to document __call__
109
+ #
110
+ # def skip(app, what, name, obj, would_skip, options):
111
+ # if name == "__call__":
112
+ # return False
113
+ # return would_skip
114
+ #
115
+ # def setup(app):
116
+ # app.connect("autodoc-skip-member", skip)
117
+
118
+ autosummary_imported_members = True
119
+ autosummary_ignore_module_all = False
120
+ autodoc_member_order = "bysource"
s3prl_s3prl_main/docs/source/contribute/general.rst ADDED
@@ -0,0 +1,167 @@
1
+ .. _general-contribution-guideline:
2
+
3
+ General Guideline
4
+ ==================
5
+
6
+ Thank you for considering contributing to S3PRL; we really appreciate it.
7
+
8
+ However, due to the increasing difficulty of maintenance, please understand that **we might not accept the new feature**.
9
+
10
+ Hence, **before submitting the implemented pull request**,
11
+ please **submit your feature request** to the Github issue page so we can discuss whether we want it and how to achieve it.
12
+
13
+ .. warning::
14
+
15
+ If this discussion has not taken place, the pull request will not be reviewed and will be closed directly.
16
+
17
+
18
+ Discuss
19
+ -----------
20
+
21
+ `Submit your feature request <https://github.com/s3prl/s3prl/issues/new?assignees=&labels=&template=feature_request.md&title=>`_
22
+ on our Github Issue page to propose features and changes.
23
+
24
+ Please wait for our response and do not move on to the following steps before we have a consensus on what is going to be changed.
25
+
26
+ Setup
27
+ -----------
28
+
29
+ Clone the repository to **S3PRL_ROOT** and install the package
30
+
31
+ .. code-block:: bash
32
+
33
+ S3PRL_ROOT="/home/leo/d/s3prl"
34
+ git clone https://github.com/s3prl/s3prl.git ${S3PRL_ROOT}
35
+ cd ${S3PRL_ROOT}
36
+
37
+ pip install -e ".[dev]"
38
+ # This installs the dependencies for the full functionality of S3PRL in the editable mode,
39
+ # including the dependencies for development like testing and building doc
40
+
41
+
42
+ Tests
43
+ ----------
44
+
45
+ Add unit tests to **${S3PRL_ROOT}/test/** to test your own new modules
46
+
47
+ Verify you pass all the tests
48
+
49
+ .. code-block:: bash
50
+
51
+ cd ${S3PRL_ROOT}
52
+ pytest
53
+
54
+
55
+ Documentation
56
+ -------------
57
+
58
+ Make sure you write the documentation on the modules' docstring
59
+
60
+ Build the documentation and make sure it looks correct
61
+
62
+ .. code-block:: bash
63
+
64
+ cd ${S3PRL_ROOT}/docs
65
+
66
+ # build
67
+ ./rebuild_docs.sh
68
+
69
+ # launch http server
70
+ python3 -m http.server -d build/html/
71
+
72
+ # You can then use a browser to access the doc webpage on: YOUR_IP_OR_LOCALHOST:8000
73
+
74
+ Coding-style check
75
+ ------------------
76
+
77
+ Stage your changes
78
+
79
+ .. code-block:: bash
80
+
81
+ git add "YOUR_CHANGED_OR_ADDED_FILES_ONLY"
82
+
83
+ .. warning::
84
+
85
+ Please do not use **git add .** to add all the files under your repository.
86
+ If there are files not ignored by git (specified in **.gitignore**), like
87
+ temporary experiment result files, they will all be added into git version
88
+ control, which will mess up our repository.
89
+
90
+ .. note::
91
+
92
+ In our **.gitignore**, there are lots of ignored files especially for *.yaml*
93
+ and *.sh* files. If the config files or the shell scripts are important, please
94
+ remember to add them forcibly, for example :code:`git add -f asr.yaml`
95
+
96
+ Run **pre-commit** to apply the standardized coding-style on **YOUR_CHANGED_OR_ADDED_FILES_ONLY**
97
+
98
+ .. code-block:: bash
99
+
100
+ pre-commit run
101
+
102
+ If the results show there are files modified by **pre-commit**, you need to re-stage
103
+ these files following the previous step.
104
+
105
+
106
+ Commit / Push
107
+ -------------
108
+
109
+ Commit and push the changes
110
+
111
+ .. code-block:: bash
112
+
113
+ git commit -m "YOUR_COMMIT_MESSAGE"
114
+ git push origin "YOUR_BRANCH"
115
+
116
+
117
+ (Optional) Test against multiple environments
118
+ ---------------------------------------------
119
+
120
+ We leverage **tox** to simulate multiple envs, see the `tox configuration <https://github.com/s3prl/s3prl/blob/main/tox.ini>`_ for more information.
121
+ Tox helps automate the pipeline of creating different virtual envs, installing different dependencies of S3PRL, and running different testing commands.
122
+ Our Github Action CI also relies on tox, hence you can debug the CI error locally with tox.
123
+
124
+ Before using tox, make sure your CLI can launch the following Python versions. Usually, this can be achieved via `pyenv <https://github.com/pyenv/pyenv>`_
125
+
126
+ - python3.7
127
+ - python3.8
128
+ - python3.9
129
+ - python3.10
130
+
131
+ List all the available environments. An environment means a pre-defined routine of packaging S3PRL, installing S3PRL, installing specific dependencies, and
132
+ running specific test commands. See `tox configuration <https://github.com/s3prl/s3prl/blob/main/tox.ini>`_ for more information.
133
+
134
+ .. code-block:: bash
135
+
136
+ tox -l
137
+
138
+ Suppose there is an environment named :code:`all_upstream-py38-audio0.12.1`, you can also test against this specific env:
139
+
140
+ .. code-block:: bash
141
+
142
+ tox -e all_upstream-py38-audio0.12.1
143
+
144
+ Test all environments. This simulates the environments you will meet on the Github Action CI.
145
+
146
+ .. code-block:: bash
147
+
148
+ tox
149
+
150
+
151
+ Send a pull request
152
+ -------------------
153
+
154
+ Verify your codes are in the proper format
155
+
156
+ .. code-block:: bash
157
+
158
+ ./ci/format.sh --check
159
+ # If this fails, simply remove --check to do the actual formatting
160
+
161
+ Make sure you add test cases and your change pass the tests
162
+
163
+ .. code-block:: bash
164
+
165
+ pytest
166
+
167
+ Send a pull request on GitHub
s3prl_s3prl_main/docs/source/contribute/private.rst ADDED
@@ -0,0 +1,104 @@
1
+ Internal S3PRL Development
2
+ ==========================
3
+
4
+ Write code
5
+ ----------
6
+
7
+ 1. Make sure you have access to `s3prl/s3prl-private <https://github.com/s3prl/s3prl-private/>`_
8
+
9
+ 2. Clone the repository to **S3PRL_ROOT** and install the package
10
+
11
+ .. code-block:: bash
12
+
13
+ git clone s3prl/s3prl-private "S3PRL_ROOT"
14
+ cd "S3PRL_ROOT"
15
+ pip install -e ".[dev]" # This installs the dependencies for the full functionality of S3PRL
16
+
17
+ 3. Write code into the packages listed in **S3PRL_ROOT/valid_paths.txt**
18
+
19
+ Unit tests
20
+ ----------
21
+
22
+ 4. Add unit tests to **S3PRL_ROOT/test/** to test your own new modules
23
+
24
+ 5. Verify you pass all the tests
25
+
26
+ .. code-block:: bash
27
+
28
+ cd "S3PRL_ROOT"
29
+ pytest
30
+
31
+ Documentation
32
+ -------------
33
+
34
+ 6. Make sure you write the documentation on the modules' docstring
35
+
36
+ 7. Build the documentation and make sure it looks correct
37
+
38
+ .. code-block:: bash
39
+
40
+ cd "S3PRL_ROOT"/docs
41
+
42
+ # build
43
+ ./rebuild_docs.sh
44
+
45
+ # launch http server
46
+ python3 -m http.server -d build/html/
47
+
48
+ # You can then use a browser to access the doc webpage on: YOUR_IP_OR_LOCALHOST:8000
49
+
50
+ 8. OK, now your new changes are ready to be committed
51
+
52
+ Coding-style check
53
+ ------------------
54
+
55
+ 9. Stage your changes
56
+
57
+ .. code-block:: bash
58
+
59
+ git add "YOUR_CHANGED_OR_ADDED_FILES_ONLY"
60
+
61
+ .. warning::
62
+
63
+ Please do not use **git add .** to add all the files under your repository.
64
+ If there are files not ignored by git (specified in **.gitignore**), like
65
+ temporary experiment result files, they will all be added into git version
66
+ control, which will mess up our repository.
67
+
68
+ 10. Run **pre-commit** to apply the standardized coding-style on **YOUR_CHANGED_OR_ADDED_FILES_ONLY**
69
+
70
+ .. code-block:: bash
71
+
72
+ pre-commit run
73
+
74
+ If the results show there are files modified by **pre-commit**, you need to re-stage
75
+ these files following step 9.
76
+
77
+ Commit / Push
78
+ -------------
79
+
80
+ 11. Commit and push the changes
81
+
82
+ .. code-block:: bash
83
+
84
+ git commit -m "YOUR_COMMIT_MESSAGE"
85
+ git push origin "YOUR_BRANCH"
86
+
87
+ Send a pull request
88
+ -------------------
89
+
90
+ Only do this when you are ready to merge your branch. Since once you send a pull request,
91
+ every newly pushed commit will cause GitHub to run CI, but we have a limited number of
92
+ CI runs per month, as regulated by GitHub. Hence, you should do this only after the
93
+ branch is ready.
94
+
95
+ 12. Verify you can pass the CI locally
96
+
97
+ .. code-block:: bash
98
+
99
+ ./ci/format.sh --check
100
+ # If this fails, simply remove --check to do the actual formatting
101
+
102
+ pytest
103
+
104
+ 13. Send a pull request on GitHub
s3prl_s3prl_main/docs/source/contribute/public.rst ADDED
@@ -0,0 +1,29 @@
1
+ Contribute to S3PRL
2
+ ===================
3
+
4
+ Thank you for considering contributing to S3PRL! We really appreciate it!
5
+
6
+ **Please discuss with us on the Github issue page about your "feature request"** before submitting
7
+ the implemented pull request, and we will discuss the following things:
8
+
9
+ 1. Whether we want to add it in S3PRL
10
+ 2. If yes, how should we achieve this feature
11
+ 3. How can we collaborate on the implementation
12
+ 4. Proceed with the actual implementation and the pull request
13
+
14
+ .. warning::
15
+
16
+ If this discussion has not taken place, the pull request will not be reviewed and will be closed directly.
17
+
18
+ .. note::
19
+
20
+ S3PRL has grown too large and become too hard to maintain.
21
+ In principle, we plan to keep only a few essential functions that are guaranteed to work,
22
+ instead of continually adding lots of tasks, which are easy to add but hard to maintain when nobody
23
+ stays around to look after them in the long term.
24
+ Hence, *we welcome all the bug fixes or slight improvement on existing functions,
25
+ while by default we do not accept pull requests containing lots of changes or an entirely new task*,
26
+ unless our core maintainers feel it is maintainable from our side.
27
+ That is, we do not expect outside collaborators to maintain their contribution for a long time,
28
+ which in our experience is the usual case. We wish to keep all the functionality of S3PRL working in the long term,
29
+ and thanks for your kind support and understanding.
s3prl_s3prl_main/docs/source/contribute/upstream.rst ADDED
@@ -0,0 +1,100 @@
1
+ Adding New Upstream
2
+ ====================
3
+
4
+
5
+ Discuss
6
+ ---------
7
+
8
+ Please make sure that you have already gone through :doc:`./general`.
9
+ Again, please make sure we have consensus on the new feature request.
10
+ The best and the most transparent way is to
11
+ `submit your feature request <https://github.com/s3prl/s3prl/issues/new?assignees=&labels=&template=feature_request.md&title=>`_.
12
+
13
+
14
+ Copy from the template
15
+ -----------------------
16
+
17
+ To add a new upstream, you can start from the `example <https://github.com/s3prl/s3prl/tree/main/s3prl/upstream/example>`_.
18
+ Suppose your new upstream is called :code:`my_awesome_upstream`; the simplest way to start is the following:
19
+
20
+ 1.
21
+
22
+ .. code-block:: bash
23
+
24
+ cd ${S3PRL_ROOT}
25
+ cp -r s3prl/upstream/example/ s3prl/upstream/my_awesome_upstream
26
+
27
+ 2. In :code:`s3prl/upstream/my_awesome_upstream/hubconf.py`, change :code:`customized_upstream` to :code:`my_entry_1`
28
+ 3. In :code:`s3prl/hub.py`, add :code:`from s3prl.upstream.my_awesome_upstream.hubconf import *`
29
+
30
+ 4.
31
+
32
+ .. code-block:: bash
33
+
34
+ python3 utility/extract_feat.py my_entry_1 sample_hidden_states
35
+ # this script extracts hidden states from an upstream entry to the "sample_hidden_states" folder
36
+
37
+ This will extract the hidden states from this :code:`my_entry_1` entry.
38
+ The default content in :code:`s3prl/upstream/example/` always works, so you can simply edit the files
39
+ inside the new :code:`s3prl/upstream/my_awesome_upstream` folder to enable your new upstream.
40
+
41
+
42
+ Implement
43
+ ----------
44
+
45
+ The folder is in the following structure:
46
+
47
+ .. code-block:: bash
48
+
49
+ my_awesome_upstream
50
+ |
51
+ ---- expert.py
52
+ |
53
+ ---- hubconf.py
54
+
55
+ In principle, :code:`hubconf.py` serves as the URL registry, where each callable function is an entry specifying
56
+ the source of the checkpoint, while the :code:`expert.py` serves as the wrapper of your model definition to fit
57
+ with our upstream API.
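+
+ As a rough illustration (not the authoritative interface; the exact method names and return
+ format are defined by the template files in :code:`s3prl/upstream/example/`, so follow that
+ template rather than this sketch), the two files typically look like the following:
+
+ .. code-block:: python
+
+     # expert.py (sketch): wrap your model so it takes a list of waveforms and
+     # returns every layer's hidden states, mirroring s3prl/upstream/example/expert.py
+     import torch
+
+     class UpstreamExpert(torch.nn.Module):
+         def __init__(self, ckpt: str = None, **kwargs):
+             super().__init__()
+             # load your pre-trained weights from `ckpt` here
+             self.model = torch.nn.Linear(1, 8)  # placeholder model
+
+         def get_downsample_rates(self, key: str) -> int:
+             # number of waveform samples that correspond to one output frame
+             return 320
+
+         def forward(self, wavs):
+             # wavs: list of 1-D waveform tensors (possibly of different lengths)
+             feats = [self.model(wav.unsqueeze(-1)) for wav in wavs]
+             padded = torch.nn.utils.rnn.pad_sequence(feats, batch_first=True)
+             return {"hidden_states": [padded]}
+
+     # hubconf.py (sketch): each callable is one entry name that users can load
+     def my_entry_1(*args, **kwargs):
+         return UpstreamExpert(*args, **kwargs)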
58
+
59
+ During your implementation, please try to remove as many package dependencies as possible, since the upstream
60
+ functionality is our core feature, and should have minimal dependencies to be maintainable.
61
+
62
+
63
+ Tests
64
+ -------
65
+
66
+ After your implementation, please make sure all your entries can pass `the tests <https://github.com/s3prl/s3prl/blob/8eac602117003e2bb5cdb7a4d0e94cc9975fd4f2/test/test_upstream.py#L194-L250>`_.
67
+ The :code:`test_upstream_with_extracted` test case requires you to pre-extract the expected hidden states via:
68
+
69
+ .. code-block:: bash
70
+
71
+ python3 utility/extract_feat.py my_awesome_upstream ./sample_hidden_states
72
+
73
+ That is, the test case expects there will be a :code:`my_awesome_upstream.pt` in the :code:`sample_hidden_states` folder.
74
+
75
+ All the existing sampled hidden states are hosted at a `Huggingface Dataset Repo <https://huggingface.co/datasets/s3prl/sample_hidden_states/tree/main>`_,
76
+ and we expect you to clone (by :code:`git lfs`) this :code:`sample_hidden_states` repo and add the sampled hidden states for your new entries.
77
+
78
+ To make changes to this hidden states repo, please follow the steps `here <https://huggingface.co/datasets/s3prl/sample_hidden_states/discussions>`_
79
+ to create a pull request, so that our core maintainer can sync the hidden states extracted by you.
80
+
81
+ In conclusion, to add a new upstream, one needs to make two pull requests:
82
+
83
+ - To https://github.com/s3prl/s3prl/pulls
84
+ - To https://huggingface.co/datasets/s3prl/sample_hidden_states/tree/main
85
+
86
+
87
+ .. note::
88
+
89
+ In fact, due to the huge time cost, most of the upstreams in S3PRL will not be tested in Github Action CI (or else it will take several hours
90
+ to download all the checkpoints for every PR). However, our core maintainers will still clone the repository and run tox locally to make sure
91
+ everything works fine, and there is a `tox environment <https://github.com/s3prl/s3prl/blob/8eac602117003e2bb5cdb7a4d0e94cc9975fd4f2/tox.ini#L11>`_
92
+ testing all the upstreams.
93
+
94
+
95
+ Documentation
96
+ --------------
97
+
98
+ After finishing the implementation, make sure users know about your work by adding documentation for your entries at
99
+ the :doc:`../tutorial/upstream_collection` tutorial page. Also, you can add your name at the bottom of the tutorial
100
+ page if you like.
s3prl_s3prl_main/docs/source/index.rst ADDED
@@ -0,0 +1,75 @@
1
+ .. S3PRL documentation master file, created by
2
+ sphinx-quickstart on Sun May 15 15:43:39 2022.
3
+ You can adapt this file completely to your liking, but it should at least
4
+ contain the root `toctree` directive.
5
+
6
+ S3PRL
7
+ =====
8
+
9
+ .. image:: https://raw.githubusercontent.com/s3prl/s3prl/master/file/S3PRL-logo.png
10
+
11
+ **S3PRL** is a toolkit targeting for Self-Supervised Learning for speech processing.
12
+ Its full name is **S**\elf-**S**\upervised **S**\peech **P**\re-training and **R**\epresentation **L**\earning.
13
+ It supports the following three major features:
14
+
15
+ * **Pre-training**
16
+
17
+ * You can train the following models from scratch:
18
+
19
+ * *Mockingjay*, *Audio ALBERT*, *TERA*, *APC*, *VQ-APC*, *NPC*, and *DistilHuBERT*
20
+
21
+ * **Pre-trained models (Upstream) collection**
22
+
23
+ * Easily load most of the existing upstream models with pretrained weights in a unified I/O interface.
24
+ * Pretrained models are registered through torch.hub, which means you can use these models in your own project by one-line plug-and-play without depending on this toolkit's coding style.
25
+
26
+ * **Downstream Evaluation**
27
+
28
+ * Utilize upstream models in lots of downstream tasks
29
+ * The official implementation of the `SUPERB Benchmark <https://superbbenchmark.org/>`_
30
+
31
+
32
+ Getting Started
33
+ ---------------
34
+
35
+ .. toctree::
36
+ :caption: Getting started
37
+
38
+ ./tutorial/installation.rst
39
+ ./tutorial/upstream_collection.rst
40
+ ./tutorial/problem.rst
41
+
42
+
43
+ How to Contribute
44
+ -----------------
45
+
46
+ .. toctree::
47
+ :caption: How to Contribute
48
+
49
+ ./contribute/general.rst
50
+ ./contribute/upstream.rst
51
+
52
+
53
+ API Documentation
54
+ -----------------
55
+
56
+ .. autosummary::
57
+ :caption: API Documentation
58
+ :toctree: _autosummary
59
+ :template: custom-module-template.rst
60
+ :recursive:
61
+
62
+ s3prl.nn
63
+ s3prl.problem
64
+ s3prl.task
65
+ s3prl.dataio
66
+ s3prl.metric
67
+ s3prl.util
68
+
69
+
70
+ Indices and tables
71
+ ==================
72
+
73
+ * :ref:`genindex`
74
+ * :ref:`modindex`
75
+ * :ref:`search`
s3prl_s3prl_main/docs/source/tutorial/installation.rst ADDED
@@ -0,0 +1,55 @@
1
+ Install S3PRL
2
+ =============
3
+
4
+ Minimal installation
5
+ --------------------
6
+
7
+ This installation only enables the **S3PRL Upstream collection** function to
8
+ keep the minimal dependency. To enable all the functions including downstream benchmarking,
9
+ you need to follow `Full installation`_.
10
+
11
+ .. code-block:: bash
12
+
13
+ pip install s3prl
14
+
15
+
16
+ Editable installation
17
+ ---------------------
18
+
19
+ Installing a package in editable mode means that when you use an imported class/function,
20
+ the source code of the class/function is right in your cloned repository.
21
+ So, when you modify the code inside it, the newly imported class/function will reflect
22
+ your modification.
23
+
24
+ .. code-block:: bash
25
+
26
+ git clone https://github.com/s3prl/s3prl.git
27
+ cd s3prl
28
+ pip install -e .
29
+
30
+
31
+ Full installation
32
+ ------------------
33
+
34
+ Install all the dependencies to enable all the S3PRL functions. However, there are **LOTS**
35
+ of dependencies.
36
+
37
+ .. code-block:: bash
38
+
39
+ pip install s3prl[all]
40
+
41
+ # editable
42
+ pip install ".[all]"
43
+
44
+
45
+ Development installation
46
+ -------------------------
47
+
48
+ Install dependencies of full installation and extra packages for development,
49
+ including **pytest** for unit-testing and **sphinx** for documentation.
50
+
51
+ Usually, you will use this installation variant only in editable mode
52
+
53
+ .. code-block:: bash
54
+
55
+ pip install ".[dev]"
s3prl_s3prl_main/docs/source/tutorial/problem.rst ADDED
@@ -0,0 +1,122 @@
1
+ Use Problem module to run customizable recipes
2
+ =======================================================
3
+
4
+ The :obj:`s3prl.problem` module provides customizable recipes in pure python (almost).
5
+ See :obj:`s3prl.problem` for all the recipes ready to be run.
6
+
7
+
8
+ Usage 1. Import and run on Colab
9
+ --------------------------------
10
+
11
+ All the problem classes follow the same usage:
12
+
13
+ >>> import torch
14
+ >>> from s3prl.problem import SuperbASR
15
+ ...
16
+ >>> problem = SuperbASR()
17
+ >>> config = problem.default_config()
18
+ >>> print(config)
19
+ ...
20
+ >>> # See the config for the '???' required fields and fill them
21
+ >>> config["target_dir"] = "result/asr_exp"
22
+ >>> config["prepare_data"]["dataset_root"] = "/corpus/LibriSpeech/"
23
+ ...
24
+ >>> problem.run(**config)
25
+
26
+
27
+ Usage 2. Run & configure from CLI
28
+ -----------------------------------
29
+
30
+ If you want to run directly from the command line, write a Python script (:code:`asr.py`) as follows:
31
+
32
+ .. code-block::
33
+
34
+ # This is asr.py
35
+
36
+ from s3prl.problem import SuperbASR
37
+ SuperbASR().main()
38
+
39
+ Then, run the command below:
40
+
41
+ >>> # Note that the main function supports overriding a field in the config by:
42
+ >>> # --{field_name} {value}
43
+ >>> # --{outer_field_name}.{inner_field_name} {value}
44
+ ...
45
+ >>> python3 asr.py --target_dir result/asr_exp --prepare_data.dataset_root /corpus/LibriSpeech/
46
+
47
+
48
+ Usage 3. Run & configure with the unified :obj:`s3prl-main`
49
+ -----------------------------------------------------------
50
+
51
+ However, this means that for every problem you still need to create a file.
52
+ Hence, we provide an easy helper supporting all the problems in :obj:`s3prl.problem`:
53
+
54
+ >>> python3 -m s3prl.main SuperbASR --target_dir result/asr_exp --prepare_data.dataset_root /corpus/LibriSpeech/
55
+
56
+ or use our CLI entry: :code:`s3prl-main`
57
+
58
+ >>> s3prl-main SuperbASR --target_dir result/asr_exp --prepare_data.dataset_root /corpus/LibriSpeech/
59
+
60
+ Customization
61
+ -------------
62
+
63
+ The core feature of the :obj:`s3prl.problem` module is customization.
64
+ You can easily change the corpus, change the SSL upstream model, change the downstream model,
65
+ optimizer, scheduler... etc, which can all be freely defined by you!
66
+
67
+ We demonstrate how to change the corpus and the downstream model in the following :code:`new_asr.py`:
68
+
69
+ .. code-block:: python
70
+
71
+ # This is new_asr.py
72
+
73
+ import torch
74
+ import pandas as pd
75
+ from s3prl.problem import SuperbASR
76
+
77
+
78
+ class LowResourceLinearSuperbASR(SuperbASR):
79
+ def prepare_data(
80
+ self, prepare_data: dict, target_dir: str, cache_dir: str, get_path_only=False
81
+ ):
82
+ train_path, valid_path, test_paths = super().prepare_data(
83
+ prepare_data, target_dir, cache_dir, get_path_only
84
+ )
85
+
86
+ # Take only the first 100 utterances for training
87
+ df = pd.read_csv(train_path)
88
+ df = df.iloc[:100]
89
+ df.to_csv(train_path, index=False)
90
+
91
+ return train_path, valid_path, test_paths
92
+
93
+ def build_downstream(
94
+ self,
95
+ build_downstream: dict,
96
+ downstream_input_size: int,
97
+ downstream_output_size: int,
98
+ downstream_input_stride: int,
99
+ ):
100
+ class Model(torch.nn.Module):
101
+ def __init__(self, input_size, output_size) -> None:
102
+ super().__init__()
103
+ self.linear = torch.nn.Linear(input_size, output_size)
104
+
105
+ def forward(self, x, x_len):
106
+ return self.linear(x), x_len
107
+
108
+ return Model(downstream_input_size, downstream_output_size)
109
+
110
+
111
+ if __name__ == "__main__":
112
+ LowResourceLinearSuperbASR().main()
113
+
114
+
115
+ By subclassing :obj:`SuperbASR`, we create a new problem called :code:`LowResourceLinearSuperbASR` by
116
+ overriding the :code:`prepare_data` and :code:`build_downstream` methods. After this simple modification,
117
+ :code:`LowResourceLinearSuperbASR` works exactly the same as :code:`SuperbASR`, except for two slight
118
+ setting changes. You can then follow the first two usages introduced above to launch this new class.
119
+
120
+ For example:
121
+
122
+ >>> python3 new_asr.py --target_dir result/new_asr_exp --prepare_data.dataset_root /corpus/LibriSpeech/
s3prl_s3prl_main/docs/source/tutorial/upstream_collection.rst ADDED
@@ -0,0 +1,1457 @@
1
+ :tocdepth: 2
2
+
3
+ S3PRL Upstream Collection
4
+ =======================================
5
+
6
+ We collect almost all the existing SSL pre-trained models in S3PRL,
7
+ so you can import and use them easily in a unified I/O interface.
8
+
9
+ :obj:`s3prl.nn.upstream.S3PRLUpstream` is an easy interface to retrieve all the self-supervised learning (SSL) pre-trained models
10
+ available in S3PRL. The :code:`name` argument for :obj:`s3prl.nn.upstream.S3PRLUpstream` specifies the checkpoint,
11
+ and then the pre-trained models in this checkpoint will be automatically constructed and
12
+ initialized.
13
+
14
+ Here is an example on how to get a hubert model and its representation using the :code:`name='hubert'`:
15
+
16
+ .. code-block:: python
17
+
18
+ import torch
19
+ from s3prl.nn import S3PRLUpstream
20
+
21
+ model = S3PRLUpstream("hubert")
22
+ model.eval()
23
+
24
+ with torch.no_grad():
25
+ wavs = torch.randn(2, 16000 * 2)
26
+ wavs_len = torch.LongTensor([16000 * 1, 16000 * 2])
27
+ all_hs, all_hs_len = model(wavs, wavs_len)
28
+
29
+ for hs, hs_len in zip(all_hs, all_hs_len):
30
+ assert isinstance(hs, torch.FloatTensor)
31
+ assert isinstance(hs_len, torch.LongTensor)
32
+
33
+ batch_size, max_seq_len, hidden_size = hs.shape
34
+ assert hs_len.dim() == 1
35
+
36
+ .. tip::
37
+
38
+ For each SSL method, like wav2vec 2.0, there are several checkpoint variants, trained with
39
+ different amounts of unlabeled data or different model sizes. Hence, there are also various
40
+ :code:`name` values to retrieve these different models.
41
+
42
+ For example, the HuBERT method has "hubert" and "hubert_large_ll60k" as names for different
43
+ checkpoint variants.
44
+
45
+ .. tip::
46
+
47
+ Some SSL pre-trained models' entries can be further configured by an :code:`extra_conf` dictionary.
48
+ See :obj:`s3prl.nn.S3PRLUpstream`. You can find the valid :code:`extra_conf` options in each SSL
49
+ model category. If not documented, by default it does not support any :code:`extra_conf`.
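+
+ For example, a minimal sketch of passing :code:`extra_conf` (here using the
+ :code:`feature_selection` option documented under the wav2vec 2.0 entries below):
+
+ .. code-block:: python
+
+     from s3prl.nn import S3PRLUpstream
+
+     # "feature_selection" is only meaningful for entries that document it;
+     # other entries may support different options or none at all
+     model = S3PRLUpstream("wav2vec2", extra_conf={"feature_selection": "fairseq_layers"})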
50
+
51
+ The following includes the model and checkpoint information for each :code:`name`, including the release date,
52
+ paper, citation, model architecture, pre-training data, criterion, and their source code. The format follows:
53
+
54
+
55
+
56
+ SSL Method
57
+ --------------------------------------------------------
58
+ `Paper full title with arxiv link <https://arxiv.org/>`_
59
+
60
+ .. code-block:: bash
61
+
62
+ @article{citation-block,
63
+ title={Paper Title},
64
+ author={Authors},
65
+ year={2020},
66
+ month={May}
67
+ }
68
+
69
+ The information shared across checkpoint variants.
70
+
71
+ name1
72
+ ~~~~~~~~~~~~~~~~~~~
73
+
74
+ The detailed specific information for this checkpoint variant (:code:`name=name1`)
75
+
76
+ name2
77
+ ~~~~~~~~~~~~~~~~~~~
78
+
79
+ The detailed specific information for this checkpoint variant (:code:`name=name2`)
80
+
81
+
82
+
83
+ Mockingjay
84
+ --------------------------------------------------------
85
+ `Mockingjay: Unsupervised Speech Representation Learning with Deep Bidirectional Transformer Encoders <https://arxiv.org/abs/1910.12638>`_
86
+
87
+ .. code-block:: bash
88
+
89
+ @article{mockingjay,
90
+ title={Mockingjay: Unsupervised Speech Representation Learning with Deep Bidirectional Transformer Encoders},
91
+ ISBN={9781509066315},
92
+ url={http://dx.doi.org/10.1109/ICASSP40776.2020.9054458},
93
+ DOI={10.1109/icassp40776.2020.9054458},
94
+ journal={ICASSP 2020 - 2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
95
+ publisher={IEEE},
96
+ author={Liu, Andy T. and Yang, Shu-wen and Chi, Po-Han and Hsu, Po-chun and Lee, Hung-yi},
97
+ year={2020},
98
+ month={May}
99
+ }
100
+
101
+ Mockingjay is a BERT-style model on spectrograms, with 12 layers of transformer encoders in the paper.
102
+
103
+
104
+ mockingjay
105
+ ~~~~~~~~~~~~~~~~
106
+
107
+ This is an alias for `mockingjay_origin`_
108
+
109
+ mockingjay_origin
110
+ ~~~~~~~~~~~~~~~~~~~~~~~~
111
+
112
+ This is an alias for `mockingjay_logMelLinearLarge_T_AdamW_b32_500k_360hr_drop1`_
113
+
114
+ mockingjay_100hr
115
+ ~~~~~~~~~~~~~~~~
116
+
117
+ This is an alias for `mockingjay_logMelBase_T_AdamW_b32_200k_100hr`_
118
+
119
+ mockingjay_960hr
120
+ ~~~~~~~~~~~~~~~~
121
+
122
+ This is an alias for `mockingjay_logMelBase_T_AdamW_b32_1m_960hr_drop1`_
123
+
124
+ mockingjay_logMelBase_T_AdamW_b32_200k_100hr
125
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
126
+
127
+ - Feature: 80-dim log Mel
128
+ - Alteration: time
129
+ - Optimizer: AdamW
130
+ - Batch size: 32
131
+ - Total steps: 200k
132
+ - Unlabeled Speech: LibriSpeech 100hr
133
+
134
+ mockingjay_logMelLinearLarge_T_AdamW_b32_500k_360hr_drop1
135
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
136
+
137
+ - Feature: 80-dim log Mel (input) / 201-dim Linear (target)
138
+ - Alteration: time
139
+ - Optimizer: AdamW
140
+ - Batch size: 32
141
+ - Total steps: 500k
142
+ - Unlabeled Speech: LibriSpeech 360hr
143
+
144
+ mockingjay_logMelBase_T_AdamW_b32_1m_960hr
145
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
146
+
147
+ - Feature: 80-dim log Mel
148
+ - Alteration: time
149
+ - Optimizer: AdamW
150
+ - Batch size: 32
151
+ - Total steps: 1M
152
+ - Unlabeled Speech: LibriSpeech 960hr
153
+
154
+ mockingjay_logMelBase_T_AdamW_b32_1m_960hr_drop1
155
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
156
+
157
+ - Feature: 80-dim log Mel
158
+ - Alteration: time
159
+ - Optimizer: AdamW
160
+ - Batch size: 32
161
+ - Total steps: 1M
162
+ - Unlabeled Speech: LibriSpeech 960hr
163
+ - Differences: Dropout of 0.1 (instead of 0.3)
164
+
165
+
166
+ mockingjay_logMelBase_T_AdamW_b32_1m_960hr_seq3k
167
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
168
+
169
+ - Feature: 80-dim log Mel
170
+ - Alteration: time
171
+ - Optimizer: AdamW
172
+ - Batch size: 32
173
+ - Total steps: 1M
174
+ - Unlabeled Speech: LibriSpeech 960hr
175
+ - Differences: sequence length of 3k (instead of 1.5k)
176
+
177
+
178
+
179
+ TERA
180
+ --------------------------------------------------------
181
+ `TERA: Self-Supervised Learning of Transformer Encoder Representation for Speech <https://arxiv.org/abs/2007.06028>`_
182
+
183
+ .. code-block:: bash
184
+
185
+ @misc{tera,
186
+ title={TERA: Self-Supervised Learning of Transformer Encoder Representation for Speech},
187
+ author={Andy T. Liu and Shang-Wen Li and Hung-yi Lee},
188
+ year={2020},
189
+ eprint={2007.06028},
190
+ archivePrefix={arXiv},
191
+ primaryClass={eess.AS}
192
+ }
193
+
194
+
195
+ tera
196
+ ~~~~~~~~~~~~~~~~
197
+
198
+ This is an alias for `tera_960hr`_
199
+
200
+ tera_100hr
201
+ ~~~~~~~~~~~~~~~~~~
202
+
203
+ This is an alias for `tera_logMelBase_T_F_M_AdamW_b32_200k_100hr`_
204
+
205
+ tera_960hr
206
+ ~~~~~~~~~~~~~~~~~~~
207
+
208
+ This is an alias for `tera_logMelBase_T_F_M_AdamW_b32_1m_960hr_drop1`_
209
+
210
+ tera_logMelBase_T_F_AdamW_b32_200k_100hr
211
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
212
+
213
+ - Feature: 80-dim log Mel
214
+ - Alteration: time + freq
215
+ - Optimizer: AdamW
216
+ - Batch size: 32
217
+ - Total steps: 200k
218
+ - Unlabeled Speech: LibriSpeech 100hr
219
+
220
+ tera_logMelBase_T_F_M_AdamW_b32_200k_100hr
221
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
222
+
223
+ - Feature: 80-dim log Mel
224
+ - Alteration: time + freq + mag
225
+ - Optimizer: AdamW
226
+ - Batch size: 32
227
+ - Total steps: 200k
228
+ - Unlabeled Speech: LibriSpeech 100hr
229
+
230
+ tera_logMelBase_T_F_AdamW_b32_1m_960hr
231
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
232
+
233
+ - Feature: 80-dim log Mel
234
+ - Alteration: time + freq
235
+ - Optimizer: AdamW
236
+ - Batch size: 32
237
+ - Total steps: 1M
238
+ - Unlabeled Speech: LibriSpeech 960hr
239
+
240
+ tera_logMelBase_T_F_AdamW_b32_1m_960hr_drop1
241
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
242
+
243
+ - Feature: 80-dim log Mel
244
+ - Alteration: time + freq
245
+ - Optimizer: AdamW
246
+ - Batch size: 32
247
+ - Total steps: 1M
248
+ - Unlabeled Speech: LibriSpeech 960hr
249
+ - Differences: Dropout of 0.1 (instead of 0.3)
250
+
251
+ tera_logMelBase_T_F_AdamW_b32_1m_960hr_seq3k
252
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
253
+
254
+ - Feature: 80-dim log Mel
255
+ - Alteration: time + freq
256
+ - Optimizer: AdamW
257
+ - Batch size: 32
258
+ - Total steps: 1M
259
+ - Unlabeled Speech: LibriSpeech 960hr
260
+ - Differences: sequence length of 3k (instead of 1.5k)
261
+
262
+ tera_logMelBase_T_F_M_AdamW_b32_1m_960hr_drop1
263
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
264
+
265
+ - Feature: 80-dim log Mel
266
+ - Alteration: time + freq + mag
267
+ - Optimizer: AdamW
268
+ - Batch size: 32
269
+ - Total steps: 1M
270
+ - Unlabeled Speech: LibriSpeech 960hr
271
+ - Differences: Dropout of 0.1 (instead of 0.3)
272
+
273
+ tera_fbankBase_T_F_AdamW_b32_200k_100hr
274
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
275
+
276
+ - Feature: 240-dim fbank
277
+ - Alteration: time + freq
278
+ - Optimizer: AdamW
279
+ - Batch size: 32
280
+ - Total steps: 200k
281
+ - Unlabeled Speech: LibriSpeech 100hr
282
+
283
+
284
+
285
+ Audio ALBERT
286
+ --------------------------------------------------------
287
+ `Audio ALBERT: A Lite BERT for Self-supervised Learning of Audio Representation <https://arxiv.org/abs/2007.06028>`_
288
+
289
+ .. code-block:: bash
290
+
291
+ @inproceedings{chi2021audio,
292
+ title={Audio albert: A lite bert for self-supervised learning of audio representation},
293
+ author={Chi, Po-Han and Chung, Pei-Hung and Wu, Tsung-Han and Hsieh, Chun-Cheng and Chen, Yen-Hao and Li, Shang-Wen and Lee, Hung-yi},
294
+ booktitle={2021 IEEE Spoken Language Technology Workshop (SLT)},
295
+ pages={344--350},
296
+ year={2021},
297
+ organization={IEEE}
298
+ }
299
+
300
+
301
+ audio_albert
302
+ ~~~~~~~~~~~~~~~~
303
+
304
+ This is an alias of `audio_albert_960hr`_
305
+
306
+
307
+ audio_albert_960hr
308
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~
309
+
310
+ This is an alias of `audio_albert_logMelBase_T_share_AdamW_b32_1m_960hr_drop1`_
311
+
312
+
313
+ audio_albert_logMelBase_T_share_AdamW_b32_1m_960hr_drop1
314
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
315
+
316
+ - Feature: 80-dim log Mel
317
+ - Alteration: time
318
+ - Optimizer: AdamW
319
+ - Batch size: 32
320
+ - Total steps: 1M
321
+ - Unlabeled Speech: LibriSpeech 960hr
322
+
323
+
324
+
325
+ APC
326
+ --------------------------------------------------------
327
+ `An Unsupervised Autoregressive Model for Speech Representation Learning <https://arxiv.org/abs/1904.03240>`_
328
+
329
+ .. code-block:: bash
330
+
331
+ @inproceedings{chung2019unsupervised,
332
+ title = {An unsupervised autoregressive model for speech representation learning},
333
+ author = {Chung, Yu-An and Hsu, Wei-Ning and Tang, Hao and Glass, James},
334
+ booktitle = {Interspeech},
335
+ year = {2019}
336
+ }
337
+
338
+
339
+ apc
340
+ ~~~~~~~~~~~~~~~~
341
+
342
+ This is an alias of `apc_360hr`_
343
+
344
+
345
+ apc_360hr
346
+ ~~~~~~~~~~~~~~~~~~
347
+
348
+ - Unlabeled Speech: LibriSpeech 360hr
349
+
350
+
351
+ apc_960hr
352
+ ~~~~~~~~~~~~~~~~~~
353
+
354
+ - Unlabeled Speech: LibriSpeech 960hr
355
+
356
+
357
+
358
+ VQ-APC
359
+ --------------------------------------------------------
360
+ `Vector-Quantized Autoregressive Predictive Coding <https://arxiv.org/abs/2005.08392>`_
361
+
362
+ .. code-block:: bash
363
+
364
+ @inproceedings{chung2020vqapc,
365
+ title = {Vector-quantized autoregressive predictive coding},
366
+ autohor = {Chung, Yu-An and Tang, Hao and Glass, James},
367
+ booktitle = {Interspeech},
368
+ year = {2020}
369
+ }
370
+
371
+ vq_apc
372
+ ~~~~~~~~~~~~~~~~
373
+
374
+ This is an alias of `vq_apc_360hr`_
375
+
376
+
377
+ vq_apc_360hr
378
+ ~~~~~~~~~~~~~~~~
379
+
380
+ - Unlabeled Speech: LibriSpeech 360hr
381
+
382
+
383
+ vq_apc_960hr
384
+ ~~~~~~~~~~~~~~~~~
385
+
386
+ - Unlabeled Speech: LibriSpeech 960hr
387
+
388
+
389
+
390
+ NPC
391
+ --------------------------------------------------------
392
+ `Non-Autoregressive Predictive Coding for Learning Speech Representations from Local Dependencies <https://arxiv.org/abs/2011.00406>`_
393
+
394
+ .. code-block:: bash
395
+
396
+ @article{liu2020nonautoregressive,
397
+ title = {Non-Autoregressive Predictive Coding for Learning Speech Representations from Local Dependencies},
398
+ author = {Liu, Alexander and Chung, Yu-An and Glass, James},
399
+ journal = {arXiv preprint arXiv:2011.00406},
400
+ year = {2020}
401
+ }
402
+
403
+
404
+ npc
405
+ ~~~~~~~~~~~~~~~~
406
+
407
+ This is an alias of `npc_360hr`_
408
+
409
+
410
+ npc_360hr
411
+ ~~~~~~~~~~~~~~~~~~
412
+
413
+ - Unlabeled Speech: LibriSpeech 360hr
414
+
415
+
416
+ npc_960hr
417
+ ~~~~~~~~~~~~~~~~~~~~
418
+
419
+ - Unlabeled Speech: LibriSpeech 960hr
420
+
421
+
422
+
423
+ PASE+
424
+ --------------------------------------------------------
425
+ `Multi-task self-supervised learning for Robust Speech Recognition <https://arxiv.org/abs/2001.09239>`_
426
+
427
+ .. code-block:: bash
428
+
429
+ @inproceedings{ravanelli2020multi,
430
+ title={Multi-task self-supervised learning for robust speech recognition},
431
+ author={Ravanelli, Mirco and Zhong, Jianyuan and Pascual, Santiago and Swietojanski, Pawel and Monteiro, Joao and Trmal, Jan and Bengio, Yoshua},
432
+ booktitle={ICASSP 2020-2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
433
+ pages={6989--6993},
434
+ year={2020},
435
+ organization={IEEE}
436
+ }
437
+
438
+ .. hint::
439
+
440
+ To use PASE models, there are many extra dependencies required to install.
441
+ Please follow the below installation instruction:
442
+
443
+ .. code-block:: bash
444
+
445
+ pip install -r https://raw.githubusercontent.com/s3prl/s3prl/master/s3prl/upstream/pase/requirements.txt
446
+
447
+
448
+ pase_plus
449
+ ~~~~~~~~~~~~~~~~
450
+
451
+ - Unlabeled Speech: LibriSpeech 50hr
452
+
453
+
454
+
455
+ Modified CPC
456
+ --------------------------------------------------------
457
+ `Unsupervised pretraining transfers well across languages <https://arxiv.org/abs/2002.02848>`_
458
+
459
+ .. code-block:: bash
460
+
461
+ @inproceedings{riviere2020unsupervised,
462
+ title={Unsupervised pretraining transfers well across languages},
463
+ author={Riviere, Morgane and Joulin, Armand and Mazar{\'e}, Pierre-Emmanuel and Dupoux, Emmanuel},
464
+ booktitle={ICASSP 2020-2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
465
+ pages={7414--7418},
466
+ year={2020},
467
+ organization={IEEE}
468
+ }
469
+
470
+ .. note::
471
+
472
+ This is a slightly improved version on the original CPC by DeepMind. To cite the DeepMind version:
473
+
474
+ .. code-block:: bash
475
+
476
+ @article{oord2018representation,
477
+ title={Representation learning with contrastive predictive coding},
478
+ author={Oord, Aaron van den and Li, Yazhe and Vinyals, Oriol},
479
+ journal={arXiv preprint arXiv:1807.03748},
480
+ year={2018}
481
+ }
482
+
483
+
484
+ modified_cpc
485
+ ~~~~~~~~~~~~~~~~
486
+
487
+ - Unlabeled Speech: LibriLight 60k hours
488
+
489
+
490
+
491
+ DeCoAR
492
+ --------------------------------------------------------
493
+ `Deep contextualized acoustic representations for semi-supervised speech recognition <https://arxiv.org/abs/1912.01679>`_
494
+
495
+ .. code-block:: bash
496
+
497
+ @inproceedings{ling2020deep,
498
+ title={Deep contextualized acoustic representations for semi-supervised speech recognition},
499
+ author={Ling, Shaoshi and Liu, Yuzong and Salazar, Julian and Kirchhoff, Katrin},
500
+ booktitle={ICASSP 2020-2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
501
+ pages={6429--6433},
502
+ year={2020},
503
+ organization={IEEE}
504
+ }
505
+
506
+
507
+ decoar_layers
508
+ ~~~~~~~~~~~~~~~~
509
+
510
+ - Unlabeled Speech: LibriSpeech 960hr
511
+
512
+
513
+ DeCoAR 2.0
514
+ --------------------------------------------------------
515
+ `DeCoAR 2.0: Deep Contextualized Acoustic Representations with Vector Quantization <https://arxiv.org/abs/2012.06659>`_
516
+
517
+ .. code-block:: bash
518
+
519
+ @misc{ling2020decoar,
520
+ title={DeCoAR 2.0: Deep Contextualized Acoustic Representations with Vector Quantization},
521
+ author={Shaoshi Ling and Yuzong Liu},
522
+ year={2020},
523
+ eprint={2012.06659},
524
+ archivePrefix={arXiv},
525
+ primaryClass={eess.AS}
526
+ }
527
+
528
+
529
+ decoar2
530
+ ~~~~~~~~~~~~~~~~~~~~~
531
+
532
+ - Unlabeled Speech: LibriSpeech 960hr
533
+
534
+
535
+
536
+ wav2vec
537
+ --------------------------------------------------
538
+ `wav2vec: Unsupervised Pre-Training for Speech Recognition <https://arxiv.org/abs/1904.05862>`_
539
+
540
+ .. code-block:: bash
541
+
542
+ @article{schneider2019wav2vec,
543
+ title={wav2vec: Unsupervised Pre-Training for Speech Recognition},
544
+ author={Schneider, Steffen and Baevski, Alexei and Collobert, Ronan and Auli, Michael},
545
+ journal={Proc. Interspeech 2019},
546
+ pages={3465--3469},
547
+ year={2019}
548
+ }
549
+
550
+
551
+ wav2vec
552
+ ~~~~~~~~~~~
553
+
554
+ This is an alias of `wav2vec_large`_
555
+
556
+
557
+ wav2vec_large
558
+ ~~~~~~~~~~~~~~~
559
+
560
+ This is the official wav2vec model from fairseq.
561
+
562
+ - Unlabeled Speech: LibriSpeech 960hr
563
+
564
+
565
+ vq-wav2vec
566
+ --------------------------------------------------
567
+ `vq-wav2vec: Self-supervised learning of discrete speech representations <https://arxiv.org/abs/1910.05453>`_
568
+
569
+ .. code-block:: bash
570
+
571
+ @inproceedings{baevski2019vq,
572
+ title={vq-wav2vec: Self-Supervised Learning of Discrete Speech Representations},
573
+ author={Baevski, Alexei and Schneider, Steffen and Auli, Michael},
574
+ booktitle={International Conference on Learning Representations},
575
+ year={2019}
576
+ }
577
+
578
+ .. note::
579
+
580
+ We only take the Conv encoders' hidden_states for vq-wav2vec in this SSL method category.
581
+ If you wish to consider the BERT model after the Conv encoders, please refer to `Discrete BERT`_.
582
+
583
+ vq_wav2vec
584
+ ~~~~~~~~~~~
585
+
586
+ This is an alias of `vq_wav2vec_gumbel`_
587
+
588
+
589
+ vq_wav2vec_gumbel
590
+ ~~~~~~~~~~~~~~~~~~~~
591
+
592
+ This is the official vq-wav2vec model from fairseq.
593
+ This model uses gumbel-softmax as the quantization technique
594
+
595
+ - Unlabeled Speech: LibriSpeech 960hr
596
+
597
+
598
+ vq_wav2vec_kmeans
599
+ ~~~~~~~~~~~~~~~~~~~~~
600
+
601
+ This is the official vq-wav2vec model from fairseq.
602
+ This model uses K-means as the quantization technique
603
+
604
+
605
+ Discrete BERT
606
+ --------------------------------------------------
607
+ `vq-wav2vec: Self-supervised learning of discrete speech representations <https://arxiv.org/abs/1910.05453>`_
608
+
609
+ .. code-block:: bash
610
+
611
+ @inproceedings{baevski2019vq,
612
+ title={vq-wav2vec: Self-Supervised Learning of Discrete Speech Representations},
613
+ author={Baevski, Alexei and Schneider, Steffen and Auli, Michael},
614
+ booktitle={International Conference on Learning Representations},
615
+ year={2019}
616
+ }
617
+
618
+ This method takes the Conv feature encoder's output, quantizes it into token ids, and feeds the
619
+ tokens into an NLP BERT (specifically, RoBERTa). The output hidden_states are all the hidden_states
620
+ of the NLP BERT (excluding the hidden_states in `vq-wav2vec`_)
621
+
622
+
623
+ discretebert
624
+ ~~~~~~~~~~~~~~~~
625
+
626
+ Alias of `vq_wav2vec_kmeans_roberta`_
627
+
628
+
629
+ vq_wav2vec_kmeans_roberta
630
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
631
+
632
+ This model uses `vq_wav2vec_kmeans`_ as the frontend waveform tokenizer. After the waveform is tokenized
633
+ into a sequence of token ids, tokens are then fed into a RoBERTa model.
634
+
635
+
636
+
637
+ wav2vec 2.0
638
+ --------------------------------------------------
639
+ `wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations <https://arxiv.org/abs/2006.11477>`_
640
+
641
+ .. code-block:: bash
642
+
643
+ @article{baevski2020wav2vec,
644
+ title={wav2vec 2.0: A framework for self-supervised learning of speech representations},
645
+ author={Baevski, Alexei and Zhou, Yuhao and Mohamed, Abdelrahman and Auli, Michael},
646
+ journal={Advances in Neural Information Processing Systems},
647
+ volume={33},
648
+ pages={12449--12460},
649
+ year={2020}
650
+ }
651
+
652
+ All the entries below support the following :code:`extra_conf`:
653
+
654
+ ==================== ====================
655
+ column description
656
+ ==================== ====================
657
+ feature_selection (str) -
658
+ if :code:`fairseq_layers` or :code:`fairseq_layers_before_residual`,
659
+ extract the representation following official fairseq API.
660
+ for :code:`fairseq_layers`, it is the output of each transformer
661
+ encoder layer; for :code:`fairseq_layers_before_residual`, it is
662
+ the output of the feedforward layer (before adding with the
663
+ main residual) of each transformer encoder layer. by default
664
+ this option is None, which follows the default place to extract
665
+ in S3PRL.
666
+ ==================== ====================
667
+
668
+
669
+ wav2vec2_custom
670
+ ~~~~~~~~~~~~~~~~~~~~~
671
+
672
+ This entry expects you to provide the source of the checkpoint: :code:`path_or_url`, which should be
673
+ a local path or a URL of a checkpoint converted by :code:`s3prl/upstream/wav2vec2/convert.py`
674
+ (from a regular fairseq checkpoint).
675
+
676
+ This entry also supports the following additional :code:`extra_conf`.
677
+
678
+ ==================== ====================
679
+ column description
680
+ ==================== ====================
681
+ fairseq (bool) -
682
+ If True, perform the on-the-fly checkpoint conversion, so that
683
+ you can directly give the fairseq checkpoint to the :code:`path_or_url`
684
+ argument, either a fairseq URL or a fairseq checkpoint local path.
685
+ ==================== ====================
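+
+ For instance, a sketch of loading your own fairseq checkpoint through this entry; the path below
+ is a stand-in for your own file, and we assume the :code:`path_or_url` keyword of
+ :obj:`s3prl.nn.S3PRLUpstream` is used to forward this argument:
+
+ .. code-block:: python
+
+     from s3prl.nn import S3PRLUpstream
+
+     # with fairseq=True the fairseq checkpoint is converted on the fly
+     model = S3PRLUpstream(
+         "wav2vec2_custom",
+         path_or_url="/path/to/your/wav2vec2_checkpoint.pt",  # stand-in path
+         extra_conf={"fairseq": True},
+     )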
686
+
687
+
688
+ hf_wav2vec2_custom
689
+ ~~~~~~~~~~~~~~~~~~~~
690
+
691
+ This entry expects you to provide the source of the checkpoint: :code:`path_or_url`, which should be
692
+ in the HuggingFace format, like :code:`facebook/wav2vec2-large-960h`
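+
+ For example (a sketch, reusing the HuggingFace model id mentioned above):
+
+ .. code-block:: python
+
+     from s3prl.nn import S3PRLUpstream
+
+     model = S3PRLUpstream("hf_wav2vec2_custom", path_or_url="facebook/wav2vec2-large-960h")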
693
+
694
+
695
+ wav2vec2
696
+ ~~~~~~~~~~~~~~~~
697
+
698
+ This is the alias of `wav2vec2_base_960`_
699
+
700
+
701
+ wav2vec2_base_960
702
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~
703
+ This is the official wav2vec 2.0 model in fairseq
704
+
705
+ - Architecture: 12-layer Transformer encoders
706
+ - Unlabeled Speech: LibriSpeech 960hr
707
+
708
+
709
+ wav2vec2_large_960
710
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
711
+
712
+ - Architecture: 24-layer Transformer encoders
713
+ - Unlabeled Speech: LibriSpeech 960hr
714
+
715
+
716
+ wav2vec2_large_ll60k
717
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
718
+
719
+ - Architecture: 24-layer Transformer encoders
720
+ - Unlabeled Speech: LibriLight LL60k hours
721
+
722
+
723
+ wav2vec2_large_lv60_cv_swbd_fsh
724
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
725
+
726
+ The Large model trained on Libri-Light 60k hours + CommonVoice + Switchboard + Fisher
727
+
728
+ - Architecture: 24-layer Transformer encoders
729
+ - Unlabeled Speech: Libri-Light 60k hours + CommonVoice + Switchboard + Fisher
730
+
731
+
732
+ wav2vec2_conformer_relpos
733
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
734
+
735
+ The results can be found in the Table 4 of `fairseq S2T: Fast Speech-to-Text Modeling with fairseq <https://arxiv.org/abs/2010.05171>`_.
736
+
737
+ - Architecture: 24-layer Conformer encoders with relative positional encoding
738
+ - Unlabeled Speech: LibriLight LL60k hours
739
+
740
+
741
+ wav2vec2_conformer_rope
742
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
743
+
744
+ The results can be found in the Table 4 of `fairseq S2T: Fast Speech-to-Text Modeling with fairseq <https://arxiv.org/abs/2010.05171>`_.
745
+
746
+ - Architecture: 24-layer Conformer encoders with ROPE positional encoding
747
+ - Unlabeled Speech: LibriLight LL60k hours
748
+
749
+
750
+ wav2vec2_base_s2st_es_voxpopuli
751
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
752
+
753
+ - The wav2vec2 model from `Enhanced Direct Speech-to-Speech Translation Using Self-supervised Pre-training and Data Augmentation <https://arxiv.org/abs/2204.02967>`_,
754
+ - released in Fairseq with the link: `https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/w2v2/es/transformer_B.pt <https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/w2v2/es/transformer_B.pt>`_
755
+
756
+
757
+ wav2vec2_base_s2st_en_librilight
758
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
759
+
760
+ - The wav2vec2 model from `Enhanced Direct Speech-to-Speech Translation Using Self-supervised Pre-training and Data Augmentation <https://arxiv.org/abs/2204.02967>`_,
761
+ - released in Fairseq with the link: `https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/w2v2/en/transformer_B.pt <https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/w2v2/en/transformer_B.pt>`_
762
+
763
+
764
+ wav2vec2_conformer_large_s2st_es_voxpopuli
765
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
766
+
767
+ - The wav2vec2 model from `Enhanced Direct Speech-to-Speech Translation Using Self-supervised Pre-training and Data Augmentation <https://arxiv.org/abs/2204.02967>`_,
768
+ - released in Fairseq with the link: `https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/w2v2/es/conformer_L.pt <https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/w2v2/es/conformer_L.pt>`_
769
+
770
+
771
+ wav2vec2_conformer_large_s2st_en_librilight
772
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
773
+
774
+ - The wav2vec2 model from `Enhanced Direct Speech-to-Speech Translation Using Self-supervised Pre-training and Data Augmentation <https://arxiv.org/abs/2204.02967>`_,
775
+ - released in Fairseq with the link: `https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/w2v2/en/conformer_L.pt <https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/w2v2/en/conformer_L.pt>`_
776
+
777
+
778
+ xlsr_53
779
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
780
+
781
+ The wav2vec 2.0 model trained on multilingual data, presented in `Unsupervised Cross-lingual Representation Learning for Speech Recognition <https://arxiv.org/abs/2006.13979>`_
782
+
783
+ .. code-block:: bash
784
+
785
+ @article{conneau2020unsupervised,
786
+ title={Unsupervised cross-lingual representation learning for speech recognition},
787
+ author={Conneau, Alexis and Baevski, Alexei and Collobert, Ronan and Mohamed, Abdelrahman and Auli, Michael},
788
+ journal={arXiv preprint arXiv:2006.13979},
789
+ year={2020}
790
+ }
791
+
792
+
793
+ XLS-R
794
+ --------------------------------------------------
795
+ `XLS-R: Self-supervised Cross-lingual Speech Representation Learning at Scale <https://arxiv.org/abs/2111.09296>`_
796
+
797
+ .. code-block:: bash
798
+
799
+ @article{babu2021xls,
800
+ title={XLS-R: Self-supervised cross-lingual speech representation learning at scale},
801
+ author={Babu, Arun and Wang, Changhan and Tjandra, Andros and Lakhotia, Kushal and Xu, Qiantong and Goyal, Naman and Singh, Kritika and von Platen, Patrick and Saraf, Yatharth and Pino, Juan and others},
802
+ journal={arXiv preprint arXiv:2111.09296},
803
+ year={2021}
804
+ }
805
+
806
+
807
+ xls_r_300m
808
+ ~~~~~~~~~~~~~~~~~~~~~
809
+
810
+ - Unlabeled Speech: 128 languages, 436K hours
811
+
812
+
813
+ xls_r_1b
814
+ ~~~~~~~~~~~~~~~~~~~~~
815
+
816
+ - Unlabeled Speech: 128 languages, 436K hours
817
+
818
+
819
+ xls_r_2b
820
+ ~~~~~~~~~~~~~~~~~~~~~
821
+
822
+ - Unlabeled Speech: 128 languages, 436K hours
823
+
824
+
825
+ HuBERT
826
+ --------------------------------------------------
827
+ `HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units <https://arxiv.org/abs/2106.07447>`_
828
+
829
+ .. code-block:: bash
830
+
831
+ @article{hsu2021hubert,
832
+ title={Hubert: Self-supervised speech representation learning by masked prediction of hidden units},
833
+ author={Hsu, Wei-Ning and Bolte, Benjamin and Tsai, Yao-Hung Hubert and Lakhotia, Kushal and Salakhutdinov, Ruslan and Mohamed, Abdelrahman},
834
+ journal={IEEE/ACM Transactions on Audio, Speech, and Language Processing},
835
+ volume={29},
836
+ pages={3451--3460},
837
+ year={2021},
838
+ publisher={IEEE}
839
+ }
840
+
841
+
842
+ hubert_custom
843
+ ~~~~~~~~~~~~~~~~~~~~~
844
+
845
+ This entry expects you to provide the source of the checkpoint: :code:`path_or_url`, which should be
846
+ the local path or a URL of the checkpoint converted by :code:`s3prl/upstream/hubert/convert.py`
847
+ (from a regular fairseq checkpoint).
848
+
849
+ This entry also supports the following additional :code:`extra_conf`.
850
+
851
+ ==================== ====================
852
+ column description
853
+ ==================== ====================
854
+ fairseq (bool) -
855
+ If True, perform the on-the-fly checkpoint conversion, so that
856
+ you can directly give the fairseq checkpoint to the :code:`path_or_url`
857
+ argument, either a fairseq URL or a fairseq checkpoint local path.
858
+ ==================== ====================
859
+
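+ A minimal usage sketch: assuming the :code:`s3prl.nn.S3PRLUpstream` wrapper forwards
+ :code:`path_or_url` and :code:`extra_conf` to this entry, and with a placeholder checkpoint path:
+ 
+ .. code-block:: python
+ 
+     import torch
+     from s3prl.nn import S3PRLUpstream
+ 
+     # hand a regular fairseq checkpoint directly and convert it on the fly
+     model = S3PRLUpstream(
+         "hubert_custom",
+         path_or_url="/path/to/fairseq_hubert.pt",  # placeholder path
+         extra_conf=dict(fairseq=True),
+     )
+     model.eval()
+ 
+     # assuming the wrapper's (wavs, wavs_len) forward interface
+     wavs = torch.randn(2, 32000)                    # (batch, samples) at 16kHz
+     wavs_len = torch.LongTensor([32000, 16000])
+     with torch.no_grad():
+         all_hs, all_hs_len = model(wavs, wavs_len)  # hidden states from every layer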
860
+
861
+ hf_hubert_custom
862
+ ~~~~~~~~~~~~~~~~~~~~
863
+
864
+ This entry expects you to provide the source of the checkpoint: :code:`path_or_url`, which should be
865
+ in the HuggingFace format, like :code:`facebook/hubert-large-ll60k`
866
+
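+ For example, a minimal sketch (same assumption about the wrapper as above, using the
+ model id mentioned in this entry):
+ 
+ .. code-block:: python
+ 
+     from s3prl.nn import S3PRLUpstream
+ 
+     model = S3PRLUpstream("hf_hubert_custom", path_or_url="facebook/hubert-large-ll60k")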
867
+
868
+ hubert
869
+ ~~~~~~~~~~~~~~~~~~~~~
870
+
871
+ This is an alias of `hubert_base`_
872
+
873
+
874
+ hubert_base
875
+ ~~~~~~~~~~~~~~~~~~~~~
876
+
877
+ - Unlabeled Speech: LibriSpeech 960hr
878
+
879
+
880
+ hubert_large_ll60k
881
+ ~~~~~~~~~~~~~~~~~~~~~
882
+
883
+ - Unlabeled Speech: LibriLight LL60k hours
884
+
885
+
886
+ mhubert_base_vp_en_es_fr_it3
887
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
888
+
889
+ - The multilingual model from `Textless Speech-to-Speech Translation on Real Data <https://arxiv.org/abs/2112.08352>`_
890
+
891
+
892
+ ESPnetHuBERT
893
+ ----------------------
894
+ `Reducing Barriers to Self-Supervised Learning: HuBERT Pre-training with Academic Compute <https://arxiv.org/abs/2306.06672>`_
895
+
896
+ .. code-block:: bash
897
+
898
+ @inproceedings{chen23l_interspeech,
899
+ author={William Chen and Xuankai Chang and Yifan Peng and Zhaoheng Ni and Soumi Maiti and Shinji Watanabe},
900
+ title={{Reducing Barriers to Self-Supervised Learning: HuBERT Pre-training with Academic Compute}},
901
+ year=2023,
902
+ booktitle={Proc. INTERSPEECH 2023},
903
+ pages={4404--4408},
904
+ doi={10.21437/Interspeech.2023-1176}
905
+ }
906
+
907
+
908
+ espnet_hubert_custom
909
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
910
+
911
+ This entry expects you to provide the source of the checkpoint: :code:`ckpt`, which should be
912
+ the local path of the checkpoint pretrained from ESPnet (e.g., latest.pth).
913
+
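+ A minimal sketch, assuming the :code:`ckpt` keyword is simply forwarded to this entry by
+ :code:`torch.hub.load` (the checkpoint path below is a placeholder):
+ 
+ .. code-block:: python
+ 
+     import torch
+ 
+     # "latest.pth" stands for a checkpoint pre-trained with ESPnet
+     upstream = torch.hub.load("s3prl/s3prl", "espnet_hubert_custom", ckpt="/path/to/latest.pth")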
914
+
915
+ espnet_hubert_base_iter0
916
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
917
+
918
+ - Unlabeled Speech: LibriSpeech 960hr (first iteration of HuBERT pre-training)
919
+
920
+
921
+ espnet_hubert_base_iter1
922
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
923
+
924
+ - Unlabeled Speech: LibriSpeech 960hr (second iteration of HuBERT pre-training)
925
+
926
+
927
+ espnet_hubert_large_gs_ll60k
928
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
929
+
930
+ - Unlabeled Speech: LibriLight ll60k hours
931
+ - Labeled Speech: GigaSpeech 10k hours (to get units)
932
+
933
+
934
+ WavLabLM
935
+ ----------------------
936
+ `Joint Prediction and Denoising for Large-scale Multilingual Self-supervised Learning <https://arxiv.org/abs/2309.15317>`_
937
+
938
+ .. code-block:: bash
939
+
940
+ @inproceedings{chen23joint,
941
+ author={William Chen and Jiatong Shi and Brian Yan and Dan Berrebbi and Wangyou Zhang and Yifan Peng and Xuankai Chang and Soumi Maiti and Shinji Watanabe},
942
+ title={Joint Prediction and Denoising for Large-scale Multilingual Self-supervised Learning},
943
+ year=2023,
944
+ booktitle={IEEE Automatic Speech Recognition and Understanding Workshop (ASRU)},
945
+ }
946
+
947
+
948
+ cvhubert
949
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
950
+
951
+ - Unlabeled Speech: Commonvoice V11 Multilingual Data (13.6k hours)
952
+ - Only the 20ms resolution version is provided; `check HuggingFace for other resolutions <https://huggingface.co/espnet/espnet_cvhubert/tree/main>`_
953
+
954
+
955
+ wavlablm_ek_40k
956
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
957
+
958
+ - Unlabeled Speech: Openli110 (Combination of Commonvoice, Voxpopuli, MLS, Googlei18n, around 39k hours)
959
+ - Initialized from hubert_large_ll60k and continually trained with an English k-means model derived from LibriSpeech
960
+
961
+
962
+ wavlablm_mk_40k
963
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
964
+
965
+ - Unlabeled Speech: Openli110 (Combination of Commonvoice, Voxpopuli, MLS, Googlei18n, around 39k hours)
966
+ - Trained from scratch with a multilingual k-means model derived from the training data
967
+
968
+
969
+ wavlablm_ms_40k
970
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
971
+
972
+ - Unlabeled Speech: Openli110 (Combination of Commonvoice, Voxpopuli, MLS, Googlei18n, around 39k hours)
973
+ - Trained from scratch with a multilingual k-means model derived from the training data, using multi-stage training
974
+
975
+
976
+ Multiresolution HuBERT (MR-HuBERT)
977
+ --------------------------------------------------
978
+ `Multi-resolution HuBERT: Multi-resolution Speech Self-Supervised Learning with Masked Unit Prediction <https://openreview.net/pdf?id=kUuKFW7DIF>`_
979
+
980
+ .. code-block:: bash
981
+
982
+ @inproceedings{anonymous2023multiresolution,
983
+ title={Multi-resolution Hu{BERT}: Multi-resolution Speech Self-Supervised Learning with Masked Unit Prediction},
984
+ author={Anonymous},
985
+ booktitle={Submitted to The Twelfth International Conference on Learning Representations},
986
+ year={2023},
987
+ url={https://openreview.net/forum?id=kUuKFW7DIF},
988
+ note={under review}
989
+ }
990
+
991
+
992
+ multires_hubert_custom
993
+ ~~~~~~~~~~~~~~~~~~~~~~~~~
994
+
995
+ This entry expects you to provide the source of the checkpoint: :code:`ckpt`, which should be
996
+ the local path or a URL of the checkpoint converted by :code:`s3prl/upstream/multires_hubert/convert.py`
997
+ (from a regular fairseq checkpoint).
998
+ For more available checkpoints, please check `Fairseq official release <https://github.com/facebookresearch/fairseq/blob/main/examples/mr_hubert/README.md>`_
999
+ Related converted checkpoints are also at `S3PRL HuggingFace Repo <https://huggingface.co/s3prl/mr_hubert>`_
1000
+
1001
+
1002
+ multires_hubert_base
1003
+ ~~~~~~~~~~~~~~~~~~~~~
1004
+
1005
+ - Unlabeled Speech: LibriSpeech 960hr
1006
+ - K-means extracted from `hubert_base`_
1007
+
1008
+
1009
+ multires_hubert_large
1010
+ ~~~~~~~~~~~~~~~~~~~~~
1011
+
1012
+ - Unlabeled Speech: LibriLight 60khr
1013
+ - K-means extracted from `hubert_base`_
1014
+
1015
+
1016
+ multires_hubert_multilingual_base
1017
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1018
+
1019
+ - Unlabeled Speech: Voxpopuli 100khr
1020
+ - K-means extracted from `hubert_base`_
1021
+
1022
+
1023
+ multires_hubert_multilingual_large400k
1024
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1025
+
1026
+ - Unlabeled Speech: Voxpopuli 100khr
1027
+ - K-means extracted from `hubert_base`_
1028
+ - Training steps 400k
1029
+
1030
+
1031
+ multires_hubert_multilingual_large600k
1032
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1033
+
1034
+ - Unlabeled Speech: Voxpopuli 100khr
1035
+ - K-means extracted from `hubert_base`_
1036
+ - Training steps 600k
1037
+
1038
+
1039
+ DistilHuBERT
1040
+ ----------------------
1041
+ `DistilHuBERT: Speech Representation Learning by Layer-wise Distillation of Hidden-unit BERT <https://arxiv.org/abs/2110.01900>`_
1042
+
1043
+ .. code-block:: bash
1044
+
1045
+ @inproceedings{chang2022distilhubert,
1046
+ title={DistilHuBERT: Speech representation learning by layer-wise distillation of hidden-unit BERT},
1047
+ author={Chang, Heng-Jui and Yang, Shu-wen and Lee, Hung-yi},
1048
+ booktitle={ICASSP 2022-2022 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
1049
+ pages={7087--7091},
1050
+ year={2022},
1051
+ organization={IEEE}
1052
+ }
1053
+
1054
+
1055
+ distilhubert
1056
+ ~~~~~~~~~~~~~~~~~~~~~
1057
+
1058
+ Alias of `distilhubert_base`_
1059
+
1060
+
1061
+ distilhubert_base
1062
+ ~~~~~~~~~~~~~~~~~~~~~
1063
+
1064
+ - Teacher: `hubert_base`_
1065
+ - Unlabeled Speech: LibriSpeech 960hr
1066
+
1067
+
1068
+ HuBERT-MGR
1069
+ --------------------------------------------------
1070
+ `Improving Distortion Robustness of Self-supervised Speech Processing Tasks with Domain Adaptation <https://arxiv.org/abs/2203.16104>`_
1071
+
1072
+ .. code-block:: bash
1073
+
1074
+ @article{huang2022improving,
1075
+ title={Improving Distortion Robustness of Self-supervised Speech Processing Tasks with Domain Adaptation},
1076
+ author={Huang, Kuan Po and Fu, Yu-Kuan and Zhang, Yu and Lee, Hung-yi},
1077
+ journal={arXiv preprint arXiv:2203.16104},
1078
+ year={2022}
1079
+ }
1080
+
1081
+
1082
+ hubert_base_robust_mgr
1083
+ ~~~~~~~~~~~~~~~~~~~~~~~
1084
+
1085
+ - Unlabeled Speech: LibriSpeech 960hr
1086
+ - Augmentation: MUSAN, Gaussian noise, reverberation
1087
+
1088
+
1089
+ Unispeech-SAT
1090
+ --------------------------------------------------
1091
+ `Unispeech-sat: Universal speech representation learning with speaker aware pre-training <https://arxiv.org/abs/2110.05752>`_
1092
+
1093
+ .. code-block:: bash
1094
+
1095
+ @inproceedings{chen2022unispeech,
1096
+ title={Unispeech-sat: Universal speech representation learning with speaker aware pre-training},
1097
+ author={Chen, Sanyuan and Wu, Yu and Wang, Chengyi and Chen, Zhengyang and Chen, Zhuo and Liu, Shujie and Wu, Jian and Qian, Yao and Wei, Furu and Li, Jinyu and others},
1098
+ booktitle={ICASSP 2022-2022 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
1099
+ pages={6152--6156},
1100
+ year={2022},
1101
+ organization={IEEE}
1102
+ }
1103
+
1104
+
1105
+ unispeech_sat
1106
+ ~~~~~~~~~~~~~~~~~~~~~
1107
+
1108
+ Alias of `unispeech_sat_base`_
1109
+
1110
+
1111
+ unispeech_sat_base
1112
+ ~~~~~~~~~~~~~~~~~~~~~~
1113
+
1114
+ - Model Architecture: 12-layer Transformer encoders
1115
+ - Unlabeled Speech: LibriSpeech 960 hours
1116
+
1117
+
1118
+ unispeech_sat_base_plus
1119
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~
1120
+
1121
+ - Model Architecture: 12-layer Transformer encoders
1122
+ - Unlabeled Speech: LibriLight 60k hours + Gigaspeech 10k hours + VoxPopuli 24k hours = 94k hours
1123
+
1124
+
1125
+ unispeech_sat_large
1126
+ ~~~~~~~~~~~~~~~~~~~~~~~~
1127
+
1128
+ - Model Architecture: 24-layer Transformer encoders
1129
+ - Unlabeled Speech: LibriLight 60k hours + Gigaspeech 10k hours + VoxPopuli 24k hours = 94k hours
1130
+
1131
+
1132
+
1133
+ WavLM
1134
+ --------------------------------------------------
1135
+ `WavLM: Large-Scale Self-Supervised Pre-Training for Full Stack Speech Processing <https://arxiv.org/abs/2110.13900>`_
1136
+
1137
+ .. code-block:: bash
1138
+
1139
+ @article{Chen2021WavLM,
1140
+ title = {WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing},
1141
+ author = {Sanyuan Chen and Chengyi Wang and Zhengyang Chen and Yu Wu and Shujie Liu and Zhuo Chen and Jinyu Li and Naoyuki Kanda and Takuya Yoshioka and Xiong Xiao and Jian Wu and Long Zhou and Shuo Ren and Yanmin Qian and Yao Qian and Jian Wu and Michael Zeng and Furu Wei},
1142
+ eprint={2110.13900},
1143
+ archivePrefix={arXiv},
1144
+ primaryClass={cs.CL},
1145
+ year={2021}
1146
+ }
1147
+
1148
+
1149
+ wavlm
1150
+ ~~~~~~~~~~~~~~~~~
1151
+
1152
+ Alias of `wavlm_base_plus`_
1153
+
1154
+
1155
+ wavlm_base
1156
+ ~~~~~~~~~~~~~~~~
1157
+
1158
+ - Model Architecture: 12-layer Transformer encoders
1159
+ - Unlabeled Speech: LibriSpeech 960 hours
1160
+
1161
+
1162
+ wavlm_base_plus
1163
+ ~~~~~~~~~~~~~~~~~~~~~
1164
+
1165
+ - Model Architecture: 12-layer Transformer encoders
1166
+ - Unlabeled Speech: LibriLight 60k hours + Gigaspeech 10k hours + VoxPopuli 24k hours = 94k hours
1167
+
1168
+
1169
+ wavlm_large
1170
+ ~~~~~~~~~~~~~~~~~~~~~
1171
+
1172
+ - Model Architecture: 24-layer Transformer encoders
1173
+ - Unlabeled Speech: LibriLight 60k hours + Gigaspeech 10k hours + VoxPopuli 24k hours = 94k hours
1174
+
1175
+
1176
+ data2vec
1177
+ --------------------------------------------------
1178
+ `data2vec: A General Framework for Self-supervised Learning in Speech, Vision and Language <https://arxiv.org/abs/2202.03555>`_
1179
+
1180
+ .. code-block:: bash
1181
+
1182
+ @article{baevski2022data2vec,
1183
+ title={Data2vec: A general framework for self-supervised learning in speech, vision and language},
1184
+ author={Baevski, Alexei and Hsu, Wei-Ning and Xu, Qiantong and Babu, Arun and Gu, Jiatao and Auli, Michael},
1185
+ journal={arXiv preprint arXiv:2202.03555},
1186
+ year={2022}
1187
+ }
1188
+
1189
+
1190
+ data2vec
1191
+ ~~~~~~~~~~~~~~~~~
1192
+
1193
+ Alias of `data2vec_base_960`_
1194
+
1195
+
1196
+ data2vec_base_960
1197
+ ~~~~~~~~~~~~~~~~~~
1198
+
1199
+ - Model Architecture: 12-layer Transformer encoders
1200
+ - Unlabeled Speech: LibriSpeech 960 hours
1201
+
1202
+
1203
+ data2vec_large_ll60k
1204
+ ~~~~~~~~~~~~~~~~~~~~~
1205
+
1206
+ - Model Architecture: 24-layer Transformer encoders
1207
+ - Unlabeled Speech: LibriLight 60k hours
1208
+
1209
+
1210
+ AST
1211
+ --------------------------------------------------
1212
+ `AST: Audio Spectrogram Transformer <https://arxiv.org/abs/2104.01778>`_
1213
+
1214
+ .. code-block:: bash
1215
+
1216
+ @article{gong2021ast,
1217
+ title={Ast: Audio spectrogram transformer},
1218
+ author={Gong, Yuan and Chung, Yu-An and Glass, James},
1219
+ journal={arXiv preprint arXiv:2104.01778},
1220
+ year={2021}
1221
+ }
1222
+
1223
+
1224
+ All the entries below support the following :code:`extra_conf`:
1225
+
1226
+ ==================== ====================
1227
+ column description
1228
+ ==================== ====================
1229
+ window_secs (float) -
1230
+ The segment waveform length to feed into the
1231
+ AST model. If the input waveform is longer than this
1232
+ length, do sliding windowing on the waveform and concat
1233
+ the results along the time axis.
1234
+ stride_secs (float) -
1235
+ When doing sliding window on the waveform (see
1236
+ above), the stride seconds between windows.
1237
+ ==================== ====================
1238
+
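+ A minimal sketch of setting these options (the 10-second values below are placeholders,
+ assuming :code:`extra_conf` is forwarded to the entry by :code:`s3prl.nn.S3PRLUpstream`):
+ 
+ .. code-block:: python
+ 
+     from s3prl.nn import S3PRLUpstream
+ 
+     # slide a 10-second window with a 10-second hop over long inputs
+     model = S3PRLUpstream("ast", extra_conf=dict(window_secs=10.0, stride_secs=10.0))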
1239
+
1240
+ ast
1241
+ ~~~~~~~~~~~~~~~~~~
1242
+
1243
+ - Labeled Data: AudioSet
1244
+
1245
+
1246
+ SSAST
1247
+ --------------------------------------------------
1248
+ `SSAST: Self-Supervised Audio Spectrogram Transformer <https://arxiv.org/abs/2110.09784>`_
1249
+
1250
+ .. code-block:: bash
1251
+
1252
+ @inproceedings{gong2022ssast,
1253
+ title={Ssast: Self-supervised audio spectrogram transformer},
1254
+ author={Gong, Yuan and Lai, Cheng-I and Chung, Yu-An and Glass, James},
1255
+ booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
1256
+ volume={36},
1257
+ number={10},
1258
+ pages={10699--10709},
1259
+ year={2022}
1260
+ }
1261
+
1262
+
1263
+ All the entries below support the following :code:`extra_conf`:
1264
+
1265
+ ==================== ====================
1266
+ column description
1267
+ ==================== ====================
1268
+ window_secs (float) -
1269
+ The segment waveform length to feed into the
1270
+ AST model. If the input waveform is longer than this
1271
+ length, do sliding windowing on the waveform and concat
1272
+ the results along the time axis.
1273
+ ==================== ====================
1274
+
1275
+
1276
+ ssast_frame_base
1277
+ ~~~~~~~~~~~~~~~~~~
1278
+
1279
+ - Unlabeled Data: LibriSpeech & AudioSet
1280
+ - fbank patch size: 128 (freq) * 2 (time)
1281
+
1282
+ ssast_patch_base
1283
+ ~~~~~~~~~~~~~~~~~~~
1284
+
1285
+ - Unlabeled Data: LibriSpeech & AudioSet
1286
+ - fbank patch size: 16 (freq) * 16 (time)
1287
+
1288
+
1289
+ MAE-AST
1290
+ --------------------------------------------------
1291
+ `MAE-AST: Masked Autoencoding Audio Spectrogram Transformer <https://arxiv.org/abs/2203.16691>`_
1292
+
1293
+ .. code-block:: bash
1294
+
1295
+ @article{baade2022mae,
1296
+ title={MAE-AST: Masked Autoencoding Audio Spectrogram Transformer},
1297
+ author={Baade, Alan and Peng, Puyuan and Harwath, David},
1298
+ journal={arXiv preprint arXiv:2203.16691},
1299
+ year={2022}
1300
+ }
1301
+
1302
+
1303
+ mae_ast_frame
1304
+ ~~~~~~~~~~~~~~~~~~
1305
+
1306
+ - Unlabeled Data: LibriSpeech & AudioSet
1307
+ - fbank patch size: 128 (freq) * 2 (time)
1308
+
1309
+
1310
+ mae_ast_patch
1311
+ ~~~~~~~~~~~~~~~~~~
1312
+
1313
+ - Unlabeled Data: LibriSpeech & AudioSet
1314
+ - fbank patch size: 16 (freq) * 16 (time)
1315
+
1316
+
1317
+ Byol-A
1318
+ --------------------------------------------------
1319
+ `BYOL for Audio: Self-Supervised Learning for General-Purpose Audio Representation <https://arxiv.org/abs/2103.06695>`_
1320
+
1321
+ .. code-block:: bash
1322
+
1323
+ @inproceedings{niizumi2021byol,
1324
+ title={BYOL for audio: Self-supervised learning for general-purpose audio representation},
1325
+ author={Niizumi, Daisuke and Takeuchi, Daiki and Ohishi, Yasunori and Harada, Noboru and Kashino, Kunio},
1326
+ booktitle={2021 International Joint Conference on Neural Networks (IJCNN)},
1327
+ pages={1--8},
1328
+ year={2021},
1329
+ organization={IEEE}
1330
+ }
1331
+
1332
+
1333
+ All the entries below support the following :code:`extra_conf`:
1334
+
1335
+ ==================== ====================
1336
+ column description
1337
+ ==================== ====================
1338
+ window_secs (float) -
1339
+ The segment waveform length to feed into the
1340
+ AST model. If the input waveform is longer than this
1341
+ length, do sliding windowing on the waveform and concat
1342
+ the results along the time axis.
1343
+ stride_secs (float) -
1344
+ When doing sliding window on the waveform (see
1345
+ above), the stride seconds between windows.
1346
+ ==================== ====================
1347
+
1348
+
1349
+ byol_a_2048
1350
+ ~~~~~~~~~~~~~~~~~~
1351
+
1352
+ - Unlabeled Data: AudioSet
1353
+
1354
+
1355
+ byol_a_1024
1356
+ ~~~~~~~~~~~~~~~~~~
1357
+
1358
+ - Unlabeled Data: AudioSet
1359
+
1360
+
1361
+ byol_a_512
1362
+ ~~~~~~~~~~~~~~~~~~
1363
+
1364
+ - Unlabeled Data: AudioSet
1365
+
1366
+
1367
+ Byol-S
1368
+ --------------------------------------------------
1369
+ `BYOL-S: Learning Self-supervised Speech Representations by Bootstrapping <https://arxiv.org/abs/2206.12038>`_
1370
+
1371
+ .. code-block:: bash
1372
+
1373
+ @article{elbanna2022byol,
1374
+ title={Byol-s: Learning self-supervised speech representations by bootstrapping},
1375
+ author={Elbanna, Gasser and Scheidwasser-Clow, Neil and Kegler, Mikolaj and Beckmann, Pierre and Hajal, Karl El and Cernak, Milos},
1376
+ journal={arXiv preprint arXiv:2206.12038},
1377
+ year={2022}
1378
+ }
1379
+
1380
+
1381
+ byol_s_default
1382
+ ~~~~~~~~~~~~~~~~~~
1383
+
1384
+ - Unlabeled Data: AudioSet (Speech subset)
1385
+
1386
+
1387
+ byol_s_cvt
1388
+ ~~~~~~~~~~~~~~~~~~
1389
+
1390
+ - Unlabeled Data: AudioSet (Speech subset)
1391
+
1392
+
1393
+ byol_s_resnetish34
1394
+ ~~~~~~~~~~~~~~~~~~
1395
+
1396
+ - Unlabeled Data: AudioSet (Speech subset)
1397
+
1398
+
1399
+ VGGish
1400
+ --------------------------------------------------
1401
+ `CNN Architectures for Large-Scale Audio Classification <https://arxiv.org/abs/1609.09430>`_
1402
+
1403
+ .. code-block:: bash
1404
+
1405
+ @inproceedings{hershey2017cnn,
1406
+ title={CNN architectures for large-scale audio classification},
1407
+ author={Hershey, Shawn and Chaudhuri, Sourish and Ellis, Daniel PW and Gemmeke, Jort F and Jansen, Aren and Moore, R Channing and Plakal, Manoj and Platt, Devin and Saurous, Rif A and Seybold, Bryan and others},
1408
+ booktitle={2017 ieee international conference on acoustics, speech and signal processing (icassp)},
1409
+ pages={131--135},
1410
+ year={2017},
1411
+ organization={IEEE}
1412
+ }
1413
+
1414
+
1415
+ vggish
1416
+ ~~~~~~~~~~~~~~~~~~
1417
+
1418
+ - Labeled Data: AudioSet
1419
+
1420
+
1421
+ PaSST
1422
+ --------------------------------------------------
1423
+ `Efficient Training of Audio Transformers with Patchout <https://arxiv.org/abs/2110.05069>`_
1424
+
1425
+ .. code-block:: bash
1426
+
1427
+ @article{koutini2021efficient,
1428
+ title={Efficient training of audio transformers with patchout},
1429
+ author={Koutini, Khaled and Schl{\"u}ter, Jan and Eghbal-zadeh, Hamid and Widmer, Gerhard},
1430
+ journal={arXiv preprint arXiv:2110.05069},
1431
+ year={2021}
1432
+ }
1433
+
1434
+ All the entries below support the following :code:`extra_conf`:
1435
+
1436
+ ==================== ====================
1437
+ column description
1438
+ ==================== ====================
1439
+ window_secs (float) -
1440
+ The segment waveform length to feed into the
1441
+ model. If the input waveform is longer than this
1442
+ length, do sliding windowing on the waveform and concat
1443
+ the results along the time axis.
1444
+ stride_secs (float) -
1445
+ When doing sliding window on the waveform (see
1446
+ above), the stride seconds between windows.
1447
+ ==================== ====================
1448
+
1449
+ passt_base
1450
+ ~~~~~~~~~~~~~~~~~~
1451
+
1452
+ - Labeled Data: AudioSet
1453
+
1454
+
1455
+ Authors:
1456
+
1457
+ - Leo 2022
s3prl_s3prl_main/docs/util/is_valid.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
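+ # Print (space-separated) the immediate sub-paths of module_root that are not
+ # listed in the valid_paths file, so callers can treat them as ignored.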
+ import os
2
+ import argparse
3
+ from pathlib import Path
4
+
5
+ if __name__ == "__main__":
6
+ parser = argparse.ArgumentParser()
7
+ parser.add_argument("module_root")
8
+ parser.add_argument("valid_paths")
9
+ args = parser.parse_args()
10
+
11
+ with open(args.valid_paths) as file:
12
+ valid_paths = [line.strip() for line in file.readlines()]
13
+
14
+ ignored_paths = []
15
+ module_root_name = Path(args.module_root).stem
16
+ for item in os.listdir(args.module_root):
17
+ pattern = f"{module_root_name}/{item}"
18
+ if pattern not in valid_paths:
19
+ ignored_paths.append(pattern)
20
+
21
+ print(" ".join(ignored_paths))
s3prl_s3prl_main/example/customize.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import torch
3
+
4
+ from s3prl.problem import SuperbASR
5
+
6
+
7
+ class LowResourceLinearSuperbASR(SuperbASR):
8
+ def prepare_data(
9
+ self, prepare_data: dict, target_dir: str, cache_dir: str, get_path_only=False
10
+ ):
11
+ train_path, valid_path, test_paths = super().prepare_data(
12
+ prepare_data, target_dir, cache_dir, get_path_only
13
+ )
14
+
15
+ # Take only the first 100 utterances for training
16
+ df = pd.read_csv(train_path)
17
+ df = df.iloc[:100]
18
+ df.to_csv(train_path, index=False)
19
+
20
+ return train_path, valid_path, test_paths
21
+
22
+ def build_downstream(
23
+ self,
24
+ build_downstream: dict,
25
+ downstream_input_size: int,
26
+ downstream_output_size: int,
27
+ downstream_input_stride: int,
28
+ ):
29
+ import torch
30
+
31
+ class Model(torch.nn.Module):
32
+ def __init__(self, input_size, output_size) -> None:
33
+ super().__init__()
34
+ self.linear = torch.nn.Linear(input_size, output_size)
35
+
36
+ def forward(self, x, x_len):
37
+ return self.linear(x), x_len
38
+
39
+ return Model(downstream_input_size, downstream_output_size)
40
+
41
+
42
+ if __name__ == "__main__":
43
+ LowResourceLinearSuperbASR().main()
s3prl_s3prl_main/example/run_asr.sh ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ python3 s3prl/main.py SuperbASR --target_dir result/asr --prepare_data.dataset_root /home/leo/d/datasets/LibriSpeech/ --build_upstream.name apc
s3prl_s3prl_main/example/run_sid.sh ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ python3 s3prl/main.py SuperbSID --target_dir result/tmp/ic --prepare_data.dataset_root /home/leo/d/datasets/VoxCeleb1/ --build_upstream.name apc
4
+
s3prl_s3prl_main/example/ssl/pretrain.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ import math
4
+ import os
5
+ from copy import deepcopy
6
+ from pathlib import Path
7
+
8
+ import torch
9
+ from torch.utils.data import DataLoader
10
+ from tqdm import tqdm
11
+
12
+ from s3prl import Container, Logs, Object, Output
13
+ from s3prl.dataset.base import AugmentedDynamicItemDataset, DataLoader
14
+ from s3prl.sampler import DistributedBatchSamplerWrapper
15
+ from s3prl.util.configuration import parse_override, qualname_to_cls
16
+ from s3prl.util.seed import fix_random_seeds
17
+
18
+ device = "cuda" if torch.cuda.is_available() else "cpu"
19
+ logger = logging.getLogger(__name__)
20
+
21
+ DRYRUN_CONFIG = dict(
22
+ Trainer=dict(
23
+ total_steps=200000,
24
+ log_step=5000,
25
+ valid_step=5000,
26
+ save_step=5000,
27
+ eval_batch=8,
28
+ ),
29
+ )
30
+
31
+
32
+ def parse_args():
33
+ parser = argparse.ArgumentParser()
34
+ parser.add_argument(
35
+ "problem", help="The problem module. E.g. `s3prl.problem.ssl.tera.Tera`"
36
+ )
37
+ parser.add_argument("dataset_root", help="The dataset root for pretrain.")
38
+ parser.add_argument("save_to", help="The directory to save checkpoint")
39
+ parser.add_argument("--n_jobs", type=int, default=8)
40
+ parser.add_argument(
41
+ "--override",
42
+ default=None,
43
+ help=(
44
+ "Override the default_config of the problem module. "
45
+ "E.g. --override ValidSampler.batch_size=4,,TestSampler.batch_size=4"
46
+ ),
47
+ )
48
+ parser.add_argument("--resume", action="store_true")
49
+ parser.add_argument("--dryrun", action="store_true")
50
+ parser.add_argument("--seed", type=int, default=1337)
51
+ args = parser.parse_args()
52
+
53
+ fix_random_seeds(args.seed)
54
+ problem = qualname_to_cls(args.problem)
55
+ config = Container(deepcopy(problem.default_config))
56
+
57
+ for key, value in vars(args).items():
58
+ if key not in ["override"]:
59
+ config[key] = value
60
+
61
+ if args.dryrun:
62
+ config.override(DRYRUN_CONFIG)
63
+
64
+ if isinstance(args.override, str) and len(args.override) > 0:
65
+ override_dict = parse_override(args.override)
66
+ config.override(override_dict)
67
+
68
+ return problem, config
69
+
70
+
71
+ def main():
72
+ logging.basicConfig(level=logging.INFO)
73
+
74
+ problem, config = parse_args()
75
+ save_to = Path(config.save_to)
76
+ save_to.mkdir(exist_ok=True, parents=True)
77
+
78
+ # configure any upstream
79
+ body = problem.Body(**config.Body)
80
+ head = problem.Head(**config.Head)
81
+ loss = problem.Loss(**config.Loss)
82
+ stats = Container()
83
+
84
+ logger.info("Preparing corpus")
85
+ corpus = problem.Corpus(config.dataset_root, **config.Corpus)
86
+ train_data, valid_data, test_data, corpus_stats = corpus().split(3)
87
+ stats.add(corpus_stats)
88
+
89
+ logger.info("Preparing train data")
90
+ train_dataset = AugmentedDynamicItemDataset(train_data, tools=stats)
91
+ train_dataset = problem.TrainData(**config.TrainData)(train_dataset)
92
+ assert train_dataset.get_tool("feat_dim") == problem.input_size
93
+ train_sampler = DistributedBatchSamplerWrapper(
94
+ problem.TrainSampler(train_dataset, **config.TrainSampler),
95
+ num_replicas=1,
96
+ rank=0,
97
+ )
98
+ train_dataloader = DataLoader(
99
+ train_dataset,
100
+ train_sampler,
101
+ num_workers=config.n_jobs,
102
+ )
103
+ stats.add(train_dataset.all_tools())
104
+
105
+ logger.info("Preparing valid data")
106
+ valid_dataset = AugmentedDynamicItemDataset(valid_data, tools=stats)
107
+ valid_dataset = problem.ValidData(**config.ValidData)(valid_dataset)
108
+ valid_sampler = DistributedBatchSamplerWrapper(
109
+ problem.ValidSampler(valid_dataset, **config.ValidSampler),
110
+ num_replicas=1,
111
+ rank=0,
112
+ )
113
+ valid_dataloader = DataLoader(
114
+ valid_dataset,
115
+ valid_sampler,
116
+ num_workers=12,
117
+ )
118
+
119
+ logger.info("Preparing test data")
120
+ test_dataset = AugmentedDynamicItemDataset(test_data, tools=stats)
121
+ test_dataset = problem.TestData(**config.TestData)(test_dataset)
122
+ test_sampler = DistributedBatchSamplerWrapper(
123
+ problem.ValidSampler(test_dataset, **config.TestSampler),
124
+ num_replicas=1,
125
+ rank=0,
126
+ )
127
+ test_dataloader = DataLoader(
128
+ test_dataset,
129
+ test_sampler,
130
+ num_workers=12,
131
+ )
132
+
133
+ sorted_ckpt_dirs = sorted(
134
+ [
135
+ file
136
+ for file in save_to.iterdir()
137
+ if file.is_dir() and str(file).endswith(".ckpts")
138
+ ],
139
+ key=os.path.getmtime,
140
+ )
141
+
142
+ if config.resume and len(sorted_ckpt_dirs) > 0:
143
+ logger.info("Last checkpoint found. Load model and optimizer from checkpoint")
144
+ task = Object.load_checkpoint(sorted_ckpt_dirs[-1] / "task.ckpt").to(device)
145
+ else:
146
+ logger.info("Create a new model")
147
+ task = problem.Task(body, head, loss, **stats)
148
+ task = task.to(device)
149
+
150
+ # ALL THE FOLLOWING CODES ARE FOR TRAINER
151
+ # WHICH CAN BE LARGELY SIMPLIFIED WHEN USING OTHER TRAINER PACKAGES
152
+
153
+ opt_cls_qualname, opt_cfgs = config.Optimizer.split(1)
154
+ optimizer = qualname_to_cls(opt_cls_qualname)(task.parameters(), **opt_cfgs)
155
+ if config.resume and len(sorted_ckpt_dirs) > 0:
156
+ optimizer.load_state_dict(torch.load(sorted_ckpt_dirs[-1] / "optimizer.ckpt"))
157
+
158
+ if config.Trainer.use_valid:
159
+ if config.resume and len(sorted_ckpt_dirs) > 0:
160
+ valid_best_score = torch.load(
161
+ sorted_ckpt_dirs[-1] / "valid_best_score.ckpt"
162
+ )[config.Trainer.valid_metric]
163
+ else:
164
+ valid_best_score = -100000 if config.Trainer.valid_higher_better else 100000
165
+
166
+ def save_checkpoint(name):
167
+ ckpt_dir: Path = save_to / f"{name}.ckpts"
168
+ ckpt_dir.mkdir(parents=True, exist_ok=True)
169
+ logger.info(f"Save checkpoint to: {ckpt_dir}")
170
+
171
+ if hasattr(problem, "save_checkpoint"):
172
+ logger.info(f"Save upstream checkpoint to: {ckpt_dir}")
173
+ problem.save_checkpoint(config, body, head, ckpt_dir / "upstream.ckpt")
174
+ task.save_checkpoint(ckpt_dir / "task.ckpt")
175
+ torch.save(optimizer.state_dict(), ckpt_dir / "optimizer.ckpt")
176
+ torch.save(
177
+ {config.Trainer.valid_metric: valid_best_score},
178
+ ckpt_dir / "valid_best_score.ckpt",
179
+ )
180
+
181
+ pbar = tqdm(total=config.Trainer.total_steps, desc="Total")
182
+ train_completed = False
183
+ accum_grad_steps = 0
184
+ while not train_completed:
185
+ batch_results = []
186
+ for batch in tqdm(train_dataloader, desc="Train", total=len(train_dataloader)):
187
+ pbar.update(1)
188
+ global_step = pbar.n
189
+
190
+ assert isinstance(batch, Output)
191
+ batch = batch.to(device)
192
+
193
+ task.train()
194
+ result = task.train_step(**batch)
195
+ assert isinstance(result, Output)
196
+
197
+ result.loss /= config.Trainer.gradient_accumulate_steps
198
+ result.loss.backward()
199
+
200
+ grad_norm = torch.nn.utils.clip_grad_norm_(
201
+ task.parameters(), max_norm=config.Trainer.gradient_clipping
202
+ )
203
+
204
+ if math.isnan(grad_norm):
205
+ logger.warning(f"Grad norm is NaN at step {global_step}")
206
+ optimizer.zero_grad()
207
+ accum_grad_steps = 0
208
+ else:
209
+ accum_grad_steps += 1
210
+ if accum_grad_steps == config.Trainer.gradient_accumulate_steps:
211
+ optimizer.step()
212
+ optimizer.zero_grad()
213
+ accum_grad_steps = 0
214
+ batch_results.append(result.cacheable())
215
+
216
+ if global_step % config.Trainer.log_step == 0:
217
+ logs: Logs = task.train_reduction(batch_results).logs
218
+ logger.info(f"[Train] step {global_step}")
219
+ for name, value in logs.Scalar.items():
220
+ if name == "loss":
221
+ value *= config.Trainer.gradient_accumulate_steps
222
+ logger.info(f"{name}: {value}")
223
+ batch_results = []
224
+
225
+ if global_step % config.Trainer.valid_step == 0:
226
+ with torch.no_grad():
227
+ if config.Trainer.use_valid:
228
+ valid_results = []
229
+ for batch_idx, batch in enumerate(
230
+ tqdm(
231
+ valid_dataloader,
232
+ desc="Valid",
233
+ total=len(valid_dataloader),
234
+ )
235
+ ):
236
+ if batch_idx == config.Trainer.get("eval_batch", -1):
237
+ break
238
+ batch = batch.to(device)
239
+ task.eval()
240
+ result = task.valid_step(**batch)
241
+ valid_results.append(result.cacheable())
242
+
243
+ logs: Logs = task.valid_reduction(valid_results).slice(1)
244
+ logger.info(f"[Valid] step {global_step}")
245
+ for name, value in logs.Scalar.items():
246
+ logger.info(f"{name}: {value}")
247
+ if name == config.Trainer.valid_metric:
248
+ cond1 = config.Trainer.valid_higher_better and (
249
+ value > valid_best_score
250
+ )
251
+ cond2 = (not config.Trainer.valid_higher_better) and (
252
+ value < valid_best_score
253
+ )
254
+ if cond1 or cond2:
255
+ valid_best_score = value
256
+ save_checkpoint("valid_best")
257
+
258
+ if (
259
+ global_step % config.Trainer.save_step == 0
260
+ or global_step == config.Trainer.total_steps
261
+ ):
262
+ save_checkpoint(f"global_step_{global_step}")
263
+
264
+ if global_step == config.Trainer.total_steps:
265
+ train_completed = True
266
+ break
267
+
268
+ test_results = []
269
+ for batch_idx, batch in enumerate(
270
+ tqdm(test_dataloader, desc="Test", total=len(test_dataloader))
271
+ ):
272
+ if batch_idx == config.Trainer.get("eval_batch", -1):
273
+ break
274
+ batch = batch.to(device)
275
+ result = task.test_step(**batch)
276
+ test_results.append(result.cacheable())
277
+
278
+ logs: Logs = task.test_reduction(test_results).slice(1)
279
+ logger.info(f"[Test] step {global_step}")
280
+ for name, value in logs.Scalar.items():
281
+ logger.info(f"{name}: {value}")
282
+
283
+
284
+ if __name__ == "__main__":
285
+ main()
s3prl_s3prl_main/example/superb/train.py ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ import math
4
+ import os
5
+ from copy import deepcopy
6
+ from pathlib import Path
7
+
8
+ import torch
9
+ from tqdm import tqdm
10
+
11
+ from s3prl import Container, Logs, Object, Output
12
+ from s3prl.dataset.base import AugmentedDynamicItemDataset, DataLoader
13
+ from s3prl.nn import S3PRLUpstream, UpstreamDownstreamModel
14
+ from s3prl.sampler import DistributedBatchSamplerWrapper
15
+ from s3prl.util.configuration import parse_override, qualname_to_cls
16
+ from s3prl.util.seed import fix_random_seeds
17
+
18
+ device = "cuda" if torch.cuda.is_available() else "cpu"
19
+ logger = logging.getLogger(__name__)
20
+
21
+ DRYRUN_CONFIG = dict(
22
+ Trainer=dict(
23
+ total_steps=10,
24
+ log_step=2,
25
+ valid_step=5,
26
+ save_step=5,
27
+ eval_batch=5,
28
+ ),
29
+ )
30
+
31
+
32
+ def parse_args():
33
+ parser = argparse.ArgumentParser()
34
+ parser.add_argument("upstream", help="The upstream name. E.g. wav2vec2")
35
+ parser.add_argument(
36
+ "problem",
37
+ help="The problem module. E.g. s3prl.problem.SuperbSID",
38
+ )
39
+ parser.add_argument(
40
+ "dataset_root",
41
+ help="The dataset root of your problem.",
42
+ )
43
+ parser.add_argument("save_to", help="The directory to save checkpoint")
44
+ parser.add_argument("--feature_selection", default="hidden_states")
45
+ parser.add_argument("--n_jobs", type=int, default=6)
46
+ parser.add_argument(
47
+ "--override",
48
+ default=None,
49
+ help=(
50
+ "Override the default_config of the problem module. "
51
+ "E.g. --override ValidSampler.batch_size=4,,TestSampler.batch_size=4"
52
+ ),
53
+ )
54
+ parser.add_argument("--resume", action="store_true")
55
+ parser.add_argument("--dryrun", action="store_true")
56
+ parser.add_argument("--seed", type=int, default=1337)
57
+ args = parser.parse_args()
58
+
59
+ fix_random_seeds(args.seed)
60
+ problem = qualname_to_cls(args.problem)
61
+ config = Container(deepcopy(problem.default_config))
62
+
63
+ for key, value in vars(args).items():
64
+ if key not in ["override"]:
65
+ config[key] = value
66
+
67
+ if args.dryrun:
68
+ config.override(DRYRUN_CONFIG)
69
+
70
+ if isinstance(args.override, str) and len(args.override) > 0:
71
+ override_dict = parse_override(args.override)
72
+ config.override(override_dict)
73
+
74
+ return problem, config
75
+
76
+
77
+ def main():
78
+ logging.basicConfig(level=logging.INFO)
79
+
80
+ problem, config = parse_args()
81
+ save_to = Path(config.save_to)
82
+ save_to.mkdir(exist_ok=True, parents=True)
83
+
84
+ # configure any upstream
85
+ upstream = S3PRLUpstream(config.upstream, config.feature_selection)
86
+ stats = Container(upstream_rate=upstream.downsample_rate)
87
+
88
+ logger.info("Preparing corpus")
89
+ corpus = problem.Corpus(config.dataset_root, **config.Corpus)
90
+ train_data, valid_data, test_data, corpus_stats = corpus().split(3)
91
+ stats.add(corpus_stats)
92
+
93
+ logger.info("Preparing train data")
94
+ train_dataset = AugmentedDynamicItemDataset(train_data, tools=stats)
95
+ train_dataset = problem.TrainData(**config.TrainData)(train_dataset)
96
+ train_sampler = DistributedBatchSamplerWrapper(
97
+ problem.TrainSampler(train_dataset, **config.TrainSampler),
98
+ num_replicas=1,
99
+ rank=0,
100
+ )
101
+ train_dataloader = DataLoader(
102
+ train_dataset,
103
+ train_sampler,
104
+ num_workers=config.n_jobs,
105
+ )
106
+ stats.add(train_dataset.all_tools())
107
+
108
+ logger.info("Preparing valid data")
109
+ valid_dataset = AugmentedDynamicItemDataset(valid_data, tools=stats)
110
+ valid_dataset = problem.ValidData(**config.ValidData)(valid_dataset)
111
+ valid_sampler = DistributedBatchSamplerWrapper(
112
+ problem.ValidSampler(valid_dataset, **config.ValidSampler),
113
+ num_replicas=1,
114
+ rank=0,
115
+ )
116
+ valid_dataloader = DataLoader(
117
+ valid_dataset,
118
+ valid_sampler,
119
+ num_workers=12,
120
+ )
121
+
122
+ logger.info("Preparing test data")
123
+ test_dataset = AugmentedDynamicItemDataset(test_data, tools=stats)
124
+ test_dataset = problem.TestData(**config.TestData)(test_dataset)
125
+ test_sampler = DistributedBatchSamplerWrapper(
126
+ problem.ValidSampler(test_dataset, **config.TestSampler),
127
+ num_replicas=1,
128
+ rank=0,
129
+ )
130
+ test_dataloader = DataLoader(
131
+ test_dataset,
132
+ test_sampler,
133
+ num_workers=12,
134
+ )
135
+
136
+ sorted_ckpt_dirs = sorted(
137
+ [
138
+ file
139
+ for file in save_to.iterdir()
140
+ if file.is_dir() and str(file).endswith(".ckpts")
141
+ ],
142
+ key=os.path.getmtime,
143
+ )
144
+
145
+ if config.resume and len(sorted_ckpt_dirs) > 0:
146
+ logger.info("Last checkpoint found. Load model and optimizer from checkpoint")
147
+ task = Object.load_checkpoint(sorted_ckpt_dirs[-1] / "task.ckpt").to(device)
148
+ else:
149
+ logger.info("Create a new model")
150
+ downstream = problem.Downstream(
151
+ upstream.output_size,
152
+ **stats,
153
+ )
154
+ model = UpstreamDownstreamModel(upstream, downstream)
155
+ # task = problem.Task(model, **{**stats, **config.Task})
156
+ task = problem.Task(model, **stats, **config.Task)
157
+ task = task.to(device)
158
+
159
+ # ALL THE FOLLOWING CODES ARE FOR TRAINER
160
+ # WHICH CAN BE LARGELY SIMPLIFIED WHEN USING OTHER TRAINER PACKAGES
161
+
162
+ opt_cls_qualname, opt_cfgs = config.Optimizer.split(1)
163
+ optimizer = qualname_to_cls(opt_cls_qualname)(task.parameters(), **opt_cfgs)
164
+ if config.resume and len(sorted_ckpt_dirs) > 0:
165
+ optimizer.load_state_dict(torch.load(sorted_ckpt_dirs[-1] / "optimizer.ckpt"))
166
+
167
+ if config.Trainer.use_valid:
168
+ if config.resume and len(sorted_ckpt_dirs) > 0:
169
+ valid_best_score = torch.load(
170
+ sorted_ckpt_dirs[-1] / "valid_best_score.ckpt"
171
+ )[config.Trainer.valid_metric]
172
+ else:
173
+ valid_best_score = -100000 if config.Trainer.valid_higher_better else 100000
174
+
175
+ def save_checkpoint(name):
176
+ ckpt_dir: Path = save_to / f"{name}.ckpts"
177
+ ckpt_dir.mkdir(parents=True, exist_ok=True)
178
+ logger.info(f"Save checkpoint to: {ckpt_dir}")
179
+
180
+ task.save_checkpoint(ckpt_dir / "task.ckpt")
181
+ torch.save(optimizer.state_dict(), ckpt_dir / "optimizer.ckpt")
182
+ torch.save(
183
+ {config.Trainer.valid_metric: valid_best_score},
184
+ ckpt_dir / "valid_best_score.ckpt",
185
+ )
186
+
187
+ pbar = tqdm(total=config.Trainer.total_steps, desc="Total")
188
+ train_completed = False
189
+ accum_grad_steps = 0
190
+ while not train_completed:
191
+ batch_results = []
192
+ for batch in tqdm(train_dataloader, desc="Train", total=len(train_dataloader)):
193
+ pbar.update(1)
194
+ global_step = pbar.n
195
+
196
+ assert isinstance(batch, Output)
197
+ batch = batch.to(device)
198
+
199
+ task.train()
200
+ result = task.train_step(**batch)
201
+ assert isinstance(result, Output)
202
+
203
+ result.loss /= config.Trainer.gradient_accumulate_steps
204
+ result.loss.backward()
205
+
206
+ grad_norm = torch.nn.utils.clip_grad_norm_(
207
+ task.parameters(), max_norm=config.Trainer.gradient_clipping
208
+ )
209
+
210
+ if math.isnan(grad_norm):
211
+ logger.warning(f"Grad norm is NaN at step {global_step}")
212
+ optimizer.zero_grad()
213
+ accum_grad_steps = 0
214
+ else:
215
+ accum_grad_steps += 1
216
+ if accum_grad_steps == config.Trainer.gradient_accumulate_steps:
217
+ optimizer.step()
218
+ optimizer.zero_grad()
219
+ accum_grad_steps = 0
220
+ batch_results.append(result.cacheable())
221
+
222
+ if global_step % config.Trainer.log_step == 0:
223
+ logs: Logs = task.train_reduction(batch_results).logs
224
+ logger.info(f"[Train] step {global_step}")
225
+ for name, value in logs.Scalar.items():
226
+ if name == "loss":
227
+ value *= config.Trainer.gradient_accumulate_steps
228
+ logger.info(f"{name}: {value}")
229
+ batch_results = []
230
+
231
+ if global_step % config.Trainer.valid_step == 0:
232
+ with torch.no_grad():
233
+ if config.Trainer.use_valid:
234
+ valid_results = []
235
+ for batch_idx, batch in enumerate(
236
+ tqdm(
237
+ valid_dataloader,
238
+ desc="Valid",
239
+ total=len(valid_dataloader),
240
+ )
241
+ ):
242
+ if batch_idx == config.Trainer.get("eval_batch", -1):
243
+ break
244
+ batch = batch.to(device)
245
+ task.eval()
246
+ result = task.valid_step(**batch)
247
+ valid_results.append(result.cacheable())
248
+
249
+ logs: Logs = task.valid_reduction(valid_results).slice(1)
250
+ logger.info(f"[Valid] step {global_step}")
251
+ for name, value in logs.Scalar.items():
252
+ logger.info(f"{name}: {value}")
253
+ if name == config.Trainer.valid_metric:
254
+ cond1 = config.Trainer.valid_higher_better and (
255
+ value > valid_best_score
256
+ )
257
+ cond2 = (not config.Trainer.valid_higher_better) and (
258
+ value < valid_best_score
259
+ )
260
+ if cond1 or cond2:
261
+ valid_best_score = value
262
+ save_checkpoint("valid_best")
263
+
264
+ if (
265
+ global_step % config.Trainer.save_step == 0
266
+ or global_step == config.Trainer.total_steps
267
+ ):
268
+ save_checkpoint(f"global_step_{global_step}")
269
+
270
+ if global_step == config.Trainer.total_steps:
271
+ train_completed = True
272
+ break
273
+
274
+ test_results = []
275
+ for batch_idx, batch in enumerate(
276
+ tqdm(test_dataloader, desc="Test", total=len(test_dataloader))
277
+ ):
278
+ if batch_idx == config.Trainer.get("eval_batch", -1):
279
+ break
280
+ batch = batch.to(device)
281
+ result = task.test_step(**batch)
282
+ test_results.append(result.cacheable())
283
+
284
+ logs: Logs = task.test_reduction(test_results).slice(1)
285
+ logger.info(f"[Test] step {global_step}")
286
+ for name, value in logs.Scalar.items():
287
+ logger.info(f"{name}: {value}")
288
+
289
+
290
+ if __name__ == "__main__":
291
+ main()
s3prl_s3prl_main/example/superb_asr/inference.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from pathlib import Path
3
+
4
+ import torch
5
+
6
+ from s3prl import Dataset, Output, Task
7
+ from s3prl.base.object import Object
8
+
9
+ device = "cuda" if torch.cuda.is_available() else "cpu"
10
+
11
+
12
+ def parse_args():
13
+ parser = argparse.ArgumentParser()
14
+ parser.add_argument(
15
+ "load_from", help="The directory containing all the checkpoints"
16
+ )
17
+ args = parser.parse_args()
18
+ return args
19
+
20
+
21
+ def main():
22
+ args = parse_args()
23
+ load_from = Path(args.load_from)
24
+
25
+ task: Task = Object.load_checkpoint(load_from / "task.ckpt").to(device)
26
+ task.eval()
27
+
28
+ test_dataset: Dataset = Object.load_checkpoint(load_from / "test_dataset.ckpt")
29
+ test_dataloader = test_dataset.to_dataloader(batch_size=1, num_workers=6)
30
+
31
+ with torch.no_grad():
32
+ for batch in test_dataloader:
33
+ batch: Output = batch.to(device)
34
+ result = task(**batch.subset("x", "x_len", as_type="dict"))
35
+ for name, prediction in zip(batch.name, result.prediction):
36
+ print(name, prediction)
37
+
38
+
39
+ if __name__ == "__main__":
40
+ main()
s3prl_s3prl_main/example/superb_asr/train.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ import math
4
+ from pathlib import Path
5
+
6
+ import torch
7
+ import torch.optim as optim
8
+ from torch.utils.data import DataLoader
9
+ from tqdm import tqdm
10
+
11
+ from s3prl import Logs, Object, Output
12
+ from s3prl.nn import S3PRLUpstream, UpstreamDownstreamModel
13
+ from s3prl.sampler import DistributedBatchSamplerWrapper
14
+ from s3prl.superb import asr as problem
15
+
16
+ device = "cuda" if torch.cuda.is_available() else "cpu"
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ def parse_args():
21
+ parser = argparse.ArgumentParser()
22
+ parser.add_argument("librispeech", help="The root directory of LibriSpeech")
23
+ parser.add_argument("save_to", help="The directory to save checkpoint")
24
+ parser.add_argument("--total_steps", type=int, default=200000)
25
+ parser.add_argument("--log_step", type=int, default=100)
26
+ parser.add_argument("--eval_step", type=int, default=5000)
27
+ parser.add_argument("--save_step", type=int, default=100)
28
+ args = parser.parse_args()
29
+ return args
30
+
31
+
32
+ def main():
33
+ logging.basicConfig()
34
+ logger.setLevel(logging.INFO)
35
+
36
+ args = parse_args()
37
+ librispeech = Path(args.librispeech)
38
+ assert librispeech.is_dir()
39
+ save_to = Path(args.save_to)
40
+ save_to.mkdir(exist_ok=True, parents=True)
41
+
42
+ logger.info("Preparing preprocessor")
43
+ preprocessor = problem.Preprocessor(
44
+ librispeech, splits=["train-clean-100", "dev-clean", "test-clean"]
45
+ )
46
+
47
+ logger.info("Preparing train dataloader")
48
+ train_dataset = problem.TrainDataset(**preprocessor.train_data())
49
+ train_sampler = problem.TrainSampler(
50
+ train_dataset, max_timestamp=16000 * 1000, shuffle=True
51
+ )
52
+ train_sampler = DistributedBatchSamplerWrapper(
53
+ train_sampler, num_replicas=1, rank=0
54
+ )
55
+ train_dataloader = DataLoader(
56
+ train_dataset,
57
+ batch_sampler=train_sampler,
58
+ num_workers=4,
59
+ collate_fn=train_dataset.collate_fn,
60
+ )
61
+
62
+ logger.info("Preparing valid dataloader")
63
+ valid_dataset = problem.ValidDataset(
64
+ **preprocessor.valid_data(),
65
+ **train_dataset.statistics(),
66
+ )
67
+ valid_dataset.save_checkpoint(save_to / "valid_dataset.ckpt")
68
+ valid_sampler = problem.ValidSampler(valid_dataset, 8)
69
+ valid_sampler = DistributedBatchSamplerWrapper(
70
+ valid_sampler, num_replicas=1, rank=0
71
+ )
72
+ valid_dataloader = DataLoader(
73
+ valid_dataset,
74
+ batch_sampler=valid_sampler,
75
+ num_workers=4,
76
+ collate_fn=valid_dataset.collate_fn,
77
+ )
78
+
79
+ logger.info("Preparing test dataloader")
80
+ test_dataset = problem.TestDataset(
81
+ **preprocessor.test_data(),
82
+ **train_dataset.statistics(),
83
+ )
84
+ test_dataset.save_checkpoint(save_to / "test_dataset.ckpt")
85
+ test_sampler = problem.TestSampler(test_dataset, 8)
86
+ test_sampler = DistributedBatchSamplerWrapper(test_sampler, num_replicas=1, rank=0)
87
+ test_dataloader = DataLoader(
88
+ test_dataset,
89
+ batch_sampler=test_sampler,
90
+ num_workers=4,
91
+ collate_fn=test_dataset.collate_fn,
92
+ )
93
+
94
+ latest_task = save_to / "task.ckpt"
95
+ if latest_task.is_file():
96
+ logger.info("Last checkpoint found. Load model and optimizer from checkpoint")
97
+
98
+ # Object.load_checkpoint() from a checkpoint path and
99
+ # Object.from_checkpoint() from a loaded checkpoint dictionary
100
+ # are like AutoModel in Huggingface which you only need to
101
+ # provide the checkpoint for restoring the module.
102
+ #
103
+ # Note that source code definition should be importable, since this
104
+ # auto loading mechanism is just automating the model re-initialization
105
+ # steps instead of scriptify (torch.jit) all the source code in the
106
+ # checkpoint
107
+
108
+ task = Object.load_checkpoint(latest_task).to(device)
109
+
110
+ else:
111
+ logger.info("No last checkpoint found. Create new model")
112
+
113
+ # Model creation block which can be fully customized
114
+ upstream = S3PRLUpstream("apc")
115
+ downstream = problem.DownstreamModel(
116
+ upstream.output_size,
117
+ preprocessor.statistics().output_size,
118
+ hidden_size=[512],
119
+ dropout=[0.2],
120
+ )
121
+ model = UpstreamDownstreamModel(upstream, downstream)
122
+
123
+ # After customize your own model, simply put it into task object
124
+ task = problem.Task(model, preprocessor.statistics().label_loader)
125
+ task = task.to(device)
126
+
127
+ # We do not handle optimizer/scheduler in any special way in S3PRL, since
128
+ # there are lots of dedicated package for this. Hence, we also do not handle
129
+ # the checkpointing for optimizer/scheduler. Depends on what training pipeline
130
+ # the user prefer, either Lightning or SpeechBrain, these frameworks will
131
+ # provide different solutions on how to save these objects. By not handling
132
+ # these objects in S3PRL we are making S3PRL more flexible and agnostic to training pipeline
133
+ # The following optimizer codeblock aims to align with the standard usage
134
+ # of PyTorch which is the standard way to save it.
135
+
136
+ optimizer = optim.Adam(task.parameters(), lr=1e-3)
137
+ latest_optimizer = save_to / "optimizer.ckpt"
138
+ if latest_optimizer.is_file():
139
+ optimizer.load_state_dict(torch.load(save_to / "optimizer.ckpt"))
140
+ else:
141
+ optimizer = optim.Adam(task.parameters(), lr=1e-3)
142
+
143
+ # The following code block demonstrate how to train with your own training loop
144
+ # This entire block can be easily replaced with Lightning/SpeechBrain Trainer as
145
+ #
146
+ # Trainer(task)
147
+ # Trainer.fit(train_dataloader, valid_dataloader, test_dataloader)
148
+ #
149
+ # As you can see, there is a huge similarity among train/valid/test loops below,
150
+ # so it is a natural step to share these logics with a generic Trainer class
151
+ # as done in Lightning/SpeechBrain
152
+
153
+ pbar = tqdm(total=args.total_steps, desc="Total")
154
+ while True:
155
+ batch_results = []
156
+ for batch in tqdm(train_dataloader, desc="Train", total=len(train_dataloader)):
157
+ pbar.update(1)
158
+ global_step = pbar.n
159
+
160
+ assert isinstance(batch, Output)
161
+ optimizer.zero_grad()
162
+
163
+ # An Output object can more all its direct
164
+ # attributes/values to the device
165
+ batch = batch.to(device)
166
+
167
+ # An Output object is an OrderedDict so we
168
+ # can use dict decomposition here
169
+ task.train()
170
+ result = task.train_step(**batch)
171
+ assert isinstance(result, Output)
172
+
173
+ # The output of train step must contain
174
+ # at least a loss key
175
+ result.loss.backward()
176
+
177
+ # gradient clipping
178
+ grad_norm = torch.nn.utils.clip_grad_norm_(task.parameters(), max_norm=1.0)
179
+
180
+ if math.isnan(grad_norm):
181
+ logger.warning(f"Grad norm is NaN at step {global_step}")
182
+ else:
183
+ optimizer.step()
184
+
185
+ # Detach from GPU, remove large logging (to Tensorboard or local files)
186
+ # objects like logits and leave only small data like loss scalars / prediction
187
+ # strings, so that these objects can be safely cached in a list in the MEM,
188
+ # and become useful for calculating metrics later
189
+ # The Output class can do these with self.cacheable()
190
+ cacheable_result = result.cacheable()
191
+
192
+ # Cache these small data for later metric calculation
193
+ batch_results.append(cacheable_result)
194
+
195
+ if (global_step + 1) % args.log_step == 0:
196
+ logs: Logs = task.train_reduction(batch_results).logs
197
+ logger.info(f"[Train] step {global_step}")
198
+ for log in logs.values():
199
+ logger.info(f"{log.name}: {log.data}")
200
+ batch_results = []
201
+
202
+ if (global_step + 1) % args.eval_step == 0:
203
+ with torch.no_grad():
204
+ task.eval()
205
+
206
+ # valid
207
+ valid_results = []
208
+ for batch in tqdm(
209
+ valid_dataloader, desc="Valid", total=len(valid_dataloader)
210
+ ):
211
+ batch = batch.to(device)
212
+ result = task.valid_step(**batch)
213
+ cacheable_result = result.cacheable()
214
+ valid_results.append(cacheable_result)
215
+
216
+ logs: Logs = task.valid_reduction(valid_results).logs
217
+ logger.info(f"[Valid] step {global_step}")
218
+ for log in logs.values():
219
+ logger.info(f"{log.name}: {log.data}")
220
+
221
+ if (global_step + 1) % args.save_step == 0:
222
+ task.save_checkpoint(save_to / "task.ckpt")
223
+ torch.save(optimizer.state_dict(), save_to / "optimizer.ckpt")
224
+
225
+ with torch.no_grad():
226
+ # test
227
+ test_results = []
228
+ for batch in tqdm(test_dataloader, desc="Test", total=len(test_dataloader)):
229
+ batch = batch.to(device)
230
+ result = task.test_step(**batch)
231
+ cacheable_result = result.cacheable()
232
+ test_results.append(cacheable_result)
233
+
234
+ logs: Logs = task.test_reduction(test_results).logs
235
+ logger.info(f"[Test] step results")
236
+ for log in logs.values():
237
+ logger.info(f"{log.name}: {log.data}")
238
+
239
+
240
+ if __name__ == "__main__":
241
+ main()
s3prl_s3prl_main/example/superb_asr/train_with_lightning.py ADDED
@@ -0,0 +1,127 @@
1
+ import argparse
2
+ import logging
3
+ from pathlib import Path
4
+
5
+ import torch
6
+ import torch.optim as optim
7
+ from pytorch_lightning import Trainer
8
+ from pytorch_lightning.callbacks import ModelCheckpoint
9
+
10
+ from s3prl.nn import S3PRLUpstream, UpstreamDownstreamModel
11
+ from s3prl.superb import asr as problem
12
+ from s3prl.wrapper import LightningModuleSimpleWrapper
13
+
14
+ device = "cuda" if torch.cuda.is_available() else "cpu"
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ def parse_args():
19
+ parser = argparse.ArgumentParser()
20
+ parser.add_argument("librispeech", help="The root directory of LibriSpeech")
21
+ parser.add_argument("save_to", help="The directory to save checkpoint")
22
+ parser.add_argument("--total_steps", type=int, default=200000)
23
+ parser.add_argument("--log_step", type=int, default=100)
24
+ parser.add_argument("--eval_step", type=int, default=5000)
25
+ parser.add_argument("--save_step", type=int, default=100)
26
+ parser.add_argument(
27
+ "--not_resume",
28
+ action="store_true",
29
+ help="Don't resume from the last checkpoint",
30
+ )
31
+
32
+ # for debugging
33
+ parser.add_argument("--limit_train_batches", type=int)
34
+ parser.add_argument("--limit_val_batches", type=int)
35
+ parser.add_argument("--fast_dev_run", action="store_true")
36
+ args = parser.parse_args()
37
+ return args
38
+
39
+
40
+ def main():
41
+ logging.basicConfig(level=logging.INFO)
42
+
43
+ args = parse_args()
44
+ librispeech = Path(args.librispeech)
45
+ save_to = Path(args.save_to)
46
+ save_to.mkdir(exist_ok=True, parents=True)
47
+
48
+ logger.info("Preparing preprocessor")
49
+ preprocessor = problem.Preprocessor(librispeech)
50
+
51
+ logger.info("Preparing train dataloader")
52
+ train_dataset = problem.TrainDataset(**preprocessor.train_data())
53
+ train_dataloader = train_dataset.to_dataloader(
54
+ batch_size=8,
55
+ num_workers=6,
56
+ shuffle=True,
57
+ )
58
+
59
+ logger.info("Preparing valid dataloader")
60
+ valid_dataset = problem.ValidDataset(
61
+ **preprocessor.valid_data(),
62
+ **train_dataset.statistics(),
63
+ )
64
+ valid_dataloader = valid_dataset.to_dataloader(batch_size=8, num_workers=6)
65
+
66
+ logger.info("Preparing test dataloader")
67
+ test_dataset = problem.TestDataset(
68
+ **preprocessor.test_data(),
69
+ **train_dataset.statistics(),
70
+ )
71
+ test_dataloader = test_dataset.to_dataloader(batch_size=8, num_workers=6)
72
+
73
+ valid_dataset.save_checkpoint(save_to / "valid_dataset.ckpt")
74
+ test_dataset.save_checkpoint(save_to / "test_dataset.ckpt")
75
+
76
+ upstream = S3PRLUpstream("apc")
77
+ downstream = problem.DownstreamModel(
78
+ upstream.output_size, preprocessor.statistics().output_size
79
+ )
80
+ model = UpstreamDownstreamModel(upstream, downstream)
81
+ task = problem.Task(model, preprocessor.statistics().label_loader)
82
+
83
+ optimizer = optim.Adam(task.parameters(), lr=1e-3)
84
+ lightning_task = LightningModuleSimpleWrapper(task, optimizer)
85
+
86
+ # The above is the usage of our library
87
+
88
+ # The below is pytorch-lightning specific usage, which can be very simple
89
+ # or very sophisticated, depending on how much you want to customize your
90
+ # training loop
91
+
92
+ checkpoint_callback = ModelCheckpoint(
93
+ dirpath=str(save_to),
94
+ filename="superb-asr-{step:02d}-{valid_0_wer:.2f}",
95
+ monitor="valid_0_wer", # since might have multiple valid dataloaders
96
+ save_last=True,
97
+ save_top_k=3, # top 3 best ckpt on valid
98
+ mode="min", # lower, better
99
+ every_n_train_steps=args.save_step,
100
+ )
101
+
102
+ trainer = Trainer(
103
+ callbacks=[checkpoint_callback],
104
+ accelerator="gpu",
105
+ gpus=1,
106
+ max_steps=args.total_steps,
107
+ log_every_n_steps=args.log_step,
108
+ val_check_interval=args.eval_step,
109
+ limit_val_batches=args.limit_val_batches or 1.0,
110
+ limit_train_batches=args.limit_train_batches or 1.0,
111
+ fast_dev_run=args.fast_dev_run,
112
+ )
113
+
114
+ last_ckpt = save_to / "last.ckpt"
115
+ if args.not_resume or not last_ckpt.is_file():
116
+ last_ckpt = None
117
+
118
+ trainer.fit(
119
+ lightning_task,
120
+ train_dataloader,
121
+ val_dataloaders=[valid_dataloader, test_dataloader],
122
+ ckpt_path=last_ckpt,
123
+ )
124
+
125
+
126
+ if __name__ == "__main__":
127
+ main()
s3prl_s3prl_main/example/superb_sid/inference.py ADDED
@@ -0,0 +1,40 @@
1
+ import argparse
2
+ from pathlib import Path
3
+
4
+ import torch
5
+
6
+ from s3prl import Dataset, Output, Task
7
+ from s3prl.base.object import Object
8
+
9
+ device = "cuda" if torch.cuda.is_available() else "cpu"
10
+
11
+
12
+ def parse_args():
13
+ parser = argparse.ArgumentParser()
14
+ parser.add_argument(
15
+ "load_from", help="The directory containing all the checkpoints"
16
+ )
17
+ args = parser.parse_args()
18
+ return args
19
+
20
+
21
+ def main():
22
+ args = parse_args()
23
+ load_from = Path(args.load_from)
24
+
25
+ task: Task = Object.load_checkpoint(load_from / "task.ckpt").to(device)
26
+ task.eval()
27
+
28
+ test_dataset: Dataset = Object.load_checkpoint(load_from / "test_dataset.ckpt")
29
+ test_dataloader = test_dataset.to_dataloader(batch_size=1, num_workers=6)
30
+
31
+ with torch.no_grad():
32
+ for batch in test_dataloader:
33
+ batch: Output = batch.to(device)
34
+ result = task(**batch.subset("x", "x_len", as_type="dict"))
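+ # (Assumption from the call above: batch.subset("x", "x_len", as_type="dict")
+ # keeps only the input waveform and its length and returns them as a plain
+ # dict, so they can be unpacked as keyword arguments for the task.)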
35
+ for name, prediction in zip(batch.name, result.prediction):
36
+ print(name, prediction)
37
+
38
+
39
+ if __name__ == "__main__":
40
+ main()
s3prl_s3prl_main/example/superb_sid/train.py ADDED
@@ -0,0 +1,235 @@
1
+ import argparse
2
+ import logging
3
+ import math
4
+ from pathlib import Path
5
+
6
+ import torch
7
+ import torch.optim as optim
8
+ from torch.utils.data import DataLoader
9
+ from tqdm import tqdm
10
+
11
+ from s3prl import Logs, Object, Output
12
+ from s3prl.nn import S3PRLUpstream, UpstreamDownstreamModel
13
+ from s3prl.sampler import DistributedBatchSamplerWrapper
14
+ from s3prl.superb import sid as problem
15
+
16
+ device = "cuda" if torch.cuda.is_available() else "cpu"
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ def parse_args():
21
+ parser = argparse.ArgumentParser()
22
+ parser.add_argument("voxceleb1", help="The root directory of VoxCeleb1")
23
+ parser.add_argument("save_to", help="The directory to save checkpoint")
24
+ parser.add_argument("--total_steps", type=int, default=200000)
25
+ parser.add_argument("--log_step", type=int, default=100)
26
+ parser.add_argument("--eval_step", type=int, default=5000)
27
+ parser.add_argument("--save_step", type=int, default=100)
28
+ parser.add_argument("--resume", action="store_true")
29
+ args = parser.parse_args()
30
+ return args
31
+
32
+
33
+ def main():
34
+ logging.basicConfig()
35
+ logger.setLevel(logging.INFO)
36
+
37
+ args = parse_args()
38
+ voxceleb1 = Path(args.voxceleb1)
39
+ assert voxceleb1.is_dir()
40
+ save_to = Path(args.save_to)
41
+ save_to.mkdir(exist_ok=True, parents=True)
42
+
43
+ logger.info("Preparing preprocessor")
44
+ preprocessor = problem.Preprocessor(voxceleb1)
45
+
46
+ logger.info("Preparing train dataloader")
47
+ train_dataset = problem.TrainDataset(**preprocessor.train_data())
48
+ train_sampler = problem.TrainSampler(
49
+ train_dataset, max_timestamp=16000 * 200, shuffle=True
50
+ )
51
+ train_sampler = DistributedBatchSamplerWrapper(
52
+ train_sampler, num_replicas=1, rank=0
53
+ )
54
+ train_dataloader = DataLoader(
55
+ train_dataset,
56
+ batch_sampler=train_sampler,
57
+ num_workers=12,
58
+ collate_fn=train_dataset.collate_fn,
59
+ )
60
+
61
+ logger.info("Preparing valid dataloader")
62
+ valid_dataset = problem.ValidDataset(
63
+ **preprocessor.valid_data(),
64
+ **train_dataset.statistics(),
65
+ )
66
+ valid_dataset.save_checkpoint(save_to / "valid_dataset.ckpt")
67
+ valid_sampler = problem.ValidSampler(valid_dataset, 8)
68
+ valid_sampler = DistributedBatchSamplerWrapper(
69
+ valid_sampler, num_replicas=1, rank=0
70
+ )
71
+ valid_dataloader = DataLoader(
72
+ valid_dataset,
73
+ batch_sampler=valid_sampler,
74
+ num_workers=12,
75
+ collate_fn=valid_dataset.collate_fn,
76
+ )
77
+
78
+ logger.info("Preparing test dataloader")
79
+ test_dataset = problem.TestDataset(
80
+ **preprocessor.test_data(),
81
+ **train_dataset.statistics(),
82
+ )
83
+ test_dataset.save_checkpoint(save_to / "test_dataset.ckpt")
84
+ test_sampler = problem.TestSampler(test_dataset, 8)
85
+ test_sampler = DistributedBatchSamplerWrapper(test_sampler, num_replicas=1, rank=0)
86
+ test_dataloader = DataLoader(
87
+ test_dataset, batch_size=8, num_workers=12, collate_fn=test_dataset.collate_fn
88
+ )
89
+
90
+ latest_task = save_to / "task.ckpt"
91
+ if args.resume and latest_task.is_file():
92
+ logger.info("Last checkpoint found. Load model and optimizer from checkpoint")
93
+
94
+ # Object.load_checkpoint() from a checkpoint path and
95
+ # Object.from_checkpoint() from a loaded checkpoint dictionary
96
+ # are like AutoModel in Huggingface which you only need to
97
+ # provide the checkpoint for restoring the module.
98
+ #
99
+ # Note that source code definition should be importable, since this
100
+ # auto loading mechanism is just automating the model re-initialization
101
+ # steps instead of scriptify (torch.jit) all the source code in the
102
+ # checkpoint
103
+
104
+ task = Object.load_checkpoint(latest_task).to(device)
105
+
106
+ else:
107
+ logger.info("No last checkpoint found. Create new model")
108
+
109
+ # Model creation block which can be fully customized
110
+ upstream = S3PRLUpstream("wav2vec2")
111
+ downstream = problem.DownstreamModel(
112
+ upstream.output_size, len(preprocessor.statistics().category)
113
+ )
114
+ model = UpstreamDownstreamModel(upstream, downstream)
115
+
116
+ # After customize your own model, simply put it into task object
117
+ task = problem.Task(model, preprocessor.statistics().category)
118
+ task = task.to(device)
119
+
120
+ # We do not handle optimizer/scheduler in any special way in S3PRL, since
121
+ # there are lots of dedicated package for this. Hence, we also do not handle
122
+ # the checkpointing for optimizer/scheduler. Depends on what training pipeline
123
+ # the user prefer, either Lightning or SpeechBrain, these frameworks will
124
+ # provide different solutions on how to save these objects. By not handling
125
+ # these objects in S3PRL we are making S3PRL more flexible and agnostic to training pipeline
126
+ # The following optimizer codeblock aims to align with the standard usage
127
+ # of PyTorch which is the standard way to save it.
128
+
129
+ optimizer = optim.Adam(task.parameters(), lr=1e-3)
130
+ latest_optimizer = save_to / "optimizer.ckpt"
131
+ if args.resume and latest_optimizer.is_file():
132
+ optimizer.load_state_dict(torch.load(save_to / "optimizer.ckpt"))
133
+ else:
134
+ optimizer = optim.Adam(task.parameters(), lr=1e-3)
135
+
136
+ # The following code block demonstrate how to train with your own training loop
137
+ # This entire block can be easily replaced with Lightning/SpeechBrain Trainer as
138
+ #
139
+ # Trainer(task)
140
+ # Trainer.fit(train_dataloader, valid_dataloader, test_dataloader)
141
+ #
142
+ # As you can see, there is a huge similarity among train/valid/test loops below,
143
+ # so it is a natural step to share these logics with a generic Trainer class
144
+ # as done in Lightning/SpeechBrain
145
+
146
+ pbar = tqdm(total=args.total_steps, desc="Total")
147
+ while True:
148
+ batch_results = []
149
+ for batch in tqdm(train_dataloader, desc="Train", total=len(train_dataloader)):
150
+ pbar.update(1)
151
+ global_step = pbar.n
152
+
153
+ assert isinstance(batch, Output)
154
+ optimizer.zero_grad()
155
+
156
+ # An Output object can more all its direct
157
+ # attributes/values to the device
158
+ batch = batch.to(device)
159
+
160
+ # An Output object is an OrderedDict so we
161
+ # can use dict decomposition here
162
+ task.train()
163
+ result = task.train_step(**batch)
164
+ assert isinstance(result, Output)
165
+
166
+ # The output of train step must contain
167
+ # at least a loss key
168
+ result.loss.backward()
169
+
170
+ # gradient clipping
171
+ grad_norm = torch.nn.utils.clip_grad_norm_(task.parameters(), max_norm=1.0)
172
+
173
+ if math.isnan(grad_norm):
174
+ logger.warning(f"Grad norm is NaN at step {global_step}")
175
+ else:
176
+ optimizer.step()
177
+
178
+ # Detach from GPU, remove large logging (to Tensorboard or local files)
179
+ # objects like logits and leave only small data like loss scalars / prediction
180
+ # strings, so that these objects can be safely cached in a list in the MEM,
181
+ # and become useful for calculating metrics later
182
+ # The Output class can do these with self.cacheable()
183
+ cacheable_result = result.cacheable()
184
+
185
+ # Cache these small data for later metric calculation
186
+ batch_results.append(cacheable_result)
187
+
188
+ if (global_step + 1) % args.log_step == 0:
189
+ logs: Logs = task.train_reduction(batch_results).logs
190
+ logger.info(f"[Train] step {global_step}")
191
+ for log in logs.values():
192
+ logger.info(f"{log.name}: {log.data}")
193
+ batch_results = []
194
+
195
+ if (global_step + 1) % args.eval_step == 0:
196
+ with torch.no_grad():
197
+ task.eval()
198
+
199
+ # valid
200
+ valid_results = []
201
+ for batch in tqdm(
202
+ valid_dataloader, desc="Valid", total=len(valid_dataloader)
203
+ ):
204
+ batch = batch.to(device)
205
+ result = task.valid_step(**batch)
206
+ cacheable_result = result.cacheable()
207
+ valid_results.append(cacheable_result)
208
+
209
+ logs: Logs = task.valid_reduction(valid_results).logs
210
+ logger.info(f"[Valid] step {global_step}")
211
+ for log in logs.values():
212
+ logger.info(f"{log.name}: {log.data}")
213
+
214
+ # test
215
+ test_results = []
216
+ for batch in tqdm(
217
+ test_dataloader, desc="Test", total=len(test_dataloader)
218
+ ):
219
+ batch = batch.to(device)
220
+ result = task.test_step(**batch)
221
+ cacheable_result = result.cacheable()
222
+ test_results.append(cacheable_result)
223
+
224
+ logs: Logs = task.test_reduction(test_results).logs
225
+ logger.info(f"[Test] step {global_step}")
226
+ for log in logs.values():
227
+ logger.info(f"{log.name}: {log.data}")
228
+
229
+ if (global_step + 1) % args.save_step == 0:
230
+ task.save_checkpoint(save_to / "task.ckpt")
231
+ torch.save(optimizer.state_dict(), save_to / "optimizer.ckpt")
232
+
233
+
234
+ if __name__ == "__main__":
235
+ main()
s3prl_s3prl_main/example/superb_sid/train_with_lightning.py ADDED
@@ -0,0 +1,127 @@
1
+ import argparse
2
+ import logging
3
+ from pathlib import Path
4
+
5
+ import torch
6
+ import torch.optim as optim
7
+ from pytorch_lightning import Trainer
8
+ from pytorch_lightning.callbacks import ModelCheckpoint
9
+
10
+ from s3prl.nn import S3PRLUpstream, UpstreamDownstreamModel
11
+ from s3prl.superb import sid as problem
12
+ from s3prl.wrapper import LightningModuleSimpleWrapper
13
+
14
+ device = "cuda" if torch.cuda.is_available() else "cpu"
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ def parse_args():
19
+ parser = argparse.ArgumentParser()
20
+ parser.add_argument("voxceleb1", help="The root directory of VoxCeleb1")
21
+ parser.add_argument("save_to", help="The directory to save checkpoint")
22
+ parser.add_argument("--total_steps", type=int, default=200000)
23
+ parser.add_argument("--log_step", type=int, default=100)
24
+ parser.add_argument("--eval_step", type=int, default=5000)
25
+ parser.add_argument("--save_step", type=int, default=100)
26
+ parser.add_argument(
27
+ "--not_resume",
28
+ action="store_true",
29
+ help="Don't resume from the last checkpoint",
30
+ )
31
+
32
+ # for debugging
33
+ parser.add_argument("--limit_train_batches", type=int)
34
+ parser.add_argument("--limit_val_batches", type=int)
35
+ parser.add_argument("--fast_dev_run", action="store_true")
36
+ args = parser.parse_args()
37
+ return args
38
+
39
+
40
+ def main():
41
+ logging.basicConfig(level=logging.INFO)
42
+
43
+ args = parse_args()
44
+ voxceleb1 = Path(args.voxceleb1)
45
+ save_to = Path(args.save_to)
46
+ save_to.mkdir(exist_ok=True, parents=True)
47
+
48
+ logger.info("Preparing preprocessor")
49
+ preprocessor = problem.Preprocessor(voxceleb1)
50
+
51
+ logger.info("Preparing train dataloader")
52
+ train_dataset = problem.TrainDataset(**preprocessor.train_data())
53
+ train_dataloader = train_dataset.to_dataloader(
54
+ batch_size=8,
55
+ num_workers=6,
56
+ shuffle=True,
57
+ )
58
+
59
+ logger.info("Preparing valid dataloader")
60
+ valid_dataset = problem.ValidDataset(
61
+ **preprocessor.valid_data(),
62
+ **train_dataset.statistics(),
63
+ )
64
+ valid_dataloader = valid_dataset.to_dataloader(batch_size=8, num_workers=6)
65
+
66
+ logger.info("Preparing test dataloader")
67
+ test_dataset = problem.TestDataset(
68
+ **preprocessor.test_data(),
69
+ **train_dataset.statistics(),
70
+ )
71
+ test_dataloader = test_dataset.to_dataloader(batch_size=8, num_workers=6)
72
+
73
+ valid_dataset.save_checkpoint(save_to / "valid_dataset.ckpt")
74
+ test_dataset.save_checkpoint(save_to / "test_dataset.ckpt")
75
+
76
+ upstream = S3PRLUpstream("apc")
77
+ downstream = problem.DownstreamModel(
78
+ upstream.output_size, len(preprocessor.statistics().category)
79
+ )
80
+ model = UpstreamDownstreamModel(upstream, downstream)
81
+ task = problem.Task(model, preprocessor.statistics().category)
82
+
83
+ optimizer = optim.Adam(task.parameters(), lr=1e-3)
84
+ lightning_task = LightningModuleSimpleWrapper(task, optimizer)
85
+
86
+ # The above is the usage of our library
87
+
88
+ # The below is pytorch-lightning specific usage, which can be very simple
89
+ # or very sophisticated, depending on how much you want to customized your
90
+ # training loop
91
+
92
+ checkpoint_callback = ModelCheckpoint(
93
+ dirpath=str(save_to),
94
+ filename="superb-sid-{step:02d}-{valid_0_accuracy:.2f}",
95
+ monitor="valid_0_accuracy", # since might have multiple valid dataloaders
96
+ save_last=True,
97
+ save_top_k=3, # top 3 best ckpt on valid
98
+ mode="max", # higher, better
99
+ every_n_train_steps=args.save_step,
100
+ )
101
+
102
+ trainer = Trainer(
103
+ callbacks=[checkpoint_callback],
104
+ accelerator="gpu",
105
+ gpus=1,
106
+ max_steps=args.total_steps,
107
+ log_every_n_steps=args.log_step,
108
+ val_check_interval=args.eval_step,
109
+ limit_val_batches=args.limit_val_batches or 1.0,
110
+ limit_train_batches=args.limit_train_batches or 1.0,
111
+ fast_dev_run=args.fast_dev_run,
112
+ )
113
+
114
+ last_ckpt = save_to / "last.ckpt"
115
+ if args.not_resume or not last_ckpt.is_file():
116
+ last_ckpt = None
117
+
118
+ trainer.fit(
119
+ lightning_task,
120
+ train_dataloader,
121
+ val_dataloaders=[valid_dataloader, test_dataloader],
122
+ ckpt_path=last_ckpt,
123
+ )
124
+
125
+
126
+ if __name__ == "__main__":
127
+ main()
s3prl_s3prl_main/example/superb_sv/inference.py ADDED
@@ -0,0 +1,47 @@
1
+ import argparse
2
+ from pathlib import Path
3
+
4
+ import torch
5
+ from torch.utils.data import DataLoader
6
+
7
+ from s3prl import Output
8
+ from s3prl.base.object import Object
9
+ from s3prl.dataset import Dataset
10
+ from s3prl.task import Task
11
+
12
+ device = "cuda" if torch.cuda.is_available() else "cpu"
13
+
14
+
15
+ def parse_args():
16
+ parser = argparse.ArgumentParser()
17
+ parser.add_argument(
18
+ "--load_from",
19
+ type=str,
20
+ default="result/sv",
21
+ help="The directory containing all the checkpoints",
22
+ )
23
+ args = parser.parse_args()
24
+ return args
25
+
26
+
27
+ def main():
28
+ args = parse_args()
29
+ load_from = Path(args.load_from)
30
+
31
+ task: Task = Object.load_checkpoint(load_from / "task.ckpt").to(device)
32
+ task.eval()
33
+
34
+ test_dataset: Dataset = Object.load_checkpoint(load_from / "test_dataset.ckpt")
35
+ test_dataloader = DataLoader(
36
+ test_dataset, batch_size=1, num_workers=6, collate_fn=test_dataset.collate_fn
37
+ )
38
+
39
+ with torch.no_grad():
40
+ for batch in test_dataloader:
41
+ batch: Output = batch.to(device)
42
+ result = task(**batch.subset("x", "x_len", as_type="dict"))
43
+ print(result.hidden_states.shape)
44
+
45
+
46
+ if __name__ == "__main__":
47
+ main()
s3prl_s3prl_main/example/superb_sv/train.py ADDED
@@ -0,0 +1,266 @@
1
+ import argparse
2
+ import logging
3
+ import math
4
+ from pathlib import Path
5
+
6
+ import torch
7
+ import torch.optim as optim
8
+ from torch.utils.data import DataLoader
9
+ from tqdm import tqdm
10
+
11
+ from s3prl import Logs, Object, Output
12
+ from s3prl.nn import S3PRLUpstream, UpstreamDownstreamModel
13
+ from s3prl.sampler import DistributedBatchSamplerWrapper
14
+ from s3prl.superb import sv as problem
15
+
16
+ device = "cuda" if torch.cuda.is_available() else "cpu"
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ def parse_args():
21
+ parser = argparse.ArgumentParser()
22
+ parser.add_argument(
23
+ "--voxceleb1",
24
+ type=str,
25
+ default="/work/jason410/PublicData/Voxceleb1",
26
+ help="The root directory of VoxCeleb1",
27
+ )
28
+ parser.add_argument(
29
+ "--save_to",
30
+ type=str,
31
+ default="result/sv",
32
+ help="The directory to save checkpoint",
33
+ )
34
+ parser.add_argument("--total_steps", type=int, default=200000)
35
+ parser.add_argument("--log_step", type=int, default=100)
36
+ parser.add_argument("--eval_step", type=int, default=200)
37
+ parser.add_argument("--save_step", type=int, default=100)
38
+
39
+ parser.add_argument("--backbone", type=str, default="XVector")
40
+ parser.add_argument("--pooling_type", type=str, default="TAP")
41
+ parser.add_argument("--loss_type", type=str, default="softmax")
42
+ parser.add_argument("--spk_embd_dim", type=int, default=1500)
43
+ args = parser.parse_args()
44
+ return args
45
+
46
+
47
+ def main():
48
+ logging.basicConfig()
49
+ logger.setLevel(logging.INFO)
50
+
51
+ args = parse_args()
52
+ voxceleb1 = Path(args.voxceleb1)
53
+ assert voxceleb1.is_dir()
54
+ save_to = Path(args.save_to)
55
+ save_to.mkdir(exist_ok=True, parents=True)
56
+
57
+ logger.info("Preparing preprocessor")
58
+ preprocessor = problem.Preprocessor(voxceleb1)
59
+
60
+ logger.info("Preparing train dataloader")
61
+ train_dataset = problem.TrainDataset(**preprocessor.train_data())
62
+ train_sampler = problem.TrainSampler(
63
+ train_dataset, max_timestamp=16000 * 1000, shuffle=True
64
+ )
65
+ train_sampler = DistributedBatchSamplerWrapper(
66
+ train_sampler, num_replicas=1, rank=0
67
+ )
68
+ train_dataloader = DataLoader(
69
+ train_dataset,
70
+ batch_sampler=train_sampler,
71
+ num_workers=6,
72
+ collate_fn=train_dataset.collate_fn,
73
+ )
74
+
75
+ logger.info("Preparing valid dataloader")
76
+ valid_dataset = problem.ValidDataset(
77
+ **preprocessor.valid_data(),
78
+ **train_dataset.statistics(),
79
+ )
80
+ valid_dataset.save_checkpoint(save_to / "valid_dataset.ckpt")
81
+ # valid_dataset_reload = Object.load_checkpoint(save_to / "valid_dataset.ckpt")
82
+ valid_sampler = problem.TrainSampler(
83
+ valid_dataset, max_timestamp=16000 * 1000, shuffle=True
84
+ )
85
+ valid_sampler = DistributedBatchSamplerWrapper(
86
+ valid_sampler, num_replicas=1, rank=0
87
+ )
88
+ valid_dataloader = DataLoader(
89
+ valid_dataset,
90
+ batch_sampler=valid_sampler,
91
+ num_workers=6,
92
+ collate_fn=valid_dataset.collate_fn,
93
+ )
94
+
95
+ logger.info("Preparing test dataloader")
96
+ test_dataset = problem.TestDataset(
97
+ **preprocessor.test_data(),
98
+ **train_dataset.statistics(),
99
+ )
100
+ test_dataset.save_checkpoint(save_to / "test_dataset.ckpt")
101
+ test_sampler = problem.TestSampler(test_dataset, 8)
102
+ test_sampler = DistributedBatchSamplerWrapper(test_sampler, num_replicas=1, rank=0)
103
+ test_dataloader = DataLoader(
104
+ test_dataset, batch_size=1, num_workers=6, collate_fn=test_dataset.collate_fn
105
+ )
106
+
107
+ latest_task = save_to / "task.ckpt"
108
+ if latest_task.is_file():
109
+ logger.info("Last checkpoint found. Load model and optimizer from checkpoint")
110
+
111
+ # Object.load_checkpoint() from a checkpoint path and
112
+ # Object.from_checkpoint() from a loaded checkpoint dictionary
113
+ # are like AutoModel in Huggingface, where you only need to
114
+ # provide the checkpoint for restoring the module.
115
+ #
116
+ # Note that the source code definition should be importable, since this
117
+ # auto loading mechanism is just automating the model re-initialization
118
+ # steps instead of scripting (torch.jit) all the source code into the
119
+ # checkpoint
120
+
121
+ task = Object.load_checkpoint(latest_task).to(device)
122
+
123
+ else:
124
+ logger.info("No last checkpoint found. Create new model")
125
+
126
+ # Model creation block which can be fully customized
127
+ upstream = S3PRLUpstream("apc")
128
+ # Have to specify the backbone, pooling_type
129
+ downstream = problem.speaker_embedding_extractor(
130
+ backbone=args.backbone,
131
+ pooling_type=args.pooling_type,
132
+ input_size=upstream.output_size,
133
+ output_size=args.spk_embd_dim,
134
+ )
135
+
136
+ model = UpstreamDownstreamModel(upstream, downstream)
137
+
138
+ # Have to specify the loss_type
139
+ task = problem.Task(
140
+ model=model,
141
+ categories=preprocessor.statistics().category,
142
+ loss_type=args.loss_type,
143
+ trials=test_dataset.statistics().label,
144
+ )
145
+ task = task.to(device)
146
+
147
+ # We do not handle the optimizer/scheduler in any special way in S3PRL, since
148
+ # there are plenty of dedicated packages for this. Hence, we also do not handle
149
+ # the checkpointing of the optimizer/scheduler. Depending on which training
150
+ # pipeline the user prefers, either Lightning or SpeechBrain, these frameworks
151
+ # provide different solutions for saving these objects. By not handling them
152
+ # in S3PRL we keep S3PRL flexible and agnostic to the training pipeline.
153
+ # The following optimizer code block follows the standard PyTorch usage,
154
+ # which is also the standard way to save and restore it.
155
+
156
+ optimizer = optim.Adam(task.parameters(), lr=1e-3)
157
+ latest_optimizer = save_to / "optimizer.ckpt"
158
+
159
+ if latest_optimizer.is_file():
160
+ optimizer.load_state_dict(torch.load(save_to / "optimizer.ckpt"))
161
+ else:
162
+ optimizer = optim.Adam(task.parameters(), lr=1e-3)
163
+
164
+ # The following code block demonstrates how to train with your own training loop
165
+ # This entire block can be easily replaced with Lightning/SpeechBrain Trainer as
166
+ #
167
+ # Trainer(task)
168
+ # Trainer.fit(train_dataloader, valid_dataloader, test_dataloader)
169
+ #
170
+ # As you can see, there is a huge similarity among train/valid/test loops below,
171
+ # so it is a natural step to share this logic with a generic Trainer class
172
+ # as done in Lightning/SpeechBrain
173
+
174
+ pbar = tqdm(total=args.total_steps, desc="Total")
175
+ while True:
176
+ batch_results = []
177
+ for batch in tqdm(train_dataloader, desc="Train", total=len(train_dataloader)):
178
+ pbar.update(1)
179
+ global_step = pbar.n
180
+
181
+ assert isinstance(batch, Output)
182
+ optimizer.zero_grad()
183
+
184
+ # An Output object can move all of its direct
185
+ # attributes/values to the device
186
+ batch = batch.to(device)
187
+
188
+ # An Output object is an OrderedDict so we
189
+ # can use dict decomposition here
190
+ task.train()
191
+ result = task.train_step(**batch)
192
+ assert isinstance(result, Output)
193
+
194
+ # The output of train step must contain
195
+ # at least a loss key
196
+ result.loss.backward()
197
+
198
+ # gradient clipping
199
+ grad_norm = torch.nn.utils.clip_grad_norm_(task.parameters(), max_norm=1.0)
200
+
201
+ if math.isnan(grad_norm):
202
+ logger.warning(f"Grad norm is NaN at step {global_step}")
203
+ else:
204
+ optimizer.step()
205
+
206
+ # Detach from GPU and drop large objects that are only needed for logging
207
+ # (to Tensorboard or local files), such as logits, keeping only small data like
208
+ # loss scalars / prediction strings, so that these results can be safely cached
209
+ # in a list in memory and used later for calculating metrics.
210
+ # The Output class can do this with self.cacheable()
211
+ cacheable_result = result.cacheable()
212
+
213
+ # Cache these small data for later metric calculation
214
+ batch_results.append(cacheable_result)
215
+
216
+ if (global_step + 1) % args.log_step == 0:
217
+ logs: Logs = task.train_reduction(batch_results).logs
218
+ logger.info(f"[Train] step {global_step}")
219
+ for log in logs.values():
220
+ logger.info(f"{log.name}: {log.data}")
221
+ batch_results = []
222
+
223
+ if (global_step + 1) % args.eval_step == 0:
224
+ with torch.no_grad():
225
+ task.eval()
226
+
227
+ # valid
228
+ valid_results = []
229
+ for batch in tqdm(
230
+ valid_dataloader, desc="Valid", total=len(valid_dataloader)
231
+ ):
232
+ batch = batch.to(device)
233
+ result = task.valid_step(**batch)
234
+ cacheable_result = result.cacheable()
235
+ valid_results.append(cacheable_result)
236
+
237
+ logs: Logs = task.valid_reduction(valid_results).logs
238
+ logger.info(f"[Valid] step {global_step}")
239
+ for log in logs.values():
240
+ logger.info(f"{log.name}: {log.data}")
241
+
242
+ # test
243
+ test_results = []
244
+ for batch in tqdm(
245
+ test_dataloader, desc="Test", total=len(test_dataloader)
246
+ ):
247
+ batch = batch.to(device)
248
+ result = task.test_step(**batch)
249
+ test_results.append(result)
250
+ # for key, value in zip(result.name, result.output):
251
+ # test_results[key] = value
252
+
253
+ logs: Logs = task.test_reduction(
254
+ batch_results=test_results,
255
+ ).logs
256
+ logger.info(f"[Test] step {global_step}")
257
+ for log in logs.values():
258
+ logger.info(f"{log.name}: {log.data}")
259
+
260
+ if (global_step + 1) % args.save_step == 0:
261
+ task.save_checkpoint(save_to / "task.ckpt")
262
+ torch.save(optimizer.state_dict(), save_to / "optimizer.ckpt")
263
+
264
+
265
+ if __name__ == "__main__":
266
+ main()
s3prl_s3prl_main/example/superb_sv/train_with_lightning.py ADDED
@@ -0,0 +1,184 @@
1
+ import argparse
2
+ import logging
3
+ from pathlib import Path
4
+
5
+ import torch
6
+ import torch.optim as optim
7
+ from pytorch_lightning import Trainer
8
+ from pytorch_lightning.callbacks import ModelCheckpoint
9
+ from torch.utils.data import DataLoader
10
+
11
+ from s3prl.nn import S3PRLUpstream, UpstreamDownstreamModel
12
+ from s3prl.sampler import DistributedBatchSamplerWrapper
13
+ from s3prl.superb import sv as problem
14
+ from s3prl.wrapper import LightningModuleSimpleWrapper
15
+
16
+ device = "cuda" if torch.cuda.is_available() else "cpu"
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ def parse_args():
21
+ parser = argparse.ArgumentParser()
22
+ parser.add_argument(
23
+ "--voxceleb1",
24
+ type=str,
25
+ default="/work/jason410/PublicData/Voxceleb1",
26
+ help="The root directory of VoxCeleb1",
27
+ )
28
+ parser.add_argument(
29
+ "--save_to",
30
+ type=str,
31
+ default="lightning_result/sv",
32
+ help="The directory to save checkpoint",
33
+ )
34
+ parser.add_argument("--total_steps", type=int, default=200000)
35
+ parser.add_argument("--log_step", type=int, default=100)
36
+ parser.add_argument("--eval_step", type=int, default=1000)
37
+ parser.add_argument("--save_step", type=int, default=100)
38
+ parser.add_argument(
39
+ "--not_resume",
40
+ action="store_true",
41
+ help="Don't resume from the last checkpoint",
42
+ )
43
+
44
+ # for debugging
45
+ parser.add_argument("--limit_train_batches", type=int)
46
+ parser.add_argument("--limit_val_batches", type=int)
47
+ parser.add_argument("--fast_dev_run", action="store_true")
48
+
49
+ parser.add_argument("--backbone", type=str, default="XVector")
50
+ parser.add_argument("--pooling_type", type=str, default="TAP")
51
+ parser.add_argument("--loss_type", type=str, default="softmax")
52
+ parser.add_argument("--spk_embd_dim", type=int, default=1500)
53
+ args = parser.parse_args()
54
+ return args
55
+
56
+
57
+ def main():
58
+ logging.basicConfig()
59
+ logger.setLevel(logging.INFO)
60
+
61
+ args = parse_args()
62
+ voxceleb1 = Path(args.voxceleb1)
63
+ assert voxceleb1.is_dir()
64
+ save_to = Path(args.save_to)
65
+ save_to.mkdir(exist_ok=True, parents=True)
66
+
67
+ logger.info("Preparing preprocessor")
68
+ preprocessor = problem.Preprocessor(voxceleb1)
69
+
70
+ logger.info("Preparing train dataloader")
71
+ train_dataset = problem.TrainDataset(**preprocessor.train_data())
72
+ train_sampler = problem.TrainSampler(
73
+ train_dataset, max_timestamp=16000 * 1000, shuffle=True
74
+ )
75
+ train_sampler = DistributedBatchSamplerWrapper(
76
+ train_sampler, num_replicas=1, rank=0
77
+ )
78
+ train_dataloader = DataLoader(
79
+ train_dataset,
80
+ batch_sampler=train_sampler,
81
+ num_workers=6,
82
+ collate_fn=train_dataset.collate_fn,
83
+ )
84
+
85
+ logger.info("Preparing valid dataloader")
86
+ valid_dataset = problem.ValidDataset(
87
+ **preprocessor.valid_data(),
88
+ **train_dataset.statistics(),
89
+ )
90
+ valid_dataset.save_checkpoint(save_to / "valid_dataset.ckpt")
91
+ # valid_dataset_reload = Object.load_checkpoint(save_to / "valid_dataset.ckpt")
92
+ valid_sampler = problem.TrainSampler(
93
+ valid_dataset, max_timestamp=16000 * 1000, shuffle=True
94
+ )
95
+ valid_sampler = DistributedBatchSamplerWrapper(
96
+ valid_sampler, num_replicas=1, rank=0
97
+ )
98
+ valid_dataloader = DataLoader(
99
+ valid_dataset,
100
+ batch_sampler=valid_sampler,
101
+ num_workers=6,
102
+ collate_fn=valid_dataset.collate_fn,
103
+ )
104
+
105
+ logger.info("Preparing test dataloader")
106
+ test_dataset = problem.TestDataset(
107
+ **preprocessor.test_data(),
108
+ **train_dataset.statistics(),
109
+ )
110
+ test_dataset.save_checkpoint(save_to / "test_dataset.ckpt")
111
+ test_sampler = problem.TestSampler(test_dataset, 8)
112
+ test_sampler = DistributedBatchSamplerWrapper(test_sampler, num_replicas=1, rank=0)
113
+ test_dataloader = DataLoader(
114
+ test_dataset, batch_size=1, num_workers=6, collate_fn=test_dataset.collate_fn
115
+ )
116
+
117
+ upstream = S3PRLUpstream("apc")
118
+ # Have to specify the backbone, pooling_type, spk_embd_dim
119
+ downstream = problem.speaker_embedding_extractor(
120
+ backbone=args.backbone,
121
+ pooling_type=args.pooling_type,
122
+ input_size=upstream.output_size,
123
+ output_size=args.spk_embd_dim,
124
+ )
125
+ model = UpstreamDownstreamModel(upstream, downstream)
126
+ # Have to specify the loss_type
127
+ task = problem.Task(
128
+ model=model,
129
+ categories=preprocessor.statistics().category,
130
+ loss_type=args.loss_type,
131
+ trials=test_dataset.statistics().label,
132
+ )
133
+
134
+ optimizer = optim.Adam(task.parameters(), lr=1e-3)
135
+ lightning_task = LightningModuleSimpleWrapper(task, optimizer)
136
+
137
+ # The above is the usage of our library
138
+
139
+ # The below is pytorch-lightning specific usage, which can be very simple
140
+ # or very sophisticated, depending on how much you want to customized your
141
+ # training loop
142
+
143
+ checkpoint_callback = ModelCheckpoint(
144
+ dirpath=str(save_to),
145
+ filename="superb-sv-{step:02d}-{valid_0_accuracy:.2f}",
146
+ monitor="valid_0_accuracy", # since might have multiple valid dataloaders
147
+ save_last=True,
148
+ save_top_k=3, # top 3 best ckpt on valid
149
+ mode="max", # higher, better
150
+ every_n_train_steps=args.save_step,
151
+ )
152
+
153
+ trainer = Trainer(
154
+ callbacks=[checkpoint_callback],
155
+ accelerator="gpu",
156
+ gpus=1,
157
+ max_steps=args.total_steps,
158
+ log_every_n_steps=args.log_step,
159
+ val_check_interval=args.eval_step,
160
+ limit_val_batches=args.limit_val_batches or 1.0,
161
+ limit_train_batches=args.limit_train_batches or 1.0,
162
+ fast_dev_run=args.fast_dev_run,
163
+ )
164
+
165
+ last_ckpt = save_to / "last.ckpt"
166
+ if args.not_resume or not last_ckpt.is_file():
167
+ last_ckpt = None
168
+
169
+ trainer.fit(
170
+ lightning_task,
171
+ train_dataloader,
172
+ val_dataloaders=valid_dataloader,
173
+ ckpt_path=last_ckpt,
174
+ )
175
+
176
+ trainer.test(
177
+ lightning_task,
178
+ dataloaders=test_dataloader,
179
+ ckpt_path=last_ckpt,
180
+ )
181
+
182
+
183
+ if __name__ == "__main__":
184
+ main()
s3prl_s3prl_main/external_tools/install_espnet.sh ADDED
@@ -0,0 +1,20 @@
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ if [ $# != 0 ]; then
5
+ echo "Usage: $0"
6
+ exit 1;
7
+ fi
8
+
9
+ if [ ! -e espnet.done ]; then
10
+ if ! python3 -c "import espnet2" &> /dev/null; then
11
+ pip install "espnet>=202308"
12
+ else
13
+ echo echo "espnet is already installed"
14
+ fi
15
+ touch espnet.done
16
+ else
17
+ echo "espnet is already installed"
18
+ fi
19
+
20
+
s3prl_s3prl_main/file/S3PRL-integration.png ADDED
s3prl_s3prl_main/file/S3PRL-interface.png ADDED
s3prl_s3prl_main/file/S3PRL-logo.png ADDED
s3prl_s3prl_main/file/license.svg ADDED
s3prl_s3prl_main/find_content.sh ADDED
@@ -0,0 +1,8 @@
1
+ #!/bin/bash
2
+
3
+ if [ -z "$1" ]; then
4
+ echo "Usage: $0 <pattern>"
5
+ exit 1
6
+ fi
7
+
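+ # Usage sketch (assuming `ag`, the silver searcher, is installed and
+ # valid_paths.txt sits in the repository root):
+ #   ./find_content.sh S3PRLUpstream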
8
+ ag "$1" $(cat valid_paths.txt)
s3prl_s3prl_main/hubconf.py ADDED
@@ -0,0 +1,4 @@
1
+ from s3prl import hub
2
+
3
+ for _option in hub.options():
4
+ globals()[_option] = getattr(hub, _option)
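+
+ # Exposing every entry of s3prl.hub in this module's global namespace lets
+ # torch.hub discover them, e.g. (illustrative):
+ #   torch.hub.load("s3prl/s3prl", "wav2vec2")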
s3prl_s3prl_main/pyrightconfig.json ADDED
@@ -0,0 +1,5 @@
1
+ {
2
+ "exclude": [
3
+ "**/result/**",
4
+ ],
5
+ }
s3prl_s3prl_main/pytest.ini ADDED
@@ -0,0 +1,5 @@
1
+ [pytest]
2
+ minversion = 6.0
3
+ addopts = -ra -q --log-cli-level INFO
4
+ testpaths =
5
+ test
s3prl_s3prl_main/requirements/all.txt ADDED
@@ -0,0 +1,33 @@
1
+ joblib >=0.12.4
2
+ librosa >=0.7.2
3
+ scipy >=1.5.4
4
+ scikit-learn >=0.23.2
5
+ pandas>=1.1.5
6
+ matplotlib >=3.3.4
7
+ Pillow >=6.2.2
8
+ numba >=0.48
9
+ cython >=0.29.21
10
+ packaging >=20.9
11
+ transformers >=4.10.0,<5.0
12
+ dtw-python ==1.3.0
13
+ asteroid ==0.4.4
14
+ sacrebleu >=2.0.0
15
+ h5py
16
+ sox
17
+ tabulate
18
+ intervaltree
19
+ lxml
20
+ soundfile
21
+ pysndfx
22
+ nltk
23
+ normalise
24
+ editdistance
25
+ easydict
26
+ catalyst
27
+ sentencepiece
28
+ huggingface_hub >=0.2.1
29
+ mutagen
30
+ pydantic
31
+ sed_eval
32
+ more_itertools
33
+ tensorboard
s3prl_s3prl_main/requirements/dev.txt ADDED
@@ -0,0 +1,11 @@
1
+ python-dotenv
2
+ pytest
3
+ pytest-xdist
4
+ tox
5
+ pre-commit
6
+ flake8 ==3.9.2
7
+ black ==22.3.0
8
+ isort ==5.10.1
9
+ sphinx ==5.1.1 # follow https://github.com/sphinx-doc/sphinx/issues/6316
10
+ furo
11
+ importlib_metadata <5.0 # required by flake8