GlowCheese committed
Commit 9756d99 · 1 Parent(s): 23d93ea

First model version
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.csv filter=lfs diff=lfs merge=lfs -text
.python-version ADDED
@@ -0,0 +1 @@
+ 3.8.20
LICENSE ADDED
@@ -0,0 +1,203 @@
+ Copyright 2018- The Hugging Face team. All rights reserved.
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
README.md ADDED
@@ -0,0 +1,27 @@
+ # CS 224N Default Final Project - Multitask BERT
+
+ This is the default final project for the Stanford CS 224N class. Please refer to the project handout on the course website for detailed instructions and an overview of the codebase.
+
+ This project comprises two parts. In the first part, you will implement some important components of the BERT model to better understand its architecture.
+ In the second part, you will use the embeddings produced by your BERT model on three downstream tasks: sentiment classification, paraphrase detection, and semantic similarity. You will implement extensions to improve your model's performance on the three downstream tasks.
+
+ In broad strokes, Part 1 of this project targets:
+ * bert.py: Missing code blocks.
+ * classifier.py: Missing code blocks.
+ * optimizer.py: Missing code blocks.
+
+ And Part 2 targets:
+ * multitask_classifier.py: Missing code blocks.
+ * datasets.py: Possibly useful functions/classes for extensions.
+ * evaluation.py: Possibly useful functions/classes for extensions.
+
+ ## Setup instructions
+
+ Follow `setup.sh` to properly set up a conda environment and install dependencies.
+
+ ## Acknowledgement
+
+ The BERT implementation part of the project was adapted from the "minbert" assignment developed at Carnegie Mellon University's [CS11-711 Advanced NLP](http://phontron.com/class/anlp2021/index.html),
+ created by Shuyan Zhou, Zhengbao Jiang, Ritam Dutt, Brendon Boldt, Aditya Veerubhotla, and Graham Neubig.
+
+ Parts of the code are from the [`transformers`](https://github.com/huggingface/transformers) library ([Apache License 2.0](./LICENSE)).
__pycache__/base_bert.cpython-38.pyc ADDED
Binary file (7.19 kB)
__pycache__/bert.cpython-38.pyc ADDED
Binary file (6.3 kB)
__pycache__/config.cpython-38.pyc ADDED
Binary file (6.64 kB)
__pycache__/optimizer.cpython-38.pyc ADDED
Binary file (2.37 kB)
__pycache__/tokenizer.cpython-38.pyc ADDED
Binary file (76.3 kB)
__pycache__/utils.cpython-38.pyc ADDED
Binary file (9.09 kB)
base_bert.py ADDED
@@ -0,0 +1,248 @@
+ import re
+ from torch import device, dtype
+ from config import BertConfig, PretrainedConfig
+ from utils import *
+
+
+ class BertPreTrainedModel(nn.Module):
+     config_class = BertConfig
+     base_model_prefix = "bert"
+     _keys_to_ignore_on_load_missing = [r"position_ids"]
+     _keys_to_ignore_on_load_unexpected = None
+
+     def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
+         super().__init__()
+         self.config = config
+         self.name_or_path = config.name_or_path
+
+     def init_weights(self):
+         # Initialize weights.
+         self.apply(self._init_weights)
+
+     def _init_weights(self, module):
+         """Initialize the weights."""
+         if isinstance(module, (nn.Linear, nn.Embedding)):
+             # Slightly different from the TF version, which uses truncated_normal for initialization.
+             # cf. https://github.com/pytorch/pytorch/pull/5617
+             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+         elif isinstance(module, nn.LayerNorm):
+             module.bias.data.zero_()
+             module.weight.data.fill_(1.0)
+         if isinstance(module, nn.Linear) and module.bias is not None:
+             module.bias.data.zero_()
+
+     @property
+     def dtype(self) -> dtype:
+         return get_parameter_dtype(self)
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
+         config = kwargs.pop("config", None)
+         state_dict = kwargs.pop("state_dict", None)
+         cache_dir = kwargs.pop("cache_dir", None)
+         force_download = kwargs.pop("force_download", False)
+         resume_download = kwargs.pop("resume_download", False)
+         proxies = kwargs.pop("proxies", None)
+         output_loading_info = kwargs.pop("output_loading_info", False)
+         local_files_only = kwargs.pop("local_files_only", False)
+         use_auth_token = kwargs.pop("use_auth_token", None)
+         revision = kwargs.pop("revision", None)
+         mirror = kwargs.pop("mirror", None)
+
+         # Load the config if a configuration was not provided.
+         if not isinstance(config, PretrainedConfig):
+             config_path = config if config is not None else pretrained_model_name_or_path
+             config, model_kwargs = cls.config_class.from_pretrained(
+                 config_path,
+                 *model_args,
+                 cache_dir=cache_dir,
+                 return_unused_kwargs=True,
+                 force_download=force_download,
+                 resume_download=resume_download,
+                 proxies=proxies,
+                 local_files_only=local_files_only,
+                 use_auth_token=use_auth_token,
+                 revision=revision,
+                 **kwargs,
+             )
+         else:
+             model_kwargs = kwargs
+
+         # Locate the model weights.
+         if pretrained_model_name_or_path is not None:
+             pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+             if os.path.isdir(pretrained_model_name_or_path):
+                 # Load from a PyTorch checkpoint directory.
+                 archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
+             elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
+                 archive_file = pretrained_model_name_or_path
+             else:
+                 archive_file = hf_bucket_url(
+                     pretrained_model_name_or_path,
+                     filename=WEIGHTS_NAME,
+                     revision=revision,
+                     mirror=mirror,
+                 )
+             try:
+                 # Load from URL, or from the cache if already cached.
+                 resolved_archive_file = cached_path(
+                     archive_file,
+                     cache_dir=cache_dir,
+                     force_download=force_download,
+                     proxies=proxies,
+                     resume_download=resume_download,
+                     local_files_only=local_files_only,
+                     use_auth_token=use_auth_token,
+                 )
+             except EnvironmentError as err:
+                 msg = (
+                     f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
+                     f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
+                     f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named {WEIGHTS_NAME}.\n\n"
+                 )
+                 raise EnvironmentError(msg) from err
+         else:
+             resolved_archive_file = None
+
+         config.name_or_path = pretrained_model_name_or_path
+
+         # Instantiate the model.
+         model = cls(config, *model_args, **model_kwargs)
+
+         if state_dict is None:
+             try:
+                 state_dict = torch.load(resolved_archive_file, map_location="cpu", weights_only=True)
+             except Exception:
+                 raise OSError(
+                     f"Unable to load weights from pytorch checkpoint file for '{pretrained_model_name_or_path}' "
+                     f"at '{resolved_archive_file}'"
+                 )
+
+         missing_keys = []
+         unexpected_keys = []
+         error_msgs = []
+
+         # Convert old-format keys in a PyTorch state_dict to this codebase's naming scheme.
+         old_keys = []
+         new_keys = []
+         m = {'embeddings.word_embeddings': 'word_embedding',
+              'embeddings.position_embeddings': 'pos_embedding',
+              'embeddings.token_type_embeddings': 'tk_type_embedding',
+              'embeddings.LayerNorm': 'embed_layer_norm',
+              'embeddings.dropout': 'embed_dropout',
+              'encoder.layer': 'bert_layers',
+              'pooler.dense': 'pooler_dense',
+              'pooler.activation': 'pooler_af',
+              'attention.self': 'self_attention',
+              'attention.output.dense': 'attention_dense',
+              'attention.output.LayerNorm': 'attention_layer_norm',
+              'attention.output.dropout': 'attention_dropout',
+              'intermediate.dense': 'interm_dense',
+              'intermediate.intermediate_act_fn': 'interm_af',
+              'output.dense': 'out_dense',
+              'output.LayerNorm': 'out_layer_norm',
+              'output.dropout': 'out_dropout'}
+
+         for key in state_dict.keys():
+             new_key = None
+             if "gamma" in key:
+                 new_key = key.replace("gamma", "weight")
+             if "beta" in key:
+                 new_key = key.replace("beta", "bias")
+             for x, y in m.items():
+                 if new_key is not None:
+                     _key = new_key
+                 else:
+                     _key = key
+                 if x in key:
+                     new_key = _key.replace(x, y)
+             if new_key:
+                 old_keys.append(key)
+                 new_keys.append(new_key)
+
+         for old_key, new_key in zip(old_keys, new_keys):
+             state_dict[new_key] = state_dict.pop(old_key)
+
+         # Copy state_dict so _load_from_state_dict can modify it.
+         metadata = getattr(state_dict, "_metadata", None)
+         state_dict = state_dict.copy()
+         if metadata is not None:
+             state_dict._metadata = metadata
+
+         your_bert_params = [f"bert.{x[0]}" for x in model.named_parameters()]
+         for k in state_dict:
+             if k not in your_bert_params and not k.startswith("cls."):
+                 possible_rename = [x for x in k.split(".")[1:-1] if x in m.values()]
+                 raise ValueError(f"{k} cannot be reloaded into your model; one/some of {possible_rename} we provided have been renamed")
+
+         # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants,
+         # so we need to apply the function recursively.
+         def load(module: nn.Module, prefix=""):
+             local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
+             module._load_from_state_dict(
+                 state_dict,
+                 prefix,
+                 local_metadata,
+                 True,
+                 missing_keys,
+                 unexpected_keys,
+                 error_msgs,
+             )
+             for name, child in module._modules.items():
+                 if child is not None:
+                     load(child, prefix + name + ".")
+
+         # Make sure we are able to load base models as well as derived models (with heads).
+         start_prefix = ""
+         model_to_load = model
+         has_prefix_module = any(s.startswith(cls.base_model_prefix) for s in state_dict.keys())
+         if not hasattr(model, cls.base_model_prefix) and has_prefix_module:
+             start_prefix = cls.base_model_prefix + "."
+         if hasattr(model, cls.base_model_prefix) and not has_prefix_module:
+             model_to_load = getattr(model, cls.base_model_prefix)
+         load(model_to_load, prefix=start_prefix)
+
+         if model.__class__.__name__ != model_to_load.__class__.__name__:
+             base_model_state_dict = model_to_load.state_dict().keys()
+             head_model_state_dict_without_base_prefix = [
+                 key.split(cls.base_model_prefix + ".")[-1] for key in model.state_dict().keys()
+             ]
+             missing_keys.extend(set(head_model_state_dict_without_base_prefix) - set(base_model_state_dict))
+
+         # Some models may have keys that are not in the state dict by design; remove them
+         # before needlessly warning the user.
+         if cls._keys_to_ignore_on_load_missing is not None:
+             for pat in cls._keys_to_ignore_on_load_missing:
+                 missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
+
+         if cls._keys_to_ignore_on_load_unexpected is not None:
+             for pat in cls._keys_to_ignore_on_load_unexpected:
+                 unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+
+         if len(error_msgs) > 0:
+             raise RuntimeError(
+                 "Error(s) in loading state_dict for {}:\n\t{}".format(
+                     model.__class__.__name__, "\n\t".join(error_msgs)
+                 )
+             )
+
+         # Set the model in evaluation mode to deactivate dropout modules by default.
+         model.eval()
+
+         if output_loading_info:
+             loading_info = {
+                 "missing_keys": missing_keys,
+                 "unexpected_keys": unexpected_keys,
+                 "error_msgs": error_msgs,
+             }
+             return model, loading_info
+
+         if hasattr(config, "xla_device") and config.xla_device and is_torch_tpu_available():
+             import torch_xla.core.xla_model as xm
+
+             model = xm.send_cpu_data_to_device(model, xm.xla_device())
+             model.to(xm.xla_device())
+
+         return model
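The subtle part of `from_pretrained` above is the key-remapping loop: Hugging Face checkpoints use names such as `embeddings.word_embeddings` and the legacy `gamma`/`beta` LayerNorm names, which must be rewritten to this codebase's module names before `_load_from_state_dict` runs. Below is a minimal, self-contained sketch of that rewrite on a toy state dict; the sample keys are hypothetical stand-ins and the mapping `m` is abbreviated from the full table above.

```python
# Toy illustration of the key-remapping loop in BertPreTrainedModel.from_pretrained.
# The keys are made-up stand-ins; the values would normally be tensors.
m = {'embeddings.word_embeddings': 'word_embedding',
     'encoder.layer': 'bert_layers',
     'attention.self': 'self_attention'}

state_dict = {
    'bert.embeddings.word_embeddings.weight': 0,
    'bert.encoder.layer.0.attention.self.query.gamma': 1,
}

old_keys, new_keys = [], []
for key in state_dict.keys():
    new_key = None
    if 'gamma' in key:  # legacy LayerNorm naming: gamma -> weight, beta -> bias
        new_key = key.replace('gamma', 'weight')
    if 'beta' in key:
        new_key = key.replace('beta', 'bias')
    for x, y in m.items():
        _key = new_key if new_key is not None else key
        if x in key:
            new_key = _key.replace(x, y)
    if new_key:
        old_keys.append(key)
        new_keys.append(new_key)

for old_key, new_key in zip(old_keys, new_keys):
    state_dict[new_key] = state_dict.pop(old_key)

print(sorted(state_dict))
# ['bert.bert_layers.0.self_attention.query.weight', 'bert.word_embedding.weight']
```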
bert.py ADDED
@@ -0,0 +1,225 @@
+ import math
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from base_bert import BertPreTrainedModel
+ from utils import *
+
+
+ class BertSelfAttention(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+
+         self.num_attention_heads = config.num_attention_heads
+         self.attention_head_size = config.hidden_size // config.num_attention_heads
+         self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+         # Initialize the linear transformation layers for key, value, query.
+         self.query = nn.Linear(config.hidden_size, self.all_head_size)
+         self.key = nn.Linear(config.hidden_size, self.all_head_size)
+         self.value = nn.Linear(config.hidden_size, self.all_head_size)
+         # This dropout is applied to normalized attention scores following the original
+         # implementation of the transformer. Although it is a bit unusual, we empirically
+         # observe that it yields better performance.
+         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+     def transform(self, x, linear_layer):
+         # The corresponding linear_layer of k, v, q is used to project the hidden_state (x).
+         bs, seq_len = x.shape[:2]
+         proj = linear_layer(x)
+         # Next, we need to produce multiple heads for the proj. This is done by splitting the
+         # hidden state into self.num_attention_heads heads, each of size self.attention_head_size.
+         proj = proj.view(bs, seq_len, self.num_attention_heads, self.attention_head_size)
+         # By a proper transpose, we get proj of size [bs, num_attention_heads, seq_len, attention_head_size].
+         proj = proj.transpose(1, 2)
+         return proj
+
+     def attention(self, key, query, value, attention_mask):
+         """
+         key, query, value: [batch_size, num_attention_heads, seq_len, attention_head_size]
+         attention_mask: [batch_size, 1, 1, seq_len], masks padding tokens in the input.
+         """
+         d_k = query.size(-1)  # attention_head_size
+         attention_scores = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(d_k)
+         # attention_scores shape: [batch_size, num_attention_heads, seq_len, seq_len]
+
+         # Apply the additive attention mask: 0 for real tokens, a large negative value for
+         # padding tokens, so softmax assigns padding positions near-zero probability.
+         attention_scores = attention_scores + attention_mask
+
+         # Normalize the scores with softmax and apply dropout.
+         attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+         attention_probs = self.dropout(attention_probs)
+
+         context = torch.matmul(attention_probs, value)
+         # context shape: [batch_size, num_attention_heads, seq_len, attention_head_size]
+
+         # Concatenate all attention heads to recover the original shape: [batch_size, seq_len, hidden_size].
+         context = context.transpose(1, 2).contiguous()
+         context = context.view(context.size(0), context.size(1), -1)
+
+         return context
+
+     def forward(self, hidden_states, attention_mask):
+         """
+         hidden_states: [bs, seq_len, hidden_state]
+         attention_mask: [bs, 1, 1, seq_len]
+         output: [bs, seq_len, hidden_state]
+         """
+         # First, generate the key, value, and query for each token for multi-head attention
+         # using self.transform (more details inside the function).
+         # The size of each *_layer is [bs, num_attention_heads, seq_len, attention_head_size].
+         key_layer = self.transform(hidden_states, self.key)
+         value_layer = self.transform(hidden_states, self.value)
+         query_layer = self.transform(hidden_states, self.query)
+         # Calculate the multi-head attention.
+         attn_value = self.attention(key_layer, query_layer, value_layer, attention_mask)
+         return attn_value
+
+
+ class BertLayer(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         # Multi-head attention.
+         self.self_attention = BertSelfAttention(config)
+         # Add-norm for multi-head attention.
+         self.attention_dense = nn.Linear(config.hidden_size, config.hidden_size)
+         self.attention_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+         self.attention_dropout = nn.Dropout(config.hidden_dropout_prob)
+         # Feed forward.
+         self.interm_dense = nn.Linear(config.hidden_size, config.intermediate_size)
+         self.interm_af = F.gelu
+         # Add-norm for feed forward.
+         self.out_dense = nn.Linear(config.intermediate_size, config.hidden_size)
+         self.out_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+         self.out_dropout = nn.Dropout(config.hidden_dropout_prob)
+
+     def add_norm(self, input, output, dense_layer, dropout, ln_layer):
+         transformed_output = dense_layer(output)          # Project the sublayer output with dense_layer.
+         transformed_output = dropout(transformed_output)  # Apply dropout.
+         added_output = input + transformed_output         # Residual connection: add the input back.
+         normalized_output = ln_layer(added_output)        # Apply layer normalization.
+         return normalized_output
+
+     def forward(self, hidden_states, attention_mask):
+         # 1. Multi-head attention.
+         attention_output = self.self_attention(hidden_states, attention_mask)
+
+         # 2. Add-norm after attention.
+         attention_output = self.add_norm(
+             hidden_states,
+             attention_output,
+             self.attention_dense,
+             self.attention_dropout,
+             self.attention_layer_norm
+         )
+
+         # 3. Feed-forward network.
+         intermediate_output = self.interm_af(self.interm_dense(attention_output))
+
+         # 4. Add-norm after feed-forward.
+         layer_output = self.add_norm(
+             attention_output,
+             intermediate_output,
+             self.out_dense,
+             self.out_dropout,
+             self.out_layer_norm
+         )
+
+         return layer_output
+
+
+ class BertModel(BertPreTrainedModel):
+     """
+     The BERT model returns the final embeddings for each token in a sentence.
+
+     The model consists of:
+     1. Embedding layers (used in self.embed).
+     2. A stack of n BERT layers (used in self.encode).
+     3. A linear transformation layer for the [CLS] token (used in self.forward, as given).
+     """
+     def __init__(self, config):
+         super().__init__(config)
+         self.config = config
+
+         # Embedding layers.
+         self.word_embedding = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+         self.pos_embedding = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+         self.tk_type_embedding = nn.Embedding(config.type_vocab_size, config.hidden_size)
+         self.embed_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+         self.embed_dropout = nn.Dropout(config.hidden_dropout_prob)
+         # Register position_ids (1, max position embedding length) as a buffer because it is a constant.
+         position_ids = torch.arange(config.max_position_embeddings).unsqueeze(0)
+         self.register_buffer('position_ids', position_ids)
+
+         # BERT encoder.
+         self.bert_layers = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
+
+         # [CLS] token transformations.
+         self.pooler_dense = nn.Linear(config.hidden_size, config.hidden_size)
+         self.pooler_af = nn.Tanh()
+
+         self.init_weights()
+
+     def embed(self, input_ids):
+         input_shape = input_ids.size()
+         seq_length = input_shape[1]
+
+         inputs_embeds = self.word_embedding(input_ids)
+
+         pos_ids = self.position_ids[:, :seq_length]
+         pos_embeds = self.pos_embedding(pos_ids)
+
+         # Since we are not considering token types, this embedding is just a placeholder.
+         tk_type_ids = torch.zeros(input_shape, dtype=torch.long, device=input_ids.device)
+         tk_type_embeds = self.tk_type_embedding(tk_type_ids)
+
+         embeddings = inputs_embeds + pos_embeds + tk_type_embeds
+         embeddings = self.embed_layer_norm(embeddings)
+         embeddings = self.embed_dropout(embeddings)
+
+         return embeddings
+
+     def encode(self, hidden_states, attention_mask):
+         """
+         hidden_states: the output from the embedding layer [batch_size, seq_len, hidden_size]
+         attention_mask: [batch_size, seq_len]
+         """
+         # Get the extended attention mask for self-attention,
+         # of size [batch_size, 1, 1, seq_len].
+         # It distinguishes non-padding tokens (with a value of 0) from padding tokens
+         # (with a value of a large negative number).
+         extended_attention_mask: torch.Tensor = get_extended_attention_mask(attention_mask, self.dtype)
+
+         # Pass the hidden states through the encoder layers.
+         for i, layer_module in enumerate(self.bert_layers):
+             # Feed the encoding from the last bert_layer to the next.
+             hidden_states = layer_module(hidden_states, extended_attention_mask)
+
+         return hidden_states
+
+     def forward(self, input_ids, attention_mask):
+         """
+         input_ids: [batch_size, seq_len], where seq_len is the max length of the batch
+         attention_mask: same size as input_ids; 1 represents non-padding tokens, 0 represents padding tokens
+         """
+         # Get the embedding for each input token.
+         embedding_output = self.embed(input_ids=input_ids)
+
+         # Feed to a transformer (a stack of BertLayers).
+         sequence_output = self.encode(embedding_output, attention_mask=attention_mask)
+
+         # Get the [CLS] token hidden state.
+         first_tk = sequence_output[:, 0]
+         first_tk = self.pooler_dense(first_tk)
+         first_tk = self.pooler_af(first_tk)
+
+         return {'last_hidden_state': sequence_output, 'pooler_output': first_tk}
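For orientation, here is a minimal usage sketch of the finished model; it is not part of this commit and assumes the `bert-base-uncased` checkpoint is reachable and the repo's `tokenizer.py` is on the path.

```python
# Usage sketch for bert.py (illustrative, not part of the commit).
import torch
from bert import BertModel
from tokenizer import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')
model.eval()  # from_pretrained already calls eval(); made explicit here

encoding = tokenizer(['a great movie', 'it was fine'],
                     return_tensors='pt', padding=True, truncation=True)
with torch.no_grad():
    out = model(encoding['input_ids'], encoding['attention_mask'])

print(out['last_hidden_state'].shape)  # [2, seq_len, 768]: one vector per token
print(out['pooler_output'].shape)      # [2, 768]: transformed [CLS] state
```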
cfimdb-classifier.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e1c66df3c0ce0e4326519041f49707f102df5f680de5ded1b5125ba689a9d141
+ size 438045778
classifier.py ADDED
@@ -0,0 +1,406 @@
+ import random, numpy as np, argparse
+ from types import SimpleNamespace
+ import csv
+
+ import torch
+ import torch.nn.functional as F
+ from torch.utils.data import Dataset, DataLoader
+ from sklearn.metrics import f1_score, accuracy_score
+
+ from tokenizer import BertTokenizer
+ from bert import BertModel
+ from optimizer import AdamW
+ from tqdm import tqdm
+
+
+ TQDM_DISABLE = False
+
+
+ # Fix the random seed.
+ def seed_everything(seed=11711):
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     torch.backends.cudnn.benchmark = False
+     torch.backends.cudnn.deterministic = True
+
+
+ class BertSentimentClassifier(torch.nn.Module):
+     '''
+     This module performs sentiment classification using BERT embeddings on the SST dataset.
+
+     In the SST dataset, there are 5 sentiment categories (from 0 - "negative" to 4 - "positive").
+     Thus, your forward() should return one logit for each of the 5 classes.
+     '''
+     def __init__(self, config):
+         super(BertSentimentClassifier, self).__init__()
+         self.num_labels = config.num_labels
+         self.bert: BertModel = BertModel.from_pretrained('bert-base-uncased')
+
+         # Pretrain mode does not require updating BERT parameters.
+         assert config.fine_tune_mode in ["last-linear-layer", "full-model"]
+         for param in self.bert.parameters():
+             if config.fine_tune_mode == 'last-linear-layer':
+                 param.requires_grad = False
+             elif config.fine_tune_mode == 'full-model':
+                 param.requires_grad = True
+
+         # Create any instance variables you need to classify the sentiment of BERT embeddings.
+         self.classifier = torch.nn.Linear(config.hidden_size, self.num_labels)
+
+     def forward(self, input_ids, attention_mask):
+         '''Takes a batch of sentences and returns logits for the sentiment classes.'''
+         # The final BERT contextualized embedding is the hidden state of the [CLS] token (the first token).
+         # HINT: You should consider what is an appropriate return value given that
+         # the training loop currently uses F.cross_entropy as the loss function.
+
+         # Get the embedding for each input token.
+         embedding_output = self.bert.embed(input_ids=input_ids)
+
+         # Feed to a transformer (BERT layers).
+         sequence_output = self.bert.encode(embedding_output, attention_mask=attention_mask)
+
+         # The final BERT contextualized embedding is the hidden state of the [CLS] token.
+         cls_token_output = sequence_output[:, 0, :]  # The first token is [CLS].
+
+         # Pass the [CLS] token representation through the classifier.
+         logits = self.classifier(cls_token_output)
+
+         return logits
+
+
+ class SentimentDataset(Dataset):
+     def __init__(self, dataset, args):
+         self.dataset = dataset
+         self.p = args
+         self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+
+     def __len__(self):
+         return len(self.dataset)
+
+     def __getitem__(self, idx):
+         return self.dataset[idx]
+
+     def pad_data(self, data):
+         sents = [x[0] for x in data]
+         labels = [x[1] for x in data]
+         sent_ids = [x[2] for x in data]
+
+         encoding = self.tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
+         token_ids = torch.LongTensor(encoding['input_ids'])
+         attention_mask = torch.LongTensor(encoding['attention_mask'])
+         labels = torch.LongTensor(labels)
+
+         return token_ids, attention_mask, labels, sents, sent_ids
+
+     def collate_fn(self, all_data):
+         token_ids, attention_mask, labels, sents, sent_ids = self.pad_data(all_data)
+
+         batched_data = {
+             'token_ids': token_ids,
+             'attention_mask': attention_mask,
+             'labels': labels,
+             'sents': sents,
+             'sent_ids': sent_ids
+         }
+
+         return batched_data
+
+
+ class SentimentTestDataset(Dataset):
+     def __init__(self, dataset, args):
+         self.dataset = dataset
+         self.p = args
+         self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+
+     def __len__(self):
+         return len(self.dataset)
+
+     def __getitem__(self, idx):
+         return self.dataset[idx]
+
+     def pad_data(self, data):
+         sents = [x[0] for x in data]
+         sent_ids = [x[1] for x in data]
+
+         encoding = self.tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
+         token_ids = torch.LongTensor(encoding['input_ids'])
+         attention_mask = torch.LongTensor(encoding['attention_mask'])
+
+         return token_ids, attention_mask, sents, sent_ids
+
+     def collate_fn(self, all_data):
+         token_ids, attention_mask, sents, sent_ids = self.pad_data(all_data)
+
+         batched_data = {
+             'token_ids': token_ids,
+             'attention_mask': attention_mask,
+             'sents': sents,
+             'sent_ids': sent_ids
+         }
+
+         return batched_data
+
+
+ # Load the data: a list of (sentence, label).
+ def load_data(filename, flag='train'):
+     num_labels = {}
+     data = []
+     if flag == 'test':
+         with open(filename, 'r') as fp:
+             for record in csv.DictReader(fp, delimiter='\t'):
+                 sent = record['sentence'].lower().strip()
+                 sent_id = record['id'].lower().strip()
+                 data.append((sent, sent_id))
+     else:
+         with open(filename, 'r') as fp:
+             for record in csv.DictReader(fp, delimiter='\t'):
+                 sent = record['sentence'].lower().strip()
+                 sent_id = record['id'].lower().strip()
+                 label = int(record['sentiment'].strip())
+                 if label not in num_labels:
+                     num_labels[label] = len(num_labels)
+                 data.append((sent, label, sent_id))
+     print(f"Loaded {len(data)} examples from {filename}")
+
+     if flag == 'train':
+         return data, len(num_labels)
+     else:
+         return data
+
+
+ # Evaluate the model on dev examples.
+ def model_eval(dataloader, model, device):
+     model.eval()  # Switch to eval mode, which turns off randomness like dropout.
+     y_true = []
+     y_pred = []
+     sents = []
+     sent_ids = []
+     for step, batch in enumerate(tqdm(dataloader, desc='eval', disable=TQDM_DISABLE)):
+         b_ids, b_mask, b_labels, b_sents, b_sent_ids = batch['token_ids'], batch['attention_mask'], \
+                                                        batch['labels'], batch['sents'], batch['sent_ids']
+
+         b_ids = b_ids.to(device)
+         b_mask = b_mask.to(device)
+
+         logits = model(b_ids, b_mask)
+         logits = logits.detach().cpu().numpy()
+         preds = np.argmax(logits, axis=1).flatten()
+
+         b_labels = b_labels.flatten()
+         y_true.extend(b_labels)
+         y_pred.extend(preds)
+         sents.extend(b_sents)
+         sent_ids.extend(b_sent_ids)
+
+     f1 = f1_score(y_true, y_pred, average='macro')
+     acc = accuracy_score(y_true, y_pred)
+
+     return acc, f1, y_pred, y_true, sents, sent_ids
+
+
+ # Evaluate the model on test examples.
+ def model_test_eval(dataloader, model, device):
+     model.eval()  # Switch to eval mode, which turns off randomness like dropout.
+     y_pred = []
+     sents = []
+     sent_ids = []
+     for step, batch in enumerate(tqdm(dataloader, desc='eval', disable=TQDM_DISABLE)):
+         b_ids, b_mask, b_sents, b_sent_ids = batch['token_ids'], batch['attention_mask'], \
+                                              batch['sents'], batch['sent_ids']
+
+         b_ids = b_ids.to(device)
+         b_mask = b_mask.to(device)
+
+         logits = model(b_ids, b_mask)
+         logits = logits.detach().cpu().numpy()
+         preds = np.argmax(logits, axis=1).flatten()
+
+         y_pred.extend(preds)
+         sents.extend(b_sents)
+         sent_ids.extend(b_sent_ids)
+
+     return y_pred, sents, sent_ids
+
+
+ def save_model(model, optimizer, args, config, filepath):
+     save_info = {
+         'model': model.state_dict(),
+         'optim': optimizer.state_dict(),
+         'args': args,
+         'model_config': config,
+         'system_rng': random.getstate(),
+         'numpy_rng': np.random.get_state(),
+         'torch_rng': torch.random.get_rng_state(),
+     }
+
+     torch.save(save_info, filepath)
+     print(f"Saved the model to {filepath}")
+
+
+ def train(args):
+     device = torch.device('cuda') if args.use_gpu else torch.device('cpu')
+     # Create the data and its corresponding datasets and dataloaders.
+     train_data, num_labels = load_data(args.train, 'train')
+     dev_data = load_data(args.dev, 'valid')
+
+     train_dataset = SentimentDataset(train_data, args)
+     dev_dataset = SentimentDataset(dev_data, args)
+
+     train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size,
+                                   collate_fn=train_dataset.collate_fn)
+     dev_dataloader = DataLoader(dev_dataset, shuffle=False, batch_size=args.batch_size,
+                                 collate_fn=dev_dataset.collate_fn)
+
+     # Init model.
+     config = {'hidden_dropout_prob': args.hidden_dropout_prob,
+               'num_labels': num_labels,
+               'hidden_size': 768,
+               'data_dir': '.',
+               'fine_tune_mode': args.fine_tune_mode}
+
+     config = SimpleNamespace(**config)
+
+     model = BertSentimentClassifier(config)
+     model = model.to(device)
+
+     lr = args.lr
+     optimizer = AdamW(model.parameters(), lr=lr)
+     best_dev_acc = 0
+
+     # Run for the specified number of epochs.
+     for epoch in range(args.epochs):
+         model.train()
+         train_loss = 0
+         num_batches = 0
+         for batch in tqdm(train_dataloader, desc=f'train-{epoch}', disable=TQDM_DISABLE):
+             b_ids, b_mask, b_labels = (batch['token_ids'],
+                                        batch['attention_mask'], batch['labels'])
+
+             b_ids = b_ids.to(device)
+             b_mask = b_mask.to(device)
+             b_labels = b_labels.to(device)
+
+             optimizer.zero_grad()
+             logits = model(b_ids, b_mask)
+             loss = F.cross_entropy(logits, b_labels.view(-1), reduction='sum') / args.batch_size
+
+             loss.backward()
+             optimizer.step()
+
+             train_loss += loss.item()
+             num_batches += 1
+
+         train_loss = train_loss / num_batches
+
+         train_acc, train_f1, *_ = model_eval(train_dataloader, model, device)
+         dev_acc, dev_f1, *_ = model_eval(dev_dataloader, model, device)
+
+         if dev_acc > best_dev_acc:
+             best_dev_acc = dev_acc
+             save_model(model, optimizer, args, config, args.filepath)
+
+         print(f"Epoch {epoch}: train loss :: {train_loss :.3f}, train acc :: {train_acc :.3f}, dev acc :: {dev_acc :.3f}")
+
+
+ def test(args):
+     with torch.no_grad():
+         device = torch.device('cuda') if args.use_gpu else torch.device('cpu')
+         saved = torch.load(args.filepath)
+         config = saved['model_config']
+         model = BertSentimentClassifier(config)
+         model.load_state_dict(saved['model'])
+         model = model.to(device)
+         print(f"Loaded model from {args.filepath}")
+
+         dev_data = load_data(args.dev, 'valid')
+         dev_dataset = SentimentDataset(dev_data, args)
+         dev_dataloader = DataLoader(dev_dataset, shuffle=False, batch_size=args.batch_size, collate_fn=dev_dataset.collate_fn)
+
+         test_data = load_data(args.test, 'test')
+         test_dataset = SentimentTestDataset(test_data, args)
+         test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=args.batch_size, collate_fn=test_dataset.collate_fn)
+
+         dev_acc, dev_f1, dev_pred, dev_true, dev_sents, dev_sent_ids = model_eval(dev_dataloader, model, device)
+         print('DONE DEV')
+         test_pred, test_sents, test_sent_ids = model_test_eval(test_dataloader, model, device)
+         print('DONE Test')
+         with open(args.dev_out, "w+") as f:
+             print(f"dev acc :: {dev_acc :.3f}")
+             f.write("id \t Predicted_Sentiment \n")
+             for p, s in zip(dev_sent_ids, dev_pred):
+                 f.write(f"{p} , {s} \n")
+
+         with open(args.test_out, "w+") as f:
+             f.write("id \t Predicted_Sentiment \n")
+             for p, s in zip(test_sent_ids, test_pred):
+                 f.write(f"{p} , {s} \n")
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--seed", type=int, default=11711)
+     parser.add_argument("--epochs", type=int, default=10)
+     parser.add_argument("--fine-tune-mode", type=str,
+                         help='last-linear-layer: the BERT parameters are frozen and the task-specific head parameters are updated; full-model: BERT parameters are updated as well',
+                         choices=('last-linear-layer', 'full-model'), default="last-linear-layer")
+     parser.add_argument("--use_gpu", action='store_true')
+
+     parser.add_argument("--batch_size", help='sst: 64, cfimdb: 8 can fit a 12GB GPU', type=int, default=8)
+     parser.add_argument("--hidden_dropout_prob", type=float, default=0.3)
+     parser.add_argument("--lr", type=float, help="learning rate, default lr for 'pretrain': 1e-3, 'finetune': 1e-5",
+                         default=1e-3)
+
+     args = parser.parse_args()
+     return args
+
+
+ if __name__ == "__main__":
+     args = get_args()
+     seed_everything(args.seed)
+
+     print('Training Sentiment Classifier on SST...')
+     config = SimpleNamespace(
+         filepath='sst-classifier.pt',
+         lr=args.lr,
+         use_gpu=args.use_gpu,
+         epochs=args.epochs,
+         batch_size=args.batch_size,
+         hidden_dropout_prob=args.hidden_dropout_prob,
+         train='data/ids-sst-train.csv',
+         dev='data/ids-sst-dev.csv',
+         test='data/ids-sst-test-student.csv',
+         fine_tune_mode=args.fine_tune_mode,
+         dev_out='predictions/' + args.fine_tune_mode + '-sst-dev-out.csv',
+         test_out='predictions/' + args.fine_tune_mode + '-sst-test-out.csv'
+     )
+
+     train(config)
+
+     print('Evaluating on SST...')
+     test(config)
+
+     print('Training Sentiment Classifier on cfimdb...')
+     config = SimpleNamespace(
+         filepath='cfimdb-classifier.pt',
+         lr=args.lr,
+         use_gpu=args.use_gpu,
+         epochs=args.epochs,
+         batch_size=8,
+         hidden_dropout_prob=args.hidden_dropout_prob,
+         train='data/ids-cfimdb-train.csv',
+         dev='data/ids-cfimdb-dev.csv',
+         test='data/ids-cfimdb-test-student.csv',
+         fine_tune_mode=args.fine_tune_mode,
+         dev_out='predictions/' + args.fine_tune_mode + '-cfimdb-dev-out.csv',
+         test_out='predictions/' + args.fine_tune_mode + '-cfimdb-test-out.csv'
+     )
+
+     train(config)
+
+     print('Evaluating on cfimdb...')
+     test(config)
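The classifier can also be driven directly for a quick sanity check. The sketch below mirrors the config fields that `train()` constructs, uses a made-up input sentence, and is not part of this commit; with an untrained classification head, the predicted class is of course arbitrary.

```python
# Hypothetical inference sketch for BertSentimentClassifier.
import torch
from types import SimpleNamespace
from classifier import BertSentimentClassifier
from tokenizer import BertTokenizer

config = SimpleNamespace(hidden_size=768, num_labels=5, data_dir='.',
                         hidden_dropout_prob=0.3,
                         fine_tune_mode='last-linear-layer')
model = BertSentimentClassifier(config)
model.eval()

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
enc = tokenizer(['a touching and funny film'],
                return_tensors='pt', padding=True, truncation=True)
with torch.no_grad():
    logits = model(enc['input_ids'], enc['attention_mask'])
print(logits.argmax(dim=1).item())  # predicted SST class in 0..4
```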
config.py ADDED
@@ -0,0 +1,222 @@
1
+ from typing import Union, Tuple, Dict, Any, Optional
2
+ import os
3
+ import json
4
+ from collections import OrderedDict
5
+ import torch
6
+ from utils import CONFIG_NAME, hf_bucket_url, cached_path, is_remote_url
7
+
8
+ class PretrainedConfig(object):
9
+ model_type: str = ""
10
+ is_composition: bool = False
11
+
12
+ def __init__(self, **kwargs):
13
+ # Attributes with defaults
14
+ self.return_dict = kwargs.pop("return_dict", True)
15
+ self.output_hidden_states = kwargs.pop("output_hidden_states", False)
16
+ self.output_attentions = kwargs.pop("output_attentions", False)
17
+ self.torchscript = kwargs.pop("torchscript", False) # Only used by PyTorch models
18
+ self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
19
+ self.pruned_heads = kwargs.pop("pruned_heads", {})
20
+ self.tie_word_embeddings = kwargs.pop(
21
+ "tie_word_embeddings", True
22
+ ) # Whether input and output word embeddings should be tied for all MLM, LM and Seq2Seq models.
23
+
24
+ # Is decoder is used in encoder-decoder models to differentiate encoder from decoder
25
+ self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
26
+ self.is_decoder = kwargs.pop("is_decoder", False)
27
+ self.add_cross_attention = kwargs.pop("add_cross_attention", False)
28
+ self.tie_encoder_decoder = kwargs.pop("tie_encoder_decoder", False)
29
+
30
+ # Parameters for sequence generation
31
+ self.max_length = kwargs.pop("max_length", 20)
32
+ self.min_length = kwargs.pop("min_length", 0)
33
+ self.do_sample = kwargs.pop("do_sample", False)
34
+ self.early_stopping = kwargs.pop("early_stopping", False)
35
+ self.num_beams = kwargs.pop("num_beams", 1)
36
+ self.num_beam_groups = kwargs.pop("num_beam_groups", 1)
37
+ self.diversity_penalty = kwargs.pop("diversity_penalty", 0.0)
38
+ self.temperature = kwargs.pop("temperature", 1.0)
39
+ self.top_k = kwargs.pop("top_k", 50)
40
+ self.top_p = kwargs.pop("top_p", 1.0)
41
+ self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
42
+ self.length_penalty = kwargs.pop("length_penalty", 1.0)
43
+ self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
44
+ self.encoder_no_repeat_ngram_size = kwargs.pop("encoder_no_repeat_ngram_size", 0)
45
+ self.bad_words_ids = kwargs.pop("bad_words_ids", None)
46
+ self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
47
+ self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)
48
+ self.output_scores = kwargs.pop("output_scores", False)
49
+ self.return_dict_in_generate = kwargs.pop("return_dict_in_generate", False)
50
+         self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None)
+         self.forced_eos_token_id = kwargs.pop("forced_eos_token_id", None)
+
+         # Fine-tuning task arguments
+         self.architectures = kwargs.pop("architectures", None)
+         self.finetuning_task = kwargs.pop("finetuning_task", None)
+         self.id2label = kwargs.pop("id2label", None)
+         self.label2id = kwargs.pop("label2id", None)
+         if self.id2label is not None:
+             kwargs.pop("num_labels", None)
+             # Keys are always strings in JSON, so convert ids to int here.
+             self.id2label = dict((int(key), value) for key, value in self.id2label.items())
+         else:
+             self.num_labels = kwargs.pop("num_labels", 2)
+
+         # Tokenizer arguments
+         self.tokenizer_class = kwargs.pop("tokenizer_class", None)
+         self.prefix = kwargs.pop("prefix", None)
+         self.bos_token_id = kwargs.pop("bos_token_id", None)
+         self.pad_token_id = kwargs.pop("pad_token_id", None)
+         self.eos_token_id = kwargs.pop("eos_token_id", None)
+         self.sep_token_id = kwargs.pop("sep_token_id", None)
+
+         self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
+
+         # Task-specific arguments
+         self.task_specific_params = kwargs.pop("task_specific_params", None)
+
+         # TPU arguments
+         self.xla_device = kwargs.pop("xla_device", None)
+
+         # Name or path of the pretrained checkpoint
+         self._name_or_path = str(kwargs.pop("name_or_path", ""))
+
+         # Drop the transformers version info
+         kwargs.pop("transformers_version", None)
+
+         # Additional attributes without default values
+         for key, value in kwargs.items():
+             try:
+                 setattr(self, key, value)
+             except AttributeError as err:
+                 raise err
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
+         config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
+         return cls.from_dict(config_dict, **kwargs)
+
+     @classmethod
+     def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
+         with open(json_file, "r", encoding="utf-8") as reader:
+             text = reader.read()
+         return json.loads(text)
+
+     @classmethod
+     def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig":
+         return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
+
+         config = cls(**config_dict)
+
+         if hasattr(config, "pruned_heads"):
+             config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())
+
+         # Update config with kwargs if needed
+         to_remove = []
+         for key, value in kwargs.items():
+             if hasattr(config, key):
+                 setattr(config, key, value)
+                 to_remove.append(key)
+         for key in to_remove:
+             kwargs.pop(key, None)
+
+         if return_unused_kwargs:
+             return config, kwargs
+         else:
+             return config
+
+     @classmethod
+     def get_config_dict(
+         cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
+     ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+         cache_dir = kwargs.pop("cache_dir", None)
+         force_download = kwargs.pop("force_download", False)
+         resume_download = kwargs.pop("resume_download", False)
+         proxies = kwargs.pop("proxies", None)
+         use_auth_token = kwargs.pop("use_auth_token", None)
+         local_files_only = kwargs.pop("local_files_only", False)
+         revision = kwargs.pop("revision", None)
+
+         pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+         if os.path.isdir(pretrained_model_name_or_path):
+             config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
+         elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
+             config_file = pretrained_model_name_or_path
+         else:
+             config_file = hf_bucket_url(
+                 pretrained_model_name_or_path, filename=CONFIG_NAME, revision=revision, mirror=None
+             )
+
+         try:
+             # Load from URL or cache if already cached
+             resolved_config_file = cached_path(
+                 config_file,
+                 cache_dir=cache_dir,
+                 force_download=force_download,
+                 proxies=proxies,
+                 resume_download=resume_download,
+                 local_files_only=local_files_only,
+                 use_auth_token=use_auth_token,
+             )
+             # Load config dict
+             config_dict = cls._dict_from_json_file(resolved_config_file)
+
+         except EnvironmentError:
+             msg = (
+                 f"Can't load config for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
+                 f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
+                 f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a {CONFIG_NAME} file\n\n"
+             )
+             raise EnvironmentError(msg)
+
+         except json.JSONDecodeError:
+             msg = (
+                 "Couldn't reach server at '{}' to download configuration file or "
+                 "configuration file is not a valid JSON file. "
+                 "Please check network or file content here: {}.".format(config_file, resolved_config_file)
+             )
+             raise EnvironmentError(msg)
+
+         return config_dict, kwargs
+
+
+ class BertConfig(PretrainedConfig):
+     model_type = "bert"
+
+     def __init__(
+         self,
+         vocab_size=30522,
+         hidden_size=768,
+         num_hidden_layers=12,
+         num_attention_heads=12,
+         intermediate_size=3072,
+         hidden_act="gelu",
+         hidden_dropout_prob=0.1,
+         attention_probs_dropout_prob=0.1,
+         max_position_embeddings=512,
+         type_vocab_size=2,
+         initializer_range=0.02,
+         layer_norm_eps=1e-12,
+         pad_token_id=0,
+         gradient_checkpointing=False,
+         position_embedding_type="absolute",
+         use_cache=True,
+         **kwargs
+     ):
+         super().__init__(pad_token_id=pad_token_id, **kwargs)
+
+         self.vocab_size = vocab_size
+         self.hidden_size = hidden_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.hidden_act = hidden_act
+         self.intermediate_size = intermediate_size
+         self.hidden_dropout_prob = hidden_dropout_prob
+         self.attention_probs_dropout_prob = attention_probs_dropout_prob
+         self.max_position_embeddings = max_position_embeddings
+         self.type_vocab_size = type_vocab_size
+         self.initializer_range = initializer_range
+         self.layer_norm_eps = layer_norm_eps
+         self.gradient_checkpointing = gradient_checkpointing
+         self.position_embedding_type = position_embedding_type
+         self.use_cache = use_cache
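
For context, a minimal sketch of how the configuration class above is typically exercised. The `config` module name is hypothetical (wherever the classes above live); the printed values are the BertConfig defaults shown above:

    from config import BertConfig  # hypothetical module name for the classes above

    # Load a config from a local directory containing config.json, or from the Hub.
    config = BertConfig.from_pretrained('bert-base-uncased')
    print(config.hidden_size, config.num_hidden_layers)  # 768 12

    # from_dict hands back any kwargs it did not consume:
    config, unused = BertConfig.from_dict({'vocab_size': 30522},
                                          not_a_config_key=1,
                                          return_unused_kwargs=True)
    print(unused)  # {'not_a_config_key': 1}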
data/ids-cfimdb-dev.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3087f571b66860fe5d035b5a018d08202ad3fd3720e4821c04b2acf6c7ded559
+ size 249095
data/ids-cfimdb-test-student.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7ae611548c9eac879e9ebb406cc9f8ae68ff12f78090e4965af5cbdfa06240f4
+ size 495595
data/ids-cfimdb-train.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:140fc513045a966109faed46a5c7a898767b96714d71bcb9c15f659129fadcea
+ size 1693182
data/ids-sst-dev.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a186ce94577635fbe10beaaddd50f16cccf6c30973221cefdf90deed2a584bfe
+ size 151384
data/ids-sst-test-student.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bdd5a767faa0c26782117e37767ece154c30d5d04fb8727d09c71e3850a55c7b
+ size 313202
data/ids-sst-train.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:03b2b625c090f94a6afd59f114cde5282e2053aab0b101e87ed695d8a0c5b1df
+ size 1175139
data/quora-dev.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1e9dc46b273a711d82a065f55e1754a9b92c10ad7345ebe0b0ebba61397dda4a
+ size 6896912
data/quora-test-student.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4fa130f532cdde70287081aa04af13a4b12e3aa862e9162763d15fb46385497a
+ size 13487951
data/quora-train.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7cd59e1ddb3a5b5d03f4a885c64e67aaf50122d9ab9ed7a476b5d2d6f7137ae8
+ size 48270674
data/sts-dev.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ce3cad6f16062586ac7ba462c28b010a9be10c530fd5074165860d7b7ab4e93d
+ size 132265
data/sts-test-student.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:dee455745b72e9ca3ff74e7c056bd73e34bad5b8d5641045a2c1e7e131866f47
+ size 256677
data/sts-train.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:15d12efc2d656fffb1d61ac1f08ec4227f43925fd16f420c037cbd063699c21b
+ size 928832
datasets.py ADDED
@@ -0,0 +1,272 @@
+ #!/usr/bin/env python3
+
+ '''
+ This module contains our Dataset classes and functions that load the three datasets
+ for training and evaluating multitask BERT.
+
+ Feel free to edit code in this file if you wish to modify the way in which the data
+ examples are preprocessed.
+ '''
+
+ import csv
+
+ import torch
+ from torch.utils.data import Dataset
+ from tokenizer import BertTokenizer
+
+
+ def preprocess_string(s):
+     return ' '.join(s.lower()
+                     .replace('.', ' .')
+                     .replace('?', ' ?')
+                     .replace(',', ' ,')
+                     .replace('\'', ' \'')
+                     .split())
+
+
+ class SentenceClassificationDataset(Dataset):
+     def __init__(self, dataset, args):
+         self.dataset = dataset
+         self.p = args
+         self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+
+     def __len__(self):
+         return len(self.dataset)
+
+     def __getitem__(self, idx):
+         return self.dataset[idx]
+
+     def pad_data(self, data):
+         sents = [x[0] for x in data]
+         labels = [x[1] for x in data]
+         sent_ids = [x[2] for x in data]
+
+         encoding = self.tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
+         token_ids = torch.LongTensor(encoding['input_ids'])
+         attention_mask = torch.LongTensor(encoding['attention_mask'])
+         labels = torch.LongTensor(labels)
+
+         return token_ids, attention_mask, labels, sents, sent_ids
+
+     def collate_fn(self, all_data):
+         token_ids, attention_mask, labels, sents, sent_ids = self.pad_data(all_data)
+
+         batched_data = {
+             'token_ids': token_ids,
+             'attention_mask': attention_mask,
+             'labels': labels,
+             'sents': sents,
+             'sent_ids': sent_ids
+         }
+
+         return batched_data
+
+
+ # Unlike SentenceClassificationDataset, we do not load labels in SentenceClassificationTestDataset.
+ class SentenceClassificationTestDataset(Dataset):
+     def __init__(self, dataset, args):
+         self.dataset = dataset
+         self.p = args
+         self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+
+     def __len__(self):
+         return len(self.dataset)
+
+     def __getitem__(self, idx):
+         return self.dataset[idx]
+
+     def pad_data(self, data):
+         sents = [x[0] for x in data]
+         sent_ids = [x[1] for x in data]
+
+         encoding = self.tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
+         token_ids = torch.LongTensor(encoding['input_ids'])
+         attention_mask = torch.LongTensor(encoding['attention_mask'])
+
+         return token_ids, attention_mask, sents, sent_ids
+
+     def collate_fn(self, all_data):
+         token_ids, attention_mask, sents, sent_ids = self.pad_data(all_data)
+
+         batched_data = {
+             'token_ids': token_ids,
+             'attention_mask': attention_mask,
+             'sents': sents,
+             'sent_ids': sent_ids
+         }
+
+         return batched_data
+
+
+ class SentencePairDataset(Dataset):
+     def __init__(self, dataset, args, isRegression=False):
+         self.dataset = dataset
+         self.p = args
+         self.isRegression = isRegression
+         self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+
+     def __len__(self):
+         return len(self.dataset)
+
+     def __getitem__(self, idx):
+         return self.dataset[idx]
+
+     def pad_data(self, data):
+         sent1 = [x[0] for x in data]
+         sent2 = [x[1] for x in data]
+         labels = [x[2] for x in data]
+         sent_ids = [x[3] for x in data]
+
+         encoding1 = self.tokenizer(sent1, return_tensors='pt', padding=True, truncation=True)
+         encoding2 = self.tokenizer(sent2, return_tensors='pt', padding=True, truncation=True)
+
+         token_ids = torch.LongTensor(encoding1['input_ids'])
+         attention_mask = torch.LongTensor(encoding1['attention_mask'])
+         token_type_ids = torch.LongTensor(encoding1['token_type_ids'])
+
+         token_ids2 = torch.LongTensor(encoding2['input_ids'])
+         attention_mask2 = torch.LongTensor(encoding2['attention_mask'])
+         token_type_ids2 = torch.LongTensor(encoding2['token_type_ids'])
+
+         if self.isRegression:
+             labels = torch.DoubleTensor(labels)
+         else:
+             labels = torch.LongTensor(labels)
+
+         return (token_ids, token_type_ids, attention_mask,
+                 token_ids2, token_type_ids2, attention_mask2,
+                 labels, sent_ids)
+
+     def collate_fn(self, all_data):
+         (token_ids, token_type_ids, attention_mask,
+          token_ids2, token_type_ids2, attention_mask2,
+          labels, sent_ids) = self.pad_data(all_data)
+
+         batched_data = {
+             'token_ids_1': token_ids,
+             'token_type_ids_1': token_type_ids,
+             'attention_mask_1': attention_mask,
+             'token_ids_2': token_ids2,
+             'token_type_ids_2': token_type_ids2,
+             'attention_mask_2': attention_mask2,
+             'labels': labels,
+             'sent_ids': sent_ids
+         }
+
+         return batched_data
+
+
+ # Unlike SentencePairDataset, we do not load labels in SentencePairTestDataset.
+ class SentencePairTestDataset(Dataset):
+     def __init__(self, dataset, args):
+         self.dataset = dataset
+         self.p = args
+         self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+
+     def __len__(self):
+         return len(self.dataset)
+
+     def __getitem__(self, idx):
+         return self.dataset[idx]
+
+     def pad_data(self, data):
+         sent1 = [x[0] for x in data]
+         sent2 = [x[1] for x in data]
+         sent_ids = [x[2] for x in data]
+
+         encoding1 = self.tokenizer(sent1, return_tensors='pt', padding=True, truncation=True)
+         encoding2 = self.tokenizer(sent2, return_tensors='pt', padding=True, truncation=True)
+
+         token_ids = torch.LongTensor(encoding1['input_ids'])
+         attention_mask = torch.LongTensor(encoding1['attention_mask'])
+         token_type_ids = torch.LongTensor(encoding1['token_type_ids'])
+
+         token_ids2 = torch.LongTensor(encoding2['input_ids'])
+         attention_mask2 = torch.LongTensor(encoding2['attention_mask'])
+         token_type_ids2 = torch.LongTensor(encoding2['token_type_ids'])
+
+         return (token_ids, token_type_ids, attention_mask,
+                 token_ids2, token_type_ids2, attention_mask2,
+                 sent_ids)
+
+     def collate_fn(self, all_data):
+         (token_ids, token_type_ids, attention_mask,
+          token_ids2, token_type_ids2, attention_mask2,
+          sent_ids) = self.pad_data(all_data)
+
+         batched_data = {
+             'token_ids_1': token_ids,
+             'token_type_ids_1': token_type_ids,
+             'attention_mask_1': attention_mask,
+             'token_ids_2': token_ids2,
+             'token_type_ids_2': token_type_ids2,
+             'attention_mask_2': attention_mask2,
+             'sent_ids': sent_ids
+         }
+
+         return batched_data
+
+
+ def load_multitask_data(sentiment_filename, paraphrase_filename, similarity_filename, split='train'):
+     sentiment_data = []
+     num_labels = {}
+     if split == 'test':
+         with open(sentiment_filename, 'r') as fp:
+             for record in csv.DictReader(fp, delimiter='\t'):
+                 sent = record['sentence'].lower().strip()
+                 sent_id = record['id'].lower().strip()
+                 sentiment_data.append((sent, sent_id))
+     else:
+         with open(sentiment_filename, 'r') as fp:
+             for record in csv.DictReader(fp, delimiter='\t'):
+                 sent = record['sentence'].lower().strip()
+                 sent_id = record['id'].lower().strip()
+                 label = int(record['sentiment'].strip())
+                 if label not in num_labels:
+                     num_labels[label] = len(num_labels)
+                 sentiment_data.append((sent, label, sent_id))
+
+     print(f"Loaded {len(sentiment_data)} {split} examples from {sentiment_filename}")
+
+     paraphrase_data = []
+     if split == 'test':
+         with open(paraphrase_filename, 'r') as fp:
+             for record in csv.DictReader(fp, delimiter='\t'):
+                 sent_id = record['id'].lower().strip()
+                 paraphrase_data.append((preprocess_string(record['sentence1']),
+                                         preprocess_string(record['sentence2']),
+                                         sent_id))
+     else:
+         with open(paraphrase_filename, 'r') as fp:
+             for record in csv.DictReader(fp, delimiter='\t'):
+                 try:
+                     sent_id = record['id'].lower().strip()
+                     paraphrase_data.append((preprocess_string(record['sentence1']),
+                                             preprocess_string(record['sentence2']),
+                                             int(float(record['is_duplicate'])), sent_id))
+                 except Exception:
+                     # Skip malformed Quora rows (e.g., missing or non-numeric fields).
+                     pass
+
+     print(f"Loaded {len(paraphrase_data)} {split} examples from {paraphrase_filename}")
+
+     similarity_data = []
+     if split == 'test':
+         with open(similarity_filename, 'r') as fp:
+             for record in csv.DictReader(fp, delimiter='\t'):
+                 sent_id = record['id'].lower().strip()
+                 similarity_data.append((preprocess_string(record['sentence1']),
+                                         preprocess_string(record['sentence2']),
+                                         sent_id))
+     else:
+         with open(similarity_filename, 'r') as fp:
+             for record in csv.DictReader(fp, delimiter='\t'):
+                 sent_id = record['id'].lower().strip()
+                 similarity_data.append((preprocess_string(record['sentence1']),
+                                         preprocess_string(record['sentence2']),
+                                         float(record['similarity']), sent_id))
+
+     print(f"Loaded {len(similarity_data)} {split} examples from {similarity_filename}")
+
+     return sentiment_data, num_labels, paraphrase_data, similarity_data
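
As a point of reference, a minimal sketch of how these classes are wired to a PyTorch DataLoader; the one-field `args` namespace is a hypothetical stand-in for the script's argparse namespace (the datasets only store it):

    from types import SimpleNamespace
    from torch.utils.data import DataLoader
    from datasets import SentenceClassificationDataset, load_multitask_data

    args = SimpleNamespace(batch_size=8)  # hypothetical minimal args
    sst, num_labels, para, sts = load_multitask_data(
        'data/ids-sst-train.csv', 'data/quora-train.csv', 'data/sts-train.csv', split='train')

    train_data = SentenceClassificationDataset(sst, args)
    loader = DataLoader(train_data, shuffle=True, batch_size=args.batch_size,
                        collate_fn=train_data.collate_fn)
    batch = next(iter(loader))
    print(batch['token_ids'].shape, batch['labels'].shape)  # (batch, seq_len), (batch,)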
evaluation.py ADDED
@@ -0,0 +1,205 @@
+ #!/usr/bin/env python3
+
+ '''
+ Multitask BERT evaluation functions.
+
+ When training your multitask model, you will find it useful to call
+ model_eval_multitask to evaluate your model on the 3 tasks' dev sets.
+ '''
+
+ import torch
+ from sklearn.metrics import f1_score, accuracy_score
+ from tqdm import tqdm
+ import numpy as np
+
+
+ TQDM_DISABLE = False
+
+
+ # Evaluate multitask model on SST only.
+ def model_eval_sst(dataloader, model, device):
+     model.eval()  # Switch to eval mode, which turns off randomness like dropout.
+     y_true = []
+     y_pred = []
+     sents = []
+     sent_ids = []
+     for step, batch in enumerate(tqdm(dataloader, desc='eval', disable=TQDM_DISABLE)):
+         b_ids, b_mask, b_labels, b_sents, b_sent_ids = batch['token_ids'], batch['attention_mask'], \
+                                                        batch['labels'], batch['sents'], batch['sent_ids']
+
+         b_ids = b_ids.to(device)
+         b_mask = b_mask.to(device)
+
+         logits = model.predict_sentiment(b_ids, b_mask)
+         logits = logits.detach().cpu().numpy()
+         preds = np.argmax(logits, axis=1).flatten()
+
+         b_labels = b_labels.flatten().numpy()
+         y_true.extend(b_labels)
+         y_pred.extend(preds)
+         sents.extend(b_sents)
+         sent_ids.extend(b_sent_ids)
+
+     f1 = f1_score(y_true, y_pred, average='macro')
+     acc = accuracy_score(y_true, y_pred)
+
+     return acc, f1, y_pred, y_true, sents, sent_ids
+
+
+ # Evaluate multitask model on dev sets.
+ def model_eval_multitask(sentiment_dataloader,
+                          paraphrase_dataloader,
+                          sts_dataloader,
+                          model, device):
+     model.eval()  # Switch to eval mode, which turns off randomness like dropout.
+
+     with torch.no_grad():
+         # Evaluate sentiment classification.
+         sst_y_true = []
+         sst_y_pred = []
+         sst_sent_ids = []
+         for step, batch in enumerate(tqdm(sentiment_dataloader, desc='eval', disable=TQDM_DISABLE)):
+             b_ids, b_mask, b_labels, b_sent_ids = batch['token_ids'], batch['attention_mask'], batch['labels'], batch['sent_ids']
+
+             b_ids = b_ids.to(device)
+             b_mask = b_mask.to(device)
+
+             logits = model.predict_sentiment(b_ids, b_mask)
+             y_hat = logits.argmax(dim=-1).flatten().cpu().numpy()
+             b_labels = b_labels.flatten().cpu().numpy()
+
+             sst_y_pred.extend(y_hat)
+             sst_y_true.extend(b_labels)
+             sst_sent_ids.extend(b_sent_ids)
+
+         sentiment_accuracy = np.mean(np.array(sst_y_pred) == np.array(sst_y_true))
+
+         # Evaluate paraphrase detection.
+         para_y_true = []
+         para_y_pred = []
+         para_sent_ids = []
+         for step, batch in enumerate(tqdm(paraphrase_dataloader, desc='eval', disable=TQDM_DISABLE)):
+             (b_ids1, b_mask1,
+              b_ids2, b_mask2,
+              b_labels, b_sent_ids) = (batch['token_ids_1'], batch['attention_mask_1'],
+                                       batch['token_ids_2'], batch['attention_mask_2'],
+                                       batch['labels'], batch['sent_ids'])
+
+             b_ids1 = b_ids1.to(device)
+             b_mask1 = b_mask1.to(device)
+             b_ids2 = b_ids2.to(device)
+             b_mask2 = b_mask2.to(device)
+
+             logits = model.predict_paraphrase(b_ids1, b_mask1, b_ids2, b_mask2)
+             y_hat = logits.sigmoid().round().flatten().cpu().numpy()
+             b_labels = b_labels.flatten().cpu().numpy()
+
+             para_y_pred.extend(y_hat)
+             para_y_true.extend(b_labels)
+             para_sent_ids.extend(b_sent_ids)
+
+         paraphrase_accuracy = np.mean(np.array(para_y_pred) == np.array(para_y_true))
+
+         # Evaluate semantic textual similarity.
+         sts_y_true = []
+         sts_y_pred = []
+         sts_sent_ids = []
+         for step, batch in enumerate(tqdm(sts_dataloader, desc='eval', disable=TQDM_DISABLE)):
+             (b_ids1, b_mask1,
+              b_ids2, b_mask2,
+              b_labels, b_sent_ids) = (batch['token_ids_1'], batch['attention_mask_1'],
+                                       batch['token_ids_2'], batch['attention_mask_2'],
+                                       batch['labels'], batch['sent_ids'])
+
+             b_ids1 = b_ids1.to(device)
+             b_mask1 = b_mask1.to(device)
+             b_ids2 = b_ids2.to(device)
+             b_mask2 = b_mask2.to(device)
+
+             logits = model.predict_similarity(b_ids1, b_mask1, b_ids2, b_mask2)
+             y_hat = logits.flatten().cpu().numpy()
+             b_labels = b_labels.flatten().cpu().numpy()
+
+             sts_y_pred.extend(y_hat)
+             sts_y_true.extend(b_labels)
+             sts_sent_ids.extend(b_sent_ids)
+
+         pearson_mat = np.corrcoef(sts_y_pred, sts_y_true)
+         sts_corr = pearson_mat[1][0]
+
+         print(f'Sentiment classification accuracy: {sentiment_accuracy:.3f}')
+         print(f'Paraphrase detection accuracy: {paraphrase_accuracy:.3f}')
+         print(f'Semantic Textual Similarity correlation: {sts_corr:.3f}')
+
+         return (sentiment_accuracy, sst_y_pred, sst_sent_ids,
+                 paraphrase_accuracy, para_y_pred, para_sent_ids,
+                 sts_corr, sts_y_pred, sts_sent_ids)
+
+
+ # Evaluate multitask model on test sets.
+ def model_eval_test_multitask(sentiment_dataloader,
+                               paraphrase_dataloader,
+                               sts_dataloader,
+                               model, device):
+     model.eval()  # Switch to eval mode, which turns off randomness like dropout.
+
+     with torch.no_grad():
+         # Evaluate sentiment classification.
+         sst_y_pred = []
+         sst_sent_ids = []
+         for step, batch in enumerate(tqdm(sentiment_dataloader, desc='eval', disable=TQDM_DISABLE)):
+             b_ids, b_mask, b_sent_ids = batch['token_ids'], batch['attention_mask'], batch['sent_ids']
+
+             b_ids = b_ids.to(device)
+             b_mask = b_mask.to(device)
+
+             logits = model.predict_sentiment(b_ids, b_mask)
+             y_hat = logits.argmax(dim=-1).flatten().cpu().numpy()
+
+             sst_y_pred.extend(y_hat)
+             sst_sent_ids.extend(b_sent_ids)
+
+         # Evaluate paraphrase detection.
+         para_y_pred = []
+         para_sent_ids = []
+         for step, batch in enumerate(tqdm(paraphrase_dataloader, desc='eval', disable=TQDM_DISABLE)):
+             (b_ids1, b_mask1,
+              b_ids2, b_mask2,
+              b_sent_ids) = (batch['token_ids_1'], batch['attention_mask_1'],
+                             batch['token_ids_2'], batch['attention_mask_2'],
+                             batch['sent_ids'])
+
+             b_ids1 = b_ids1.to(device)
+             b_mask1 = b_mask1.to(device)
+             b_ids2 = b_ids2.to(device)
+             b_mask2 = b_mask2.to(device)
+
+             logits = model.predict_paraphrase(b_ids1, b_mask1, b_ids2, b_mask2)
+             y_hat = logits.sigmoid().round().flatten().cpu().numpy()
+
+             para_y_pred.extend(y_hat)
+             para_sent_ids.extend(b_sent_ids)
+
+         # Evaluate semantic textual similarity.
+         sts_y_pred = []
+         sts_sent_ids = []
+         for step, batch in enumerate(tqdm(sts_dataloader, desc='eval', disable=TQDM_DISABLE)):
+             (b_ids1, b_mask1,
+              b_ids2, b_mask2,
+              b_sent_ids) = (batch['token_ids_1'], batch['attention_mask_1'],
+                             batch['token_ids_2'], batch['attention_mask_2'],
+                             batch['sent_ids'])
+
+             b_ids1 = b_ids1.to(device)
+             b_mask1 = b_mask1.to(device)
+             b_ids2 = b_ids2.to(device)
+             b_mask2 = b_mask2.to(device)
+
+             logits = model.predict_similarity(b_ids1, b_mask1, b_ids2, b_mask2)
+             y_hat = logits.flatten().cpu().numpy()
+
+             sts_y_pred.extend(y_hat)
+             sts_sent_ids.extend(b_sent_ids)
+
+     return (sst_y_pred, sst_sent_ids,
+             para_y_pred, para_sent_ids,
+             sts_y_pred, sts_sent_ids)
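
The call suggested in the docstring looks like this in practice (dataloaders built exactly as in multitask_classifier.py; `model` and `device` assumed already constructed):

    (sst_acc, sst_pred, sst_ids,
     para_acc, para_pred, para_ids,
     sts_corr, sts_pred, sts_ids) = model_eval_multitask(
        sst_dev_dataloader, para_dev_dataloader, sts_dev_dataloader, model, device)
    print(f"dev: sst {sst_acc:.3f} | para {para_acc:.3f} | sts {sts_corr:.3f}")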
multitask_classifier.py ADDED
@@ -0,0 +1,340 @@
+ '''
+ Multitask BERT class, starter training code, evaluation, and test code.
+
+ Of note are:
+ * class MultitaskBERT: Your implementation of multitask BERT.
+ * function train_multitask: Training procedure for MultitaskBERT. Starter code
+     copies the training procedure from `classifier.py` (single-task SST).
+ * function test_multitask: Test procedure for MultitaskBERT. This function generates
+     the required files for submission.
+
+ Running `python multitask_classifier.py` trains and tests your MultitaskBERT and
+ writes all required submission files.
+ '''
+
+ import random, numpy as np, argparse
+ from types import SimpleNamespace
+
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+ from torch.utils.data import DataLoader
+
+ from bert import BertModel
+ from optimizer import AdamW
+ from tqdm import tqdm
+
+ from datasets import (
+     SentenceClassificationDataset,
+     SentenceClassificationTestDataset,
+     SentencePairDataset,
+     SentencePairTestDataset,
+     load_multitask_data
+ )
+
+ from evaluation import model_eval_sst, model_eval_multitask, model_eval_test_multitask
+
+
+ TQDM_DISABLE = False
+
+
+ # Fix the random seed.
+ def seed_everything(seed=11711):
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     torch.backends.cudnn.benchmark = False
+     torch.backends.cudnn.deterministic = True
+
+
+ BERT_HIDDEN_SIZE = 768
+ N_SENTIMENT_CLASSES = 5
+
+
+ class MultitaskBERT(nn.Module):
+     '''
+     This module should use BERT for 3 tasks:
+
+     - Sentiment classification (predict_sentiment)
+     - Paraphrase detection (predict_paraphrase)
+     - Semantic Textual Similarity (predict_similarity)
+     '''
+     def __init__(self, config):
+         super(MultitaskBERT, self).__init__()
+         self.bert = BertModel.from_pretrained('bert-base-uncased')
+         # last-linear-layer mode does not require updating BERT parameters.
+         assert config.fine_tune_mode in ["last-linear-layer", "full-model"]
+         for param in self.bert.parameters():
+             if config.fine_tune_mode == 'last-linear-layer':
+                 param.requires_grad = False
+             elif config.fine_tune_mode == 'full-model':
+                 param.requires_grad = True
+         # You will want to add layers here to perform the downstream tasks.
+         ### TODO
+         raise NotImplementedError
+
+
+     def forward(self, input_ids, attention_mask):
+         'Takes a batch of sentences and produces embeddings for them.'
+         # The final BERT embedding is the hidden state of the [CLS] token (the first token).
+         # Here, you can start by just returning the embeddings straight from BERT.
+         # When thinking of improvements, you can later try modifying this
+         # (e.g., by adding other layers).
+         ### TODO
+         raise NotImplementedError
+
+
+     def predict_sentiment(self, input_ids, attention_mask):
+         '''Given a batch of sentences, outputs logits for classifying sentiment.
+         There are 5 sentiment classes:
+         (0 - negative, 1 - somewhat negative, 2 - neutral, 3 - somewhat positive, 4 - positive)
+         Thus, your output should contain 5 logits for each sentence.
+         '''
+         ### TODO
+         raise NotImplementedError
+
+
+     def predict_paraphrase(self,
+                            input_ids_1, attention_mask_1,
+                            input_ids_2, attention_mask_2):
+         '''Given a batch of pairs of sentences, outputs a single logit for predicting whether they are paraphrases.
+         Note that your output should be unnormalized (a logit); it will be passed to the sigmoid function
+         during evaluation.
+         '''
+         ### TODO
+         raise NotImplementedError
+
+
+     def predict_similarity(self,
+                            input_ids_1, attention_mask_1,
+                            input_ids_2, attention_mask_2):
+         '''Given a batch of pairs of sentences, outputs a single logit corresponding to how similar they are.
+         Note that your output should be unnormalized (a logit).
+         '''
+         ### TODO
+         raise NotImplementedError
+
+
+ def save_model(model, optimizer, args, config, filepath):
+     save_info = {
+         'model': model.state_dict(),
+         'optim': optimizer.state_dict(),
+         'args': args,
+         'model_config': config,
+         'system_rng': random.getstate(),
+         'numpy_rng': np.random.get_state(),
+         'torch_rng': torch.random.get_rng_state(),
+     }
+
+     torch.save(save_info, filepath)
+     print(f"Saved the model to {filepath}")
+
+
+ def train_multitask(args):
+     '''Train MultitaskBERT.
+
+     Currently only trains on the SST dataset. The way you incorporate training examples
+     from other datasets into the training procedure is up to you. To begin, take a
+     look at test_multitask below to see how you can use the custom torch `Dataset`s
+     in datasets.py to load in examples from the Quora and SemEval datasets.
+     '''
+     device = torch.device('cuda') if args.use_gpu else torch.device('cpu')
+     # Create the data and its corresponding datasets and dataloader.
+     sst_train_data, num_labels, para_train_data, sts_train_data = load_multitask_data(args.sst_train, args.para_train, args.sts_train, split='train')
+     sst_dev_data, num_labels, para_dev_data, sts_dev_data = load_multitask_data(args.sst_dev, args.para_dev, args.sts_dev, split='dev')
+
+     sst_train_data = SentenceClassificationDataset(sst_train_data, args)
+     sst_dev_data = SentenceClassificationDataset(sst_dev_data, args)
+
+     sst_train_dataloader = DataLoader(sst_train_data, shuffle=True, batch_size=args.batch_size,
+                                       collate_fn=sst_train_data.collate_fn)
+     sst_dev_dataloader = DataLoader(sst_dev_data, shuffle=False, batch_size=args.batch_size,
+                                     collate_fn=sst_dev_data.collate_fn)
+
+     # Init model.
+     config = {'hidden_dropout_prob': args.hidden_dropout_prob,
+               'num_labels': num_labels,
+               'hidden_size': 768,
+               'data_dir': '.',
+               'fine_tune_mode': args.fine_tune_mode}
+
+     config = SimpleNamespace(**config)
+
+     model = MultitaskBERT(config)
+     model = model.to(device)
+
+     lr = args.lr
+     optimizer = AdamW(model.parameters(), lr=lr)
+     best_dev_acc = 0
+
+     # Run for the specified number of epochs.
+     for epoch in range(args.epochs):
+         model.train()
+         train_loss = 0
+         num_batches = 0
+         for batch in tqdm(sst_train_dataloader, desc=f'train-{epoch}', disable=TQDM_DISABLE):
+             b_ids, b_mask, b_labels = (batch['token_ids'],
+                                        batch['attention_mask'], batch['labels'])
+
+             b_ids = b_ids.to(device)
+             b_mask = b_mask.to(device)
+             b_labels = b_labels.to(device)
+
+             optimizer.zero_grad()
+             logits = model.predict_sentiment(b_ids, b_mask)
+             loss = F.cross_entropy(logits, b_labels.view(-1), reduction='sum') / args.batch_size
+
+             loss.backward()
+             optimizer.step()
+
+             train_loss += loss.item()
+             num_batches += 1
+
+         train_loss = train_loss / num_batches
+
+         train_acc, train_f1, *_ = model_eval_sst(sst_train_dataloader, model, device)
+         dev_acc, dev_f1, *_ = model_eval_sst(sst_dev_dataloader, model, device)
+
+         if dev_acc > best_dev_acc:
+             best_dev_acc = dev_acc
+             save_model(model, optimizer, args, config, args.filepath)
+
+         print(f"Epoch {epoch}: train loss :: {train_loss:.3f}, train acc :: {train_acc:.3f}, dev acc :: {dev_acc:.3f}")
+
+
+ def test_multitask(args):
+     '''Test and save predictions on the dev and test sets of all three tasks.'''
+     with torch.no_grad():
+         device = torch.device('cuda') if args.use_gpu else torch.device('cpu')
+         saved = torch.load(args.filepath)
+         config = saved['model_config']
+
+         model = MultitaskBERT(config)
+         model.load_state_dict(saved['model'])
+         model = model.to(device)
+         print(f"Loaded model to test from {args.filepath}")
+
+         sst_test_data, num_labels, para_test_data, sts_test_data = \
+             load_multitask_data(args.sst_test, args.para_test, args.sts_test, split='test')
+
+         sst_dev_data, num_labels, para_dev_data, sts_dev_data = \
+             load_multitask_data(args.sst_dev, args.para_dev, args.sts_dev, split='dev')
+
+         sst_test_data = SentenceClassificationTestDataset(sst_test_data, args)
+         sst_dev_data = SentenceClassificationDataset(sst_dev_data, args)
+
+         sst_test_dataloader = DataLoader(sst_test_data, shuffle=True, batch_size=args.batch_size,
+                                          collate_fn=sst_test_data.collate_fn)
+         sst_dev_dataloader = DataLoader(sst_dev_data, shuffle=False, batch_size=args.batch_size,
+                                         collate_fn=sst_dev_data.collate_fn)
+
+         para_test_data = SentencePairTestDataset(para_test_data, args)
+         para_dev_data = SentencePairDataset(para_dev_data, args)
+
+         para_test_dataloader = DataLoader(para_test_data, shuffle=True, batch_size=args.batch_size,
+                                           collate_fn=para_test_data.collate_fn)
+         para_dev_dataloader = DataLoader(para_dev_data, shuffle=False, batch_size=args.batch_size,
+                                          collate_fn=para_dev_data.collate_fn)
+
+         sts_test_data = SentencePairTestDataset(sts_test_data, args)
+         sts_dev_data = SentencePairDataset(sts_dev_data, args, isRegression=True)
+
+         sts_test_dataloader = DataLoader(sts_test_data, shuffle=True, batch_size=args.batch_size,
+                                          collate_fn=sts_test_data.collate_fn)
+         sts_dev_dataloader = DataLoader(sts_dev_data, shuffle=False, batch_size=args.batch_size,
+                                         collate_fn=sts_dev_data.collate_fn)
+
+         dev_sentiment_accuracy, dev_sst_y_pred, dev_sst_sent_ids, \
+             dev_paraphrase_accuracy, dev_para_y_pred, dev_para_sent_ids, \
+             dev_sts_corr, dev_sts_y_pred, dev_sts_sent_ids = model_eval_multitask(sst_dev_dataloader,
+                                                                                   para_dev_dataloader,
+                                                                                   sts_dev_dataloader, model, device)
+
+         test_sst_y_pred, \
+             test_sst_sent_ids, test_para_y_pred, test_para_sent_ids, test_sts_y_pred, test_sts_sent_ids = \
+             model_eval_test_multitask(sst_test_dataloader,
+                                       para_test_dataloader,
+                                       sts_test_dataloader, model, device)
+
+         with open(args.sst_dev_out, "w+") as f:
+             print(f"dev sentiment acc :: {dev_sentiment_accuracy:.3f}")
+             f.write("id \t Predicted_Sentiment \n")
+             for p, s in zip(dev_sst_sent_ids, dev_sst_y_pred):
+                 f.write(f"{p} , {s} \n")
+
+         with open(args.sst_test_out, "w+") as f:
+             f.write("id \t Predicted_Sentiment \n")
+             for p, s in zip(test_sst_sent_ids, test_sst_y_pred):
+                 f.write(f"{p} , {s} \n")
+
+         with open(args.para_dev_out, "w+") as f:
+             print(f"dev paraphrase acc :: {dev_paraphrase_accuracy:.3f}")
+             f.write("id \t Predicted_Is_Paraphrase \n")
+             for p, s in zip(dev_para_sent_ids, dev_para_y_pred):
+                 f.write(f"{p} , {s} \n")
+
+         with open(args.para_test_out, "w+") as f:
+             f.write("id \t Predicted_Is_Paraphrase \n")
+             for p, s in zip(test_para_sent_ids, test_para_y_pred):
+                 f.write(f"{p} , {s} \n")
+
+         with open(args.sts_dev_out, "w+") as f:
+             print(f"dev sts corr :: {dev_sts_corr:.3f}")
+             f.write("id \t Predicted_Similarity \n")
+             for p, s in zip(dev_sts_sent_ids, dev_sts_y_pred):
+                 f.write(f"{p} , {s} \n")
+
+         with open(args.sts_test_out, "w+") as f:
+             f.write("id \t Predicted_Similarity \n")
+             for p, s in zip(test_sts_sent_ids, test_sts_y_pred):
+                 f.write(f"{p} , {s} \n")
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--sst_train", type=str, default="data/ids-sst-train.csv")
+     parser.add_argument("--sst_dev", type=str, default="data/ids-sst-dev.csv")
+     parser.add_argument("--sst_test", type=str, default="data/ids-sst-test-student.csv")
+
+     parser.add_argument("--para_train", type=str, default="data/quora-train.csv")
+     parser.add_argument("--para_dev", type=str, default="data/quora-dev.csv")
+     parser.add_argument("--para_test", type=str, default="data/quora-test-student.csv")
+
+     parser.add_argument("--sts_train", type=str, default="data/sts-train.csv")
+     parser.add_argument("--sts_dev", type=str, default="data/sts-dev.csv")
+     parser.add_argument("--sts_test", type=str, default="data/sts-test-student.csv")
+
+     parser.add_argument("--seed", type=int, default=11711)
+     parser.add_argument("--epochs", type=int, default=10)
+     parser.add_argument("--fine-tune-mode", type=str,
+                         help='last-linear-layer: the BERT parameters are frozen and the task-specific head parameters are updated; full-model: BERT parameters are updated as well',
+                         choices=('last-linear-layer', 'full-model'), default="last-linear-layer")
+     parser.add_argument("--use_gpu", action='store_true')
+
+     parser.add_argument("--sst_dev_out", type=str, default="predictions/sst-dev-output.csv")
+     parser.add_argument("--sst_test_out", type=str, default="predictions/sst-test-output.csv")
+
+     parser.add_argument("--para_dev_out", type=str, default="predictions/para-dev-output.csv")
+     parser.add_argument("--para_test_out", type=str, default="predictions/para-test-output.csv")
+
+     parser.add_argument("--sts_dev_out", type=str, default="predictions/sts-dev-output.csv")
+     parser.add_argument("--sts_test_out", type=str, default="predictions/sts-test-output.csv")
+
+     parser.add_argument("--batch_size", help='sst: 64, cfimdb: 8 can fit a 12GB GPU', type=int, default=8)
+     parser.add_argument("--hidden_dropout_prob", type=float, default=0.3)
+     parser.add_argument("--lr", type=float, help="learning rate", default=1e-5)
+
+     args = parser.parse_args()
+     return args
+
+
+ if __name__ == "__main__":
+     args = get_args()
+     args.filepath = f'{args.fine_tune_mode}-{args.epochs}-{args.lr}-multitask.pt'  # Save path.
+     seed_everything(args.seed)  # Fix the seed for reproducibility.
+     train_multitask(args)
+     test_multitask(args)
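
For reference, typical invocations given the flags defined in get_args above (the hyperparameter values here are only examples, not recommended settings):

    python multitask_classifier.py --fine-tune-mode last-linear-layer --epochs 10 --use_gpu
    python multitask_classifier.py --fine-tune-mode full-model --lr 1e-5 --batch_size 8 --use_gpu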
optimizer.py ADDED
@@ -0,0 +1,90 @@
+ from typing import Callable, Iterable, Tuple
+ import math
+
+ import torch
+ from torch.optim import Optimizer
+
+
+ class AdamW(Optimizer):
+     def __init__(
+         self,
+         params: Iterable[torch.nn.parameter.Parameter],
+         lr: float = 1e-3,
+         betas: Tuple[float, float] = (0.9, 0.999),
+         eps: float = 1e-6,
+         weight_decay: float = 0.0,
+         correct_bias: bool = True,
+     ):
+         if lr < 0.0:
+             raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
+         if not 0.0 <= betas[0] < 1.0:
+             raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0)".format(betas[0]))
+         if not 0.0 <= betas[1] < 1.0:
+             raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0)".format(betas[1]))
+         if not 0.0 <= eps:
+             raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
+         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias)
+         super().__init__(params, defaults)
+
+     def step(self, closure: Callable = None):
+         loss = None
+         if closure is not None:
+             loss = closure()
+
+         for group in self.param_groups:
+             for p in group["params"]:
+                 if p.grad is None:
+                     continue
+                 grad = p.grad.data
+                 if grad.is_sparse:
+                     raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")
+
+                 # Access state
+                 state = self.state[p]
+
+                 # Initialize state if not already done
+                 if len(state) == 0:
+                     state["step"] = 0
+                     state["exp_avg"] = torch.zeros_like(p.data)
+                     state["exp_avg_sq"] = torch.zeros_like(p.data)
+
+                 # Hyperparameters
+                 alpha = group["lr"]
+                 beta1, beta2 = group["betas"]
+                 eps = group["eps"]
+                 weight_decay = group["weight_decay"]
+                 correct_bias = group["correct_bias"]
+
+                 # Retrieve state variables
+                 exp_avg = state["exp_avg"]
+                 exp_avg_sq = state["exp_avg_sq"]
+                 step = state["step"]
+
+                 # Update step
+                 step += 1
+                 state["step"] = step
+
+                 # Update biased first and second moment estimates
+                 exp_avg.mul_(beta1).add_(grad, alpha=(1 - beta1))
+                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))
+
+                 # Compute bias-corrected moments
+                 if correct_bias:
+                     bias_correction1 = 1 - beta1 ** step
+                     bias_correction2 = 1 - beta2 ** step
+                     exp_avg_corr = exp_avg / bias_correction1
+                     exp_avg_sq_corr = exp_avg_sq / bias_correction2
+                 else:
+                     exp_avg_corr = exp_avg
+                     exp_avg_sq_corr = exp_avg_sq
+
+                 # Update parameters
+                 denom = exp_avg_sq_corr.sqrt().add_(eps)
+                 step_size = alpha
+                 p.data.addcdiv_(exp_avg_corr, denom, value=-step_size)
+
+                 # Apply decoupled weight decay (AdamW): decay acts directly on the
+                 # weights rather than through the gradient.
+                 if weight_decay != 0:
+                     p.data.add_(p.data, alpha=-alpha * weight_decay)
+
+         return loss
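
Written out, the update implemented by the loop above is Adam with bias correction followed by a decoupled weight-decay step, for gradient g_t at step t:

    m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
    v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
    m_hat_t = m_t / (1 - beta1^t)
    v_hat_t = v_t / (1 - beta2^t)
    theta_t = theta_{t-1} - lr * m_hat_t / (sqrt(v_hat_t) + eps)
    theta_t = theta_t - lr * weight_decay * theta_t      # decoupled decay, not through the gradient

Note that this implementation applies the decay term after the Adam step; applying it before the step is the other common convention.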
optimizer_test.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:77b817e0dce16a9bc8d3a6bcb88035db68f7d783dc8a565737581fadd05db815
+ size 152
optimizer_test.py ADDED
@@ -0,0 +1,34 @@
+ import torch
+ import numpy as np
+ from optimizer import AdamW
+
+ seed = 0
+
+
+ def test_optimizer(opt_class) -> torch.Tensor:
+     rng = np.random.default_rng(seed)
+     torch.manual_seed(seed)
+     model = torch.nn.Linear(3, 2, bias=False)
+     opt = opt_class(
+         model.parameters(),
+         lr=1e-3,
+         weight_decay=1e-4,
+         correct_bias=True,
+     )
+     for _ in range(1000):
+         opt.zero_grad()
+         x = torch.FloatTensor(rng.uniform(size=[model.in_features]))
+         y_hat = model(x)
+         y = torch.Tensor([x[0] + x[1], -x[2]])
+         loss = ((y - y_hat) ** 2).sum()
+         loss.backward()
+         opt.step()
+     return model.weight.detach()
+
+
+ ref = torch.tensor(np.load("optimizer_test.npy"))
+ actual = test_optimizer(AdamW)
+ print(ref)
+ print(actual)
+ assert torch.allclose(ref, actual, atol=1e-6, rtol=1e-4)
+ print("Optimizer test passed!")
predictions/README ADDED
@@ -0,0 +1,2 @@
+ By default, `classifier.py` and `multitask_classifier.py` write your model predictions into this folder.
+ Before running prepare_submit.py, make sure that this directory has been populated!
predictions/last-linear-layer-cfimdb-dev-out.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c3f994587376345ea6a1e80a7946d5889259f6a427989c71e0b45de28ea4545d
+ size 7621
predictions/last-linear-layer-cfimdb-test-out.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7ebedf210c8973e02648e96152e253daa2385b230a48da151812a58d80178536
+ size 15154
predictions/last-linear-layer-sst-dev-out.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:22412dead5299ffb8fae45448f240cb135e3ad5dc04cea96975e893bdd719ba8
+ size 34157
predictions/last-linear-layer-sst-test-out.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3455d6637e5ecd118c31e48534d92298da3c865ed11ad93e2aadc09fcc743666
+ size 68536
prepare_submit.py ADDED
@@ -0,0 +1,18 @@
+ # Creates a zip file for submission on Gradescope.
+
+ import os
+ import zipfile
+
+ required_files = [p for p in os.listdir('.') if p.endswith('.py')] + \
+                  [f'predictions/{p}' for p in os.listdir('predictions')]
+
+ def main():
+     aid = 'cs224n_default_final_project_submission'
+     path = os.getcwd()
+     with zipfile.ZipFile(f"{aid}.zip", 'w') as zz:
+         for file in required_files:
+             zz.write(file, os.path.join(".", file))
+     print(f"Submission zip file created: {aid}.zip")
+
+ if __name__ == '__main__':
+     main()
sanity_check.data ADDED
Binary file (56.4 kB).
sanity_check.py ADDED
@@ -0,0 +1,19 @@
+ import torch
+ from bert import BertModel
+
+
+ sanity_data = torch.load("./sanity_check.data", weights_only=True)
+ sent_ids = torch.tensor([[101, 7592, 2088, 102, 0, 0, 0, 0],
+                          [101, 7592, 15756, 2897, 2005, 17953, 2361, 102]])
+ att_mask = torch.tensor([[1, 1, 1, 1, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1]])
+
+ # Load model.
+ bert = BertModel.from_pretrained('bert-base-uncased')
+ outputs = bert(sent_ids, att_mask)
+ att_mask = att_mask.unsqueeze(-1)
+ outputs['last_hidden_state'] = outputs['last_hidden_state'] * att_mask
+ sanity_data['last_hidden_state'] = sanity_data['last_hidden_state'] * att_mask
+
+ for k in ['last_hidden_state', 'pooler_output']:
+     assert torch.allclose(outputs[k], sanity_data[k], atol=1e-5, rtol=1e-3)
+ print("Your BERT implementation is correct!")
setup.sh ADDED
@@ -0,0 +1,13 @@
+ #!/usr/bin/env bash
+
+ conda create -n cs224n_dfp python=3.8
+ conda activate cs224n_dfp
+
+ pip install torch torchvision torchaudio
+ pip install tqdm==4.58.0
+ pip install requests==2.25.1
+ pip install importlib-metadata==3.7.0
+ pip install filelock==3.0.12
+ pip install sklearn==0.0
+ pip install tokenizers==0.15
+ pip install explainaboard_client==0.0.7
sst-classifier.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:62f6282ea608a997c1b43071cedcb1c4ba454b420305c7b15138aa9d7f70103d
+ size 438072793
tokenizer.py ADDED
The diff for this file is too large to render.
utils.py ADDED
@@ -0,0 +1,347 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Optional, Union, Tuple, BinaryIO
2
+ import os
3
+ import sys
4
+ import json
5
+ import tempfile
6
+ import copy
7
+ from tqdm.auto import tqdm
8
+ from functools import partial
9
+ from urllib.parse import urlparse
10
+ from pathlib import Path
11
+ import requests
12
+ from hashlib import sha256
13
+ from filelock import FileLock
14
+ import importlib_metadata
15
+ import torch
16
+ import torch.nn as nn
17
+ from torch import Tensor
18
+ import fnmatch
19
+
20
+ __version__ = "4.0.0"
21
+ _torch_version = importlib_metadata.version("torch")
22
+
23
+ hf_cache_home = os.path.expanduser(os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface")))
24
+ default_cache_path = os.path.join(hf_cache_home, "transformers")
25
+ PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
26
+ PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
27
+ TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)
28
+
29
+ PRESET_MIRROR_DICT = {
30
+ "tuna": "https://mirrors.tuna.tsinghua.edu.cn/hugging-face-models",
31
+ "bfsu": "https://mirrors.bfsu.edu.cn/hugging-face-models",
32
+ }
33
+ HUGGINGFACE_CO_PREFIX = "https://huggingface.co/{model_id}/resolve/{revision}/{filename}"
34
+ WEIGHTS_NAME = "pytorch_model.bin"
35
+ CONFIG_NAME = "config.json"
36
+
37
+
38
+ def is_torch_available():
39
+ return True
40
+
41
+
42
+ def is_tf_available():
43
+ return False
44
+
45
+
46
+ def is_remote_url(url_or_filename):
47
+ parsed = urlparse(url_or_filename)
48
+ return parsed.scheme in ("http", "https")
49
+
50
+
51
+ def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers: Optional[Dict[str, str]] = None):
52
+ headers = copy.deepcopy(headers)
53
+ if resume_size > 0:
54
+ headers["Range"] = "bytes=%d-" % (resume_size,)
55
+ r = requests.get(url, stream=True, proxies=proxies, headers=headers)
56
+ r.raise_for_status()
57
+ content_length = r.headers.get("Content-Length")
58
+ total = resume_size + int(content_length) if content_length is not None else None
59
+ progress = tqdm(
60
+ unit="B",
61
+ unit_scale=True,
62
+ total=total,
63
+ initial=resume_size,
64
+ desc="Downloading",
65
+ disable=False,
66
+ )
67
+ for chunk in r.iter_content(chunk_size=1024):
68
+ if chunk: # filter out keep-alive new chunks
69
+ progress.update(len(chunk))
70
+ temp_file.write(chunk)
71
+ progress.close()
72
+
73
+
74
+ def url_to_filename(url: str, etag: Optional[str] = None) -> str:
75
+ url_bytes = url.encode("utf-8")
76
+ filename = sha256(url_bytes).hexdigest()
77
+
78
+ if etag:
79
+ etag_bytes = etag.encode("utf-8")
80
+ filename += "." + sha256(etag_bytes).hexdigest()
81
+
82
+ if url.endswith(".h5"):
83
+ filename += ".h5"
84
+
85
+ return filename
86
+
87
+
88
+ def hf_bucket_url(
89
+ model_id: str, filename: str, subfolder: Optional[str] = None, revision: Optional[str] = None, mirror=None
90
+ ) -> str:
91
+ if subfolder is not None:
92
+ filename = f"{subfolder}/{filename}"
93
+
94
+ if mirror:
95
+ endpoint = PRESET_MIRROR_DICT.get(mirror, mirror)
96
+ legacy_format = "/" not in model_id
97
+ if legacy_format:
98
+ return f"{endpoint}/{model_id}-{filename}"
99
+ else:
100
+ return f"{endpoint}/{model_id}/{filename}"
101
+
102
+ if revision is None:
103
+ revision = "main"
104
+ return HUGGINGFACE_CO_PREFIX.format(model_id=model_id, revision=revision, filename=filename)
105
+
106
+
107
+ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
108
+ ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
109
+ if is_torch_available():
110
+ ua += f"; torch/{_torch_version}"
111
+ if is_tf_available():
112
+ ua += f"; tensorflow/{_tf_version}"
113
+ if isinstance(user_agent, dict):
114
+ ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items())
115
+ elif isinstance(user_agent, str):
116
+ ua += "; " + user_agent
117
+ return ua
118
+
119
+
120
+ def get_from_cache(
121
+ url: str,
122
+ cache_dir=None,
123
+ force_download=False,
124
+ proxies=None,
125
+ etag_timeout=10,
126
+ resume_download=False,
127
+ user_agent: Union[Dict, str, None] = None,
128
+ use_auth_token: Union[bool, str, None] = None,
129
+ local_files_only=False,
130
+ ) -> Optional[str]:
131
+ if cache_dir is None:
132
+ cache_dir = TRANSFORMERS_CACHE
133
+ if isinstance(cache_dir, Path):
134
+ cache_dir = str(cache_dir)
135
+
136
+ os.makedirs(cache_dir, exist_ok=True)
137
+
138
+ headers = {"user-agent": http_user_agent(user_agent)}
139
+ if isinstance(use_auth_token, str):
140
+ headers["authorization"] = "Bearer {}".format(use_auth_token)
141
+ elif use_auth_token:
142
+ token = HfFolder.get_token()
143
+ if token is None:
144
+ raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
145
+ headers["authorization"] = "Bearer {}".format(token)
146
+
147
+ url_to_download = url
148
+ etag = None
149
+ if not local_files_only:
150
+ try:
151
+ r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=etag_timeout)
152
+ r.raise_for_status()
153
+ etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
154
+ # We favor a custom header indicating the etag of the linked resource, and
155
+ # we fallback to the regular etag header.
156
+ # If we don't have any of those, raise an error.
157
+ if etag is None:
158
+ raise OSError(
159
+ "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility."
160
+ )
161
+ # In case of a redirect,
162
+ # save an extra redirect on the request.get call,
163
+ # and ensure we download the exact atomic version even if it changed
164
+ # between the HEAD and the GET (unlikely, but hey).
165
+ if 300 <= r.status_code <= 399:
166
+ url_to_download = r.headers["Location"]
167
+ except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
168
+ # etag is already None
169
+ pass
170
+
171
+ filename = url_to_filename(url, etag)
172
+
173
+ # get cache path to put the file
174
+ cache_path = os.path.join(cache_dir, filename)
175
+
176
+ # etag is None == we don't have a connection or we passed local_files_only.
177
+ # try to get the last downloaded one
178
+ if etag is None:
179
+ if os.path.exists(cache_path):
180
+ return cache_path
181
+ else:
182
+ matching_files = [
183
+ file
184
+ for file in fnmatch.filter(os.listdir(cache_dir), filename.split(".")[0] + ".*")
185
+ if not file.endswith(".json") and not file.endswith(".lock")
186
+ ]
187
+ if len(matching_files) > 0:
188
+ return os.path.join(cache_dir, matching_files[-1])
189
+ else:
190
+ # If files cannot be found and local_files_only=True,
191
+ # the models might've been found if local_files_only=False
192
+ # Notify the user about that
193
+ if local_files_only:
194
+ raise FileNotFoundError(
195
+ "Cannot find the requested files in the cached path and outgoing traffic has been"
196
+ " disabled. To enable model look-ups and downloads online, set 'local_files_only'"
197
+ " to False."
198
+ )
199
+ else:
200
+ raise ValueError(
201
+ "Connection error, and we cannot find the requested files in the cached path."
202
+ " Please try again or make sure your Internet connection is on."
203
+ )
204
+
205
+ # From now on, etag is not None.
206
+ if os.path.exists(cache_path) and not force_download:
207
+ return cache_path
208
+
209
+ # Prevent parallel downloads of the same file with a lock.
210
+ lock_path = cache_path + ".lock"
211
+ with FileLock(lock_path):
212
+
213
+ # If the download just completed while the lock was activated.
214
+ if os.path.exists(cache_path) and not force_download:
215
+ # Even if returning early like here, the lock will be released.
216
+ return cache_path
217
+
218
+ if resume_download:
219
+ incomplete_path = cache_path + ".incomplete"
220
+
221
+ @contextmanager
222
+ def _resumable_file_manager() -> "io.BufferedWriter":
223
+ with open(incomplete_path, "ab") as f:
224
+ yield f
225
+
226
+ temp_file_manager = _resumable_file_manager
227
+ if os.path.exists(incomplete_path):
228
+ resume_size = os.stat(incomplete_path).st_size
229
+ else:
230
+ resume_size = 0
231
+ else:
232
+ temp_file_manager = partial(tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False)
233
+ resume_size = 0
234
+
235
+ # Download to temporary file, then copy to cache dir once finished.
236
+ # Otherwise you get corrupt cache entries if the download gets interrupted.
237
+ with temp_file_manager() as temp_file:
238
+ http_get(url_to_download, temp_file, proxies=proxies, resume_size=resume_size, headers=headers)
239
+
240
+ os.replace(temp_file.name, cache_path)
241
+
242
+ meta = {"url": url, "etag": etag}
243
+ meta_path = cache_path + ".json"
244
+ with open(meta_path, "w") as meta_file:
245
+ json.dump(meta, meta_file)
246
+
247
+ return cache_path
248
+
249
+
250
+ def cached_path(
251
+     url_or_filename,
252
+     cache_dir=None,
253
+     force_download=False,
254
+     proxies=None,
255
+     resume_download=False,
256
+     user_agent: Union[Dict, str, None] = None,
257
+     extract_compressed_file=False,
258
+     force_extract=False,
259
+     use_auth_token: Union[bool, str, None] = None,
260
+     local_files_only=False,
261
+ ) -> Optional[str]:
262
+     if cache_dir is None:
263
+         cache_dir = TRANSFORMERS_CACHE
264
+     if isinstance(url_or_filename, Path):
265
+         url_or_filename = str(url_or_filename)
266
+     if isinstance(cache_dir, Path):
267
+         cache_dir = str(cache_dir)
268
+
269
+     if is_remote_url(url_or_filename):
270
+         # URL, so get it from the cache (downloading if necessary)
271
+         output_path = get_from_cache(
272
+             url_or_filename,
273
+             cache_dir=cache_dir,
274
+             force_download=force_download,
275
+             proxies=proxies,
276
+             resume_download=resume_download,
277
+             user_agent=user_agent,
278
+             use_auth_token=use_auth_token,
279
+             local_files_only=local_files_only,
280
+         )
281
+     elif os.path.exists(url_or_filename):
282
+         # File, and it exists.
283
+         output_path = url_or_filename
284
+     elif urlparse(url_or_filename).scheme == "":
285
+         # File, but it doesn't exist.
286
+         raise EnvironmentError("file {} not found".format(url_or_filename))
287
+     else:
288
+         # Something unknown
289
+         raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
290
+
291
+     if extract_compressed_file:
292
+         if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path):
293
+             return output_path
294
+
295
+         # Path where we extract compressed archives
296
+         # We avoid '.' in dir name and add "-extracted" at the end: "./model.zip" => "./model-zip-extracted/"
297
+         output_dir, output_file = os.path.split(output_path)
298
+         output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
299
+         output_path_extracted = os.path.join(output_dir, output_extract_dir_name)
300
+
301
+         if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract:
302
+             return output_path_extracted
303
+
304
+         # Prevent parallel extractions
305
+         lock_path = output_path + ".lock"
306
+         with FileLock(lock_path):
307
+             shutil.rmtree(output_path_extracted, ignore_errors=True)
308
+             os.makedirs(output_path_extracted)
309
+             if is_zipfile(output_path):
310
+                 with ZipFile(output_path, "r") as zip_file:
311
+                     zip_file.extractall(output_path_extracted)
312
+                     zip_file.close()
313
+             elif tarfile.is_tarfile(output_path):
314
+                 tar_file = tarfile.open(output_path)
315
+                 tar_file.extractall(output_path_extracted)
316
+                 tar_file.close()
317
+             else:
318
+                 raise EnvironmentError("Archive format of {} could not be identified".format(output_path))
319
+
320
+         return output_path_extracted
321
+
322
+     return output_path
323
+
324
+
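A short usage sketch for `cached_path` (the URL and cache directory below are placeholders, not values used by this repo):

    # Remote URL: downloaded once, then served from the local cache.
    local_file = cached_path(
        "https://example.com/weights.bin",  # placeholder URL
        cache_dir="/tmp/demo_cache",
        resume_download=True,  # continue from a *.incomplete file if present
    )

    # Existing local paths pass through unchanged; missing ones raise EnvironmentError.
    assert cached_path(local_file) == local_file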
325
+ def get_parameter_dtype(parameter: nn.Module):
326
+     try:
327
+         return next(parameter.parameters()).dtype
328
+     except StopIteration:
329
+         # For nn.DataParallel compatibility in PyTorch 1.5
330
+
331
+         def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
332
+             tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
333
+             return tuples
334
+
335
+         gen = parameter._named_members(get_members_fn=find_tensor_attributes)
336
+         first_tuple = next(gen)
337
+         return first_tuple[1].dtype
338
+
339
+
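A small sanity check for `get_parameter_dtype` (illustrative only, assuming the imports already present in this file):

    layer = nn.Linear(4, 2)
    assert get_parameter_dtype(layer) == torch.float32
    assert get_parameter_dtype(layer.half()) == torch.float16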
340
+ def get_extended_attention_mask(attention_mask: Tensor, dtype) -> Tensor:
341
+     # attention_mask [batch_size, seq_length]
342
+     assert attention_mask.dim() == 2
343
+     # [batch_size, 1, 1, seq_length] for multi-head attention
344
+     extended_attention_mask = attention_mask[:, None, None, :]
345
+     extended_attention_mask = extended_attention_mask.to(dtype=dtype)  # fp16 compatibility
346
+     extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
347
+     return extended_attention_mask
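Concretely, a 0/1 padding mask becomes an additive attention bias: roughly 0.0 where attention is allowed and -10000.0 where it is masked, which the subsequent softmax drives to near-zero probability. An illustrative check:

    mask = torch.tensor([[1, 1, 0]])  # [batch=1, seq_len=3]; 0 marks padding
    ext = get_extended_attention_mask(mask, torch.float32)
    # ext.shape == (1, 1, 1, 3); values: [-0., -0., -10000.]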
zemo1.py ADDED
@@ -0,0 +1,53 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from tqdm import tqdm
4
+ import torch.optim as optim
5
+
6
+ # Step 1: Prepare sample data
7
+ # Synthetic data: each row is [study hours, leisure hours, sleep hours], grade average
8
+ data = [
9
+     [2, 1, 7, 6.0],
10
+     [3, 2, 6, 7.5],
11
+     [1, 3, 8, 5.5],
12
+     [4, 1, 6, 8.0],
13
+     [5, 0, 5, 9.0],
14
+     [6, 0, 6, 9.5]
15
+ ]
16
+
17
+ # Split the features and the target
18
+ X = torch.tensor([row[:3] for row in data], dtype=torch.float32)  # study, leisure, sleep hours
19
+ y = torch.tensor([[row[3]] for row in data], dtype=torch.float32)  # grade average
20
+
21
+ # Step 2: Build the model
22
+ class StudentGradeModel(nn.Module):
23
+     def __init__(self):
24
+         super(StudentGradeModel, self).__init__()
25
+         self.linear = nn.Linear(3, 1)  # 3 inputs, 1 output
26
+
27
+     def forward(self, x):
28
+         return self.linear(x)
29
+
30
+ model = StudentGradeModel()
31
+
32
+ # Step 3: Define the loss function and the optimizer
33
+ criterion = nn.MSELoss()
34
+ optimizer = optim.SGD(model.parameters(), lr=0.01)
35
+
36
+ # Step 4: Train the model
37
+ for epoch in tqdm(range(10000), desc="Training Epochs"):
38
+     optimizer.zero_grad()  # Clear the old gradients
39
+     output = model(X)  # Forward pass through the model
40
+     loss = criterion(output, y)  # Compute the loss
41
+     loss.backward()  # Backpropagate the gradients
42
+     optimizer.step()  # Update the weights
43
+
44
+     # Print the loss to monitor training
45
+     if (epoch + 1) % 1000 == 0:
46
+         tqdm.write(f'Epoch [{epoch + 1}/10000], Loss: {loss.item():.4f}')
47
+
48
+ # Step 5: Predict for a new student
49
+ model.eval()
50
+ with torch.no_grad():
51
+     test_input = torch.tensor([[4, 1, 6]], dtype=torch.float32)  # e.g. 4 study hours, 1 leisure hour, 6 sleep hours
52
+     prediction = model(test_input)
53
+     print("Predicted grade average:", prediction.item())
zemo2.py ADDED
@@ -0,0 +1,41 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ # Build the RNN model
5
+ class RNNModel(nn.Module):
6
+     def __init__(self, input_size, hidden_size, output_size):
7
+         super(RNNModel, self).__init__()
8
+         self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)  # Define the RNN layer
9
+         self.fc = nn.Linear(hidden_size, output_size)  # Fully connected layer that predicts the output
10
+
11
+     def forward(self, x):
12
+         out, _ = self.rnn(x)  # Get the outputs from the RNN
13
+         out = out[:, -1, :]  # Keep only the last time step's output (the data has several time steps)
14
+         out = self.fc(out)  # Predict the output
15
+         return out
16
+
17
+ # Initialize the model
18
+ input_size = 10  # Input feature size
19
+ hidden_size = 20  # Number of hidden units
20
+ output_size = 1  # Output size (e.g. regression)
21
+ model = RNNModel(input_size, hidden_size, output_size)
22
+
23
+ # Create synthetic data
24
+ X = torch.randn(32, 5, 10)  # 32 samples, 5 time steps, 10 features per step
25
+ y = torch.randn(32, 1)  # 32 samples, 1 output value per sample
26
+
27
+ # Loss function and optimizer
28
+ criterion = nn.MSELoss()
29
+ optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
30
+
31
+ # Train the model
32
+ for epoch in range(100):
33
+     model.train()
34
+     optimizer.zero_grad()
35
+     output = model(X)  # Forward pass through the model
36
+     loss = criterion(output, y)  # Compute the loss
37
+     loss.backward()  # Backpropagate the gradients
38
+     optimizer.step()  # Update the weights
39
+
40
+     if (epoch + 1) % 10 == 0:
41
+         print(f'Epoch [{epoch+1}/100], Loss: {loss.item():.4f}')
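zemo2.py stops after training; a short optional inference step that could follow the loop, reusing the names defined above (a sketch, not part of the committed file):

    # Predict on a single unseen sequence.
    model.eval()
    with torch.no_grad():
        sample = torch.randn(1, 5, 10)  # 1 sample, 5 time steps, 10 features
        print(model(sample))  # tensor of shape [1, 1]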
zemo3.py ADDED
@@ -0,0 +1,32 @@
1
+ import torch
2
+ from tokenizer import BertTokenizer
3
+ from torch import nn
4
+ from bert import BertModel
5
+
6
+ # Initialize the BERT tokenizer
7
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
8
+
9
+ # Example sentences
10
+ sentences = [
11
+     "She loves reading novels in her free time",
12
+     "An apple a day keeps the doctor away",
13
+     "If you can't explain it simply, you don't understand it well enough."
14
+ ]
15
+
16
+ # Tokenize and encode the sentences
17
+ encoding = tokenizer.batch_encode_plus(
18
+     sentences,
19
+     max_length=512,
20
+     padding='max_length',
21
+     truncation=True,
22
+     return_tensors='pt'
23
+ )
24
+
25
+ # Get the token IDs from the encoding
26
+ input_ids = encoding['input_ids']
27
+ attention_mask = encoding['attention_mask']
28
+
29
+ model = BertModel.from_pretrained('bert-base-uncased')
30
+
31
+ assert isinstance(model, BertModel)
32
+ print(model.embed(input_ids).size())
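For reference, the expected output, assuming `model.embed` returns one hidden vector per token and bert-base-uncased's hidden size of 768 (an assumption about this repo's BertModel, not verified here):

    # torch.Size([3, 512, 768])  -> [batch_size, max_length, hidden_size]
    # Note: attention_mask is computed above but not passed to embed(),
    # so padding positions are embedded as well.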