Commit
·
9756d99
1
Parent(s):
23d93ea
First model version
Browse files- .gitattributes +1 -0
- .python-version +1 -0
- LICENSE +203 -0
- README.md +27 -0
- __pycache__/base_bert.cpython-38.pyc +0 -0
- __pycache__/bert.cpython-38.pyc +0 -0
- __pycache__/config.cpython-38.pyc +0 -0
- __pycache__/optimizer.cpython-38.pyc +0 -0
- __pycache__/tokenizer.cpython-38.pyc +0 -0
- __pycache__/utils.cpython-38.pyc +0 -0
- base_bert.py +248 -0
- bert.py +225 -0
- cfimdb-classifier.pt +3 -0
- classifier.py +406 -0
- config.py +222 -0
- data/ids-cfimdb-dev.csv +3 -0
- data/ids-cfimdb-test-student.csv +3 -0
- data/ids-cfimdb-train.csv +3 -0
- data/ids-sst-dev.csv +3 -0
- data/ids-sst-test-student.csv +3 -0
- data/ids-sst-train.csv +3 -0
- data/quora-dev.csv +3 -0
- data/quora-test-student.csv +3 -0
- data/quora-train.csv +3 -0
- data/sts-dev.csv +3 -0
- data/sts-test-student.csv +3 -0
- data/sts-train.csv +3 -0
- datasets.py +272 -0
- evaluation.py +205 -0
- multitask_classifier.py +340 -0
- optimizer.py +90 -0
- optimizer_test.npy +3 -0
- optimizer_test.py +34 -0
- predictions/README +2 -0
- predictions/last-linear-layer-cfimdb-dev-out.csv +3 -0
- predictions/last-linear-layer-cfimdb-test-out.csv +3 -0
- predictions/last-linear-layer-sst-dev-out.csv +3 -0
- predictions/last-linear-layer-sst-test-out.csv +3 -0
- prepare_submit.py +18 -0
- sanity_check.data +0 -0
- sanity_check.py +19 -0
- setup.sh +13 -0
- sst-classifier.pt +3 -0
- tokenizer.py +0 -0
- utils.py +347 -0
- zemo1.py +53 -0
- zemo2.py +41 -0
- zemo3.py +32 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*.csv filter=lfs diff=lfs merge=lfs -text
|
.python-version
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
3.8.20
|
LICENSE
ADDED
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Copyright 2018- The Hugging Face team. All rights reserved.
|
2 |
+
|
3 |
+
Apache License
|
4 |
+
Version 2.0, January 2004
|
5 |
+
http://www.apache.org/licenses/
|
6 |
+
|
7 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
8 |
+
|
9 |
+
1. Definitions.
|
10 |
+
|
11 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
12 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
13 |
+
|
14 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
15 |
+
the copyright owner that is granting the License.
|
16 |
+
|
17 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
18 |
+
other entities that control, are controlled by, or are under common
|
19 |
+
control with that entity. For the purposes of this definition,
|
20 |
+
"control" means (i) the power, direct or indirect, to cause the
|
21 |
+
direction or management of such entity, whether by contract or
|
22 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
23 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
24 |
+
|
25 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
26 |
+
exercising permissions granted by this License.
|
27 |
+
|
28 |
+
"Source" form shall mean the preferred form for making modifications,
|
29 |
+
including but not limited to software source code, documentation
|
30 |
+
source, and configuration files.
|
31 |
+
|
32 |
+
"Object" form shall mean any form resulting from mechanical
|
33 |
+
transformation or translation of a Source form, including but
|
34 |
+
not limited to compiled object code, generated documentation,
|
35 |
+
and conversions to other media types.
|
36 |
+
|
37 |
+
"Work" shall mean the work of authorship, whether in Source or
|
38 |
+
Object form, made available under the License, as indicated by a
|
39 |
+
copyright notice that is included in or attached to the work
|
40 |
+
(an example is provided in the Appendix below).
|
41 |
+
|
42 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
43 |
+
form, that is based on (or derived from) the Work and for which the
|
44 |
+
editorial revisions, annotations, elaborations, or other modifications
|
45 |
+
represent, as a whole, an original work of authorship. For the purposes
|
46 |
+
of this License, Derivative Works shall not include works that remain
|
47 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
48 |
+
the Work and Derivative Works thereof.
|
49 |
+
|
50 |
+
"Contribution" shall mean any work of authorship, including
|
51 |
+
the original version of the Work and any modifications or additions
|
52 |
+
to that Work or Derivative Works thereof, that is intentionally
|
53 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
54 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
55 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
56 |
+
means any form of electronic, verbal, or written communication sent
|
57 |
+
to the Licensor or its representatives, including but not limited to
|
58 |
+
communication on electronic mailing lists, source code control systems,
|
59 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
60 |
+
Licensor for the purpose of discussing and improving the Work, but
|
61 |
+
excluding communication that is conspicuously marked or otherwise
|
62 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
63 |
+
|
64 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
65 |
+
on behalf of whom a Contribution has been received by Licensor and
|
66 |
+
subsequently incorporated within the Work.
|
67 |
+
|
68 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
69 |
+
this License, each Contributor hereby grants to You a perpetual,
|
70 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
71 |
+
copyright license to reproduce, prepare Derivative Works of,
|
72 |
+
publicly display, publicly perform, sublicense, and distribute the
|
73 |
+
Work and such Derivative Works in Source or Object form.
|
74 |
+
|
75 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
76 |
+
this License, each Contributor hereby grants to You a perpetual,
|
77 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
78 |
+
(except as stated in this section) patent license to make, have made,
|
79 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
80 |
+
where such license applies only to those patent claims licensable
|
81 |
+
by such Contributor that are necessarily infringed by their
|
82 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
83 |
+
with the Work to which such Contribution(s) was submitted. If You
|
84 |
+
institute patent litigation against any entity (including a
|
85 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
86 |
+
or a Contribution incorporated within the Work constitutes direct
|
87 |
+
or contributory patent infringement, then any patent licenses
|
88 |
+
granted to You under this License for that Work shall terminate
|
89 |
+
as of the date such litigation is filed.
|
90 |
+
|
91 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
92 |
+
Work or Derivative Works thereof in any medium, with or without
|
93 |
+
modifications, and in Source or Object form, provided that You
|
94 |
+
meet the following conditions:
|
95 |
+
|
96 |
+
(a) You must give any other recipients of the Work or
|
97 |
+
Derivative Works a copy of this License; and
|
98 |
+
|
99 |
+
(b) You must cause any modified files to carry prominent notices
|
100 |
+
stating that You changed the files; and
|
101 |
+
|
102 |
+
(c) You must retain, in the Source form of any Derivative Works
|
103 |
+
that You distribute, all copyright, patent, trademark, and
|
104 |
+
attribution notices from the Source form of the Work,
|
105 |
+
excluding those notices that do not pertain to any part of
|
106 |
+
the Derivative Works; and
|
107 |
+
|
108 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
109 |
+
distribution, then any Derivative Works that You distribute must
|
110 |
+
include a readable copy of the attribution notices contained
|
111 |
+
within such NOTICE file, excluding those notices that do not
|
112 |
+
pertain to any part of the Derivative Works, in at least one
|
113 |
+
of the following places: within a NOTICE text file distributed
|
114 |
+
as part of the Derivative Works; within the Source form or
|
115 |
+
documentation, if provided along with the Derivative Works; or,
|
116 |
+
within a display generated by the Derivative Works, if and
|
117 |
+
wherever such third-party notices normally appear. The contents
|
118 |
+
of the NOTICE file are for informational purposes only and
|
119 |
+
do not modify the License. You may add Your own attribution
|
120 |
+
notices within Derivative Works that You distribute, alongside
|
121 |
+
or as an addendum to the NOTICE text from the Work, provided
|
122 |
+
that such additional attribution notices cannot be construed
|
123 |
+
as modifying the License.
|
124 |
+
|
125 |
+
You may add Your own copyright statement to Your modifications and
|
126 |
+
may provide additional or different license terms and conditions
|
127 |
+
for use, reproduction, or distribution of Your modifications, or
|
128 |
+
for any such Derivative Works as a whole, provided Your use,
|
129 |
+
reproduction, and distribution of the Work otherwise complies with
|
130 |
+
the conditions stated in this License.
|
131 |
+
|
132 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
133 |
+
any Contribution intentionally submitted for inclusion in the Work
|
134 |
+
by You to the Licensor shall be under the terms and conditions of
|
135 |
+
this License, without any additional terms or conditions.
|
136 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
137 |
+
the terms of any separate license agreement you may have executed
|
138 |
+
with Licensor regarding such Contributions.
|
139 |
+
|
140 |
+
6. Trademarks. This License does not grant permission to use the trade
|
141 |
+
names, trademarks, service marks, or product names of the Licensor,
|
142 |
+
except as required for reasonable and customary use in describing the
|
143 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
144 |
+
|
145 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
146 |
+
agreed to in writing, Licensor provides the Work (and each
|
147 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
148 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
149 |
+
implied, including, without limitation, any warranties or conditions
|
150 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
151 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
152 |
+
appropriateness of using or redistributing the Work and assume any
|
153 |
+
risks associated with Your exercise of permissions under this License.
|
154 |
+
|
155 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
156 |
+
whether in tort (including negligence), contract, or otherwise,
|
157 |
+
unless required by applicable law (such as deliberate and grossly
|
158 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
159 |
+
liable to You for damages, including any direct, indirect, special,
|
160 |
+
incidental, or consequential damages of any character arising as a
|
161 |
+
result of this License or out of the use or inability to use the
|
162 |
+
Work (including but not limited to damages for loss of goodwill,
|
163 |
+
work stoppage, computer failure or malfunction, or any and all
|
164 |
+
other commercial damages or losses), even if such Contributor
|
165 |
+
has been advised of the possibility of such damages.
|
166 |
+
|
167 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
168 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
169 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
170 |
+
or other liability obligations and/or rights consistent with this
|
171 |
+
License. However, in accepting such obligations, You may act only
|
172 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
173 |
+
of any other Contributor, and only if You agree to indemnify,
|
174 |
+
defend, and hold each Contributor harmless for any liability
|
175 |
+
incurred by, or claims asserted against, such Contributor by reason
|
176 |
+
of your accepting any such warranty or additional liability.
|
177 |
+
|
178 |
+
END OF TERMS AND CONDITIONS
|
179 |
+
|
180 |
+
APPENDIX: How to apply the Apache License to your work.
|
181 |
+
|
182 |
+
To apply the Apache License to your work, attach the following
|
183 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
184 |
+
replaced with your own identifying information. (Don't include
|
185 |
+
the brackets!) The text should be enclosed in the appropriate
|
186 |
+
comment syntax for the file format. We also recommend that a
|
187 |
+
file or class name and description of purpose be included on the
|
188 |
+
same "printed page" as the copyright notice for easier
|
189 |
+
identification within third-party archives.
|
190 |
+
|
191 |
+
Copyright [yyyy] [name of copyright owner]
|
192 |
+
|
193 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
194 |
+
you may not use this file except in compliance with the License.
|
195 |
+
You may obtain a copy of the License at
|
196 |
+
|
197 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
198 |
+
|
199 |
+
Unless required by applicable law or agreed to in writing, software
|
200 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
201 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
202 |
+
See the License for the specific language governing permissions and
|
203 |
+
limitations under the License.
|
README.md
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# CS 224N Default Final Project - Multitask BERT
|
2 |
+
|
3 |
+
This is the default final project for the Stanford CS 224N class. Please refer to the project handout on the course website for detailed instructions and an overview of the codebase.
|
4 |
+
|
5 |
+
This project comprises two parts. In the first part, you will implement some important components of the BERT model to better understand its architecture.
|
6 |
+
In the second part, you will use the embeddings produced by your BERT model on three downstream tasks: sentiment classification, paraphrase detection, and semantic similarity. You will implement extensions to improve your model's performance on the three downstream tasks.
|
7 |
+
|
8 |
+
In broad strokes, Part 1 of this project targets:
|
9 |
+
* bert.py: Missing code blocks.
|
10 |
+
* classifier.py: Missing code blocks.
|
11 |
+
* optimizer.py: Missing code blocks.
|
12 |
+
|
13 |
+
And Part 2 targets:
|
14 |
+
* multitask_classifier.py: Missing code blocks.
|
15 |
+
* datasets.py: Possibly useful functions/classes for extensions.
|
16 |
+
* evaluation.py: Possibly useful functions/classes for extensions.
|
17 |
+
|
18 |
+
## Setup instructions
|
19 |
+
|
20 |
+
Follow `setup.sh` to properly setup a conda environment and install dependencies.
|
21 |
+
|
22 |
+
## Acknowledgement
|
23 |
+
|
24 |
+
The BERT implementation part of the project was adapted from the "minbert" assignment developed at Carnegie Mellon University's [CS11-711 Advanced NLP](http://phontron.com/class/anlp2021/index.html),
|
25 |
+
created by Shuyan Zhou, Zhengbao Jiang, Ritam Dutt, Brendon Boldt, Aditya Veerubhotla, and Graham Neubig.
|
26 |
+
|
27 |
+
Parts of the code are from the [`transformers`](https://github.com/huggingface/transformers) library ([Apache License 2.0](./LICENSE)).
|
__pycache__/base_bert.cpython-38.pyc
ADDED
Binary file (7.19 kB). View file
|
|
__pycache__/bert.cpython-38.pyc
ADDED
Binary file (6.3 kB). View file
|
|
__pycache__/config.cpython-38.pyc
ADDED
Binary file (6.64 kB). View file
|
|
__pycache__/optimizer.cpython-38.pyc
ADDED
Binary file (2.37 kB). View file
|
|
__pycache__/tokenizer.cpython-38.pyc
ADDED
Binary file (76.3 kB). View file
|
|
__pycache__/utils.cpython-38.pyc
ADDED
Binary file (9.09 kB). View file
|
|
base_bert.py
ADDED
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
from torch import device, dtype
|
3 |
+
from config import BertConfig, PretrainedConfig
|
4 |
+
from utils import *
|
5 |
+
|
6 |
+
|
7 |
+
class BertPreTrainedModel(nn.Module):
    """Base class for BERT models.

    Provides standard weight initialization and a `from_pretrained` loader
    that resolves a Hugging Face checkpoint (local dir, file/URL, or hub id)
    and maps its state-dict keys onto this codebase's renamed attributes.
    """

    # Config class used by `from_pretrained` when given a name/path instead of a config.
    config_class = BertConfig
    # Attribute name under which derived (head) models store the base BERT encoder.
    base_model_prefix = "bert"
    # Regex patterns for state-dict keys that may legitimately be missing/unexpected.
    _keys_to_ignore_on_load_missing = [r"position_ids"]
    _keys_to_ignore_on_load_unexpected = None

    def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
        """Store the configuration; parameters are created by subclasses."""
        super().__init__()
        self.config = config
        self.name_or_path = config.name_or_path

    def init_weights(self):
        # Initialize weights recursively on every submodule.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize the weights of a single submodule."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    @property
    def dtype(self) -> dtype:
        # Dtype of the model's parameters (delegates to utils.get_parameter_dtype).
        return get_parameter_dtype(self)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
        """Instantiate the model and load pretrained weights into it.

        `pretrained_model_name_or_path` may be a hub model identifier, a local
        directory containing a weights file, or a direct file path / URL.
        Checkpoint keys are renamed from the Hugging Face module layout to
        this codebase's attribute names before loading. Returns the model
        (in eval mode), or `(model, loading_info)` if `output_loading_info`
        is passed.
        """
        config = kwargs.pop("config", None)
        state_dict = kwargs.pop("state_dict", None)
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        output_loading_info = kwargs.pop("output_loading_info", False)
        local_files_only = kwargs.pop("local_files_only", False)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        mirror = kwargs.pop("mirror", None)

        # Load config if we don't provide a configuration
        if not isinstance(config, PretrainedConfig):
            config_path = config if config is not None else pretrained_model_name_or_path
            config, model_kwargs = cls.config_class.from_pretrained(
                config_path,
                *model_args,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
                **kwargs,
            )
        else:
            model_kwargs = kwargs

        # Resolve the checkpoint location: local directory, local file / remote
        # URL, or a hub bucket URL built from the model identifier.
        if pretrained_model_name_or_path is not None:
            pretrained_model_name_or_path = str(pretrained_model_name_or_path)
            if os.path.isdir(pretrained_model_name_or_path):
                # Load from a PyTorch checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
            elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
                archive_file = pretrained_model_name_or_path
            else:
                archive_file = hf_bucket_url(
                    pretrained_model_name_or_path,
                    filename=WEIGHTS_NAME,
                    revision=revision,
                    mirror=mirror,
                )
            try:
                # Load from URL or cache if already cached
                resolved_archive_file = cached_path(
                    archive_file,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                    use_auth_token=use_auth_token,
                )
            except EnvironmentError as err:
                #logger.error(err)
                msg = (
                    f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                    f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
                    f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {WEIGHTS_NAME}.\n\n"
                )
                raise EnvironmentError(msg)
        else:
            resolved_archive_file = None

        config.name_or_path = pretrained_model_name_or_path

        # Instantiate model.
        model = cls(config, *model_args, **model_kwargs)

        if state_dict is None:
            try:
                # weights_only=True restricts unpickling to tensor data (safer
                # than arbitrary pickle).
                state_dict = torch.load(resolved_archive_file, map_location="cpu", weights_only=True)
            except Exception:
                raise OSError(
                    f"Unable to load weights from pytorch checkpoint file for '{pretrained_model_name_or_path}' "
                    f"at '{resolved_archive_file}'"
                )

        missing_keys = []
        unexpected_keys = []
        error_msgs = []

        # Convert old format to new format if needed from a PyTorch state_dict,
        # then rename Hugging Face module paths to this codebase's attribute names.
        old_keys = []
        new_keys = []
        m = {'embeddings.word_embeddings': 'word_embedding',
             'embeddings.position_embeddings': 'pos_embedding',
             'embeddings.token_type_embeddings': 'tk_type_embedding',
             'embeddings.LayerNorm': 'embed_layer_norm',
             'embeddings.dropout': 'embed_dropout',
             'encoder.layer': 'bert_layers',
             'pooler.dense': 'pooler_dense',
             'pooler.activation': 'pooler_af',
             'attention.self': "self_attention",
             'attention.output.dense': 'attention_dense',
             'attention.output.LayerNorm': 'attention_layer_norm',
             'attention.output.dropout': 'attention_dropout',
             'intermediate.dense': 'interm_dense',
             'intermediate.intermediate_act_fn': 'interm_af',
             'output.dense': 'out_dense',
             'output.LayerNorm': 'out_layer_norm',
             'output.dropout': 'out_dropout'}

        for key in state_dict.keys():
            new_key = None
            # Old TF-style parameter names: gamma/beta -> weight/bias.
            if "gamma" in key:
                new_key = key.replace("gamma", "weight")
            if "beta" in key:
                new_key = key.replace("beta", "bias")
            # Apply every matching module-path rename, chaining onto any
            # rename already made above.
            for x, y in m.items():
                if new_key is not None:
                    _key = new_key
                else:
                    _key = key
                if x in key:
                    new_key = _key.replace(x, y)
            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)

        for old_key, new_key in zip(old_keys, new_keys):
            # print(old_key, new_key)
            state_dict[new_key] = state_dict.pop(old_key)

        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, "_metadata", None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        # Sanity check: every renamed, non-head ("cls.") checkpoint key must
        # correspond to a parameter of this model; otherwise the rename map is stale.
        your_bert_params = [f"bert.{x[0]}" for x in model.named_parameters()]
        for k in state_dict:
            if k not in your_bert_params and not k.startswith("cls."):
                possible_rename = [x for x in k.split(".")[1:-1] if x in m.values()]
                raise ValueError(f"{k} cannot be reload to your model, one/some of {possible_rename} we provided have been renamed")

        # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
        # so we need to apply the function recursively.
        def load(module: nn.Module, prefix=""):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict,
                prefix,
                local_metadata,
                True,
                missing_keys,
                unexpected_keys,
                error_msgs,
            )
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + ".")

        # Make sure we are able to load base models as well as derived models (with heads)
        start_prefix = ""
        model_to_load = model
        has_prefix_module = any(s.startswith(cls.base_model_prefix) for s in state_dict.keys())
        if not hasattr(model, cls.base_model_prefix) and has_prefix_module:
            start_prefix = cls.base_model_prefix + "."
        if hasattr(model, cls.base_model_prefix) and not has_prefix_module:
            model_to_load = getattr(model, cls.base_model_prefix)
        load(model_to_load, prefix=start_prefix)

        if model.__class__.__name__ != model_to_load.__class__.__name__:
            # Only the base model was loaded; the head's extra keys count as missing.
            base_model_state_dict = model_to_load.state_dict().keys()
            head_model_state_dict_without_base_prefix = [
                key.split(cls.base_model_prefix + ".")[-1] for key in model.state_dict().keys()
            ]
            missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict)

        # Some models may have keys that are not in the state by design, removing them before needlessly warning
        # the user.
        if cls._keys_to_ignore_on_load_missing is not None:
            for pat in cls._keys_to_ignore_on_load_missing:
                missing_keys = [k for k in missing_keys if re.search(pat, k) is None]

        if cls._keys_to_ignore_on_load_unexpected is not None:
            for pat in cls._keys_to_ignore_on_load_unexpected:
                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

        if len(error_msgs) > 0:
            raise RuntimeError(
                "Error(s) in loading state_dict for {}:\n\t{}".format(
                    model.__class__.__name__, "\n\t".join(error_msgs)
                )
            )

        # Set model in evaluation mode to deactivate DropOut modules by default
        model.eval()

        if output_loading_info:
            loading_info = {
                "missing_keys": missing_keys,
                "unexpected_keys": unexpected_keys,
                "error_msgs": error_msgs,
            }
            return model, loading_info

        # Optionally move the model to an XLA/TPU device when the config asks for it.
        if hasattr(config, "xla_device") and config.xla_device and is_torch_tpu_available():
            import torch_xla.core.xla_model as xm

            model = xm.send_cpu_data_to_device(model, xm.xla_device())
            model.to(xm.xla_device())

        return model
bert.py
ADDED
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from base_bert import BertPreTrainedModel
|
6 |
+
from utils import *
|
7 |
+
|
8 |
+
|
9 |
+
class BertSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention, BERT style."""

    def __init__(self, config):
        super().__init__()

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Per-token linear projections producing query, key and value vectors.
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)
        # Dropout applied to the *normalized* attention probabilities, following
        # the original transformer implementation; unusual, but empirically it
        # yields better performance.
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transform(self, x, linear_layer):
        """Project x with the given q/k/v layer and split into heads.

        [bs, seq_len, hidden] -> [bs, num_attention_heads, seq_len, attention_head_size]
        """
        bs, seq_len = x.shape[:2]
        heads = linear_layer(x).view(
            bs, seq_len, self.num_attention_heads, self.attention_head_size
        )
        return heads.transpose(1, 2)

    def attention(self, key, query, value, attention_mask):
        """Scaled dot-product attention over all heads.

        key, query, value: [bs, num_attention_heads, seq_len, attention_head_size]
        attention_mask: [bs, 1, 1, seq_len] additive mask (0 on real tokens,
        a large negative number on padding tokens).
        Returns the merged context of shape [bs, seq_len, hidden].
        """
        scale = math.sqrt(query.size(-1))  # sqrt(attention_head_size)
        # [bs, num_heads, seq_len, seq_len], with padding positions pushed to -inf-ish.
        scores = query @ key.transpose(-1, -2) / scale + attention_mask
        probs = self.dropout(F.softmax(scores, dim=-1))
        context = probs @ value  # [bs, num_heads, seq_len, head_size]
        # Re-concatenate the heads: [bs, seq_len, hidden].
        context = context.transpose(1, 2).contiguous()
        return context.view(context.size(0), context.size(1), -1)

    def forward(self, hidden_states, attention_mask):
        """
        hidden_states: [bs, seq_len, hidden_state]
        attention_mask: [bs, 1, 1, seq_len]
        output: [bs, seq_len, hidden_state]
        """
        # Produce per-head key/value/query tensors, each of shape
        # [bs, num_attention_heads, seq_len, attention_head_size].
        key_layer = self.transform(hidden_states, self.key)
        value_layer = self.transform(hidden_states, self.value)
        query_layer = self.transform(hidden_states, self.query)
        return self.attention(key_layer, query_layer, value_layer, attention_mask)
|
79 |
+
|
80 |
+
|
81 |
+
class BertLayer(nn.Module):
    """One transformer encoder block: self-attention then a position-wise
    feed-forward network, each followed by an add & norm step."""

    def __init__(self, config):
        super().__init__()
        # Multi-head attention.
        self.self_attention = BertSelfAttention(config)
        # Add-norm applied after the attention sub-layer.
        self.attention_dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.attention_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.attention_dropout = nn.Dropout(config.hidden_dropout_prob)
        # Feed-forward expansion with GELU activation.
        self.interm_dense = nn.Linear(config.hidden_size, config.intermediate_size)
        self.interm_af = F.gelu
        # Add-norm applied after the feed-forward sub-layer.
        self.out_dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.out_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.out_dropout = nn.Dropout(config.hidden_dropout_prob)

    def add_norm(self, input, output, dense_layer, dropout, ln_layer):
        """Residual connection with normalization:
        ln_layer(input + dropout(dense_layer(output)))."""
        projected = dense_layer(output)   # Project the sub-layer output.
        projected = dropout(projected)    # Regularize before the residual add.
        return ln_layer(input + projected)

    def forward(self, hidden_states, attention_mask):
        """hidden_states: [bs, seq_len, hidden]; attention_mask: [bs, 1, 1, seq_len].
        Returns [bs, seq_len, hidden]."""
        # 1-2. Multi-head attention, then its add & norm.
        attn = self.self_attention(hidden_states, attention_mask)
        attn = self.add_norm(
            hidden_states,
            attn,
            self.attention_dense,
            self.attention_dropout,
            self.attention_layer_norm,
        )

        # 3-4. Feed-forward network, then its add & norm.
        ff = self.interm_af(self.interm_dense(attn))
        return self.add_norm(
            attn,
            ff,
            self.out_dense,
            self.out_dropout,
            self.out_layer_norm,
        )
|
133 |
+
|
134 |
+
|
135 |
+
|
136 |
+
|
137 |
+
class BertModel(BertPreTrainedModel):
    """
    The BERT model returns the final embeddings for each token in a sentence.

    The model consists of:
    1. Embedding layers (used in self.embed).
    2. A stack of n BERT layers (used in self.encode).
    3. A linear transformation layer for the [CLS] token (used in self.forward, as given).
    """
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Embedding layers: word, absolute position, and token-type embeddings,
        # summed and normalized in self.embed.
        self.word_embedding = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.pos_embedding = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.tk_type_embedding = nn.Embedding(config.type_vocab_size, config.hidden_size)
        self.embed_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.embed_dropout = nn.Dropout(config.hidden_dropout_prob)
        # Register position_ids (1, len position emb) to buffer because it is a constant
        # (moves with the module across devices but is not a trainable parameter).
        position_ids = torch.arange(config.max_position_embeddings).unsqueeze(0)
        self.register_buffer('position_ids', position_ids)

        # BERT encoder: a stack of num_hidden_layers transformer blocks.
        self.bert_layers = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])

        # [CLS] token transformations (the "pooler": dense + tanh).
        self.pooler_dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.pooler_af = nn.Tanh()

        self.init_weights()


    def embed(self, input_ids):
        """Compute the input embeddings: word + position + token-type embeddings,
        then layer-norm and dropout.

        input_ids: [batch_size, seq_len] -> returns [batch_size, seq_len, hidden_size]
        """
        input_shape = input_ids.size()
        seq_length = input_shape[1]

        inputs_embeds = self.word_embedding(input_ids)

        # Slice the constant position-id buffer down to this batch's length.
        pos_ids = self.position_ids[:, :seq_length]
        pos_embeds = self.pos_embedding(pos_ids)

        # Since we are not considering token type, this embedding is just a placeholder
        # (all zeros, i.e. every token is treated as segment 0).
        tk_type_ids = torch.zeros(input_shape, dtype=torch.long, device=input_ids.device)
        tk_type_embeds = self.tk_type_embedding(tk_type_ids)

        embeddings = inputs_embeds + pos_embeds + tk_type_embeds
        embeddings = self.embed_layer_norm(embeddings)
        embeddings = self.embed_dropout(embeddings)

        return embeddings


    def encode(self, hidden_states, attention_mask):
        """
        hidden_states: the output from the embedding layer [batch_size, seq_len, hidden_size]
        attention_mask: [batch_size, seq_len]
        Returns the final hidden states [batch_size, seq_len, hidden_size].
        """
        # Get the extended attention mask for self-attention.
        # Returns extended_attention_mask of size [batch_size, 1, 1, seq_len].
        # Distinguishes between non-padding tokens (with a value of 0) and padding tokens
        # (with a value of a large negative number).
        extended_attention_mask: torch.Tensor = get_extended_attention_mask(attention_mask, self.dtype)

        # Pass the hidden states through the encoder layers.
        for i, layer_module in enumerate(self.bert_layers):
            # Feed the encoding from the last bert_layer to the next.
            hidden_states = layer_module(hidden_states, extended_attention_mask)

        return hidden_states


    def forward(self, input_ids, attention_mask):
        """
        input_ids: [batch_size, seq_len], seq_len is the max length of the batch
        attention_mask: same size as input_ids, 1 represents non-padding tokens, 0 represents padding tokens

        Returns a dict with:
          'last_hidden_state': [batch_size, seq_len, hidden_size] contextual embeddings
          'pooler_output': [batch_size, hidden_size] pooled (dense+tanh) [CLS] state
        """
        # Get the embedding for each input token.
        embedding_output = self.embed(input_ids=input_ids)

        # Feed to a transformer (a stack of BertLayers).
        sequence_output = self.encode(embedding_output, attention_mask=attention_mask)

        # Get cls token hidden state and apply the pooler transformation.
        first_tk = sequence_output[:, 0]
        first_tk = self.pooler_dense(first_tk)
        first_tk = self.pooler_af(first_tk)

        return {'last_hidden_state': sequence_output, 'pooler_output': first_tk}
|
cfimdb-classifier.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e1c66df3c0ce0e4326519041f49707f102df5f680de5ded1b5125ba689a9d141
|
3 |
+
size 438045778
|
classifier.py
ADDED
@@ -0,0 +1,406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random, numpy as np, argparse
|
2 |
+
from types import SimpleNamespace
|
3 |
+
import csv
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from torch.utils.data import Dataset, DataLoader
|
8 |
+
from sklearn.metrics import f1_score, accuracy_score
|
9 |
+
|
10 |
+
from tokenizer import BertTokenizer
|
11 |
+
from bert import BertModel
|
12 |
+
from optimizer import AdamW
|
13 |
+
from tqdm import tqdm
|
14 |
+
|
15 |
+
|
16 |
+
TQDM_DISABLE=False
|
17 |
+
|
18 |
+
|
19 |
+
# Fix the random seed.
|
20 |
+
def seed_everything(seed=11711):
    """Seed every RNG in use (python, numpy, torch CPU and CUDA) and force
    cuDNN into deterministic mode for reproducible runs."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed,
                   torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
|
28 |
+
|
29 |
+
|
30 |
+
class BertSentimentClassifier(torch.nn.Module):
    '''Sentiment classifier: a single linear head over the BERT [CLS] state.

    On the SST dataset there are 5 sentiment categories (0 - "negative" to
    4 - "positive"), so forward() produces one logit per class.
    '''
    def __init__(self, config):
        super(BertSentimentClassifier, self).__init__()
        self.num_labels = config.num_labels
        self.bert: BertModel = BertModel.from_pretrained('bert-base-uncased')

        # 'last-linear-layer' freezes BERT (pretrain mode); 'full-model' fine-tunes it.
        assert config.fine_tune_mode in ["last-linear-layer", "full-model"]
        trainable = config.fine_tune_mode == 'full-model'
        for param in self.bert.parameters():
            param.requires_grad = trainable

        # Task head mapping the [CLS] hidden state to class logits.
        self.classifier = torch.nn.Linear(config.hidden_size, self.num_labels)


    def forward(self, input_ids, attention_mask):
        '''Return unnormalized logits for the sentiment classes of a batch.

        No softmax here: the training loop applies F.cross_entropy directly.
        '''
        # Embed the tokens, then contextualize them with the BERT encoder.
        token_embeddings = self.bert.embed(input_ids=input_ids)
        hidden_states = self.bert.encode(token_embeddings, attention_mask=attention_mask)

        # The sentence representation is the hidden state of the leading [CLS] token.
        cls_state = hidden_states[:, 0, :]
        return self.classifier(cls_state)
|
73 |
+
|
74 |
+
|
75 |
+
|
76 |
+
class SentimentDataset(Dataset):
    """Labeled sentiment dataset of (sentence, label, id) triples.

    Examples are kept raw; tokenization/padding happens per-batch in collate_fn.
    """
    def __init__(self, dataset, args):
        self.dataset = dataset
        self.p = args
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def pad_data(self, data):
        """Tokenize a batch of (sentence, label, id) examples with padding."""
        sents = [example[0] for example in data]
        labels = [example[1] for example in data]
        sent_ids = [example[2] for example in data]

        encoding = self.tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
        return (torch.LongTensor(encoding['input_ids']),
                torch.LongTensor(encoding['attention_mask']),
                torch.LongTensor(labels),
                sents,
                sent_ids)

    def collate_fn(self, all_data):
        """Collate raw examples into a batch dict for the DataLoader."""
        token_ids, attention_mask, labels, sents, sent_ids = self.pad_data(all_data)
        return {
            'token_ids': token_ids,
            'attention_mask': attention_mask,
            'labels': labels,
            'sents': sents,
            'sent_ids': sent_ids,
        }
|
112 |
+
|
113 |
+
|
114 |
+
class SentimentTestDataset(Dataset):
    """Unlabeled (test split) sentiment dataset of (sentence, id) pairs.

    Examples are kept raw; tokenization/padding happens per-batch in collate_fn.
    """
    def __init__(self, dataset, args):
        self.dataset = dataset
        self.p = args
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def pad_data(self, data):
        """Tokenize a batch of (sentence, id) examples with padding."""
        sents = [example[0] for example in data]
        sent_ids = [example[1] for example in data]

        encoding = self.tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
        return (torch.LongTensor(encoding['input_ids']),
                torch.LongTensor(encoding['attention_mask']),
                sents,
                sent_ids)

    def collate_fn(self, all_data):
        """Collate raw examples into a batch dict for the DataLoader."""
        token_ids, attention_mask, sents, sent_ids = self.pad_data(all_data)
        return {
            'token_ids': token_ids,
            'attention_mask': attention_mask,
            'sents': sents,
            'sent_ids': sent_ids,
        }
|
147 |
+
|
148 |
+
|
149 |
+
# Load the data: a list of (sentence, label).
|
150 |
+
def load_data(filename, flag='train'):
    """Load a tab-separated sentiment dataset.

    Each row has 'id' and 'sentence' columns, plus 'sentiment' for labeled splits.
    Sentences and ids are lower-cased and stripped.

    Returns:
        flag == 'train': (data, num_labels) where data is [(sentence, label, id)]
        flag == 'test' : [(sentence, id)]
        otherwise      : [(sentence, label, id)]
    """
    num_labels = {}
    data = []
    with open(filename, 'r') as fp:
        reader = csv.DictReader(fp, delimiter='\t')
        if flag == 'test':
            for record in reader:
                sent = record['sentence'].lower().strip()
                sent_id = record['id'].lower().strip()
                data.append((sent, sent_id))
        else:
            for record in reader:
                sent = record['sentence'].lower().strip()
                sent_id = record['id'].lower().strip()
                label = int(record['sentiment'].strip())
                # Track distinct labels to report the class count for 'train'.
                if label not in num_labels:
                    num_labels[label] = len(num_labels)
                data.append((sent, label, sent_id))
            # Fix: the log line previously printed a literal placeholder
            # instead of the file actually loaded.
            print(f"load {len(data)} data from {filename}")

    if flag == 'train':
        return data, len(num_labels)
    return data
|
174 |
+
|
175 |
+
|
176 |
+
# Evaluate the model on dev examples.
|
177 |
+
def model_eval(dataloader, model, device):
    """Evaluate the model on labeled batches.

    Returns (accuracy, macro-F1, predictions, gold labels, sentences, sentence ids).
    """
    model.eval()  # Switch to eval mode: disables dropout and other randomness.
    y_true, y_pred, sents, sent_ids = [], [], [], []
    for batch in tqdm(dataloader, desc='eval', disable=TQDM_DISABLE):
        b_ids = batch['token_ids'].to(device)
        b_mask = batch['attention_mask'].to(device)

        logits = model(b_ids, b_mask).detach().cpu().numpy()
        preds = np.argmax(logits, axis=1).flatten()

        y_true.extend(batch['labels'].flatten())
        y_pred.extend(preds)
        sents.extend(batch['sents'])
        sent_ids.extend(batch['sent_ids'])

    acc = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred, average='macro')
    return acc, f1, y_pred, y_true, sents, sent_ids
|
204 |
+
|
205 |
+
|
206 |
+
# Evaluate the model on test examples.
|
207 |
+
def model_test_eval(dataloader, model, device):
    """Run the model over unlabeled test batches.

    Returns (predictions, sentences, sentence ids).
    """
    model.eval()  # Switch to eval mode: disables dropout and other randomness.
    y_pred, sents, sent_ids = [], [], []
    for batch in tqdm(dataloader, desc='eval', disable=TQDM_DISABLE):
        b_ids = batch['token_ids'].to(device)
        b_mask = batch['attention_mask'].to(device)

        logits = model(b_ids, b_mask).detach().cpu().numpy()
        y_pred.extend(np.argmax(logits, axis=1).flatten())
        sents.extend(batch['sents'])
        sent_ids.extend(batch['sent_ids'])

    return y_pred, sents, sent_ids
|
228 |
+
|
229 |
+
|
230 |
+
def save_model(model, optimizer, args, config, filepath):
    """Checkpoint the run to `filepath`: model/optimizer state, the run's args
    and model config, and all RNG states (for exact resumption)."""
    torch.save(
        {
            'model': model.state_dict(),
            'optim': optimizer.state_dict(),
            'args': args,
            'model_config': config,
            'system_rng': random.getstate(),
            'numpy_rng': np.random.get_state(),
            'torch_rng': torch.random.get_rng_state(),
        },
        filepath,
    )
    print(f"save the model to {filepath}")
|
243 |
+
|
244 |
+
|
245 |
+
def train(args):
    """Train a BertSentimentClassifier on args.train, evaluating on args.dev
    after every epoch and checkpointing the best dev-accuracy model to
    args.filepath.

    args is a namespace carrying: train/dev file paths, filepath, lr,
    batch_size, epochs, hidden_dropout_prob, fine_tune_mode and use_gpu.
    """
    device = torch.device('cuda') if args.use_gpu else torch.device('cpu')
    # Create the data and its corresponding datasets and dataloader.
    train_data, num_labels = load_data(args.train, 'train')
    dev_data = load_data(args.dev, 'valid')

    train_dataset = SentimentDataset(train_data, args)
    dev_dataset = SentimentDataset(dev_data, args)

    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size,
                                  collate_fn=train_dataset.collate_fn)
    dev_dataloader = DataLoader(dev_dataset, shuffle=False, batch_size=args.batch_size,
                                collate_fn=dev_dataset.collate_fn)

    # Init model. hidden_size is fixed at 768 (bert-base).
    config = {'hidden_dropout_prob': args.hidden_dropout_prob,
              'num_labels': num_labels,
              'hidden_size': 768,
              'data_dir': '.',
              'fine_tune_mode': args.fine_tune_mode}

    config = SimpleNamespace(**config)

    model = BertSentimentClassifier(config)
    model = model.to(device)

    lr = args.lr
    optimizer = AdamW(model.parameters(), lr=lr)
    best_dev_acc = 0

    # Run for the specified number of epochs.
    for epoch in range(args.epochs):
        model.train()
        train_loss = 0
        num_batches = 0
        for batch in tqdm(train_dataloader, desc=f'train-{epoch}', disable=TQDM_DISABLE):
            b_ids, b_mask, b_labels = (batch['token_ids'],
                                       batch['attention_mask'], batch['labels'])

            b_ids = b_ids.to(device)
            b_mask = b_mask.to(device)
            b_labels = b_labels.to(device)

            optimizer.zero_grad()
            logits = model(b_ids, b_mask)
            # Sum-reduce the loss and divide by the nominal batch size.
            # NOTE(review): the final (possibly smaller) batch is scaled by the
            # same constant, slightly down-weighting it — confirm intended.
            loss = F.cross_entropy(logits, b_labels.view(-1), reduction='sum') / args.batch_size

            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            num_batches += 1

        train_loss = train_loss / (num_batches)

        # Measure accuracy on both splits; keep the checkpoint with best dev accuracy.
        train_acc, train_f1, *_ = model_eval(train_dataloader, model, device)
        dev_acc, dev_f1, *_ = model_eval(dev_dataloader, model, device)

        if dev_acc > best_dev_acc:
            best_dev_acc = dev_acc
            save_model(model, optimizer, args, config, args.filepath)

        print(f"Epoch {epoch}: train loss :: {train_loss :.3f}, train acc :: {train_acc :.3f}, dev acc :: {dev_acc :.3f}")
|
308 |
+
|
309 |
+
|
310 |
+
def test(args):
    """Load the checkpoint at args.filepath, evaluate it on the dev split, and
    write dev/test predictions to args.dev_out / args.test_out.

    Runs entirely under torch.no_grad() since no training happens here.
    """
    with torch.no_grad():
        device = torch.device('cuda') if args.use_gpu else torch.device('cpu')
        saved = torch.load(args.filepath)
        # Rebuild the model from the config stored alongside the weights.
        config = saved['model_config']
        model = BertSentimentClassifier(config)
        model.load_state_dict(saved['model'])
        model = model.to(device)
        print(f"load model from {args.filepath}")

        dev_data = load_data(args.dev, 'valid')
        dev_dataset = SentimentDataset(dev_data, args)
        dev_dataloader = DataLoader(dev_dataset, shuffle=False, batch_size=args.batch_size, collate_fn=dev_dataset.collate_fn)

        test_data = load_data(args.test, 'test')
        test_dataset = SentimentTestDataset(test_data, args)
        test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=args.batch_size, collate_fn=test_dataset.collate_fn)

        dev_acc, dev_f1, dev_pred, dev_true, dev_sents, dev_sent_ids = model_eval(dev_dataloader, model, device)
        print('DONE DEV')
        test_pred, test_sents, test_sent_ids = model_test_eval(test_dataloader, model, device)
        print('DONE Test')
        # Write predictions as "<id> , <label>" lines under a header row.
        with open(args.dev_out, "w+") as f:
            print(f"dev acc :: {dev_acc :.3f}")
            f.write(f"id \t Predicted_Sentiment \n")
            for p, s in zip(dev_sent_ids,dev_pred ):
                f.write(f"{p} , {s} \n")

        with open(args.test_out, "w+") as f:
            f.write(f"id \t Predicted_Sentiment \n")
            for p, s in zip(test_sent_ids,test_pred ):
                f.write(f"{p} , {s} \n")
|
342 |
+
|
343 |
+
|
344 |
+
def get_args():
    """Define and parse the command-line arguments for the classifier."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=11711)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument(
        "--fine-tune-mode",
        type=str,
        choices=('last-linear-layer', 'full-model'),
        default="last-linear-layer",
        help='last-linear-layer: the BERT parameters are frozen and the task specific head parameters are updated; full-model: BERT parameters are updated as well',
    )
    parser.add_argument("--use_gpu", action='store_true')
    parser.add_argument("--batch_size", help='sst: 64, cfimdb: 8 can fit a 12GB GPU', type=int, default=8)
    parser.add_argument("--hidden_dropout_prob", type=float, default=0.3)
    parser.add_argument(
        "--lr",
        type=float,
        default=1e-3,
        help="learning rate, default lr for 'pretrain': 1e-3, 'finetune': 1e-5",
    )
    return parser.parse_args()
|
360 |
+
|
361 |
+
|
362 |
+
if __name__ == "__main__":
    args = get_args()
    seed_everything(args.seed)  # Fix all RNGs for reproducibility.

    # --- SST: 5-way sentiment classification ---
    print('Training Sentiment Classifier on SST...')
    config = SimpleNamespace(
        filepath='sst-classifier.pt',
        lr=args.lr,
        use_gpu=args.use_gpu,
        epochs=args.epochs,
        batch_size=args.batch_size,
        hidden_dropout_prob=args.hidden_dropout_prob,
        train='data/ids-sst-train.csv',
        dev='data/ids-sst-dev.csv',
        test='data/ids-sst-test-student.csv',
        fine_tune_mode=args.fine_tune_mode,
        dev_out = 'predictions/' + args.fine_tune_mode + '-sst-dev-out.csv',
        test_out = 'predictions/' + args.fine_tune_mode + '-sst-test-out.csv'
    )

    train(config)

    print('Evaluating on SST...')
    test(config)

    # --- CFIMDB: second dataset; batch size fixed at 8 (fits a 12GB GPU). ---
    print('Training Sentiment Classifier on cfimdb...')
    config = SimpleNamespace(
        filepath='cfimdb-classifier.pt',
        lr=args.lr,
        use_gpu=args.use_gpu,
        epochs=args.epochs,
        batch_size=8,
        hidden_dropout_prob=args.hidden_dropout_prob,
        train='data/ids-cfimdb-train.csv',
        dev='data/ids-cfimdb-dev.csv',
        test='data/ids-cfimdb-test-student.csv',
        fine_tune_mode=args.fine_tune_mode,
        dev_out = 'predictions/' + args.fine_tune_mode + '-cfimdb-dev-out.csv',
        test_out = 'predictions/' + args.fine_tune_mode + '-cfimdb-test-out.csv'
    )

    train(config)

    print('Evaluating on cfimdb...')
    test(config)
|
config.py
ADDED
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Union, Tuple, Dict, Any, Optional
|
2 |
+
import os
|
3 |
+
import json
|
4 |
+
from collections import OrderedDict
|
5 |
+
import torch
|
6 |
+
from utils import CONFIG_NAME, hf_bucket_url, cached_path, is_remote_url
|
7 |
+
|
8 |
+
class PretrainedConfig(object):
|
9 |
+
model_type: str = ""
|
10 |
+
is_composition: bool = False
|
11 |
+
|
12 |
+
def __init__(self, **kwargs):
|
13 |
+
# Attributes with defaults
|
14 |
+
self.return_dict = kwargs.pop("return_dict", True)
|
15 |
+
self.output_hidden_states = kwargs.pop("output_hidden_states", False)
|
16 |
+
self.output_attentions = kwargs.pop("output_attentions", False)
|
17 |
+
self.torchscript = kwargs.pop("torchscript", False) # Only used by PyTorch models
|
18 |
+
self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
|
19 |
+
self.pruned_heads = kwargs.pop("pruned_heads", {})
|
20 |
+
self.tie_word_embeddings = kwargs.pop(
|
21 |
+
"tie_word_embeddings", True
|
22 |
+
) # Whether input and output word embeddings should be tied for all MLM, LM and Seq2Seq models.
|
23 |
+
|
24 |
+
# Is decoder is used in encoder-decoder models to differentiate encoder from decoder
|
25 |
+
self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
|
26 |
+
self.is_decoder = kwargs.pop("is_decoder", False)
|
27 |
+
self.add_cross_attention = kwargs.pop("add_cross_attention", False)
|
28 |
+
self.tie_encoder_decoder = kwargs.pop("tie_encoder_decoder", False)
|
29 |
+
|
30 |
+
# Parameters for sequence generation
|
31 |
+
self.max_length = kwargs.pop("max_length", 20)
|
32 |
+
self.min_length = kwargs.pop("min_length", 0)
|
33 |
+
self.do_sample = kwargs.pop("do_sample", False)
|
34 |
+
self.early_stopping = kwargs.pop("early_stopping", False)
|
35 |
+
self.num_beams = kwargs.pop("num_beams", 1)
|
36 |
+
self.num_beam_groups = kwargs.pop("num_beam_groups", 1)
|
37 |
+
self.diversity_penalty = kwargs.pop("diversity_penalty", 0.0)
|
38 |
+
self.temperature = kwargs.pop("temperature", 1.0)
|
39 |
+
self.top_k = kwargs.pop("top_k", 50)
|
40 |
+
self.top_p = kwargs.pop("top_p", 1.0)
|
41 |
+
self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
|
42 |
+
self.length_penalty = kwargs.pop("length_penalty", 1.0)
|
43 |
+
self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
|
44 |
+
self.encoder_no_repeat_ngram_size = kwargs.pop("encoder_no_repeat_ngram_size", 0)
|
45 |
+
self.bad_words_ids = kwargs.pop("bad_words_ids", None)
|
46 |
+
self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
|
47 |
+
self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)
|
48 |
+
self.output_scores = kwargs.pop("output_scores", False)
|
49 |
+
self.return_dict_in_generate = kwargs.pop("return_dict_in_generate", False)
|
50 |
+
self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None)
|
51 |
+
self.forced_eos_token_id = kwargs.pop("forced_eos_token_id", None)
|
52 |
+
|
53 |
+
# Fine-tuning task arguments
|
54 |
+
self.architectures = kwargs.pop("architectures", None)
|
55 |
+
self.finetuning_task = kwargs.pop("finetuning_task", None)
|
56 |
+
self.id2label = kwargs.pop("id2label", None)
|
57 |
+
self.label2id = kwargs.pop("label2id", None)
|
58 |
+
if self.id2label is not None:
|
59 |
+
kwargs.pop("num_labels", None)
|
60 |
+
self.id2label = dict((int(key), value) for key, value in self.id2label.items())
|
61 |
+
# Keys are always strings in JSON so convert ids to int here.
|
62 |
+
else:
|
63 |
+
self.num_labels = kwargs.pop("num_labels", 2)
|
64 |
+
|
65 |
+
# Tokenizer arguments
|
66 |
+
self.tokenizer_class = kwargs.pop("tokenizer_class", None)
|
67 |
+
self.prefix = kwargs.pop("prefix", None)
|
68 |
+
self.bos_token_id = kwargs.pop("bos_token_id", None)
|
69 |
+
self.pad_token_id = kwargs.pop("pad_token_id", None)
|
70 |
+
self.eos_token_id = kwargs.pop("eos_token_id", None)
|
71 |
+
self.sep_token_id = kwargs.pop("sep_token_id", None)
|
72 |
+
|
73 |
+
self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
|
74 |
+
|
75 |
+
# task specific arguments
|
76 |
+
self.task_specific_params = kwargs.pop("task_specific_params", None)
|
77 |
+
|
78 |
+
# TPU arguments
|
79 |
+
self.xla_device = kwargs.pop("xla_device", None)
|
80 |
+
|
81 |
+
# Name or path to the pretrained checkpoint
|
82 |
+
self._name_or_path = str(kwargs.pop("name_or_path", ""))
|
83 |
+
|
84 |
+
# Drop the transformers version info
|
85 |
+
kwargs.pop("transformers_version", None)
|
86 |
+
|
87 |
+
# Additional attributes without default values
|
88 |
+
for key, value in kwargs.items():
|
89 |
+
try:
|
90 |
+
setattr(self, key, value)
|
91 |
+
except AttributeError as err:
|
92 |
+
raise err
|
93 |
+
|
94 |
+
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
    """Build a config from a pretrained checkpoint name, directory, or file.

    Resolves the checkpoint's JSON config via `get_config_dict`, then
    instantiates the class from it; leftover kwargs act as attribute overrides.
    """
    raw_config, remaining_kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
    return cls.from_dict(raw_config, **remaining_kwargs)
|
98 |
+
|
99 |
+
@classmethod
def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
    """Read *json_file* (UTF-8) and return its parsed contents as a dict."""
    with open(json_file, "r", encoding="utf-8") as reader:
        return json.loads(reader.read())
|
104 |
+
|
105 |
+
@classmethod
def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig":
    """Instantiate a config from a plain dict, applying matching kwargs as overrides.

    Kwargs whose names are existing attributes of the new config are applied
    and consumed. With ``return_unused_kwargs=True`` the unconsumed remainder
    is returned alongside the config.
    """
    return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)

    config = cls(**config_dict)

    # JSON keys are always strings, so pruned-head layer indices must be
    # converted back to ints after deserialization.
    if hasattr(config, "pruned_heads"):
        config.pruned_heads = {int(layer): heads for layer, heads in config.pruned_heads.items()}

    # Apply overrides for attributes the config already has, consuming them
    # from kwargs so only truly unknown entries remain.
    consumed = [name for name in kwargs if hasattr(config, name)]
    for name in consumed:
        setattr(config, name, kwargs.pop(name))

    if return_unused_kwargs:
        return config, kwargs
    return config
|
127 |
+
|
128 |
+
@classmethod
def get_config_dict(
    cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Resolve the checkpoint's config JSON and return (config_dict, leftover_kwargs).

    Download/cache-related kwargs are consumed here; everything left in
    ``kwargs`` is returned untouched so `from_dict` can apply it as overrides.

    Raises:
        EnvironmentError: if the config file cannot be located/downloaded, or
            if its contents are not valid JSON.
    """
    # Consume the download/caching options so they are not treated as
    # config attribute overrides by the caller.
    cache_dir = kwargs.pop("cache_dir", None)
    force_download = kwargs.pop("force_download", False)
    resume_download = kwargs.pop("resume_download", False)
    proxies = kwargs.pop("proxies", None)
    use_auth_token = kwargs.pop("use_auth_token", None)
    local_files_only = kwargs.pop("local_files_only", False)
    revision = kwargs.pop("revision", None)

    pretrained_model_name_or_path = str(pretrained_model_name_or_path)
    # A directory is expected to contain CONFIG_NAME; a file path or URL is
    # used as-is; anything else is treated as a model id on the HF hub.
    if os.path.isdir(pretrained_model_name_or_path):
        config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
    elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
        config_file = pretrained_model_name_or_path
    else:
        config_file = hf_bucket_url(
            pretrained_model_name_or_path, filename=CONFIG_NAME, revision=revision, mirror=None
        )

    try:
        # Load from URL or cache if already cached
        resolved_config_file = cached_path(
            config_file,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            local_files_only=local_files_only,
            use_auth_token=use_auth_token,
        )
        # Load config dict
        config_dict = cls._dict_from_json_file(resolved_config_file)

    except EnvironmentError as err:
        # File/network failure: re-raise with a user-facing hint about the
        # two valid forms of `pretrained_model_name_or_path`.
        msg = (
            f"Can't load config for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
            f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
            f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a {CONFIG_NAME} file\n\n"
        )
        raise EnvironmentError(msg)

    except json.JSONDecodeError:
        # JSONDecodeError can only be raised by _dict_from_json_file above,
        # so `resolved_config_file` is guaranteed to be bound here.
        msg = (
            "Couldn't reach server at '{}' to download configuration file or "
            "configuration file is not a valid JSON file. "
            "Please check network or file content here: {}.".format(config_file, resolved_config_file)
        )
        raise EnvironmentError(msg)

    return config_dict, kwargs
|
181 |
+
|
182 |
+
|
183 |
+
class BertConfig(PretrainedConfig):
    """Configuration for a BERT architecture.

    Every constructor argument is stored verbatim as an attribute; unknown
    keyword arguments are delegated to `PretrainedConfig`. Defaults match
    `bert-base-uncased`.
    """

    model_type = "bert"

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        gradient_checkpointing=False,
        position_embedding_type="absolute",
        use_cache=True,
        **kwargs
    ):
        # The padding id is consumed by PretrainedConfig along with any
        # extra keyword arguments.
        super().__init__(pad_token_id=pad_token_id, **kwargs)

        # Embedding / vocabulary sizes.
        self.vocab_size = vocab_size
        self.type_vocab_size = type_vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.position_embedding_type = position_embedding_type

        # Transformer dimensions.
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act

        # Regularization and initialization.
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps

        # Training/runtime behavior.
        self.gradient_checkpointing = gradient_checkpointing
        self.use_cache = use_cache
|
data/ids-cfimdb-dev.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3087f571b66860fe5d035b5a018d08202ad3fd3720e4821c04b2acf6c7ded559
|
3 |
+
size 249095
|
data/ids-cfimdb-test-student.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7ae611548c9eac879e9ebb406cc9f8ae68ff12f78090e4965af5cbdfa06240f4
|
3 |
+
size 495595
|
data/ids-cfimdb-train.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:140fc513045a966109faed46a5c7a898767b96714d71bcb9c15f659129fadcea
|
3 |
+
size 1693182
|
data/ids-sst-dev.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a186ce94577635fbe10beaaddd50f16cccf6c30973221cefdf90deed2a584bfe
|
3 |
+
size 151384
|
data/ids-sst-test-student.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:bdd5a767faa0c26782117e37767ece154c30d5d04fb8727d09c71e3850a55c7b
|
3 |
+
size 313202
|
data/ids-sst-train.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:03b2b625c090f94a6afd59f114cde5282e2053aab0b101e87ed695d8a0c5b1df
|
3 |
+
size 1175139
|
data/quora-dev.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1e9dc46b273a711d82a065f55e1754a9b92c10ad7345ebe0b0ebba61397dda4a
|
3 |
+
size 6896912
|
data/quora-test-student.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4fa130f532cdde70287081aa04af13a4b12e3aa862e9162763d15fb46385497a
|
3 |
+
size 13487951
|
data/quora-train.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7cd59e1ddb3a5b5d03f4a885c64e67aaf50122d9ab9ed7a476b5d2d6f7137ae8
|
3 |
+
size 48270674
|
data/sts-dev.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ce3cad6f16062586ac7ba462c28b010a9be10c530fd5074165860d7b7ab4e93d
|
3 |
+
size 132265
|
data/sts-test-student.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:dee455745b72e9ca3ff74e7c056bd73e34bad5b8d5641045a2c1e7e131866f47
|
3 |
+
size 256677
|
data/sts-train.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:15d12efc2d656fffb1d61ac1f08ec4227f43925fd16f420c037cbd063699c21b
|
3 |
+
size 928832
|
datasets.py
ADDED
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
'''
|
4 |
+
This module contains our Dataset classes and functions that load the three datasets
|
5 |
+
for training and evaluating multitask BERT.
|
6 |
+
|
7 |
+
Feel free to edit code in this file if you wish to modify the way in which the data
|
8 |
+
examples are preprocessed.
|
9 |
+
'''
|
10 |
+
|
11 |
+
import csv
|
12 |
+
|
13 |
+
import torch
|
14 |
+
from torch.utils.data import Dataset
|
15 |
+
from tokenizer import BertTokenizer
|
16 |
+
|
17 |
+
|
18 |
+
def preprocess_string(s):
|
19 |
+
return ' '.join(s.lower()
|
20 |
+
.replace('.', ' .')
|
21 |
+
.replace('?', ' ?')
|
22 |
+
.replace(',', ' ,')
|
23 |
+
.replace('\'', ' \'')
|
24 |
+
.split())
|
25 |
+
|
26 |
+
|
27 |
+
class SentenceClassificationDataset(Dataset):
    """Labeled single-sentence dataset; each element is (sentence, label, sent_id).

    `collate_fn` tokenizes and pads a batch into LongTensors for the model.
    """

    def __init__(self, dataset, args):
        self.dataset = dataset
        self.p = args
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def pad_data(self, data):
        """Tokenize and pad a list of (sentence, label, sent_id) examples."""
        sents = [example[0] for example in data]
        labels = [example[1] for example in data]
        sent_ids = [example[2] for example in data]

        # Padding and truncation are delegated to the tokenizer.
        batch_encoding = self.tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
        token_ids = torch.LongTensor(batch_encoding['input_ids'])
        attention_mask = torch.LongTensor(batch_encoding['attention_mask'])
        labels = torch.LongTensor(labels)

        return token_ids, attention_mask, labels, sents, sent_ids

    def collate_fn(self, all_data):
        """Build the batch dict consumed by the training/eval loops."""
        token_ids, attention_mask, labels, sents, sent_ids = self.pad_data(all_data)
        return {
            'token_ids': token_ids,
            'attention_mask': attention_mask,
            'labels': labels,
            'sents': sents,
            'sent_ids': sent_ids
        }
|
64 |
+
|
65 |
+
|
66 |
+
# Unlike SentenceClassificationDataset, we do not load labels in SentenceClassificationTestDataset.
|
67 |
+
# Unlike SentenceClassificationDataset, we do not load labels in SentenceClassificationTestDataset.
class SentenceClassificationTestDataset(Dataset):
    """Unlabeled single-sentence dataset for test-time prediction.

    Each element is (sentence, sent_id); no labels are available.
    """

    def __init__(self, dataset, args):
        self.dataset = dataset
        self.p = args
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def pad_data(self, data):
        """Tokenize and pad a list of (sentence, sent_id) examples."""
        sents = [example[0] for example in data]
        sent_ids = [example[1] for example in data]

        batch_encoding = self.tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
        token_ids = torch.LongTensor(batch_encoding['input_ids'])
        attention_mask = torch.LongTensor(batch_encoding['attention_mask'])

        return token_ids, attention_mask, sents, sent_ids

    def collate_fn(self, all_data):
        """Build the label-free batch dict used at test time."""
        token_ids, attention_mask, sents, sent_ids = self.pad_data(all_data)
        return {
            'token_ids': token_ids,
            'attention_mask': attention_mask,
            'sents': sents,
            'sent_ids': sent_ids
        }
|
100 |
+
|
101 |
+
|
102 |
+
class SentencePairDataset(Dataset):
    """Labeled sentence-pair dataset; each element is (sent1, sent2, label, sent_id).

    With ``isRegression=True`` labels stay floating point (STS similarity
    scores); otherwise they are integer class labels (paraphrase detection).
    """

    def __init__(self, dataset, args, isRegression=False):
        self.dataset = dataset
        self.p = args
        self.isRegression = isRegression
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def pad_data(self, data):
        """Tokenize/pad each side of the pair independently and tensorize labels."""
        sent1 = [example[0] for example in data]
        sent2 = [example[1] for example in data]
        labels = [example[2] for example in data]
        sent_ids = [example[3] for example in data]

        encoding1 = self.tokenizer(sent1, return_tensors='pt', padding=True, truncation=True)
        encoding2 = self.tokenizer(sent2, return_tensors='pt', padding=True, truncation=True)

        token_ids = torch.LongTensor(encoding1['input_ids'])
        attention_mask = torch.LongTensor(encoding1['attention_mask'])
        token_type_ids = torch.LongTensor(encoding1['token_type_ids'])

        token_ids2 = torch.LongTensor(encoding2['input_ids'])
        attention_mask2 = torch.LongTensor(encoding2['attention_mask'])
        token_type_ids2 = torch.LongTensor(encoding2['token_type_ids'])

        # Regression targets (STS) must not be truncated to ints.
        labels = torch.DoubleTensor(labels) if self.isRegression else torch.LongTensor(labels)

        return (token_ids, token_type_ids, attention_mask,
                token_ids2, token_type_ids2, attention_mask2,
                labels, sent_ids)

    def collate_fn(self, all_data):
        """Build the two-sentence batch dict consumed by the pair tasks."""
        (token_ids, token_type_ids, attention_mask,
         token_ids2, token_type_ids2, attention_mask2,
         labels, sent_ids) = self.pad_data(all_data)

        return {
            'token_ids_1': token_ids,
            'token_type_ids_1': token_type_ids,
            'attention_mask_1': attention_mask,
            'token_ids_2': token_ids2,
            'token_type_ids_2': token_type_ids2,
            'attention_mask_2': attention_mask2,
            'labels': labels,
            'sent_ids': sent_ids
        }
|
157 |
+
|
158 |
+
|
159 |
+
# Unlike SentencePairDataset, we do not load labels in SentencePairTestDataset.
|
160 |
+
# Unlike SentencePairDataset, we do not load labels in SentencePairTestDataset.
class SentencePairTestDataset(Dataset):
    """Unlabeled sentence-pair dataset for test-time prediction.

    Each element is (sent1, sent2, sent_id); no labels are available.
    """

    def __init__(self, dataset, args):
        self.dataset = dataset
        self.p = args
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def pad_data(self, data):
        """Tokenize/pad each side of the pair independently."""
        sent1 = [example[0] for example in data]
        sent2 = [example[1] for example in data]
        sent_ids = [example[2] for example in data]

        encoding1 = self.tokenizer(sent1, return_tensors='pt', padding=True, truncation=True)
        encoding2 = self.tokenizer(sent2, return_tensors='pt', padding=True, truncation=True)

        token_ids = torch.LongTensor(encoding1['input_ids'])
        attention_mask = torch.LongTensor(encoding1['attention_mask'])
        token_type_ids = torch.LongTensor(encoding1['token_type_ids'])

        token_ids2 = torch.LongTensor(encoding2['input_ids'])
        attention_mask2 = torch.LongTensor(encoding2['attention_mask'])
        token_type_ids2 = torch.LongTensor(encoding2['token_type_ids'])

        return (token_ids, token_type_ids, attention_mask,
                token_ids2, token_type_ids2, attention_mask2,
                sent_ids)

    def collate_fn(self, all_data):
        """Build the label-free two-sentence batch dict used at test time."""
        (token_ids, token_type_ids, attention_mask,
         token_ids2, token_type_ids2, attention_mask2,
         sent_ids) = self.pad_data(all_data)

        return {
            'token_ids_1': token_ids,
            'token_type_ids_1': token_type_ids,
            'attention_mask_1': attention_mask,
            'token_ids_2': token_ids2,
            'token_type_ids_2': token_type_ids2,
            'attention_mask_2': attention_mask2,
            'sent_ids': sent_ids
        }
|
209 |
+
|
210 |
+
|
211 |
+
def load_multitask_data(sentiment_filename, paraphrase_filename, similarity_filename, split='train'):
    """Load the SST sentiment, Quora paraphrase, and STS similarity datasets.

    Each input file is tab-separated with a header row. For ``split == 'test'``
    no labels are read.

    Returns:
        sentiment_data: list of (sentence, sent_id) or (sentence, label, sent_id)
        num_labels: dict mapping each sentiment label seen to a dense index
            (empty for the test split)
        paraphrase_data: list of (tokens1, tokens2, sent_id) or
            (tokens1, tokens2, is_duplicate, sent_id)
        similarity_data: list of (tokens1, tokens2, sent_id) or
            (tokens1, tokens2, similarity, sent_id)
    """
    sentiment_data = []
    num_labels = {}
    if split == 'test':
        with open(sentiment_filename, 'r') as fp:
            for record in csv.DictReader(fp, delimiter='\t'):
                sent = record['sentence'].lower().strip()
                sent_id = record['id'].lower().strip()
                sentiment_data.append((sent, sent_id))
    else:
        with open(sentiment_filename, 'r') as fp:
            for record in csv.DictReader(fp, delimiter='\t'):
                sent = record['sentence'].lower().strip()
                sent_id = record['id'].lower().strip()
                label = int(record['sentiment'].strip())
                # Assign dense indices in order of first appearance.
                if label not in num_labels:
                    num_labels[label] = len(num_labels)
                sentiment_data.append((sent, label, sent_id))

    print(f"Loaded {len(sentiment_data)} {split} examples from {sentiment_filename}")

    paraphrase_data = []
    if split == 'test':
        with open(paraphrase_filename, 'r') as fp:
            for record in csv.DictReader(fp, delimiter='\t'):
                sent_id = record['id'].lower().strip()
                paraphrase_data.append((preprocess_string(record['sentence1']),
                                        preprocess_string(record['sentence2']),
                                        sent_id))
    else:
        with open(paraphrase_filename, 'r') as fp:
            for record in csv.DictReader(fp, delimiter='\t'):
                # Some Quora rows are malformed (missing or non-numeric fields).
                # Skip exactly those rows instead of a bare `except`, which
                # also swallowed KeyboardInterrupt/SystemExit and real bugs.
                try:
                    sent_id = record['id'].lower().strip()
                    paraphrase_data.append((preprocess_string(record['sentence1']),
                                            preprocess_string(record['sentence2']),
                                            int(float(record['is_duplicate'])), sent_id))
                except (KeyError, ValueError, TypeError, AttributeError):
                    pass

    print(f"Loaded {len(paraphrase_data)} {split} examples from {paraphrase_filename}")

    similarity_data = []
    if split == 'test':
        with open(similarity_filename, 'r') as fp:
            for record in csv.DictReader(fp, delimiter='\t'):
                sent_id = record['id'].lower().strip()
                similarity_data.append((preprocess_string(record['sentence1']),
                                        preprocess_string(record['sentence2']),
                                        sent_id))
    else:
        with open(similarity_filename, 'r') as fp:
            for record in csv.DictReader(fp, delimiter='\t'):
                sent_id = record['id'].lower().strip()
                similarity_data.append((preprocess_string(record['sentence1']),
                                        preprocess_string(record['sentence2']),
                                        float(record['similarity']), sent_id))

    print(f"Loaded {len(similarity_data)} {split} examples from {similarity_filename}")

    return sentiment_data, num_labels, paraphrase_data, similarity_data
|
evaluation.py
ADDED
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
'''
|
4 |
+
Multitask BERT evaluation functions.
|
5 |
+
|
6 |
+
When training your multitask model, you will find it useful to call
|
7 |
+
model_eval_multitask to evaluate your model on the 3 tasks' dev sets.
|
8 |
+
'''
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from sklearn.metrics import f1_score, accuracy_score
|
12 |
+
from tqdm import tqdm
|
13 |
+
import numpy as np
|
14 |
+
|
15 |
+
|
16 |
+
TQDM_DISABLE = False
|
17 |
+
|
18 |
+
|
19 |
+
# Evaluate multitask model on SST only.
|
20 |
+
# Evaluate multitask model on SST only.
def model_eval_sst(dataloader, model, device):
    """Evaluate `model` on the SST sentiment dev set.

    Returns:
        (accuracy, macro-F1, predictions, gold labels, sentences, sentence ids)
    """
    model.eval()  # Switch to eval model, will turn off randomness like dropout.
    y_true = []
    y_pred = []
    sents = []
    sent_ids = []
    # Inference only: run under no_grad to skip autograd bookkeeping, matching
    # the other eval functions in this module (which already do this).
    with torch.no_grad():
        for step, batch in enumerate(tqdm(dataloader, desc=f'eval', disable=TQDM_DISABLE)):
            b_ids, b_mask, b_labels, b_sents, b_sent_ids = batch['token_ids'], batch['attention_mask'], \
                                                           batch['labels'], batch['sents'], batch['sent_ids']

            b_ids = b_ids.to(device)
            b_mask = b_mask.to(device)

            logits = model.predict_sentiment(b_ids, b_mask)
            logits = logits.detach().cpu().numpy()
            # Predicted class is the argmax over the 5 sentiment logits.
            preds = np.argmax(logits, axis=1).flatten()

            b_labels = b_labels.flatten()
            y_true.extend(b_labels)
            y_pred.extend(preds)
            sents.extend(b_sents)
            sent_ids.extend(b_sent_ids)

    f1 = f1_score(y_true, y_pred, average='macro')
    acc = accuracy_score(y_true, y_pred)

    return acc, f1, y_pred, y_true, sents, sent_ids
|
47 |
+
|
48 |
+
|
49 |
+
# Evaluate multitask model on dev sets.
|
50 |
+
# Evaluate multitask model on dev sets.
def model_eval_multitask(sentiment_dataloader,
                         paraphrase_dataloader,
                         sts_dataloader,
                         model, device):
    """Evaluate `model` on the dev sets of all three tasks.

    Metrics: accuracy for sentiment (SST) and paraphrase (Quora), Pearson
    correlation for semantic textual similarity (STS).

    Returns:
        (sentiment_accuracy, sst_y_pred, sst_sent_ids,
         paraphrase_accuracy, para_y_pred, para_sent_ids,
         sts_corr, sts_y_pred, sts_sent_ids)
    """
    model.eval()  # Switch to eval model, will turn off randomness like dropout.

    with torch.no_grad():
        # Evaluate sentiment classification.
        sst_y_true = []
        sst_y_pred = []
        sst_sent_ids = []
        for step, batch in enumerate(tqdm(sentiment_dataloader, desc=f'eval', disable=TQDM_DISABLE)):
            b_ids, b_mask, b_labels, b_sent_ids = batch['token_ids'], batch['attention_mask'], batch['labels'], batch['sent_ids']

            b_ids = b_ids.to(device)
            b_mask = b_mask.to(device)

            # 5-way classification: predicted class is the argmax over logits.
            logits = model.predict_sentiment(b_ids, b_mask)
            y_hat = logits.argmax(dim=-1).flatten().cpu().numpy()
            b_labels = b_labels.flatten().cpu().numpy()

            sst_y_pred.extend(y_hat)
            sst_y_true.extend(b_labels)
            sst_sent_ids.extend(b_sent_ids)

        sentiment_accuracy = np.mean(np.array(sst_y_pred) == np.array(sst_y_true))

        # Evaluate paraphrase detection.
        para_y_true = []
        para_y_pred = []
        para_sent_ids = []
        for step, batch in enumerate(tqdm(paraphrase_dataloader, desc=f'eval', disable=TQDM_DISABLE)):
            (b_ids1, b_mask1,
             b_ids2, b_mask2,
             b_labels, b_sent_ids) = (batch['token_ids_1'], batch['attention_mask_1'],
                          batch['token_ids_2'], batch['attention_mask_2'],
                          batch['labels'], batch['sent_ids'])

            b_ids1 = b_ids1.to(device)
            b_mask1 = b_mask1.to(device)
            b_ids2 = b_ids2.to(device)
            b_mask2 = b_mask2.to(device)

            # Binary task on a single logit: sigmoid then round to {0, 1}.
            logits = model.predict_paraphrase(b_ids1, b_mask1, b_ids2, b_mask2)
            y_hat = logits.sigmoid().round().flatten().cpu().numpy()
            b_labels = b_labels.flatten().cpu().numpy()

            para_y_pred.extend(y_hat)
            para_y_true.extend(b_labels)
            para_sent_ids.extend(b_sent_ids)

        paraphrase_accuracy = np.mean(np.array(para_y_pred) == np.array(para_y_true))

        # Evaluate semantic textual similarity.
        sts_y_true = []
        sts_y_pred = []
        sts_sent_ids = []
        for step, batch in enumerate(tqdm(sts_dataloader, desc=f'eval', disable=TQDM_DISABLE)):
            (b_ids1, b_mask1,
             b_ids2, b_mask2,
             b_labels, b_sent_ids) = (batch['token_ids_1'], batch['attention_mask_1'],
                          batch['token_ids_2'], batch['attention_mask_2'],
                          batch['labels'], batch['sent_ids'])

            b_ids1 = b_ids1.to(device)
            b_mask1 = b_mask1.to(device)
            b_ids2 = b_ids2.to(device)
            b_mask2 = b_mask2.to(device)

            # Regression task: raw predicted similarity scores, no rounding.
            logits = model.predict_similarity(b_ids1, b_mask1, b_ids2, b_mask2)
            y_hat = logits.flatten().cpu().numpy()
            b_labels = b_labels.flatten().cpu().numpy()

            sts_y_pred.extend(y_hat)
            sts_y_true.extend(b_labels)
            sts_sent_ids.extend(b_sent_ids)
        # Pearson correlation: off-diagonal entry of the 2x2 corrcoef matrix.
        pearson_mat = np.corrcoef(sts_y_pred,sts_y_true)
        sts_corr = pearson_mat[1][0]

        print(f'Sentiment classification accuracy: {sentiment_accuracy:.3f}')
        print(f'Paraphrase detection accuracy: {paraphrase_accuracy:.3f}')
        print(f'Semantic Textual Similarity correlation: {sts_corr:.3f}')

        return (sentiment_accuracy,sst_y_pred, sst_sent_ids,
                paraphrase_accuracy, para_y_pred, para_sent_ids,
                sts_corr, sts_y_pred, sts_sent_ids)
|
136 |
+
|
137 |
+
|
138 |
+
# Evaluate multitask model on test sets.
|
139 |
+
# Evaluate multitask model on test sets.
def model_eval_test_multitask(sentiment_dataloader,
                              paraphrase_dataloader,
                              sts_dataloader,
                              model, device):
    """Run the model over the (unlabeled) test sets of all three tasks.

    No labels exist for the test split, so only predictions and their
    sentence ids are returned, for writing submission files.

    Returns:
        (sst_y_pred, sst_sent_ids, para_y_pred, para_sent_ids,
         sts_y_pred, sts_sent_ids)
    """
    model.eval()  # Eval mode disables dropout and other training-time randomness.

    with torch.no_grad():
        # Sentiment classification: argmax over the 5 class logits.
        sst_y_pred = []
        sst_sent_ids = []
        for batch in tqdm(sentiment_dataloader, desc=f'eval', disable=TQDM_DISABLE):
            ids = batch['token_ids'].to(device)
            mask = batch['attention_mask'].to(device)

            logits = model.predict_sentiment(ids, mask)
            preds = logits.argmax(dim=-1).flatten().cpu().numpy()

            sst_y_pred.extend(preds)
            sst_sent_ids.extend(batch['sent_ids'])

        # Paraphrase detection: single logit, thresholded via sigmoid + round.
        para_y_pred = []
        para_sent_ids = []
        for batch in tqdm(paraphrase_dataloader, desc=f'eval', disable=TQDM_DISABLE):
            ids1 = batch['token_ids_1'].to(device)
            mask1 = batch['attention_mask_1'].to(device)
            ids2 = batch['token_ids_2'].to(device)
            mask2 = batch['attention_mask_2'].to(device)

            logits = model.predict_paraphrase(ids1, mask1, ids2, mask2)
            preds = logits.sigmoid().round().flatten().cpu().numpy()

            para_y_pred.extend(preds)
            para_sent_ids.extend(batch['sent_ids'])

        # Semantic textual similarity: raw regression outputs, no rounding.
        sts_y_pred = []
        sts_sent_ids = []
        for batch in tqdm(sts_dataloader, desc=f'eval', disable=TQDM_DISABLE):
            ids1 = batch['token_ids_1'].to(device)
            mask1 = batch['attention_mask_1'].to(device)
            ids2 = batch['token_ids_2'].to(device)
            mask2 = batch['attention_mask_2'].to(device)

            logits = model.predict_similarity(ids1, mask1, ids2, mask2)
            preds = logits.flatten().cpu().numpy()

            sts_y_pred.extend(preds)
            sts_sent_ids.extend(batch['sent_ids'])

        return (sst_y_pred, sst_sent_ids,
                para_y_pred, para_sent_ids,
                sts_y_pred, sts_sent_ids)
|
multitask_classifier.py
ADDED
@@ -0,0 +1,340 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
Multitask BERT class, starter training code, evaluation, and test code.
|
3 |
+
|
4 |
+
Of note are:
|
5 |
+
* class MultitaskBERT: Your implementation of multitask BERT.
|
6 |
+
* function train_multitask: Training procedure for MultitaskBERT. Starter code
|
7 |
+
copies training procedure from `classifier.py` (single-task SST).
|
8 |
+
* function test_multitask: Test procedure for MultitaskBERT. This function generates
|
9 |
+
the required files for submission.
|
10 |
+
|
11 |
+
Running `python multitask_classifier.py` trains and tests your MultitaskBERT and
|
12 |
+
writes all required submission files.
|
13 |
+
'''
|
14 |
+
|
15 |
+
import random, numpy as np, argparse
|
16 |
+
from types import SimpleNamespace
|
17 |
+
|
18 |
+
import torch
|
19 |
+
from torch import nn
|
20 |
+
import torch.nn.functional as F
|
21 |
+
from torch.utils.data import DataLoader
|
22 |
+
|
23 |
+
from bert import BertModel
|
24 |
+
from optimizer import AdamW
|
25 |
+
from tqdm import tqdm
|
26 |
+
|
27 |
+
from datasets import (
|
28 |
+
SentenceClassificationDataset,
|
29 |
+
SentenceClassificationTestDataset,
|
30 |
+
SentencePairDataset,
|
31 |
+
SentencePairTestDataset,
|
32 |
+
load_multitask_data
|
33 |
+
)
|
34 |
+
|
35 |
+
from evaluation import model_eval_sst, model_eval_multitask, model_eval_test_multitask
|
36 |
+
|
37 |
+
|
38 |
+
TQDM_DISABLE=False
|
39 |
+
|
40 |
+
|
41 |
+
# Fix the random seed.
|
42 |
+
# Fix the random seed.
def seed_everything(seed=11711):
    """Seed every RNG in play (python, numpy, torch CPU and CUDA) and force
    deterministic cuDNN behavior so runs are reproducible."""
    for seeder in (random.seed,
                   np.random.seed,
                   torch.manual_seed,
                   torch.cuda.manual_seed,
                   torch.cuda.manual_seed_all):
        seeder(seed)
    # Trade autotuned kernels for reproducibility.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
|
50 |
+
|
51 |
+
|
52 |
+
BERT_HIDDEN_SIZE = 768
|
53 |
+
N_SENTIMENT_CLASSES = 5
|
54 |
+
|
55 |
+
|
56 |
+
class MultitaskBERT(nn.Module):
    '''
    This module should use BERT for 3 tasks:

    - Sentiment classification (predict_sentiment)
    - Paraphrase detection (predict_paraphrase)
    - Semantic Textual Similarity (predict_similarity)

    All three tasks share the single BERT encoder created in __init__;
    every method below is a starter-code stub that raises NotImplementedError
    until the student fills it in.
    '''
    def __init__(self, config):
        super(MultitaskBERT, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        # last-linear-layer mode does not require updating BERT parameters:
        # the encoder is frozen and only task-head parameters train.
        assert config.fine_tune_mode in ["last-linear-layer", "full-model"]
        for param in self.bert.parameters():
            if config.fine_tune_mode == 'last-linear-layer':
                param.requires_grad = False
            elif config.fine_tune_mode == 'full-model':
                param.requires_grad = True
        # You will want to add layers here to perform the downstream tasks.
        ### TODO
        raise NotImplementedError


    def forward(self, input_ids, attention_mask):
        'Takes a batch of sentences and produces embeddings for them.'
        # The final BERT embedding is the hidden state of [CLS] token (the first token)
        # Here, you can start by just returning the embeddings straight from BERT.
        # When thinking of improvements, you can later try modifying this
        # (e.g., by adding other layers).
        ### TODO
        raise NotImplementedError


    def predict_sentiment(self, input_ids, attention_mask):
        '''Given a batch of sentences, outputs logits for classifying sentiment.
        There are 5 sentiment classes:
        (0 - negative, 1- somewhat negative, 2- neutral, 3- somewhat positive, 4- positive)
        Thus, your output should contain 5 logits for each sentence.
        '''
        ### TODO
        raise NotImplementedError


    def predict_paraphrase(self,
                           input_ids_1, attention_mask_1,
                           input_ids_2, attention_mask_2):
        '''Given a batch of pairs of sentences, outputs a single logit for predicting whether they are paraphrases.
        Note that your output should be unnormalized (a logit); it will be passed to the sigmoid function
        during evaluation.
        '''
        ### TODO
        raise NotImplementedError


    def predict_similarity(self,
                           input_ids_1, attention_mask_1,
                           input_ids_2, attention_mask_2):
        '''Given a batch of pairs of sentences, outputs a single logit corresponding to how similar they are.
        Note that your output should be unnormalized (a logit).
        '''
        ### TODO
        raise NotImplementedError
|
118 |
+
|
119 |
+
|
120 |
+
|
121 |
+
|
122 |
+
def save_model(model, optimizer, args, config, filepath):
    """Checkpoint the model and optimizer together with the run config and
    every RNG state, so training can be reproduced/resumed exactly."""
    checkpoint = {
        'model': model.state_dict(),
        'optim': optimizer.state_dict(),
        'args': args,
        'model_config': config,
        'system_rng': random.getstate(),
        'numpy_rng': np.random.get_state(),
        'torch_rng': torch.random.get_rng_state(),
    }
    torch.save(checkpoint, filepath)
    print(f"save the model to {filepath}")
|
135 |
+
|
136 |
+
|
137 |
+
def train_multitask(args):
    '''Train MultitaskBERT.

    Currently only trains on SST dataset. The way you incorporate training examples
    from other datasets into the training procedure is up to you. To begin, take a
    look at test_multitask below to see how you can use the custom torch `Dataset`s
    in datasets.py to load in examples from the Quora and SemEval datasets.
    '''
    device = torch.device('cuda') if args.use_gpu else torch.device('cpu')
    # Create the data and its corresponding datasets and dataloader.
    sst_train_data, num_labels,para_train_data, sts_train_data = load_multitask_data(args.sst_train,args.para_train,args.sts_train, split ='train')
    # NOTE(review): dev data is loaded with split='train' (not 'dev'); labels
    # are still loaded either way, but confirm this matches the intent of
    # load_multitask_data before relying on it.
    sst_dev_data, num_labels,para_dev_data, sts_dev_data = load_multitask_data(args.sst_dev,args.para_dev,args.sts_dev, split ='train')

    sst_train_data = SentenceClassificationDataset(sst_train_data, args)
    sst_dev_data = SentenceClassificationDataset(sst_dev_data, args)

    sst_train_dataloader = DataLoader(sst_train_data, shuffle=True, batch_size=args.batch_size,
                                      collate_fn=sst_train_data.collate_fn)
    sst_dev_dataloader = DataLoader(sst_dev_data, shuffle=False, batch_size=args.batch_size,
                                    collate_fn=sst_dev_data.collate_fn)

    # Init model.
    config = {'hidden_dropout_prob': args.hidden_dropout_prob,
              'num_labels': num_labels,
              'hidden_size': 768,
              'data_dir': '.',
              'fine_tune_mode': args.fine_tune_mode}

    config = SimpleNamespace(**config)

    model = MultitaskBERT(config)
    model = model.to(device)

    lr = args.lr
    optimizer = AdamW(model.parameters(), lr=lr)
    best_dev_acc = 0

    # Run for the specified number of epochs.
    for epoch in range(args.epochs):
        model.train()
        train_loss = 0
        num_batches = 0
        for batch in tqdm(sst_train_dataloader, desc=f'train-{epoch}', disable=TQDM_DISABLE):
            b_ids, b_mask, b_labels = (batch['token_ids'],
                                       batch['attention_mask'], batch['labels'])

            b_ids = b_ids.to(device)
            b_mask = b_mask.to(device)
            b_labels = b_labels.to(device)

            optimizer.zero_grad()
            logits = model.predict_sentiment(b_ids, b_mask)
            # Sum-reduced cross-entropy divided by the *nominal* batch size;
            # the final (possibly smaller) batch is therefore slightly
            # down-weighted relative to reduction='mean'.
            loss = F.cross_entropy(logits, b_labels.view(-1), reduction='sum') / args.batch_size

            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            num_batches += 1

        train_loss = train_loss / (num_batches)

        # Evaluate on both splits; keep only the checkpoint with the best dev accuracy.
        train_acc, train_f1, *_ = model_eval_sst(sst_train_dataloader, model, device)
        dev_acc, dev_f1, *_ = model_eval_sst(sst_dev_dataloader, model, device)

        if dev_acc > best_dev_acc:
            best_dev_acc = dev_acc
            save_model(model, optimizer, args, config, args.filepath)

        print(f"Epoch {epoch}: train loss :: {train_loss :.3f}, train acc :: {train_acc :.3f}, dev acc :: {dev_acc :.3f}")
|
207 |
+
|
208 |
+
|
209 |
+
def test_multitask(args):
    '''Test and save predictions on the dev and test sets of all three tasks.'''
    with torch.no_grad():
        device = torch.device('cuda') if args.use_gpu else torch.device('cpu')
        # Restore the best checkpoint written by train_multitask.
        saved = torch.load(args.filepath)
        config = saved['model_config']

        model = MultitaskBERT(config)
        model.load_state_dict(saved['model'])
        model = model.to(device)
        print(f"Loaded model to test from {args.filepath}")

        sst_test_data, num_labels,para_test_data, sts_test_data = \
            load_multitask_data(args.sst_test,args.para_test, args.sts_test, split='test')

        sst_dev_data, num_labels,para_dev_data, sts_dev_data = \
            load_multitask_data(args.sst_dev,args.para_dev,args.sts_dev,split='dev')

        sst_test_data = SentenceClassificationTestDataset(sst_test_data, args)
        sst_dev_data = SentenceClassificationDataset(sst_dev_data, args)

        # NOTE(review): the test loaders use shuffle=True. Predictions stay
        # paired with their sent_ids (both come from the same batch), but the
        # row order of the output files is nondeterministic — confirm the
        # grader accepts unordered rows.
        sst_test_dataloader = DataLoader(sst_test_data, shuffle=True, batch_size=args.batch_size,
                                         collate_fn=sst_test_data.collate_fn)
        sst_dev_dataloader = DataLoader(sst_dev_data, shuffle=False, batch_size=args.batch_size,
                                        collate_fn=sst_dev_data.collate_fn)

        para_test_data = SentencePairTestDataset(para_test_data, args)
        para_dev_data = SentencePairDataset(para_dev_data, args)

        para_test_dataloader = DataLoader(para_test_data, shuffle=True, batch_size=args.batch_size,
                                          collate_fn=para_test_data.collate_fn)
        para_dev_dataloader = DataLoader(para_dev_data, shuffle=False, batch_size=args.batch_size,
                                         collate_fn=para_dev_data.collate_fn)

        sts_test_data = SentencePairTestDataset(sts_test_data, args)
        # STS is a regression task (real-valued similarity), hence isRegression=True.
        sts_dev_data = SentencePairDataset(sts_dev_data, args, isRegression=True)

        sts_test_dataloader = DataLoader(sts_test_data, shuffle=True, batch_size=args.batch_size,
                                         collate_fn=sts_test_data.collate_fn)
        sts_dev_dataloader = DataLoader(sts_dev_data, shuffle=False, batch_size=args.batch_size,
                                        collate_fn=sts_dev_data.collate_fn)

        # Dev evaluation returns metrics + predictions; test evaluation returns
        # predictions only (test labels are withheld).
        dev_sentiment_accuracy,dev_sst_y_pred, dev_sst_sent_ids, \
            dev_paraphrase_accuracy, dev_para_y_pred, dev_para_sent_ids, \
            dev_sts_corr, dev_sts_y_pred, dev_sts_sent_ids = model_eval_multitask(sst_dev_dataloader,
                                                                                  para_dev_dataloader,
                                                                                  sts_dev_dataloader, model, device)

        test_sst_y_pred, \
            test_sst_sent_ids, test_para_y_pred, test_para_sent_ids, test_sts_y_pred, test_sts_sent_ids = \
                model_eval_test_multitask(sst_test_dataloader,
                                          para_test_dataloader,
                                          sts_test_dataloader, model, device)

        # Write one "id , prediction" CSV per task and split (submission format).
        with open(args.sst_dev_out, "w+") as f:
            print(f"dev sentiment acc :: {dev_sentiment_accuracy :.3f}")
            f.write(f"id \t Predicted_Sentiment \n")
            for p, s in zip(dev_sst_sent_ids, dev_sst_y_pred):
                f.write(f"{p} , {s} \n")

        with open(args.sst_test_out, "w+") as f:
            f.write(f"id \t Predicted_Sentiment \n")
            for p, s in zip(test_sst_sent_ids, test_sst_y_pred):
                f.write(f"{p} , {s} \n")

        with open(args.para_dev_out, "w+") as f:
            print(f"dev paraphrase acc :: {dev_paraphrase_accuracy :.3f}")
            f.write(f"id \t Predicted_Is_Paraphrase \n")
            for p, s in zip(dev_para_sent_ids, dev_para_y_pred):
                f.write(f"{p} , {s} \n")

        with open(args.para_test_out, "w+") as f:
            f.write(f"id \t Predicted_Is_Paraphrase \n")
            for p, s in zip(test_para_sent_ids, test_para_y_pred):
                f.write(f"{p} , {s} \n")

        with open(args.sts_dev_out, "w+") as f:
            print(f"dev sts corr :: {dev_sts_corr :.3f}")
            f.write(f"id \t Predicted_Similiary \n")
            for p, s in zip(dev_sts_sent_ids, dev_sts_y_pred):
                f.write(f"{p} , {s} \n")

        with open(args.sts_test_out, "w+") as f:
            f.write(f"id \t Predicted_Similiary \n")
            for p, s in zip(test_sts_sent_ids, test_sts_y_pred):
                f.write(f"{p} , {s} \n")
|
295 |
+
|
296 |
+
|
297 |
+
def get_args(argv=None):
    """Parse command-line options for multitask training/evaluation.

    Args:
        argv: optional list of argument strings. Defaults to None, in which
              case argparse reads sys.argv[1:] exactly as before; passing an
              explicit list makes the function unit-testable.

    Returns:
        argparse.Namespace with dataset paths, prediction-output paths and
        training hyperparameters.
    """
    parser = argparse.ArgumentParser()

    # SST (sentiment) dataset splits.
    parser.add_argument("--sst_train", type=str, default="data/ids-sst-train.csv")
    parser.add_argument("--sst_dev", type=str, default="data/ids-sst-dev.csv")
    parser.add_argument("--sst_test", type=str, default="data/ids-sst-test-student.csv")

    # Quora (paraphrase) dataset splits.
    parser.add_argument("--para_train", type=str, default="data/quora-train.csv")
    parser.add_argument("--para_dev", type=str, default="data/quora-dev.csv")
    parser.add_argument("--para_test", type=str, default="data/quora-test-student.csv")

    # SemEval STS (similarity) dataset splits.
    parser.add_argument("--sts_train", type=str, default="data/sts-train.csv")
    parser.add_argument("--sts_dev", type=str, default="data/sts-dev.csv")
    parser.add_argument("--sts_test", type=str, default="data/sts-test-student.csv")

    parser.add_argument("--seed", type=int, default=11711)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--fine-tune-mode", type=str,
                        help='last-linear-layer: the BERT parameters are frozen and the task specific head parameters are updated; full-model: BERT parameters are updated as well',
                        choices=('last-linear-layer', 'full-model'), default="last-linear-layer")
    parser.add_argument("--use_gpu", action='store_true')

    # Prediction output paths (written by test_multitask).
    parser.add_argument("--sst_dev_out", type=str, default="predictions/sst-dev-output.csv")
    parser.add_argument("--sst_test_out", type=str, default="predictions/sst-test-output.csv")

    parser.add_argument("--para_dev_out", type=str, default="predictions/para-dev-output.csv")
    parser.add_argument("--para_test_out", type=str, default="predictions/para-test-output.csv")

    parser.add_argument("--sts_dev_out", type=str, default="predictions/sts-dev-output.csv")
    parser.add_argument("--sts_test_out", type=str, default="predictions/sts-test-output.csv")

    parser.add_argument("--batch_size", help='sst: 64, cfimdb: 8 can fit a 12GB GPU', type=int, default=8)
    parser.add_argument("--hidden_dropout_prob", type=float, default=0.3)
    parser.add_argument("--lr", type=float, help="learning rate", default=1e-5)

    # argv=None preserves the original behavior of reading sys.argv.
    args = parser.parse_args(argv)
    return args
|
333 |
+
|
334 |
+
|
335 |
+
if __name__ == "__main__":
|
336 |
+
args = get_args()
|
337 |
+
args.filepath = f'{args.fine_tune_mode}-{args.epochs}-{args.lr}-multitask.pt' # Save path.
|
338 |
+
seed_everything(args.seed) # Fix the seed for reproducibility.
|
339 |
+
train_multitask(args)
|
340 |
+
test_multitask(args)
|
optimizer.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Callable, Iterable, Tuple
|
2 |
+
import math
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch.optim import Optimizer
|
6 |
+
|
7 |
+
|
8 |
+
class AdamW(Optimizer):
    """Adam with decoupled weight decay (AdamW).

    Per Loshchilov & Hutter, "Decoupled Weight Decay Regularization": the
    Adam update is applied first and the weight-decay term is then applied
    directly to the parameters, scaled by the learning rate.
    """

    def __init__(
        self,
        params: Iterable[torch.nn.parameter.Parameter],
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-6,
        weight_decay: float = 0.0,
        correct_bias: bool = True,
    ):
        # Validate hyperparameters before handing them to the base class.
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1]))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
        super().__init__(params, dict(lr=lr, betas=betas, eps=eps,
                                      weight_decay=weight_decay, correct_bias=correct_bias))

    def step(self, closure: Callable = None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss.
        """
        loss = closure() if closure is not None else None

        for group in self.param_groups:
            # Per-group hyperparameters (hoisted out of the parameter loop).
            lr = group["lr"]
            beta1, beta2 = group["betas"]
            eps = group["eps"]
            wd = group["weight_decay"]
            do_bias_correction = group["correct_bias"]

            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")

                state = self.state[p]
                if len(state) == 0:
                    # Lazy state init: step counter + first/second moment buffers.
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(p.data)
                    state["exp_avg_sq"] = torch.zeros_like(p.data)

                m = state["exp_avg"]
                v = state["exp_avg_sq"]
                state["step"] += 1
                t = state["step"]

                # m <- beta1*m + (1-beta1)*g ;  v <- beta2*v + (1-beta2)*g^2
                m.mul_(beta1).add_(grad, alpha=(1 - beta1))
                v.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))

                # Bias-correct the moment estimates (equivalent to folding the
                # correction into the step size).
                if do_bias_correction:
                    m_hat = m / (1 - beta1 ** t)
                    v_hat = v / (1 - beta2 ** t)
                else:
                    m_hat = m
                    v_hat = v

                # theta <- theta - lr * m_hat / (sqrt(v_hat) + eps)
                p.data.addcdiv_(m_hat, v_hat.sqrt().add_(eps), value=-lr)

                # Decoupled weight decay, applied after the Adam update.
                if wd != 0:
                    p.data.add_(p.data, alpha=-lr * wd)

        return loss
|
optimizer_test.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:77b817e0dce16a9bc8d3a6bcb88035db68f7d783dc8a565737581fadd05db815
|
3 |
+
size 152
|
optimizer_test.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from optimizer import AdamW
|
4 |
+
|
5 |
+
seed = 0
|
6 |
+
|
7 |
+
|
8 |
+
def test_optimizer(opt_class) -> torch.Tensor:
|
9 |
+
rng = np.random.default_rng(seed)
|
10 |
+
torch.manual_seed(seed)
|
11 |
+
model = torch.nn.Linear(3, 2, bias=False)
|
12 |
+
opt = opt_class(
|
13 |
+
model.parameters(),
|
14 |
+
lr=1e-3,
|
15 |
+
weight_decay=1e-4,
|
16 |
+
correct_bias=True,
|
17 |
+
)
|
18 |
+
for i in range(1000):
|
19 |
+
opt.zero_grad()
|
20 |
+
x = torch.FloatTensor(rng.uniform(size=[model.in_features]))
|
21 |
+
y_hat = model(x)
|
22 |
+
y = torch.Tensor([x[0] + x[1], -x[2]])
|
23 |
+
loss = ((y - y_hat) ** 2).sum()
|
24 |
+
loss.backward()
|
25 |
+
opt.step()
|
26 |
+
return model.weight.detach()
|
27 |
+
|
28 |
+
|
29 |
+
# Regression test: run the project's AdamW on the fixed-seed toy problem and
# compare the final weights against the saved reference trajectory
# (optimizer_test.npy, produced by a known-good implementation).
ref = torch.tensor(np.load("optimizer_test.npy"))
actual = test_optimizer(AdamW)
print(ref)
print(actual)
assert torch.allclose(ref, actual, atol=1e-6, rtol=1e-4)
print("Optimizer test passed!")
|
predictions/README
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
By default, `classifier.py` and `multitask_classifier.py` write your model predictions into this folder.
|
2 |
+
Before running prepare_submit.py, make sure that this directory has been populated!
|
predictions/last-linear-layer-cfimdb-dev-out.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c3f994587376345ea6a1e80a7946d5889259f6a427989c71e0b45de28ea4545d
|
3 |
+
size 7621
|
predictions/last-linear-layer-cfimdb-test-out.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7ebedf210c8973e02648e96152e253daa2385b230a48da151812a58d80178536
|
3 |
+
size 15154
|
predictions/last-linear-layer-sst-dev-out.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:22412dead5299ffb8fae45448f240cb135e3ad5dc04cea96975e893bdd719ba8
|
3 |
+
size 34157
|
predictions/last-linear-layer-sst-test-out.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3455d6637e5ecd118c31e48534d92298da3c865ed11ad93e2aadc09fcc743666
|
3 |
+
size 68536
|
prepare_submit.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Creates a zip file for submission on Gradescope.

import os
import zipfile

# Files to ship: every top-level .py source plus all generated prediction files.
required_files = [p for p in os.listdir('.') if p.endswith('.py')] + \
                 [f'predictions/{p}' for p in os.listdir('predictions')]


def main():
    """Bundle `required_files` into <aid>.zip in the current directory."""
    aid = 'cs224n_default_final_project_submission'
    # (removed an unused `path = os.getcwd()` local)
    with zipfile.ZipFile(f"{aid}.zip", 'w') as zz:
        for file in required_files:
            # Store each file under "./<relative path>" inside the archive.
            zz.write(file, os.path.join(".", file))
    print(f"Submission zip file created: {aid}.zip")


if __name__ == '__main__':
    main()
|
sanity_check.data
ADDED
Binary file (56.4 kB). View file
|
|
sanity_check.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
from bert import BertModel


# Sanity check for the student BERT implementation: run two fixed token
# sequences through the model and compare against reference activations
# saved in sanity_check.data.
sanity_data = torch.load("./sanity_check.data", weights_only=True)
# Pre-tokenized WordPiece ids for two sentences; trailing 0s are [PAD].
sent_ids = torch.tensor([[101, 7592, 2088, 102, 0, 0, 0, 0],
                         [101, 7592, 15756, 2897, 2005, 17953, 2361, 102]])
att_mask = torch.tensor([[1, 1, 1, 1, 0, 0, 0, 0],[1, 1, 1, 1, 1, 1, 1, 1]])

# Load model.
bert = BertModel.from_pretrained('bert-base-uncased')
outputs = bert(sent_ids, att_mask)
# Zero out hidden states at padded positions on both sides so the
# comparison ignores values that are undefined under the attention mask.
att_mask = att_mask.unsqueeze(-1)
outputs['last_hidden_state'] = outputs['last_hidden_state'] * att_mask
sanity_data['last_hidden_state'] = sanity_data['last_hidden_state'] * att_mask

for k in ['last_hidden_state', 'pooler_output']:
    assert torch.allclose(outputs[k], sanity_data[k], atol=1e-5, rtol=1e-3)
print("Your BERT implementation is correct!")
|
setup.sh
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash

# Environment setup for the CS224N default final project (python 3.8 + torch).
# NOTE(review): `conda activate` only works in an initialized interactive
# shell; when executing this file non-interactively you likely need
# `eval "$(conda shell.bash hook)"` (or `source activate`) first — confirm.
conda create -n cs224n_dfp python=3.8
conda activate cs224n_dfp

pip install torch torchvision torchaudio
pip install tqdm==4.58.0
pip install requests==2.25.1
pip install importlib-metadata==3.7.0
pip install filelock==3.0.12
# NOTE(review): the PyPI package `sklearn` is a deprecated alias for
# scikit-learn; newer pip versions refuse to install it — verify before use.
pip install sklearn==0.0
pip install tokenizers==0.15
pip install explainaboard_client==0.0.7
|
sst-classifier.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:62f6282ea608a997c1b43071cedcb1c4ba454b420305c7b15138aa9d7f70103d
|
3 |
+
size 438072793
|
tokenizer.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
utils.py
ADDED
@@ -0,0 +1,347 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, List, Optional, Union, Tuple, BinaryIO
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import json
|
5 |
+
import tempfile
|
6 |
+
import copy
|
7 |
+
from tqdm.auto import tqdm
|
8 |
+
from functools import partial
|
9 |
+
from urllib.parse import urlparse
|
10 |
+
from pathlib import Path
|
11 |
+
import requests
|
12 |
+
from hashlib import sha256
|
13 |
+
from filelock import FileLock
|
14 |
+
import importlib_metadata
|
15 |
+
import torch
|
16 |
+
import torch.nn as nn
|
17 |
+
from torch import Tensor
|
18 |
+
import fnmatch
|
19 |
+
|
20 |
+
__version__ = "4.0.0"
|
21 |
+
_torch_version = importlib_metadata.version("torch")
|
22 |
+
|
23 |
+
hf_cache_home = os.path.expanduser(os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface")))
|
24 |
+
default_cache_path = os.path.join(hf_cache_home, "transformers")
|
25 |
+
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
|
26 |
+
PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
|
27 |
+
TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)
|
28 |
+
|
29 |
+
# Optional download mirrors for users who cannot reach huggingface.co.
PRESET_MIRROR_DICT = {
    "tuna": "https://mirrors.tuna.tsinghua.edu.cn/hugging-face-models",
    "bfsu": "https://mirrors.bfsu.edu.cn/hugging-face-models",
}
# URL template for files hosted on the Hugging Face hub.
# Fix: restore the `{filename}` placeholder, which had been garbled into the
# literal text "(unknown)"; without it every formatted URL pointed at a
# non-existent "(unknown)" path.
HUGGINGFACE_CO_PREFIX = "https://huggingface.co/{model_id}/resolve/{revision}/{filename}"
WEIGHTS_NAME = "pytorch_model.bin"
CONFIG_NAME = "config.json"
|
36 |
+
|
37 |
+
|
38 |
+
def is_torch_available():
    # Torch is a hard dependency in this stripped-down port, so always True.
    return True
|
40 |
+
|
41 |
+
|
42 |
+
def is_tf_available():
    # TensorFlow is never used in this project, so always False.
    return False
|
44 |
+
|
45 |
+
|
46 |
+
def is_remote_url(url_or_filename):
    """Return True when the argument is an http(s) URL rather than a local path."""
    scheme = urlparse(url_or_filename).scheme
    return scheme in ("http", "https")
|
49 |
+
|
50 |
+
|
51 |
+
def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers: Optional[Dict[str, str]] = None):
    """Stream `url` into the open binary file `temp_file`, with a tqdm
    progress bar. A non-zero `resume_size` resumes a partial download via an
    HTTP Range header.

    NOTE(review): if `headers` is None and `resume_size > 0`, the Range
    assignment below would raise TypeError — callers appear to always pass a
    dict when resuming; confirm.
    """
    headers = copy.deepcopy(headers)  # don't mutate the caller's dict
    if resume_size > 0:
        headers["Range"] = "bytes=%d-" % (resume_size,)
    r = requests.get(url, stream=True, proxies=proxies, headers=headers)
    r.raise_for_status()
    content_length = r.headers.get("Content-Length")
    # When resuming, Content-Length covers only the remaining bytes.
    total = resume_size + int(content_length) if content_length is not None else None
    progress = tqdm(
        unit="B",
        unit_scale=True,
        total=total,
        initial=resume_size,
        desc="Downloading",
        disable=False,
    )
    for chunk in r.iter_content(chunk_size=1024):
        if chunk:  # filter out keep-alive new chunks
            progress.update(len(chunk))
            temp_file.write(chunk)
    progress.close()
|
72 |
+
|
73 |
+
|
74 |
+
def url_to_filename(url: str, etag: Optional[str] = None) -> str:
    """Map a URL (and optional ETag) to a stable, filesystem-safe cache name.

    The name is the sha256 hex digest of the URL, optionally followed by "."
    plus the sha256 of the ETag. A ".h5" suffix is preserved so downstream
    tooling can recognize TF weight files by extension.
    """
    name = sha256(url.encode("utf-8")).hexdigest()

    if etag:
        name = name + "." + sha256(etag.encode("utf-8")).hexdigest()

    if url.endswith(".h5"):
        name = name + ".h5"

    return name
|
86 |
+
|
87 |
+
|
88 |
+
def hf_bucket_url(
    model_id: str, filename: str, subfolder: Optional[str] = None, revision: Optional[str] = None, mirror=None
) -> str:
    """Resolve the download URL for `filename` of model `model_id` on the
    Hugging Face hub (or a configured mirror).

    Fix: the `{filename}` placeholder in the three URL-building expressions
    had been garbled into the literal text "(unknown)", so every generated
    URL pointed at a non-existent "(unknown)" file.
    """
    if subfolder is not None:
        filename = f"{subfolder}/{filename}"

    if mirror:
        # Mirrors may use either the legacy flat layout (un-namespaced model
        # id joined with '-') or the namespaced layout.
        endpoint = PRESET_MIRROR_DICT.get(mirror, mirror)
        legacy_format = "/" not in model_id
        if legacy_format:
            return f"{endpoint}/{model_id}-{filename}"
        else:
            return f"{endpoint}/{model_id}/{filename}"

    if revision is None:
        revision = "main"
    return HUGGINGFACE_CO_PREFIX.format(model_id=model_id, revision=revision, filename=filename)
|
105 |
+
|
106 |
+
|
107 |
+
def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
    """Build the User-Agent string sent with hub requests: library version,
    python version, framework versions, plus any caller-supplied extras
    (dict entries become "key/value" pairs; a string is appended verbatim)."""
    ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
    if is_torch_available():
        ua += f"; torch/{_torch_version}"
    if is_tf_available():
        # NOTE(review): `_tf_version` is not defined in this module; this
        # branch is unreachable because is_tf_available() always returns
        # False here — confirm before enabling TF support.
        ua += f"; tensorflow/{_tf_version}"
    if isinstance(user_agent, dict):
        ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items())
    elif isinstance(user_agent, str):
        ua += "; " + user_agent
    return ua
|
118 |
+
|
119 |
+
|
120 |
+
def get_from_cache(
    url: str,
    cache_dir=None,
    force_download=False,
    proxies=None,
    etag_timeout=10,
    resume_download=False,
    user_agent: Union[Dict, str, None] = None,
    use_auth_token: Union[bool, str, None] = None,
    local_files_only=False,
) -> Optional[str]:
    """Return a local path for `url`, downloading into the cache if needed.

    An HTTP HEAD fetches the ETag so each (url, etag) pair caches under a
    distinct filename. When offline (HEAD fails) or `local_files_only` is
    set, a previously cached copy is returned if one exists; otherwise
    FileNotFoundError / ValueError is raised.
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    os.makedirs(cache_dir, exist_ok=True)

    headers = {"user-agent": http_user_agent(user_agent)}
    if isinstance(use_auth_token, str):
        headers["authorization"] = "Bearer {}".format(use_auth_token)
    elif use_auth_token:
        # NOTE(review): `HfFolder` is neither imported nor defined in this
        # module; reaching this branch would raise NameError — confirm.
        token = HfFolder.get_token()
        if token is None:
            raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
        headers["authorization"] = "Bearer {}".format(token)

    url_to_download = url
    etag = None
    if not local_files_only:
        try:
            r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=etag_timeout)
            r.raise_for_status()
            etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
            # We favor a custom header indicating the etag of the linked resource, and
            # we fallback to the regular etag header.
            # If we don't have any of those, raise an error.
            if etag is None:
                raise OSError(
                    "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility."
                )
            # In case of a redirect,
            # save an extra redirect on the request.get call,
            # and ensure we download the exact atomic version even if it changed
            # between the HEAD and the GET (unlikely, but hey).
            if 300 <= r.status_code <= 399:
                url_to_download = r.headers["Location"]
        except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
            # etag is already None
            pass

    filename = url_to_filename(url, etag)

    # get cache path to put the file
    cache_path = os.path.join(cache_dir, filename)

    # etag is None == we don't have a connection or we passed local_files_only.
    # try to get the last downloaded one
    if etag is None:
        if os.path.exists(cache_path):
            return cache_path
        else:
            # Fall back to any cached file for the same URL (any etag suffix),
            # excluding lock and metadata sidecar files.
            matching_files = [
                file
                for file in fnmatch.filter(os.listdir(cache_dir), filename.split(".")[0] + ".*")
                if not file.endswith(".json") and not file.endswith(".lock")
            ]
            if len(matching_files) > 0:
                return os.path.join(cache_dir, matching_files[-1])
            else:
                # If files cannot be found and local_files_only=True,
                # the models might've been found if local_files_only=False
                # Notify the user about that
                if local_files_only:
                    raise FileNotFoundError(
                        "Cannot find the requested files in the cached path and outgoing traffic has been"
                        " disabled. To enable model look-ups and downloads online, set 'local_files_only'"
                        " to False."
                    )
                else:
                    raise ValueError(
                        "Connection error, and we cannot find the requested files in the cached path."
                        " Please try again or make sure your Internet connection is on."
                    )

    # From now on, etag is not None.
    if os.path.exists(cache_path) and not force_download:
        return cache_path

    # Prevent parallel downloads of the same file with a lock.
    lock_path = cache_path + ".lock"
    with FileLock(lock_path):

        # If the download just completed while the lock was activated.
        if os.path.exists(cache_path) and not force_download:
            # Even if returning early like here, the lock will be released.
            return cache_path

        if resume_download:
            incomplete_path = cache_path + ".incomplete"

            # NOTE(review): `contextmanager` is not imported in this module;
            # the resume_download path would raise NameError — confirm.
            @contextmanager
            def _resumable_file_manager() -> "io.BufferedWriter":
                # Append mode so a partial ".incomplete" file is continued.
                with open(incomplete_path, "ab") as f:
                    yield f

            temp_file_manager = _resumable_file_manager
            if os.path.exists(incomplete_path):
                resume_size = os.stat(incomplete_path).st_size
            else:
                resume_size = 0
        else:
            temp_file_manager = partial(tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False)
            resume_size = 0

        # Download to temporary file, then copy to cache dir once finished.
        # Otherwise you get corrupt cache entries if the download gets interrupted.
        with temp_file_manager() as temp_file:
            http_get(url_to_download, temp_file, proxies=proxies, resume_size=resume_size, headers=headers)

        os.replace(temp_file.name, cache_path)

        # Sidecar metadata records the (url, etag) this cache entry came from.
        meta = {"url": url, "etag": etag}
        meta_path = cache_path + ".json"
        with open(meta_path, "w") as meta_file:
            json.dump(meta, meta_file)

    return cache_path
|
248 |
+
|
249 |
+
|
250 |
+
def cached_path(
    url_or_filename,
    cache_dir=None,
    force_download=False,
    proxies=None,
    resume_download=False,
    user_agent: Union[Dict, str, None] = None,
    extract_compressed_file=False,
    force_extract=False,
    use_auth_token: Union[bool, str, None] = None,
    local_files_only=False,
) -> Optional[str]:
    """Resolve *url_or_filename* to a local path, downloading/caching if it is remote.

    Args:
        url_or_filename: A remote URL or a path to a local file.
        cache_dir: Directory used for the download cache (defaults to TRANSFORMERS_CACHE).
        force_download: Re-download even if a cached copy exists.
        proxies: Optional proxy configuration forwarded to the HTTP layer.
        resume_download: Resume a partially-completed download if possible.
        user_agent: Extra User-Agent information for the request.
        extract_compressed_file: If True and the file is a zip/tar archive,
            extract it and return the extraction directory instead.
        force_extract: Re-extract even if a non-empty extraction dir already exists.
        use_auth_token: Auth token (or True for the stored token) for private look-ups.
        local_files_only: Never hit the network; only use already-cached files.

    Returns:
        The local file path, or the extraction directory when
        ``extract_compressed_file`` is set.

    Raises:
        EnvironmentError: If a local path does not exist, or the archive format
            cannot be identified.
        ValueError: If the input is neither a URL nor a local path.
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    # Normalize pathlib.Path inputs to plain strings.
    if isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    if is_remote_url(url_or_filename):
        # URL, so get it from the cache (downloading if necessary).
        output_path = get_from_cache(
            url_or_filename,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            user_agent=user_agent,
            use_auth_token=use_auth_token,
            local_files_only=local_files_only,
        )
    elif os.path.exists(url_or_filename):
        # Local file that exists: use it as-is.
        output_path = url_or_filename
    elif urlparse(url_or_filename).scheme == "":
        # Looks like a local path (no URL scheme) but the file is missing.
        raise EnvironmentError("file {} not found".format(url_or_filename))
    else:
        # Has a scheme we do not recognize and is not a local path.
        raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))

    if extract_compressed_file:
        if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path):
            return output_path

        # Path where we extract compressed archives.
        # We avoid '.' in dir name and add "-extracted" at the end: "./model.zip" => "./model-zip-extracted/"
        output_dir, output_file = os.path.split(output_path)
        output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
        output_path_extracted = os.path.join(output_dir, output_extract_dir_name)

        if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract:
            return output_path_extracted

        # Prevent parallel extractions of the same archive.
        lock_path = output_path + ".lock"
        with FileLock(lock_path):
            shutil.rmtree(output_path_extracted, ignore_errors=True)
            os.makedirs(output_path_extracted)
            if is_zipfile(output_path):
                # Fix: rely on the context manager alone — the original also
                # called zip_file.close() redundantly inside the with-block.
                with ZipFile(output_path, "r") as zip_file:
                    zip_file.extractall(output_path_extracted)
            elif tarfile.is_tarfile(output_path):
                # Fix: use a context manager so the tar handle is closed even
                # if extractall() raises (the original leaked it on error).
                with tarfile.open(output_path) as tar_file:
                    tar_file.extractall(output_path_extracted)
            else:
                raise EnvironmentError("Archive format of {} could not be identified".format(output_path))

        return output_path_extracted

    return output_path
|
323 |
+
|
324 |
+
|
325 |
+
def get_parameter_dtype(parameter: nn.Module):
    """Return the dtype of a module's parameters.

    Fix: the original annotation ``Union[nn.Module]`` is a one-member Union,
    which is just ``nn.Module`` — annotate it directly.

    Args:
        parameter: The module whose parameter dtype is wanted.

    Returns:
        torch.dtype of the first parameter, or of the first raw tensor
        attribute when the module has no registered parameters.
    """
    try:
        # Common case: the module owns at least one registered parameter.
        return next(parameter.parameters()).dtype
    except StopIteration:
        # For nn.DataParallel compatibility in PyTorch 1.5: tensors may live
        # as plain instance attributes rather than registered Parameters.

        def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
            # Collect (name, tensor) pairs stored directly in the instance dict.
            tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
            return tuples

        gen = parameter._named_members(get_members_fn=find_tensor_attributes)
        first_tuple = next(gen)
        return first_tuple[1].dtype
|
338 |
+
|
339 |
+
|
340 |
+
def get_extended_attention_mask(attention_mask: Tensor, dtype) -> Tensor:
    """Convert a [batch_size, seq_length] padding mask to an additive attention bias.

    Kept positions (mask value 1) map to 0.0 and padded positions (0) map to
    -10000.0, shaped [batch_size, 1, 1, seq_length] so it broadcasts across
    attention heads and query positions.
    """
    # The mask must be 2-D: one row of 0/1 flags per sequence in the batch.
    assert attention_mask.dim() == 2
    # Insert singleton head and query-position dimensions for broadcasting.
    broadcast_mask = attention_mask[:, None, None, :]
    broadcast_mask = broadcast_mask.to(dtype=dtype)  # fp16 compatibility
    # Invert and scale: keep -> 0.0, pad -> large negative bias.
    return (1.0 - broadcast_mask) * -10000.0
|
zemo1.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
import torch.nn as nn
from tqdm import tqdm
import torch.optim as optim

# Step 1: prepare toy data.
# Each row: [study hours, leisure hours, sleep hours], grade point average.
data = [
    [2, 1, 7, 6.0],
    [3, 2, 6, 7.5],
    [1, 3, 8, 5.5],
    [4, 1, 6, 8.0],
    [5, 0, 5, 9.0],
    [6, 0, 6, 9.5]
]

# Split into feature columns (first three) and the target column (last).
features = [row[:3] for row in data]
targets = [[row[3]] for row in data]
X = torch.tensor(features, dtype=torch.float32)   # study / leisure / sleep hours
y = torch.tensor(targets, dtype=torch.float32)    # grade point average


# Step 2: define the model — a single linear layer (3 inputs -> 1 output).
class StudentGradeModel(nn.Module):
    def __init__(self):
        super(StudentGradeModel, self).__init__()
        self.linear = nn.Linear(3, 1)

    def forward(self, x):
        return self.linear(x)


model = StudentGradeModel()

# Step 3: loss function and optimizer.
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Step 4: training loop.
for epoch in tqdm(range(10000), desc="Training Epochs"):
    optimizer.zero_grad()        # clear stale gradients
    output = model(X)            # forward pass
    loss = criterion(output, y)  # compute the loss
    loss.backward()              # backpropagate
    optimizer.step()             # update the weights

    # Periodically report progress without corrupting the progress bar.
    if (epoch + 1) % 1000 == 0:
        tqdm.write(f'Epoch [{epoch + 1}/10000], Loss: {loss.item():.4f}')

# Step 5: predict for a new student.
model.eval()
with torch.no_grad():
    # Example: 4 study hours, 1 leisure hour, 6 sleep hours.
    test_input = torch.tensor([[4, 1, 6]], dtype=torch.float32)
    prediction = model(test_input)
    print("Dự đoán điểm trung bình:", prediction.item())
|
zemo2.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
import torch.nn as nn


# A minimal recurrent network: an RNN encoder followed by a linear head.
class RNNModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNNModel, self).__init__()
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)  # recurrent encoder
        self.fc = nn.Linear(hidden_size, output_size)                 # projection to the output

    def forward(self, x):
        out, _ = self.rnn(x)   # hidden outputs for every time step
        out = out[:, -1, :]    # keep only the final step (sequence summary)
        return self.fc(out)    # project to the prediction

# Model hyper-parameters.
input_size = 10    # features per time step
hidden_size = 20   # hidden units
output_size = 1    # single regression target
model = RNNModel(input_size, hidden_size, output_size)

# Synthetic data: 32 samples, 5 time steps, 10 features per step.
X = torch.randn(32, 5, 10)
y = torch.randn(32, 1)   # one target value per sample

# Loss function and optimizer.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training loop.
for epoch in range(100):
    model.train()
    optimizer.zero_grad()
    output = model(X)            # forward pass
    loss = criterion(output, y)  # compute the loss
    loss.backward()              # backpropagate
    optimizer.step()             # update the weights

    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch+1}/100], Loss: {loss.item():.4f}')
|
zemo3.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
from tokenizer import BertTokenizer
from torch import nn
from bert import BertModel

# Load the pretrained BERT tokenizer.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Example sentences to encode as one batch.
sentences = [
    "She loves reading novels in her free time",
    "An apple a day keeps the doctor away",
    "If you can't explain it simply, you don't understand it well enough."
]

# Tokenize the batch: pad/truncate to a fixed length, return PyTorch tensors.
encoding = tokenizer.batch_encode_plus(
    sentences,
    max_length=512,
    padding='max_length',
    truncation=True,
    return_tensors='pt'
)

# Pull the token ids and the padding mask out of the encoding dict.
input_ids = encoding['input_ids']
attention_mask = encoding['attention_mask']

# Load the pretrained BERT encoder weights.
model = BertModel.from_pretrained('bert-base-uncased')

assert isinstance(model, BertModel)
# Run only the embedding lookup and report the embedded batch's shape.
print(model.embed(input_ids).size())
|