Commit
·
a0b398e
1
Parent(s):
354d4d3
Transfer code from Kaggle
Browse files- .python-version +1 -1
- base_bert.py +247 -0
- bert.py +221 -0
- classifier.py +199 -0
- classifier_utils.py +235 -0
- config.py +222 -0
- constants.py +36 -0
- everything.py +20 -0
- finetune-scripts/sup.py +12 -0
- finetune-scripts/unsup.py +12 -0
- finetuning.py +173 -0
- minbert-data/amazon-polarity.parquet +3 -0
- minbert-data/ids-cfimdb-dev.csv +3 -0
- minbert-data/ids-cfimdb-test-student.csv +3 -0
- minbert-data/ids-cfimdb-train.csv +3 -0
- minbert-data/ids-sst-dev.csv +3 -0
- minbert-data/ids-sst-test-student.csv +3 -0
- minbert-data/ids-sst-train.csv +3 -0
- minbert-data/nli-train.parquet +3 -0
- minbert-data/optimizer_test.npy +3 -0
- minbert-data/sanity_check.data +0 -0
- minbert-data/stsb-dev.parquet +3 -0
- minbert-model/sup-cse-bert.pth +3 -0
- minbert-model/unsup-cse-bert.pth +3 -0
- optimizer.py +90 -0
- optimizer_test.py +36 -0
- requirements.txt +9 -0
- sanity_check.py +21 -0
- tokenizer.py +0 -0
- train-scripts/base_cfimdb_onfm.py +15 -0
- train-scripts/base_cfimdb_onll.py +15 -0
- train-scripts/base_sst_onfm.py +15 -0
- train-scripts/base_sst_onll.py +15 -0
- train-scripts/finetuned_bert.py +15 -0
- train-scripts/sup_cfimdb_onfm.py +17 -0
- train-scripts/sup_cfimdb_onll.py +17 -0
- train-scripts/sup_sst_onfm.py +17 -0
- train-scripts/sup_sst_onll.py +17 -0
- train-scripts/unsup_cfimdb_onfm.py +17 -0
- train-scripts/unsup_cfimdb_onll.py +17 -0
- train-scripts/unsup_sst_onfm.py +17 -0
- train-scripts/unsup_sst_onll.py +17 -0
- utils.py +349 -0
.python-version
CHANGED
@@ -1 +1 @@
|
|
1 |
-
3.
|
|
|
1 |
+
3.10.15
|
base_bert.py
ADDED
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
from torch import device, dtype, nn
|
3 |
+
from config import BertConfig, PretrainedConfig
|
4 |
+
from utils import *
|
5 |
+
|
6 |
+
class BertPreTrainedModel(nn.Module):
    """Base class adding weight initialisation and checkpoint loading to BERT models.

    Subclasses are constructed from a config object and can be populated from
    a checkpoint through :meth:`from_pretrained`.
    """

    config_class = BertConfig
    base_model_prefix = "bert"
    # Regex patterns: matching keys are dropped from the reported missing/unexpected lists.
    _keys_to_ignore_on_load_missing = [r"position_ids"]
    _keys_to_ignore_on_load_unexpected = None

    def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
        """Store ``config`` and its ``name_or_path``; extra arguments are accepted and ignored."""
        super().__init__()
        self.config = config
        self.name_or_path = config.name_or_path

    def init_weights(self):
        """Apply :meth:`_init_weights` to this module and every sub-module."""
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise the weights of a single ``module``.

        Linear and embedding weights are drawn from a normal distribution with
        standard deviation ``config.initializer_range``; LayerNorm is reset to
        weight 1 / bias 0; linear biases are zeroed.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    @property
    def dtype(self) -> dtype:
        """Return whatever ``get_parameter_dtype`` reports for this module."""
        return get_parameter_dtype(self)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
        """Build a model and load checkpoint weights into it.

        Args:
            pretrained_model_name_or_path: A directory containing ``WEIGHTS_NAME``,
                a checkpoint file, a remote URL, or any other identifier, which is
                turned into a URL via ``hf_bucket_url``. May be ``None`` when
                ``state_dict`` is supplied.
            *model_args: Forwarded to the config loader and to ``cls(...)``.
            **kwargs: Recognised options (``config``, ``state_dict``, ``cache_dir``,
                ``force_download``, ``resume_download``, ``proxies``,
                ``output_loading_info``, ``local_files_only``, ``use_auth_token``,
                ``revision``, ``mirror``) are consumed here; the remainder is
                passed to the config loader / model constructor.

        Returns:
            The model in eval mode, or ``(model, loading_info)`` when
            ``output_loading_info`` is true.

        Raises:
            EnvironmentError: The checkpoint location could not be resolved.
            OSError: The resolved checkpoint could not be read by ``torch.load``.
            ValueError: A checkpoint key does not correspond to any model parameter.
            RuntimeError: ``_load_from_state_dict`` reported errors.
        """
        config = kwargs.pop("config", None)
        state_dict = kwargs.pop("state_dict", None)
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        output_loading_info = kwargs.pop("output_loading_info", False)
        local_files_only = kwargs.pop("local_files_only", False)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        mirror = kwargs.pop("mirror", None)

        # Load config if we don't provide a configuration
        if not isinstance(config, PretrainedConfig):
            config_path = config if config is not None else pretrained_model_name_or_path
            config, model_kwargs = cls.config_class.from_pretrained(
                config_path,
                *model_args,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
                **kwargs,
            )
        else:
            model_kwargs = kwargs

        # Resolve where the weights live.
        if pretrained_model_name_or_path is not None:
            pretrained_model_name_or_path = str(pretrained_model_name_or_path)
            if os.path.isdir(pretrained_model_name_or_path):
                # Load from a PyTorch checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
            elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
                archive_file = pretrained_model_name_or_path
            else:
                archive_file = hf_bucket_url(
                    pretrained_model_name_or_path,
                    filename=WEIGHTS_NAME,
                    revision=revision,
                    mirror=mirror,
                )
            try:
                # Load from URL or cache if already cached
                resolved_archive_file = cached_path(
                    archive_file,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                    use_auth_token=use_auth_token,
                )
            except EnvironmentError as err:
                msg = (
                    f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                    f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
                    f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {WEIGHTS_NAME}.\n\n"
                )
                # Chain explicitly so the underlying failure stays attached as __cause__.
                raise EnvironmentError(msg) from err
        else:
            resolved_archive_file = None

        config.name_or_path = pretrained_model_name_or_path

        # Instantiate model.
        model = cls(config, *model_args, **model_kwargs)

        if state_dict is None:
            try:
                state_dict = torch.load(resolved_archive_file, map_location="cpu", weights_only=True)
            except Exception as err:
                raise OSError(
                    f"Unable to load weights from pytorch checkpoint file for '{pretrained_model_name_or_path}' "
                    f"at '{resolved_archive_file}'"
                ) from err

        missing_keys = []
        unexpected_keys = []
        error_msgs = []

        # Convert old format to new format if needed from a PyTorch state_dict
        old_keys = []
        new_keys = []
        # Checkpoint module names -> attribute names used by this implementation.
        # Insertion order matters: 'attention.output.*' must be handled before
        # the shorter 'output.*' patterns that also match the same original key.
        m = {'embeddings.word_embeddings': 'word_embedding',
             'embeddings.position_embeddings': 'pos_embedding',
             'embeddings.token_type_embeddings': 'tk_type_embedding',
             'embeddings.LayerNorm': 'embed_layer_norm',
             'embeddings.dropout': 'embed_dropout',
             'encoder.layer': 'bert_layers',
             'pooler.dense': 'pooler_dense',
             'pooler.activation': 'pooler_af',
             'attention.self': "self_attention",
             'attention.output.dense': 'attention_dense',
             'attention.output.LayerNorm': 'attention_layer_norm',
             'attention.output.dropout': 'attention_dropout',
             'intermediate.dense': 'interm_dense',
             'intermediate.intermediate_act_fn': 'interm_af',
             'output.dense': 'out_dense',
             'output.LayerNorm': 'out_layer_norm',
             'output.dropout': 'out_dropout'}

        for key in state_dict.keys():
            new_key = None
            if "gamma" in key:
                new_key = key.replace("gamma", "weight")
            if "beta" in key:
                new_key = key.replace("beta", "bias")
            for x, y in m.items():
                # Patterns are matched against the ORIGINAL key, but replacements
                # accumulate on the progressively renamed key.
                _key = new_key if new_key is not None else key
                if x in key:
                    new_key = _key.replace(x, y)
            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)

        for old_key, new_key in zip(old_keys, new_keys):
            state_dict[new_key] = state_dict.pop(old_key)

        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, "_metadata", None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        # Every checkpoint key (except the "cls." head) must match a model parameter.
        # A set keeps the membership test O(1) per key.
        your_bert_params = {f"bert.{x[0]}" for x in model.named_parameters()}
        for k in state_dict:
            if k not in your_bert_params and not k.startswith("cls."):
                possible_rename = [x for x in k.split(".")[1:-1] if x in m.values()]
                raise ValueError(f"{k} cannot be reload to your model, one/some of {possible_rename} we provided have been renamed")

        # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
        # so we need to apply the function recursively.
        def load(module: nn.Module, prefix=""):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict,
                prefix,
                local_metadata,
                True,
                missing_keys,
                unexpected_keys,
                error_msgs,
            )
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + ".")

        # Make sure we are able to load base models as well as derived models (with heads)
        start_prefix = ""
        model_to_load = model
        has_prefix_module = any(s.startswith(cls.base_model_prefix) for s in state_dict)
        if not hasattr(model, cls.base_model_prefix) and has_prefix_module:
            start_prefix = cls.base_model_prefix + "."
        if hasattr(model, cls.base_model_prefix) and not has_prefix_module:
            model_to_load = getattr(model, cls.base_model_prefix)
        load(model_to_load, prefix=start_prefix)

        if model.__class__.__name__ != model_to_load.__class__.__name__:
            base_model_state_dict = model_to_load.state_dict().keys()
            head_model_state_dict_without_base_prefix = [
                key.split(cls.base_model_prefix + ".")[-1] for key in model.state_dict().keys()
            ]
            # Difference between the head model's keys and the base model's keys view.
            missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict)

        # Some models may have keys that are not in the state by design, removing them before needlessly warning
        # the user.
        if cls._keys_to_ignore_on_load_missing is not None:
            for pat in cls._keys_to_ignore_on_load_missing:
                missing_keys = [k for k in missing_keys if re.search(pat, k) is None]

        if cls._keys_to_ignore_on_load_unexpected is not None:
            for pat in cls._keys_to_ignore_on_load_unexpected:
                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

        if len(error_msgs) > 0:
            raise RuntimeError(
                "Error(s) in loading state_dict for {}:\n\t{}".format(
                    model.__class__.__name__, "\n\t".join(error_msgs)
                )
            )

        # Set model in evaluation mode to deactivate DropOut modules by default
        model.eval()

        if output_loading_info:
            loading_info = {
                "missing_keys": missing_keys,
                "unexpected_keys": unexpected_keys,
                "error_msgs": error_msgs,
            }
            return model, loading_info

        if hasattr(config, "xla_device") and config.xla_device and is_torch_tpu_available():
            import torch_xla.core.xla_model as xm

            model = xm.send_cpu_data_to_device(model, xm.xla_device())
            model.to(xm.xla_device())

        return model
|
bert.py
ADDED
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from base_bert import BertPreTrainedModel
from utils import get_extended_attention_mask
|
5 |
+
|
6 |
+
|
7 |
+
class BertSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention."""

    def __init__(self, config):
        super().__init__()

        head_count = config.num_attention_heads
        self.num_attention_heads = head_count
        self.attention_head_size = config.hidden_size // head_count
        self.all_head_size = head_count * self.attention_head_size

        # Projections producing the query, key and value vectors.
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)
        # Dropout is applied to the normalized attention probabilities, as in
        # the original transformer implementation.
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transform(self, x, linear_layer):
        """Project ``x`` with ``linear_layer`` and split the result into heads.

        Returns a tensor of size
        [bs, num_attention_heads, seq_len, attention_head_size].
        """
        batch, length = x.shape[:2]
        per_head_shape = (batch, length, self.num_attention_heads, self.attention_head_size)
        # view() separates the heads; transpose() moves the head axis ahead of the sequence axis.
        return linear_layer(x).view(*per_head_shape).transpose(1, 2)

    def attention(self, key, query, value, attention_mask):
        """Compute attention over all heads and merge them back together.

        key, query, value: [batch_size, num_attention_heads, seq_len, attention_head_size]
        attention_mask: [batch_size, 1, 1, seq_len], added to the raw scores
            before the softmax.
        Returns: [batch_size, seq_len, num_attention_heads * attention_head_size]
        """
        scale = math.sqrt(query.size(-1))  # sqrt(attention_head_size)
        # scores: [batch_size, num_attention_heads, seq_len, seq_len]
        scores = torch.matmul(query, key.transpose(-1, -2)) / scale
        scores = scores + attention_mask

        probs = self.dropout(F.softmax(scores, dim=-1))

        # per_head: [batch_size, num_attention_heads, seq_len, attention_head_size]
        per_head = torch.matmul(probs, value)

        # Put the head axis next to the feature axis, then flatten the two.
        merged = per_head.transpose(1, 2).contiguous()
        return merged.view(merged.size(0), merged.size(1), -1)

    def forward(self, hidden_states, attention_mask):
        """Run multi-head self-attention.

        hidden_states: [bs, seq_len, hidden_state]
        attention_mask: [bs, 1, 1, seq_len]
        output: [bs, seq_len, hidden_state]
        """
        # Each projection has size [bs, num_attention_heads, seq_len, attention_head_size].
        keys = self.transform(hidden_states, self.key)
        values = self.transform(hidden_states, self.value)
        queries = self.transform(hidden_states, self.query)
        return self.attention(keys, queries, values, attention_mask)
|
77 |
+
|
78 |
+
|
79 |
+
class BertLayer(nn.Module):
    """One transformer encoder layer: self-attention and feed-forward, each followed by add-norm."""

    def __init__(self, config):
        super().__init__()
        hidden = config.hidden_size
        inner = config.intermediate_size

        # Multi-head attention.
        self.self_attention = BertSelfAttention(config)
        # Add-norm for multi-head attention.
        self.attention_dense = nn.Linear(hidden, hidden)
        self.attention_layer_norm = nn.LayerNorm(hidden, eps=config.layer_norm_eps)
        self.attention_dropout = nn.Dropout(config.hidden_dropout_prob)
        # Feed forward.
        self.interm_dense = nn.Linear(hidden, inner)
        self.interm_af = F.gelu
        # Add-norm for feed forward.
        self.out_dense = nn.Linear(inner, hidden)
        self.out_layer_norm = nn.LayerNorm(hidden, eps=config.layer_norm_eps)
        self.out_dropout = nn.Dropout(config.hidden_dropout_prob)

    def add_norm(self, input, output, dense_layer, dropout, ln_layer):
        """Return ``ln_layer(input + dropout(dense_layer(output)))``.

        ``output`` is transformed by the dense layer, passed through dropout,
        added to ``input`` as a residual connection, and the sum is normalized.
        """
        residual_branch = dropout(dense_layer(output))
        return ln_layer(input + residual_branch)

    def forward(self, hidden_states, attention_mask):
        """Apply the attention sub-layer, then the feed-forward sub-layer."""
        # Attention sub-layer with its add-norm.
        attended = self.add_norm(
            hidden_states,
            self.self_attention(hidden_states, attention_mask),
            self.attention_dense,
            self.attention_dropout,
            self.attention_layer_norm,
        )

        # Feed-forward sub-layer with its add-norm.
        expanded = self.interm_af(self.interm_dense(attended))
        return self.add_norm(
            attended,
            expanded,
            self.out_dense,
            self.out_dropout,
            self.out_layer_norm,
        )
|
131 |
+
|
132 |
+
|
133 |
+
class BertModel(BertPreTrainedModel):
    """
    The BERT model returns the final embeddings for each token in a sentence.

    The model consists of:
    1. Embedding layers (used in self.embed).
    2. A stack of n BERT layers (used in self.encode).
    3. A linear transformation layer for the [CLS] token (used in self.forward, as given).

    Note: this class needs ``BertPreTrainedModel`` (base_bert.py) and
    ``get_extended_attention_mask`` to be imported at module level.
    """
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Embedding layers.
        self.word_embedding = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.pos_embedding = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.tk_type_embedding = nn.Embedding(config.type_vocab_size, config.hidden_size)
        self.embed_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.embed_dropout = nn.Dropout(config.hidden_dropout_prob)
        # Register position_ids (1, len position emb) to buffer because it is a constant.
        position_ids = torch.arange(config.max_position_embeddings).unsqueeze(0)
        self.register_buffer('position_ids', position_ids)

        # BERT encoder.
        self.bert_layers = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])

        # [CLS] token transformations.
        self.pooler_dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.pooler_af = nn.Tanh()

        self.init_weights()

    def embed(self, input_ids):
        """Return the summed word, position and token-type embeddings for ``input_ids``.

        input_ids: [batch_size, seq_len]

        Raises:
            ValueError: if seq_len exceeds the number of available positions.
        """
        input_shape = input_ids.size()
        seq_length = input_shape[1]

        # Fail with a clear message instead of an opaque broadcasting error
        # when the position slice below would come back shorter than the input.
        max_positions = self.position_ids.size(1)
        if seq_length > max_positions:
            raise ValueError(
                f"Input sequence length {seq_length} exceeds the maximum of "
                f"{max_positions} position embeddings."
            )

        inputs_embeds = self.word_embedding(input_ids)

        pos_ids = self.position_ids[:, :seq_length]
        pos_embeds = self.pos_embedding(pos_ids)

        # Since we are not considering token type, this embedding is just a placeholder.
        tk_type_ids = torch.zeros(input_shape, dtype=torch.long, device=input_ids.device)
        tk_type_embeds = self.tk_type_embedding(tk_type_ids)

        embeddings = inputs_embeds + pos_embeds + tk_type_embeds
        embeddings = self.embed_layer_norm(embeddings)
        embeddings = self.embed_dropout(embeddings)

        return embeddings

    def encode(self, hidden_states, attention_mask):
        """
        hidden_states: the output from the embedding layer [batch_size, seq_len, hidden_size]
        attention_mask: [batch_size, seq_len]
        """
        # Get the extended attention mask for self-attention.
        # Returns extended_attention_mask of size [batch_size, 1, 1, seq_len].
        # Distinguishes between non-padding tokens (with a value of 0) and padding tokens
        # (with a value of a large negative number).
        extended_attention_mask: torch.Tensor = get_extended_attention_mask(attention_mask, self.dtype)

        # Feed the encoding from each bert_layer to the next.
        for layer_module in self.bert_layers:
            hidden_states = layer_module(hidden_states, extended_attention_mask)

        return hidden_states

    def forward(self, input_ids, attention_mask):
        """
        input_ids: [batch_size, seq_len], seq_len is the max length of the batch
        attention_mask: same size as input_ids, 1 represents non-padding tokens, 0 represents padding tokens

        Returns a dict with 'last_hidden_state' (the encoder output for every
        token) and 'pooler_output' (the first token's state after the pooler
        dense layer and tanh).
        """
        # Get the embedding for each input token.
        embedding_output = self.embed(input_ids=input_ids)

        # Feed to a transformer (a stack of BertLayers).
        sequence_output = self.encode(embedding_output, attention_mask=attention_mask)

        # Get cls token hidden state.
        first_tk = sequence_output[:, 0]
        first_tk = self.pooler_dense(first_tk)
        first_tk = self.pooler_af(first_tk)

        return {'last_hidden_state': sequence_output, 'pooler_output': first_tk}
|
classifier.py
ADDED
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from classifier_utils import *
|
2 |
+
|
3 |
+
|
4 |
+
# Passed as `disable=` to every tqdm progress bar in this module; True turns the bars off.
TQDM_DISABLE=True
|
5 |
+
|
6 |
+
|
7 |
+
class BertSentimentClassifier(torch.nn.Module):
    """Sentence classifier: a BERT encoder followed by dropout and a linear layer."""

    def __init__(self, config, custom_bert = None):
        """Build the classifier.

        config: object providing num_labels, fine_tune_mode,
            hidden_dropout_prob and hidden_size.
        custom_bert: encoder to use; when falsy, 'bert-base-uncased' is loaded
            through BertModel.from_pretrained.
        """
        super().__init__()
        self.num_labels = config.num_labels
        self.bert: BertModel = custom_bert or BertModel.from_pretrained('bert-base-uncased')

        # 'last-linear-layer' freezes the encoder; 'full-model' trains it as well.
        assert config.fine_tune_mode in ["last-linear-layer", "full-model"]
        grad_by_mode = {'last-linear-layer': False, 'full-model': True}
        mode = config.fine_tune_mode
        if mode in grad_by_mode:
            needs_grad = grad_by_mode[mode]
            for bert_param in self.bert.parameters():
                bert_param.requires_grad = needs_grad

        # Classifier = Dropout + Linear
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.classifier = torch.nn.Linear(config.hidden_size, self.num_labels)

    def forward(self, input_ids, attention_mask):
        """Return the classifier output computed from the encoder's 'pooler_output'."""
        pooled = self.bert(input_ids, attention_mask)['pooler_output']
        return self.classifier(self.dropout(pooled))
|
31 |
+
|
32 |
+
|
33 |
+
# Evaluate the model on dev examples.
|
34 |
+
def model_eval(dataloader, model: BertSentimentClassifier, device):
    """Evaluate ``model`` on labelled batches.

    Each batch must provide 'labels', 'sents', 'sent_ids', 'token_ids' and
    'attention_mask'. The model is switched to eval mode and is NOT switched
    back; callers that continue training must call ``model.train()`` again.

    Returns:
        (accuracy, macro F1, predictions, true labels, sentences, sentence ids)
    """
    model.eval()  # Switch to eval model, will turn off randomness like dropout.
    y_true = []
    y_pred = []
    sents = []
    sent_ids = []
    # Inference only: disabling autograd avoids building a graph for every
    # batch, which matters when this is called from inside the training loop.
    with torch.no_grad():
        for batch in tqdm(dataloader, desc='eval', leave=False, disable=TQDM_DISABLE):
            b_labels, b_sents, b_sent_ids = batch['labels'], batch['sents'], batch['sent_ids']

            b_ids = batch['token_ids'].to(device)
            b_mask = batch['attention_mask'].to(device)

            logits = model(b_ids, b_mask)
            logits = logits.detach().cpu().numpy()
            preds = np.argmax(logits, axis=1).flatten()

            b_labels = b_labels.flatten()
            y_true.extend(b_labels)
            y_pred.extend(preds)
            sents.extend(b_sents)
            sent_ids.extend(b_sent_ids)

    f1 = f1_score(y_true, y_pred, average='macro')
    acc = accuracy_score(y_true, y_pred)

    return acc, f1, y_pred, y_true, sents, sent_ids
|
60 |
+
|
61 |
+
|
62 |
+
# Evaluate the model on test examples.
|
63 |
+
def model_test_eval(dataloader, model, device):
    """Predict labels for unlabelled batches.

    Each batch must provide 'sents', 'sent_ids', 'token_ids' and
    'attention_mask'. The model is switched to eval mode and is NOT switched
    back.

    Returns:
        (predictions, sentences, sentence ids)
    """
    model.eval()  # Switch to eval model, will turn off randomness like dropout.
    y_pred = []
    sents = []
    sent_ids = []
    # Inference only: no autograd graph is needed for the forward passes.
    with torch.no_grad():
        for batch in tqdm(dataloader, desc='eval', leave=False, disable=TQDM_DISABLE):
            b_sents, b_sent_ids = batch['sents'], batch['sent_ids']

            b_ids = batch['token_ids'].to(device)
            b_mask = batch['attention_mask'].to(device)

            logits = model(b_ids, b_mask)
            logits = logits.detach().cpu().numpy()
            preds = np.argmax(logits, axis=1).flatten()

            y_pred.extend(preds)
            sents.extend(b_sents)
            sent_ids.extend(b_sent_ids)

    return y_pred, sents, sent_ids
|
83 |
+
|
84 |
+
|
85 |
+
def save_model(model, args, config, filepath):
    """Write a checkpoint to ``filepath`` and print a confirmation.

    The checkpoint bundles the model's state dict, ``args``, ``config`` and
    the current state of the Python, NumPy and torch random number generators.
    """
    checkpoint = dict(
        model=model.state_dict(),
        args=args,
        model_config=config,
        system_rng=random.getstate(),
        numpy_rng=np.random.get_state(),
        torch_rng=torch.random.get_rng_state(),
    )
    torch.save(checkpoint, filepath)
    print(f"save the model to {filepath}")
|
97 |
+
|
98 |
+
|
99 |
+
def train(args, custom_bert=None):
    """Train a BertSentimentClassifier and checkpoint the best epoch.

    args must provide train, dev, batch_size, fine_tune_mode, lr and filepath.
    After every epoch the model is evaluated on the training and dev data;
    whenever dev accuracy exceeds the best value so far, a checkpoint is
    written to ``args.filepath``.
    """
    device = torch.device('cuda' if USE_GPU else 'cpu')

    # Data -> datasets -> dataloaders.
    train_data, num_labels = load_data(args.train, 'train')
    dev_data = load_data(args.dev, 'valid')

    train_dataset = SentimentDataset(train_data)
    dev_dataset = SentimentDataset(dev_data)

    shared_loader_options = dict(batch_size=args.batch_size, num_workers=NUM_CPU_CORES)
    train_dataloader = DataLoader(train_dataset, shuffle=True,
                                  collate_fn=train_dataset.collate_fn, **shared_loader_options)
    dev_dataloader = DataLoader(dev_dataset, shuffle=False,
                                collate_fn=dev_dataset.collate_fn, **shared_loader_options)

    # Model configuration and model.
    config = SimpleNamespace(
        hidden_dropout_prob=HIDDEN_DROPOUT_PROB,
        num_labels=num_labels,
        hidden_size=768,
        data_dir='.',
        fine_tune_mode=args.fine_tune_mode,
    )
    model = BertSentimentClassifier(config, custom_bert).to(device)

    optimizer = AdamW(model.parameters(), lr=args.lr)
    best_dev_acc = 0

    # Run for the specified number of epochs.
    for epoch in range(EPOCHS):
        model.train()
        running_loss = 0
        batches_seen = 0
        progress = tqdm(train_dataloader, desc=f'train-{epoch}', leave=False, disable=TQDM_DISABLE)
        for batch in progress:
            token_ids = batch['token_ids'].to(device)
            mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            optimizer.zero_grad()
            logits = model(token_ids, mask)
            # Summed loss divided by the configured batch size (not the actual batch length).
            loss = F.cross_entropy(logits, labels.view(-1), reduction='sum') / args.batch_size

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            batches_seen += 1

        train_loss = running_loss / batches_seen

        train_acc, _train_f1, *_ = model_eval(train_dataloader, model, device)
        dev_acc, _dev_f1, *_ = model_eval(dev_dataloader, model, device)

        if dev_acc > best_dev_acc:
            best_dev_acc = dev_acc
            save_model(model, args, config, args.filepath)

        print(f"Epoch {epoch}: train loss :: {train_loss :.3f}, train acc :: {train_acc :.3f}, dev acc :: {dev_acc :.3f}")
|
159 |
+
|
160 |
+
|
161 |
+
def test(args):
    """Reload the checkpoint at ``args.filepath`` and report dev accuracy.

    args must provide filepath, dev and batch_size.
    """
    with torch.no_grad():
        device = torch.device('cuda') if USE_GPU else torch.device('cpu')
        # map_location lets a checkpoint written from CUDA tensors be restored
        # on a CPU-only machine instead of failing during deserialization.
        # NOTE(security): weights_only=False unpickles arbitrary Python objects
        # (needed for the stored args/config/RNG state); only load checkpoints
        # from a trusted source.
        saved = torch.load(args.filepath, map_location=device, weights_only=False)
        config = saved['model_config']
        model = BertSentimentClassifier(config)
        model.load_state_dict(saved['model'])
        model = model.to(device)
        print(f"load model from {args.filepath}")

        dev_data = load_data(args.dev, 'valid')
        dev_dataset = SentimentDataset(dev_data)
        dev_dataloader = DataLoader(dev_dataset, shuffle=False, batch_size=args.batch_size,
                                    num_workers=NUM_CPU_CORES, collate_fn=dev_dataset.collate_fn)

        dev_acc, dev_f1, dev_pred, dev_true, dev_sents, dev_sent_ids = model_eval(dev_dataloader, model, device)
        print('DONE DEV')
        print(f"dev acc :: {dev_acc :.3f}")
|
179 |
+
|
180 |
+
|
181 |
+
def classifier_run(args, custom_bert=None):
    """Train a sentiment classifier on ``args.dataset`` and then evaluate it.

    Args:
        args: namespace with ``dataset``, ``lr``, ``batch_size``,
            ``fine_tune_mode``, ``train``, ``dev`` and ``test`` attributes.
        custom_bert: forwarded unchanged to ``train``.
    """
    seed_everything(SEED)
    torch.set_num_threads(NUM_CPU_CORES)

    print(f'Training Sentiment Classifier on {args.dataset}...')

    # Collect the run settings first, then freeze them into a namespace that
    # both train() and test() consume.
    settings = {
        'filepath': f'/kaggle/working/{args.dataset}-classifier.pt',
        'lr': args.lr,
        'batch_size': args.batch_size,
        'fine_tune_mode': args.fine_tune_mode,
        'train': args.train,
        'dev': args.dev,
        'test': args.test,
        'dev_out': f'/kaggle/working/predictions/{args.fine_tune_mode}-{args.dataset}-dev-out.csv',
        'test_out': f'/kaggle/working/predictions/{args.fine_tune_mode}-{args.dataset}-test-out.csv',
    }
    config = SimpleNamespace(**settings)

    train(config, custom_bert)

    print(f'Evaluating on {args.dataset}...')
    test(config)
|
classifier_utils.py
ADDED
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from everything import *
|
2 |
+
from bert import BertModel
|
3 |
+
from optimizer import AdamW
|
4 |
+
from tokenizer import BertTokenizer
|
5 |
+
|
6 |
+
|
7 |
+
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
8 |
+
|
9 |
+
|
10 |
+
class SentimentDataset(Dataset):
    """Labelled sentiment examples stored as ``(sentence, label, sentence_id)`` records."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def pad_data(self, data):
        """Tokenize a list of records and return padded id/mask/label tensors.

        Returns:
            ``(token_ids, attention_mask, labels, sentences, sentence_ids)``.
        """
        texts = [record[0] for record in data]
        targets = [record[1] for record in data]
        ids = [record[2] for record in data]

        # The module-level tokenizer pads every sentence to the batch maximum.
        encoded = tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
        return (
            torch.LongTensor(encoded['input_ids']),
            torch.LongTensor(encoded['attention_mask']),
            torch.LongTensor(targets),
            texts,
            ids,
        )

    def collate_fn(self, all_data):
        """Turn a list of records into the batch dictionary used by the training loop."""
        token_ids, attention_mask, labels, sents, sent_ids = self.pad_data(all_data)
        return {
            'token_ids': token_ids,
            'attention_mask': attention_mask,
            'labels': labels,
            'sents': sents,
            'sent_ids': sent_ids,
        }
|
44 |
+
|
45 |
+
|
46 |
+
class SentimentTestDataset(Dataset):
    """Unlabelled sentiment examples stored as ``(sentence, sentence_id)`` records."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def pad_data(self, data):
        """Tokenize a list of records and return padded id/mask tensors.

        Returns:
            ``(token_ids, attention_mask, sentences, sentence_ids)``.
        """
        texts = [record[0] for record in data]
        ids = [record[1] for record in data]

        encoded = tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
        return (
            torch.LongTensor(encoded['input_ids']),
            torch.LongTensor(encoded['attention_mask']),
            texts,
            ids,
        )

    def collate_fn(self, all_data):
        """Turn a list of records into a batch dictionary (no labels)."""
        token_ids, attention_mask, sents, sent_ids = self.pad_data(all_data)
        return {
            'token_ids': token_ids,
            'attention_mask': attention_mask,
            'sents': sents,
            'sent_ids': sent_ids,
        }
|
77 |
+
|
78 |
+
|
79 |
+
class AmazonDataset(Dataset):
    """Unlabelled review sentences stored as ``(sentence, id)`` records."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def pad_data(self, data):
        """Tokenize a list of records.

        Returns:
            ``(token_ids, attention_mask, sentence_ids)``.
        """
        texts = [record[0] for record in data]
        ids = [record[1] for record in data]
        encoded = tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
        return (
            torch.LongTensor(encoded['input_ids']),
            torch.LongTensor(encoded['attention_mask']),
            ids,
        )

    def collate_fn(self, data):
        """Turn a list of records into the batch dictionary used for unsupervised finetuning."""
        token_ids, attention_mask, sent_ids = self.pad_data(data)
        return {
            'token_ids': token_ids,
            'attention_mask': attention_mask,
            'sent_ids': sent_ids,
        }
|
108 |
+
|
109 |
+
|
110 |
+
class SemanticDataset(Dataset):
    """Sentence pairs stored as ``(sentence1, sentence2, score, id)`` records."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def pad_data(self, data):
        """Tokenize both sides of every pair in a single tokenizer call.

        The first-sentence list and the second-sentence list are concatenated
        before tokenization, so both halves share one padded length.

        Returns:
            ``(token_ids, attention_mask, scores, sentence_ids)``; the first
            two stack all first sentences followed by all second sentences.
        """
        left = [record[0] for record in data]
        right = [record[1] for record in data]
        scores = [record[2] for record in data]
        ids = [record[3] for record in data]
        encoded = tokenizer(left + right, return_tensors='pt', padding=True, truncation=True)
        return (
            torch.LongTensor(encoded['input_ids']),
            torch.LongTensor(encoded['attention_mask']),
            scores,
            ids,
        )

    def collate_fn(self, data):
        """Split the stacked encoding back into its two halves and build the batch dict."""
        token_ids, attention_mask, score, sent_ids = self.pad_data(data)
        n = len(sent_ids)
        first_half, second_half = slice(None, n), slice(n, None)
        return {
            'token_ids_1': token_ids[first_half],
            'token_ids_2': token_ids[second_half],
            'attention_mask_1': attention_mask[first_half],
            'attention_mask_2': attention_mask[second_half],
            'score': score,
            'sent_ids': sent_ids,
        }
|
145 |
+
|
146 |
+
|
147 |
+
class InferenceDataset(Dataset):
    """NLI triplets stored as ``(anchor, positive, negative, id)`` records."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def pad_data(self, data):
        """Tokenize anchors, positives and negatives in a single tokenizer call.

        Returns:
            ``(token_ids, attention_mask, sentence_ids)``; the tensors stack
            all anchors, then all positives, then all negatives.
        """
        anchors = [record[0] for record in data]
        positives = [record[1] for record in data]
        negatives = [record[2] for record in data]
        ids = [record[3] for record in data]
        encoded = tokenizer(anchors + positives + negatives, return_tensors='pt', padding=True, truncation=True)
        return (
            torch.LongTensor(encoded['input_ids']),
            torch.LongTensor(encoded['attention_mask']),
            ids,
        )

    def collate_fn(self, data):
        """Split the stacked encoding into anchor/positive/negative thirds."""
        token_ids, attention_mask, sent_ids = self.pad_data(data)
        n = len(sent_ids)
        anchor_part = slice(None, n)
        positive_part = slice(n, 2 * n)
        negative_part = slice(2 * n, None)
        return {
            'anchor_ids': token_ids[anchor_part],
            'positive_ids': token_ids[positive_part],
            'negative_ids': token_ids[negative_part],
            'anchor_masks': attention_mask[anchor_part],
            'positive_masks': attention_mask[positive_part],
            'negative_masks': attention_mask[negative_part],
            'sent_ids': sent_ids,
        }
|
183 |
+
|
184 |
+
|
185 |
+
def load_data(filename, flag='train'):
    '''
    Load one of the project's datasets into a list of tuples.

    - for amazon dataset: list of (sent, id)
    - for nli dataset: list of (anchor, positive, negative, id)
    - for stsb dataset: list of (sentence1, sentence2, score, id)

    - for test dataset: list of (sent, id)
    - for train dataset: list of (sent, label, id)

    Any other flag (e.g. 'valid') is read like the train dataset.
    Sentences and ids from the tab-separated files are lower-cased and stripped.

    Returns ``(data, number_of_distinct_labels)`` when flag is 'train',
    otherwise just ``data``.
    '''

    # Parquet-backed datasets: flag -> columns to zip together with the index.
    parquet_columns = {
        'amazon': ('content',),
        'nli': ('anchor', 'positive', 'negative'),
        'stsb': ('sentence1', 'sentence2', 'score'),
    }
    labels = set()

    if flag in parquet_columns:
        df = pd.read_parquet(filename)
        data = list(zip(*(df[column] for column in parquet_columns[flag]), df.index))
    else:
        data = []
        # newline='' is what the csv module requires for correct handling of
        # embedded line breaks; the explicit encoding avoids depending on the
        # platform default.
        with open(filename, 'r', encoding='utf-8', newline='') as fp:
            for record in csv.DictReader(fp, delimiter='\t'):
                sent = record['sentence'].lower().strip()
                sent_id = record['id'].lower().strip()
                if flag == 'test':
                    data.append((sent, sent_id))
                else:
                    label = int(record['sentiment'].strip())
                    labels.add(label)
                    data.append((sent, label, sent_id))

    print(f"load {len(data)} data from {filename}")
    if flag == "train":
        return data, len(labels)
    return data
|
226 |
+
|
227 |
+
|
228 |
+
def seed_everything(seed=11711):
    """Seed Python, NumPy and PyTorch (CPU and CUDA) RNGs and pin cuDNN settings."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Disable cuDNN autotuning and request deterministic kernels.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
|
config.py
ADDED
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Union, Tuple, Dict, Any, Optional
|
2 |
+
import os
|
3 |
+
import json
|
4 |
+
from collections import OrderedDict
|
5 |
+
import torch
|
6 |
+
from utils import CONFIG_NAME, hf_bucket_url, cached_path, is_remote_url
|
7 |
+
|
8 |
+
class PretrainedConfig(object):
    """Generic model configuration: a bag of attributes populated from keyword arguments.

    Known keys are popped from ``kwargs`` with a default; every remaining key
    is stored verbatim as an attribute.
    """

    model_type: str = ""
    is_composition: bool = False

    def __init__(self, **kwargs):
        # Attributes with defaults
        self.return_dict = kwargs.pop("return_dict", True)
        self.output_hidden_states = kwargs.pop("output_hidden_states", False)
        self.output_attentions = kwargs.pop("output_attentions", False)
        self.torchscript = kwargs.pop("torchscript", False)  # Only used by PyTorch models
        self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
        self.pruned_heads = kwargs.pop("pruned_heads", {})
        self.tie_word_embeddings = kwargs.pop(
            "tie_word_embeddings", True
        )  # Whether input and output word embeddings should be tied for all MLM, LM and Seq2Seq models.

        # is_decoder is used in encoder-decoder models to differentiate encoder from decoder
        self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
        self.is_decoder = kwargs.pop("is_decoder", False)
        self.add_cross_attention = kwargs.pop("add_cross_attention", False)
        self.tie_encoder_decoder = kwargs.pop("tie_encoder_decoder", False)

        # Parameters for sequence generation
        self.max_length = kwargs.pop("max_length", 20)
        self.min_length = kwargs.pop("min_length", 0)
        self.do_sample = kwargs.pop("do_sample", False)
        self.early_stopping = kwargs.pop("early_stopping", False)
        self.num_beams = kwargs.pop("num_beams", 1)
        self.num_beam_groups = kwargs.pop("num_beam_groups", 1)
        self.diversity_penalty = kwargs.pop("diversity_penalty", 0.0)
        self.temperature = kwargs.pop("temperature", 1.0)
        self.top_k = kwargs.pop("top_k", 50)
        self.top_p = kwargs.pop("top_p", 1.0)
        self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
        self.length_penalty = kwargs.pop("length_penalty", 1.0)
        self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
        self.encoder_no_repeat_ngram_size = kwargs.pop("encoder_no_repeat_ngram_size", 0)
        self.bad_words_ids = kwargs.pop("bad_words_ids", None)
        self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
        self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)
        self.output_scores = kwargs.pop("output_scores", False)
        self.return_dict_in_generate = kwargs.pop("return_dict_in_generate", False)
        self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None)
        self.forced_eos_token_id = kwargs.pop("forced_eos_token_id", None)

        # Fine-tuning task arguments
        self.architectures = kwargs.pop("architectures", None)
        self.finetuning_task = kwargs.pop("finetuning_task", None)
        self.id2label = kwargs.pop("id2label", None)
        self.label2id = kwargs.pop("label2id", None)
        if self.id2label is not None:
            # An explicit label map wins over any num_labels argument.
            kwargs.pop("num_labels", None)
            # Keys are always strings in JSON so convert ids to int here.
            self.id2label = dict((int(key), value) for key, value in self.id2label.items())
            # Keep num_labels defined on this path too, derived from the map,
            # so reading config.num_labels never raises AttributeError.
            self.num_labels = len(self.id2label)
        else:
            self.num_labels = kwargs.pop("num_labels", 2)

        # Tokenizer arguments
        self.tokenizer_class = kwargs.pop("tokenizer_class", None)
        self.prefix = kwargs.pop("prefix", None)
        self.bos_token_id = kwargs.pop("bos_token_id", None)
        self.pad_token_id = kwargs.pop("pad_token_id", None)
        self.eos_token_id = kwargs.pop("eos_token_id", None)
        self.sep_token_id = kwargs.pop("sep_token_id", None)

        self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)

        # task specific arguments
        self.task_specific_params = kwargs.pop("task_specific_params", None)

        # TPU arguments
        self.xla_device = kwargs.pop("xla_device", None)

        # Name or path to the pretrained checkpoint
        self._name_or_path = str(kwargs.pop("name_or_path", ""))

        # Drop the transformers version info
        kwargs.pop("transformers_version", None)

        # Additional attributes without default values; an AttributeError from
        # setattr propagates to the caller unchanged.
        for key, value in kwargs.items():
            setattr(self, key, value)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        """Resolve and load a configuration file, then build a config from it."""
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
        return cls.from_dict(config_dict, **kwargs)

    @classmethod
    def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
        """Read ``json_file`` as UTF-8 and return the parsed JSON object."""
        with open(json_file, "r", encoding="utf-8") as reader:
            text = reader.read()
        return json.loads(text)

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig":
        """Build a config from ``config_dict``, then override existing attributes from ``kwargs``.

        Returns the config, or ``(config, unused_kwargs)`` when
        ``return_unused_kwargs=True`` is passed.
        """
        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)

        config = cls(**config_dict)

        if hasattr(config, "pruned_heads"):
            # JSON object keys are strings; layer indices are ints.
            config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())

        # Update config with kwargs if needed; keys the config does not know
        # are left in kwargs and reported back as unused.
        to_remove = []
        for key, value in kwargs.items():
            if hasattr(config, key):
                setattr(config, key, value)
                to_remove.append(key)
        for key in to_remove:
            kwargs.pop(key, None)

        if return_unused_kwargs:
            return config, kwargs
        return config

    @classmethod
    def get_config_dict(
        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Locate the configuration file for a model name/path and parse it.

        Returns:
            ``(config_dict, remaining_kwargs)``.

        Raises:
            EnvironmentError: if the file cannot be resolved or is not valid JSON.
        """
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        use_auth_token = kwargs.pop("use_auth_token", None)
        local_files_only = kwargs.pop("local_files_only", False)
        revision = kwargs.pop("revision", None)

        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
        if os.path.isdir(pretrained_model_name_or_path):
            config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
        elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
            config_file = pretrained_model_name_or_path
        else:
            config_file = hf_bucket_url(
                pretrained_model_name_or_path, filename=CONFIG_NAME, revision=revision, mirror=None
            )

        try:
            # Load from URL or cache if already cached
            resolved_config_file = cached_path(
                config_file,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
            )
            # Load config dict
            config_dict = cls._dict_from_json_file(resolved_config_file)

        except EnvironmentError as err:
            msg = (
                f"Can't load config for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
                f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a {CONFIG_NAME} file\n\n"
            )
            # Chain the cause so the underlying I/O failure stays visible.
            raise EnvironmentError(msg) from err

        except json.JSONDecodeError as err:
            msg = (
                "Couldn't reach server at '{}' to download configuration file or "
                "configuration file is not a valid JSON file. "
                "Please check network or file content here: {}.".format(config_file, resolved_config_file)
            )
            raise EnvironmentError(msg) from err

        return config_dict, kwargs
|
181 |
+
|
182 |
+
|
183 |
+
class BertConfig(PretrainedConfig):
    """BERT hyperparameters stored as attributes on top of ``PretrainedConfig``.

    ``pad_token_id`` and any extra keyword arguments are forwarded to the
    base class; the remaining arguments are stored under their own names.
    """

    model_type = "bert"

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        gradient_checkpointing=False,
        position_embedding_type="absolute",
        use_cache=True,
        **kwargs
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)

        # Assigned in a fixed order so the instance attributes are created in
        # the same sequence as before.
        bert_settings = (
            ("vocab_size", vocab_size),
            ("hidden_size", hidden_size),
            ("num_hidden_layers", num_hidden_layers),
            ("num_attention_heads", num_attention_heads),
            ("hidden_act", hidden_act),
            ("intermediate_size", intermediate_size),
            ("hidden_dropout_prob", hidden_dropout_prob),
            ("attention_probs_dropout_prob", attention_probs_dropout_prob),
            ("max_position_embeddings", max_position_embeddings),
            ("type_vocab_size", type_vocab_size),
            ("initializer_range", initializer_range),
            ("layer_norm_eps", layer_norm_eps),
            ("gradient_checkpointing", gradient_checkpointing),
            ("position_embedding_type", position_embedding_type),
            ("use_cache", use_cache),
        )
        for attribute, setting in bert_settings:
            setattr(self, attribute, setting)
|
constants.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
DATA_DIR = 'minbert-data'
MODEL_DIR = 'minbert-model'


def _data_file(name):
    """Return the path of ``name`` inside the data directory."""
    return os.path.join(DATA_DIR, name)


def _model_file(name):
    """Return the path of ``name`` inside the model directory."""
    return os.path.join(MODEL_DIR, name)


# Pretrained weights
SUP_BERT = _model_file('sup-cse-bert.pth')
UNSUP_BERT = _model_file('unsup-cse-bert.pth')

# CFIMDB dataset
IDS_CFIMDB_DEV = _data_file('ids-cfimdb-dev.csv')
IDS_CFIMDB_TEST = _data_file('ids-cfimdb-test-student.csv')
IDS_CFIMDB_TRAIN = _data_file('ids-cfimdb-train.csv')

# SST dataset
IDS_SST_DEV = _data_file('ids-sst-dev.csv')
IDS_SST_TEST = _data_file('ids-sst-test-student.csv')
IDS_SST_TRAIN = _data_file('ids-sst-train.csv')

# SimCSE train/dev dataset
NLI_TRAIN = _data_file('nli-train.parquet')
AMAZON_POLARITY = _data_file('amazon-polarity.parquet')
STSB_DEV = _data_file('stsb-dev.parquet')

# Training-specific constants
SEED = 11711
NUM_CPU_CORES = 4
EPOCHS = 10
USE_GPU = True
BATCH_SIZE_CSE = 8
BATCH_SIZE_SST = 64
BATCH_SIZE_CFIMDB = 8
HIDDEN_DROPOUT_PROB = 0.3
|
everything.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import csv
|
2 |
+
import torch
|
3 |
+
import random
|
4 |
+
import numpy as np
|
5 |
+
import pandas as pd
|
6 |
+
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
from tqdm import tqdm
|
10 |
+
from torch import nn, Tensor
|
11 |
+
from types import SimpleNamespace
|
12 |
+
from scipy.stats import spearmanr
|
13 |
+
from torch.utils.data import Dataset, DataLoader
|
14 |
+
from sklearn.metrics import f1_score, accuracy_score
|
15 |
+
|
16 |
+
from constants import *
|
17 |
+
|
18 |
+
import random, numpy as np, argparse
|
19 |
+
from types import SimpleNamespace
|
20 |
+
import csv
|
finetune-scripts/sup.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from constants import *
|
2 |
+
from types import SimpleNamespace
|
3 |
+
from finetuning import finetune_bert
|
4 |
+
|
5 |
+
# Settings for supervised SimCSE finetuning.
# The checkpoint is written to SUP_BERT (from constants) so it lands in the
# repository's model directory rather than a hard-coded path at the
# filesystem root.
ARGUMENTS = SimpleNamespace(
    mode='sup',
    filepath=SUP_BERT,
    batch_size_train=12,
    batch_size_dev=64,
    temp=0.05, lr=1e-5,
)
finetune_bert(ARGUMENTS)
|
finetune-scripts/unsup.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from constants import *
|
2 |
+
from types import SimpleNamespace
|
3 |
+
from finetuning import finetune_bert
|
4 |
+
|
5 |
+
# Settings for unsupervised SimCSE finetuning.
# The checkpoint is written to UNSUP_BERT (from constants) so it lands in the
# repository's model directory rather than a hard-coded path at the
# filesystem root.
ARGUMENTS = SimpleNamespace(
    mode='unsup',
    filepath=UNSUP_BERT,
    batch_size_train=8,
    batch_size_dev=64,
    temp=0.05, lr=1e-5,
)
finetune_bert(ARGUMENTS)
|
finetuning.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from classifier_utils import *
|
2 |
+
|
3 |
+
|
4 |
+
TQDM_DISABLE=True
|
5 |
+
|
6 |
+
|
7 |
+
def unsup_contrastive_loss(embeds_1: Tensor, embeds_2: Tensor, temp=0.05):
    '''
    In-batch contrastive loss, summed over the batch.

    Row i of embeds_1 is paired with row i of embeds_2 (the positive); every
    other row of embeds_2 acts as a negative for it.

    embeds_1: [batch_size, hidden_size]
    embeds_2: [batch_size, hidden_size]
    temp: temperature dividing the cosine similarities

    Returns a scalar tensor: sum_i -log(exp(s_ii) / sum_j exp(s_ij)).
    '''

    # [batch_size, batch_size]
    sim_matrix = F.cosine_similarity(embeds_1.unsqueeze(1), embeds_2.unsqueeze(0), dim=-1) / temp

    # [batch_size]
    positive_sim = torch.diagonal(sim_matrix)

    # -log(exp(pos) / sum(exp(row))) == logsumexp(row) - pos.
    # logsumexp avoids the overflow (inf / inf -> NaN) that explicit exp()
    # hits for small temperatures or reduced-precision inputs.
    loss_per_sample = torch.logsumexp(sim_matrix, dim=1) - positive_sim

    return loss_per_sample.sum()
|
29 |
+
|
30 |
+
|
31 |
+
def sup_contrastive_loss(embeds_1: Tensor, embeds_2: Tensor, embeds_3: Tensor, temp=0.05):
    '''
    Contrastive loss with explicit hard negatives, summed over the batch.

    Row i of embeds_1 (anchor) is paired with row i of embeds_2 (positive);
    all other rows of embeds_2 and every row of embeds_3 act as negatives.

    embeds_1: [batch_size, hidden_size]
    embeds_2: [batch_size, hidden_size]
    embeds_3: [batch_size, hidden_size]
    temp: temperature dividing the cosine similarities

    Returns a scalar tensor.
    '''

    # [batch_size, batch_size]
    pos_sim_matrix = F.cosine_similarity(embeds_1.unsqueeze(1), embeds_2.unsqueeze(0), dim=-1) / temp
    neg_sim_matrix = F.cosine_similarity(embeds_1.unsqueeze(1), embeds_3.unsqueeze(0), dim=-1) / temp

    # [batch_size]
    positive_sim = torch.diagonal(pos_sim_matrix)

    # The denominator sums exp() over both matrices row-wise, i.e. over the
    # concatenated [batch_size, 2 * batch_size] similarities. logsumexp
    # evaluates that log-denominator without overflowing.
    all_sims = torch.cat([pos_sim_matrix, neg_sim_matrix], dim=1)
    loss_per_sample = torch.logsumexp(all_sims, dim=1) - positive_sim

    return loss_per_sample.sum()
|
55 |
+
|
56 |
+
|
57 |
+
def sts_eval(dataloader, model: BertModel, device):
    '''
    Score a model on sentence-pair similarity.

    For every pair, the cosine similarity of the two 'pooler_output'
    embeddings is compared with the reference score using Spearman's rank
    correlation.

    dataloader: yields batches with 'token_ids_1/2', 'attention_mask_1/2',
        'score' and 'sent_ids' entries
    model: called as model(token_ids, attention_mask)
    device: device the input tensors are moved to

    Returns (spearman_correlation, sent_ids) where sent_ids covers every
    evaluated pair, in dataloader order.
    '''
    model.eval()
    y_true = []
    y_pred = []
    sent_ids = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc='eval', leave=False, disable=TQDM_DISABLE):
            token_ids_1 = batch['token_ids_1'].to(device)
            token_ids_2 = batch['token_ids_2'].to(device)
            attention_mask_1 = batch['attention_mask_1'].to(device)
            attention_mask_2 = batch['attention_mask_2'].to(device)

            logits_1 = model(token_ids_1, attention_mask_1)['pooler_output']
            logits_2 = model(token_ids_2, attention_mask_2)['pooler_output']

            sim = F.cosine_similarity(logits_1, logits_2)
            y_true.extend(batch['score'])
            y_pred.extend(sim.cpu().tolist())
            sent_ids.extend(batch['sent_ids'])

    spearman_corr, _ = spearmanr(y_pred, y_true)
    # Return the ids accumulated over all batches, not just the last batch's
    # (which was also unbound when the dataloader was empty).
    return spearman_corr, sent_ids
|
83 |
+
|
84 |
+
|
85 |
+
def finetune_bert(args):
    '''
    Contrastive finetuning of minBERT
    ---------------------------------
    1. Load the training set (Amazon Polarity for 'unsup', NLI triplets for
       'sup') and the STS-B dev set.
    2. Initialize the pretrained 'bert-base-uncased' model.
    3. Loop over EPOCHS epochs.
    4. Compute each batch's contrastive loss (summed over the batch).
    5. Backpropagate and step the AdamW optimizer.
    6. Evaluate on the dev set with sts_eval (Spearman's correlation between
       the cosine similarity of the two embeddings and the reference score).
    7. Save the weights to args.filepath whenever the dev correlation improves.

    args needs: mode ('unsup' or 'sup'), filepath, batch_size_train,
    batch_size_dev, temp and lr.
    '''

    assert args.mode in ['unsup', 'sup']

    seed_everything(SEED)
    torch.set_num_threads(NUM_CPU_CORES)

    unsupervised = args.mode == 'unsup'

    if unsupervised:
        train_dataset = AmazonDataset(load_data(AMAZON_POLARITY, 'amazon'))
    else:
        train_dataset = InferenceDataset(load_data(NLI_TRAIN, 'nli'))

    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size_train,
                                  num_workers=NUM_CPU_CORES, collate_fn=train_dataset.collate_fn)

    sts_dataset = SemanticDataset(load_data(STSB_DEV, 'stsb'))
    sts_dataloader = DataLoader(sts_dataset, shuffle=False, batch_size=args.batch_size_dev,
                                num_workers=NUM_CPU_CORES, collate_fn=sts_dataset.collate_fn)

    device = torch.device('cuda') if USE_GPU else torch.device('cpu')
    model = BertModel.from_pretrained('bert-base-uncased')
    model.to(device)

    def pooled(batch, ids_key, mask_key):
        # One forward pass; returns the 'pooler_output' embedding.
        return model(batch[ids_key].to(device), batch[mask_key].to(device))['pooler_output']

    best_dev_acc = 0
    optimizer = AdamW(model.parameters(), lr=args.lr)

    print(f'Finetuning minBERT with {args.mode}ervised method...')

    for epoch in range(EPOCHS):
        model.train()
        train_loss = 0
        num_batches = 0

        for batch in tqdm(train_dataloader, f'train-{epoch}', leave=False, disable=TQDM_DISABLE):
            if unsupervised:
                # Two passes over the same input: in train mode each pass
                # draws its own dropout mask, giving two distinct embeddings.
                view_a = pooled(batch, 'token_ids', 'attention_mask')
                view_b = pooled(batch, 'token_ids', 'attention_mask')
                loss = unsup_contrastive_loss(view_a, view_b, args.temp)
            else:
                anchor = pooled(batch, 'anchor_ids', 'anchor_masks')
                positive = pooled(batch, 'positive_ids', 'positive_masks')
                negative = pooled(batch, 'negative_ids', 'negative_masks')
                loss = sup_contrastive_loss(anchor, positive, negative, args.temp)

            # Back propagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            num_batches += 1

        train_loss /= num_batches
        dev_acc, _ = sts_eval(sts_dataloader, model, device)

        # Keep only the best-scoring weights on disk.
        if dev_acc > best_dev_acc:
            best_dev_acc = dev_acc
            torch.save(model.state_dict(), args.filepath)
            print(f"save the model to {args.filepath}")

        print(f"Epoch {epoch}: train loss :: {train_loss :.3f}, dev acc :: {dev_acc :.3f}")
|
minbert-data/amazon-polarity.parquet
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7cbb5fc18093875baac0f49ae926f76aa4938e3e9b8114d9a1d95fa05810d8e4
|
3 |
+
size 31246252
|
minbert-data/ids-cfimdb-dev.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3087f571b66860fe5d035b5a018d08202ad3fd3720e4821c04b2acf6c7ded559
|
3 |
+
size 249095
|
minbert-data/ids-cfimdb-test-student.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7ae611548c9eac879e9ebb406cc9f8ae68ff12f78090e4965af5cbdfa06240f4
|
3 |
+
size 495595
|
minbert-data/ids-cfimdb-train.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:140fc513045a966109faed46a5c7a898767b96714d71bcb9c15f659129fadcea
|
3 |
+
size 1693182
|
minbert-data/ids-sst-dev.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a186ce94577635fbe10beaaddd50f16cccf6c30973221cefdf90deed2a584bfe
|
3 |
+
size 151384
|
minbert-data/ids-sst-test-student.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:bdd5a767faa0c26782117e37767ece154c30d5d04fb8727d09c71e3850a55c7b
|
3 |
+
size 313202
|
minbert-data/ids-sst-train.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:03b2b625c090f94a6afd59f114cde5282e2053aab0b101e87ed695d8a0c5b1df
|
3 |
+
size 1175139
|
minbert-data/nli-train.parquet
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:bc8e9bffa6e24c1175be20aa2d064dae99501937d7dec52a341f23f75eaeaec8
|
3 |
+
size 28964735
|
minbert-data/optimizer_test.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:77b817e0dce16a9bc8d3a6bcb88035db68f7d783dc8a565737581fadd05db815
|
3 |
+
size 152
|
minbert-data/sanity_check.data
ADDED
Binary file (56.4 kB). View file
|
|
minbert-data/stsb-dev.parquet
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9c6e0e9881f1b398abe3e439a482f4686305c3784568c462f6bba58bdff03b0a
|
3 |
+
size 142187
|
minbert-model/sup-cse-bert.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8ba7bedbe15db3ce7345fa9cba47dc281d4ddb34512fe745445468adbe6abd08
|
3 |
+
size 438007966
|
minbert-model/unsup-cse-bert.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3a8cfeab46f5903b3297f9536c29516ec96c9bef525ecf13d33494dbd0edebd7
|
3 |
+
size 438008374
|
optimizer.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Callable, Iterable, Tuple
|
2 |
+
import math
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch.optim import Optimizer
|
6 |
+
|
7 |
+
|
8 |
+
class AdamW(Optimizer):
    """Adam optimizer that decays the weights directly on the parameters.

    Every :meth:`step` refreshes the running averages of the gradient and of
    the squared gradient, optionally bias-corrects them, applies the Adam
    update, and finally subtracts ``lr * weight_decay`` times the (already
    updated) parameter from itself.
    """

    def __init__(
        self,
        params: Iterable[torch.nn.parameter.Parameter],
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-6,
        weight_decay: float = 0.0,
        correct_bias: bool = True,
    ):
        """Validate the hyperparameters and store them as group defaults.

        Raises:
            ValueError: if ``lr`` or ``eps`` is negative, or if either beta
                is outside ``[0.0, 1.0)``.
        """
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        for position in (0, 1):
            beta = betas[position]
            if not 0.0 <= beta < 1.0:
                raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(beta))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
        super().__init__(
            params,
            dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias),
        )

    def step(self, closure: Callable = None):
        """Perform one optimization step over every parameter that has a gradient.

        Args:
            closure: optional callable that is invoked once before the update;
                its return value is handed back to the caller.

        Returns:
            The value returned by ``closure``, or ``None`` when none was given.

        Raises:
            RuntimeError: if a parameter carries a sparse gradient.
        """
        loss = closure() if closure is not None else None

        for group in self.param_groups:
            # Hyperparameters are shared by every parameter of the group.
            lr = group["lr"]
            beta1, beta2 = group["betas"]
            eps = group["eps"]
            weight_decay = group["weight_decay"]
            correct_bias = group["correct_bias"]

            for param in group["params"]:
                if param.grad is None:
                    continue
                gradient = param.grad.data
                if gradient.is_sparse:
                    raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")

                # Lazily create the per-parameter state on first use.
                param_state = self.state[param]
                if len(param_state) == 0:
                    param_state["step"] = 0
                    param_state["exp_avg"] = torch.zeros_like(param.data)
                    param_state["exp_avg_sq"] = torch.zeros_like(param.data)

                first_moment = param_state["exp_avg"]
                second_moment = param_state["exp_avg_sq"]

                param_state["step"] += 1
                t = param_state["step"]

                # In-place update of the biased moment estimates.
                first_moment.mul_(beta1).add_(gradient, alpha=(1 - beta1))
                second_moment.mul_(beta2).addcmul_(gradient, gradient, value=(1 - beta2))

                if correct_bias:
                    m_hat = first_moment / (1 - beta1 ** t)
                    v_hat = second_moment / (1 - beta2 ** t)
                else:
                    m_hat = first_moment
                    v_hat = second_moment

                # ``sqrt()`` allocates a new tensor, so the in-place add never
                # touches the stored second moment.
                denominator = v_hat.sqrt().add_(eps)
                param.data.addcdiv_(m_hat, denominator, value=-lr)

                # Weight decay is applied after the Adam update.
                if weight_decay != 0:
                    param.data.add_(param.data, alpha=-lr * weight_decay)

        return loss
|
optimizer_test.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
from optimizer import AdamW
|
5 |
+
from constants import DATA_DIR
|
6 |
+
|
7 |
+
seed = 0
|
8 |
+
|
9 |
+
|
10 |
+
def test_optimizer(opt_class) -> torch.Tensor:
    """Fit a small bias-free linear layer with ``opt_class`` and return its weights.

    Both NumPy's generator and torch are seeded from the module-level
    ``seed`` before the 1000 single-sample updates are run.
    """
    generator = np.random.default_rng(seed)
    torch.manual_seed(seed)
    model = torch.nn.Linear(3, 2, bias=False)
    optimizer = opt_class(
        model.parameters(),
        lr=1e-3,
        weight_decay=1e-4,
        correct_bias=True,
    )
    for _ in range(1000):
        optimizer.zero_grad()
        features = torch.FloatTensor(generator.uniform(size=[model.in_features]))
        prediction = model(features)
        # Target: sum of the first two inputs, and the negated third input.
        target = torch.Tensor([features[0] + features[1], -features[2]])
        squared_error = ((target - prediction) ** 2).sum()
        squared_error.backward()
        optimizer.step()
    return model.weight.detach()
|
29 |
+
|
30 |
+
|
31 |
+
# Compare the weights produced by AdamW with the stored reference array.
reference_path = os.path.join(DATA_DIR, "optimizer_test.npy")
ref = torch.tensor(np.load(reference_path))
actual = test_optimizer(AdamW)
print(ref)
print(actual)
assert torch.allclose(ref, actual, atol=1e-6, rtol=1e-4)
print("Optimizer test passed!")
|
requirements.txt
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==2.4.1
|
2 |
+
tqdm==4.58.0
|
3 |
+
requests==2.25.1
|
4 |
+
importlib-metadata==3.7.0
|
5 |
+
filelock==3.16.1
|
6 |
+
sklearn==0.0
|
7 |
+
tokenizers==0.15
|
8 |
+
explainaboard_client==0.0.7
|
9 |
+
pandas==2.2.3
|
sanity_check.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
|
4 |
+
from bert import BertModel
|
5 |
+
from constants import DATA_DIR
|
6 |
+
|
7 |
+
# Stored reference outputs that the model results are compared against.
sanity_data = torch.load(os.path.join(DATA_DIR, "sanity_check.data"), weights_only=True)

# Two rows of eight token ids; the first row is zero-padded after four ids
# and its attention mask is zero at those positions.
sent_ids = torch.tensor([
    [101, 7592, 2088, 102, 0, 0, 0, 0],
    [101, 7592, 15756, 2897, 2005, 17953, 2361, 102],
])
att_mask = torch.tensor([
    [1, 1, 1, 1, 0, 0, 0, 0],
    [1, 1, 1, 1, 1, 1, 1, 1],
])

# Load model.
bert = BertModel.from_pretrained('bert-base-uncased')
outputs = bert(sent_ids, att_mask)

# Zero the hidden states at masked positions in both the computed and the
# reference tensors so that only unmasked positions are compared.
att_mask = att_mask.unsqueeze(-1)
for result in (outputs, sanity_data):
    result['last_hidden_state'] = result['last_hidden_state'] * att_mask

for k in ['last_hidden_state', 'pooler_output']:
    assert torch.allclose(outputs[k], sanity_data[k], atol=1e-5, rtol=1e-3)
print("Your BERT implementation is correct!")
|
tokenizer.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
train-scripts/base_cfimdb_onfm.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from constants import *
|
2 |
+
from types import SimpleNamespace
|
3 |
+
from classifier import classifier_run
|
4 |
+
|
5 |
+
# Settings for the CFIMDB run with fine_tune_mode "full-model" and lr 1e-5.
ARGUMENTS = SimpleNamespace(
    dataset="cfimdb",
    batch_size=BATCH_SIZE_CFIMDB,
    train=IDS_CFIMDB_TRAIN,
    dev=IDS_CFIMDB_DEV,
    test=IDS_CFIMDB_TEST,
    lr=1e-5,
    fine_tune_mode="full-model",
)

# No model is passed, so classifier_run receives the settings only.
classifier_run(ARGUMENTS)
|
train-scripts/base_cfimdb_onll.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from constants import *
|
2 |
+
from types import SimpleNamespace
|
3 |
+
from classifier import classifier_run
|
4 |
+
|
5 |
+
# Settings for the CFIMDB run with fine_tune_mode "last-linear-layer" and lr 1e-3.
ARGUMENTS = SimpleNamespace(
    dataset="cfimdb",
    batch_size=BATCH_SIZE_CFIMDB,
    train=IDS_CFIMDB_TRAIN,
    dev=IDS_CFIMDB_DEV,
    test=IDS_CFIMDB_TEST,
    lr=1e-3,
    fine_tune_mode="last-linear-layer",
)

# No model is passed, so classifier_run receives the settings only.
classifier_run(ARGUMENTS)
|
train-scripts/base_sst_onfm.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from constants import *
|
2 |
+
from types import SimpleNamespace
|
3 |
+
from classifier import classifier_run
|
4 |
+
|
5 |
+
# Settings for the SST run with fine_tune_mode "full-model" and lr 1e-5.
ARGUMENTS = SimpleNamespace(
    dataset="sst",
    batch_size=BATCH_SIZE_SST,
    train=IDS_SST_TRAIN,
    dev=IDS_SST_DEV,
    test=IDS_SST_TEST,
    lr=1e-5,
    fine_tune_mode="full-model",
)

# No model is passed, so classifier_run receives the settings only.
classifier_run(ARGUMENTS)
|
train-scripts/base_sst_onll.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from constants import *
|
2 |
+
from types import SimpleNamespace
|
3 |
+
from classifier import classifier_run
|
4 |
+
|
5 |
+
# Settings for the SST run with fine_tune_mode "last-linear-layer" and lr 1e-3.
ARGUMENTS = SimpleNamespace(
    dataset="sst",
    batch_size=BATCH_SIZE_SST,
    train=IDS_SST_TRAIN,
    dev=IDS_SST_DEV,
    test=IDS_SST_TEST,
    lr=1e-3,
    fine_tune_mode="last-linear-layer",
)

# No model is passed, so classifier_run receives the settings only.
classifier_run(ARGUMENTS)
|
train-scripts/finetuned_bert.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from everything import *
|
2 |
+
from bert import BertModel
|
3 |
+
|
4 |
+
def get_finetuned_bert(mode: str):
    """Load ``bert-base-uncased`` and overwrite its weights with a fine-tuned checkpoint.

    Args:
        mode: ``'sup'`` loads the checkpoint at ``SUP_BERT``; ``'unsup'`` loads
            the one at ``UNSUP_BERT``.

    Returns:
        The model, moved to CUDA when ``USE_GPU`` is truthy and to the CPU
        otherwise.

    Raises:
        ValueError: if ``mode`` is neither ``'sup'`` nor ``'unsup'``.
    """
    checkpoints = {'sup': SUP_BERT, 'unsup': UNSUP_BERT}
    # An explicit check instead of ``assert``: asserts are stripped under -O,
    # which would silently fall through to the 'unsup' checkpoint.
    if mode not in checkpoints:
        raise ValueError("mode must be 'sup' or 'unsup', got {!r}".format(mode))

    device = torch.device('cuda') if USE_GPU else torch.device('cpu')
    bert = BertModel.from_pretrained('bert-base-uncased')

    # map_location lets a checkpoint that was saved from a CUDA device be
    # deserialized on a CPU-only machine.
    state_dict = torch.load(checkpoints[mode], map_location=device, weights_only=True)
    bert.load_state_dict(state_dict)
    return bert.to(device)
|
train-scripts/sup_cfimdb_onfm.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from constants import *
|
2 |
+
from types import SimpleNamespace
|
3 |
+
from classifier import classifier_run
|
4 |
+
from .finetuned_bert import get_finetuned_bert
|
5 |
+
|
6 |
+
# Settings for the CFIMDB run with fine_tune_mode "full-model" and lr 1e-5.
ARGUMENTS = SimpleNamespace(
    dataset="cfimdb",
    batch_size=BATCH_SIZE_CFIMDB,
    train=IDS_CFIMDB_TRAIN,
    dev=IDS_CFIMDB_DEV,
    test=IDS_CFIMDB_TEST,
    lr=1e-5,
    fine_tune_mode="full-model",
)

# Hand the model returned for the 'sup' checkpoint to the classifier run.
bert = get_finetuned_bert('sup')
classifier_run(ARGUMENTS, bert)
|
train-scripts/sup_cfimdb_onll.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from constants import *
|
2 |
+
from types import SimpleNamespace
|
3 |
+
from classifier import classifier_run
|
4 |
+
from .finetuned_bert import get_finetuned_bert
|
5 |
+
|
6 |
+
# Settings for the CFIMDB run with fine_tune_mode "last-linear-layer" and lr 1e-3.
ARGUMENTS = SimpleNamespace(
    dataset="cfimdb",
    batch_size=BATCH_SIZE_CFIMDB,
    train=IDS_CFIMDB_TRAIN,
    dev=IDS_CFIMDB_DEV,
    test=IDS_CFIMDB_TEST,
    lr=1e-3,
    fine_tune_mode="last-linear-layer",
)

# Hand the model returned for the 'sup' checkpoint to the classifier run.
bert = get_finetuned_bert('sup')
classifier_run(ARGUMENTS, bert)
|
train-scripts/sup_sst_onfm.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from constants import *
|
2 |
+
from types import SimpleNamespace
|
3 |
+
from classifier import classifier_run
|
4 |
+
from .finetuned_bert import get_finetuned_bert
|
5 |
+
|
6 |
+
# Settings for the SST run with fine_tune_mode "full-model" and lr 1e-5.
ARGUMENTS = SimpleNamespace(
    dataset="sst",
    batch_size=BATCH_SIZE_SST,
    train=IDS_SST_TRAIN,
    dev=IDS_SST_DEV,
    test=IDS_SST_TEST,
    lr=1e-5,
    fine_tune_mode="full-model",
)

# Hand the model returned for the 'sup' checkpoint to the classifier run.
bert = get_finetuned_bert('sup')
classifier_run(ARGUMENTS, bert)
|
train-scripts/sup_sst_onll.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from constants import *
|
2 |
+
from types import SimpleNamespace
|
3 |
+
from classifier import classifier_run
|
4 |
+
from .finetuned_bert import get_finetuned_bert
|
5 |
+
|
6 |
+
# Settings for the SST run with fine_tune_mode "last-linear-layer" and lr 1e-3.
ARGUMENTS = SimpleNamespace(
    dataset="sst",
    batch_size=BATCH_SIZE_SST,
    train=IDS_SST_TRAIN,
    dev=IDS_SST_DEV,
    test=IDS_SST_TEST,
    lr=1e-3,
    fine_tune_mode="last-linear-layer",
)

# Hand the model returned for the 'sup' checkpoint to the classifier run.
bert = get_finetuned_bert('sup')
classifier_run(ARGUMENTS, bert)
|
train-scripts/unsup_cfimdb_onfm.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from constants import *
|
2 |
+
from types import SimpleNamespace
|
3 |
+
from classifier import classifier_run
|
4 |
+
from .finetuned_bert import get_finetuned_bert
|
5 |
+
|
6 |
+
# Settings for the CFIMDB run with fine_tune_mode "full-model" and lr 1e-5.
ARGUMENTS = SimpleNamespace(
    dataset="cfimdb",
    batch_size=BATCH_SIZE_CFIMDB,
    train=IDS_CFIMDB_TRAIN,
    dev=IDS_CFIMDB_DEV,
    test=IDS_CFIMDB_TEST,
    lr=1e-5,
    fine_tune_mode="full-model",
)

# Hand the model returned for the 'unsup' checkpoint to the classifier run.
bert = get_finetuned_bert('unsup')
classifier_run(ARGUMENTS, bert)
|
train-scripts/unsup_cfimdb_onll.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from constants import *
|
2 |
+
from types import SimpleNamespace
|
3 |
+
from classifier import classifier_run
|
4 |
+
from .finetuned_bert import get_finetuned_bert
|
5 |
+
|
6 |
+
# Settings for the CFIMDB run with fine_tune_mode "last-linear-layer" and lr 1e-3.
ARGUMENTS = SimpleNamespace(
    dataset="cfimdb",
    batch_size=BATCH_SIZE_CFIMDB,
    train=IDS_CFIMDB_TRAIN,
    dev=IDS_CFIMDB_DEV,
    test=IDS_CFIMDB_TEST,
    lr=1e-3,
    fine_tune_mode="last-linear-layer",
)

# Hand the model returned for the 'unsup' checkpoint to the classifier run.
bert = get_finetuned_bert('unsup')
classifier_run(ARGUMENTS, bert)
|
train-scripts/unsup_sst_onfm.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from constants import *
|
2 |
+
from types import SimpleNamespace
|
3 |
+
from classifier import classifier_run
|
4 |
+
from .finetuned_bert import get_finetuned_bert
|
5 |
+
|
6 |
+
# Settings for the SST run with fine_tune_mode "full-model" and lr 1e-5.
ARGUMENTS = SimpleNamespace(
    dataset="sst",
    batch_size=BATCH_SIZE_SST,
    train=IDS_SST_TRAIN,
    dev=IDS_SST_DEV,
    test=IDS_SST_TEST,
    lr=1e-5,
    fine_tune_mode="full-model",
)

# Hand the model returned for the 'unsup' checkpoint to the classifier run.
bert = get_finetuned_bert('unsup')
classifier_run(ARGUMENTS, bert)
|
train-scripts/unsup_sst_onll.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from constants import *
|
2 |
+
from types import SimpleNamespace
|
3 |
+
from classifier import classifier_run
|
4 |
+
from .finetuned_bert import get_finetuned_bert
|
5 |
+
|
6 |
+
# Settings for the SST run with fine_tune_mode "last-linear-layer" and lr 1e-3.
ARGUMENTS = SimpleNamespace(
    dataset="sst",
    batch_size=BATCH_SIZE_SST,
    train=IDS_SST_TRAIN,
    dev=IDS_SST_DEV,
    test=IDS_SST_TEST,
    lr=1e-3,
    fine_tune_mode="last-linear-layer",
)

# Hand the model returned for the 'unsup' checkpoint to the classifier run.
bert = get_finetuned_bert('unsup')
classifier_run(ARGUMENTS, bert)
|
utils.py
ADDED
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, List, Optional, Union, Tuple, BinaryIO
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import json
|
5 |
+
import shutil
|
6 |
+
import tempfile
|
7 |
+
import copy
|
8 |
+
from tqdm.auto import tqdm
|
9 |
+
from functools import partial
|
10 |
+
from urllib.parse import urlparse
|
11 |
+
from pathlib import Path
|
12 |
+
import requests
|
13 |
+
from hashlib import sha256
|
14 |
+
from filelock import FileLock
|
15 |
+
import importlib_metadata
|
16 |
+
import torch
|
17 |
+
import torch.nn as nn
|
18 |
+
from torch import Tensor
|
19 |
+
import fnmatch
|
20 |
+
|
21 |
+
|
22 |
+
# Versions reported in the User-Agent string built by http_user_agent().
__version__ = "4.0.0"
_torch_version = importlib_metadata.version("torch")

# Cache directory resolution. Each environment variable, when set, overrides
# the value derived on the line(s) before it.
_xdg_cache_home = os.getenv("XDG_CACHE_HOME", "~/.cache")
hf_cache_home = os.path.expanduser(os.getenv("HF_HOME", os.path.join(_xdg_cache_home, "huggingface")))
default_cache_path = os.path.join(hf_cache_home, "transformers")
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)

# Short mirror names accepted by hf_bucket_url(), mapped to their endpoints.
PRESET_MIRROR_DICT = {
    "tuna": "https://mirrors.tuna.tsinghua.edu.cn/hugging-face-models",
    "bfsu": "https://mirrors.bfsu.edu.cn/hugging-face-models",
}
# URL template used by hf_bucket_url() when no mirror is requested.
HUGGINGFACE_CO_PREFIX = "https://huggingface.co/{model_id}/resolve/{revision}/{filename}"
# File names of the weights and configuration files.
WEIGHTS_NAME = "pytorch_model.bin"
CONFIG_NAME = "config.json"
|
38 |
+
|
39 |
+
|
40 |
+
def is_torch_available():
    """Report whether PyTorch is available; this module always answers ``True``."""
    return True
|
42 |
+
|
43 |
+
|
44 |
+
def is_tf_available():
    """Report whether TensorFlow is available; this module always answers ``False``."""
    return False
|
46 |
+
|
47 |
+
|
48 |
+
def is_remote_url(url_or_filename):
    """Return ``True`` when the argument parses with an ``http`` or ``https`` scheme."""
    scheme = urlparse(url_or_filename).scheme
    return scheme == "http" or scheme == "https"
|
51 |
+
|
52 |
+
|
53 |
+
def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers: Optional[Dict[str, str]] = None):
    """Stream the body of ``url`` into ``temp_file`` while showing a progress bar.

    Args:
        url: address to download.
        temp_file: writable binary file object that receives the chunks.
        proxies: forwarded unchanged to ``requests.get``.
        resume_size: number of bytes already present; when positive, a
            ``Range`` header asks the server for the remainder only.
        headers: optional request headers; the caller's dict is never mutated.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    # Work on a copy so the caller's dict is untouched, and start from an
    # empty dict when no headers were given so the Range header can be added
    # (indexing into ``None`` used to raise TypeError on resume).
    headers = {} if headers is None else copy.deepcopy(headers)
    if resume_size > 0:
        headers["Range"] = "bytes=%d-" % (resume_size,)

    # The context manager releases the connection even if streaming fails.
    with requests.get(url, stream=True, proxies=proxies, headers=headers) as r:
        r.raise_for_status()
        content_length = r.headers.get("Content-Length")
        total = resume_size + int(content_length) if content_length is not None else None
        progress = tqdm(
            unit="B",
            unit_scale=True,
            total=total,
            initial=resume_size,
            desc="Downloading",
            disable=False,
        )
        try:
            for chunk in r.iter_content(chunk_size=1024):
                if chunk:  # filter out keep-alive new chunks
                    progress.update(len(chunk))
                    temp_file.write(chunk)
        finally:
            progress.close()
|
74 |
+
|
75 |
+
|
76 |
+
def url_to_filename(url: str, etag: Optional[str] = None) -> str:
    """Build the cache file name for ``url``.

    The name is the SHA-256 hex digest of the URL. A non-empty ``etag`` adds
    ``"."`` plus the digest of the etag, and a URL ending in ``.h5`` keeps
    that extension at the end of the name.
    """
    digests = [sha256(url.encode("utf-8")).hexdigest()]
    if etag:
        digests.append(sha256(etag.encode("utf-8")).hexdigest())

    name = ".".join(digests)
    return name + ".h5" if url.endswith(".h5") else name
|
88 |
+
|
89 |
+
|
90 |
+
def hf_bucket_url(
    model_id: str, filename: str, subfolder: Optional[str] = None, revision: Optional[str] = None, mirror=None
) -> str:
    """Build the download URL of ``filename`` for the model ``model_id``.

    Args:
        model_id: model identifier; one without a ``/`` uses the legacy
            ``<model_id>-<filename>`` layout on mirrors.
        filename: file to fetch, prefixed by ``subfolder`` when that is given.
        subfolder: optional folder that contains the file.
        revision: revision placed in the URL; ``None`` means ``"main"``.
            Not used when a mirror is requested.
        mirror: a key of ``PRESET_MIRROR_DICT`` or a full endpoint.
    """
    if subfolder is not None:
        filename = f"{subfolder}/{filename}"

    if not mirror:
        return HUGGINGFACE_CO_PREFIX.format(
            model_id=model_id,
            revision="main" if revision is None else revision,
            filename=filename,
        )

    endpoint = PRESET_MIRROR_DICT.get(mirror, mirror)
    # Mirrors join legacy (slash-free) ids with "-" and namespaced ids with "/".
    separator = "/" if "/" in model_id else "-"
    return f"{endpoint}/{model_id}{separator}{filename}"
|
107 |
+
|
108 |
+
|
109 |
+
def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
    """Assemble the User-Agent string sent with download requests.

    The string lists the module version, the Python version and the
    framework versions, followed by the caller's extra entries: a dict is
    rendered as ``key/value`` pairs, a string is appended as is. All parts
    are separated by ``"; "``.
    """
    pieces = ["transformers/{}; python/{}".format(__version__, sys.version.split()[0])]
    if is_torch_available():
        pieces.append(f"torch/{_torch_version}")
    if is_tf_available():
        # NOTE(review): _tf_version is not defined in the visible part of this
        # module; the branch is only safe while is_tf_available() is False.
        pieces.append(f"tensorflow/{_tf_version}")

    if isinstance(user_agent, dict):
        pieces.append("; ".join("{}/{}".format(k, v) for k, v in user_agent.items()))
    elif isinstance(user_agent, str):
        pieces.append(user_agent)
    return "; ".join(pieces)
|
120 |
+
|
121 |
+
|
122 |
+
def _read_stored_hf_token() -> Optional[str]:
    """Return a Hugging Face token from the environment or a token file, else ``None``."""
    for variable in ("HF_TOKEN", "HUGGING_FACE_HUB_TOKEN"):
        value = os.getenv(variable)
        if value and value.strip():
            return value.strip()

    candidate_paths = (
        os.getenv("HF_TOKEN_PATH", os.path.join(hf_cache_home, "token")),
        os.path.expanduser(os.path.join("~", ".huggingface", "token")),  # legacy location
    )
    for token_path in candidate_paths:
        try:
            with open(token_path, encoding="utf-8") as token_file:
                token = token_file.read().strip()
        except OSError:
            continue
        if token:
            return token
    return None


def get_from_cache(
    url: str,
    cache_dir=None,
    force_download=False,
    proxies=None,
    etag_timeout=10,
    resume_download=False,
    user_agent: Union[Dict, str, None] = None,
    use_auth_token: Union[bool, str, None] = None,
    local_files_only=False,
) -> Optional[str]:
    """Return the local path of ``url``, downloading it into the cache if needed.

    The cache file name is derived from the URL and the resource's ETag. When
    the ETag cannot be obtained (no connection, or ``local_files_only``), the
    most recently listed cached copy for the URL is returned instead.

    Args:
        url: remote address of the file.
        cache_dir: cache directory; defaults to ``TRANSFORMERS_CACHE``.
        force_download: download even if a cached copy exists.
        proxies: forwarded to ``requests``.
        etag_timeout: timeout in seconds of the HEAD request.
        resume_download: append to a ``.incomplete`` file left by an earlier
            attempt instead of starting over.
        user_agent: extra User-Agent information (see ``http_user_agent``).
        use_auth_token: a token string, or ``True`` to use the stored token.
        local_files_only: never touch the network.

    Raises:
        EnvironmentError: ``use_auth_token=True`` but no stored token exists.
        OSError: the remote resource exposes no ETag.
        FileNotFoundError: ``local_files_only`` and nothing is cached.
        ValueError: connection failure and nothing is cached.
    """
    # Imported locally: the module's top-level imports do not provide it, so
    # resume_download=True used to fail with NameError.
    from contextlib import contextmanager

    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    os.makedirs(cache_dir, exist_ok=True)

    headers = {"user-agent": http_user_agent(user_agent)}
    if isinstance(use_auth_token, str):
        headers["authorization"] = "Bearer {}".format(use_auth_token)
    elif use_auth_token:
        # The token is looked up with the standard library only; the helper
        # class the code referred to before is not imported by this module.
        token = _read_stored_hf_token()
        if token is None:
            raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
        headers["authorization"] = "Bearer {}".format(token)

    url_to_download = url
    etag = None
    if not local_files_only:
        try:
            r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=etag_timeout)
            r.raise_for_status()
            # We favor a custom header indicating the etag of the linked resource, and
            # we fallback to the regular etag header.
            etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
            # If we don't have any of those, raise an error.
            if etag is None:
                raise OSError(
                    "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility."
                )
            # In case of a redirect,
            # save an extra redirect on the request.get call,
            # and ensure we download the exact atomic version even if it changed
            # between the HEAD and the GET (unlikely, but hey).
            if 300 <= r.status_code <= 399:
                url_to_download = r.headers["Location"]
        except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
            # etag is already None
            pass

    filename = url_to_filename(url, etag)

    # get cache path to put the file
    cache_path = os.path.join(cache_dir, filename)

    # etag is None == we don't have a connection or we passed local_files_only.
    # try to get the last downloaded one
    if etag is None:
        if os.path.exists(cache_path):
            return cache_path

        matching_files = [
            file
            for file in fnmatch.filter(os.listdir(cache_dir), filename.split(".")[0] + ".*")
            if not file.endswith(".json") and not file.endswith(".lock")
        ]
        if len(matching_files) > 0:
            return os.path.join(cache_dir, matching_files[-1])

        # If files cannot be found and local_files_only=True,
        # the models might've been found if local_files_only=False
        # Notify the user about that
        if local_files_only:
            raise FileNotFoundError(
                "Cannot find the requested files in the cached path and outgoing traffic has been"
                " disabled. To enable model look-ups and downloads online, set 'local_files_only'"
                " to False."
            )
        raise ValueError(
            "Connection error, and we cannot find the requested files in the cached path."
            " Please try again or make sure your Internet connection is on."
        )

    # From now on, etag is not None.
    if os.path.exists(cache_path) and not force_download:
        return cache_path

    # Prevent parallel downloads of the same file with a lock.
    lock_path = cache_path + ".lock"
    with FileLock(lock_path):

        # If the download just completed while the lock was activated.
        if os.path.exists(cache_path) and not force_download:
            # Even if returning early like here, the lock will be released.
            return cache_path

        if resume_download:
            incomplete_path = cache_path + ".incomplete"

            @contextmanager
            def _resumable_file_manager() -> "io.BufferedWriter":
                with open(incomplete_path, "ab") as f:
                    yield f

            temp_file_manager = _resumable_file_manager
            if os.path.exists(incomplete_path):
                resume_size = os.stat(incomplete_path).st_size
            else:
                resume_size = 0
        else:
            temp_file_manager = partial(tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False)
            resume_size = 0

        # Download to temporary file, then copy to cache dir once finished.
        # Otherwise you get corrupt cache entries if the download gets interrupted.
        with temp_file_manager() as temp_file:
            http_get(url_to_download, temp_file, proxies=proxies, resume_size=resume_size, headers=headers)

        os.replace(temp_file.name, cache_path)

        # Record where the cached file came from next to it.
        meta = {"url": url, "etag": etag}
        meta_path = cache_path + ".json"
        with open(meta_path, "w") as meta_file:
            json.dump(meta, meta_file)

    return cache_path
|
250 |
+
|
251 |
+
|
252 |
+
def cached_path(
    url_or_filename,
    cache_dir=None,
    force_download=False,
    proxies=None,
    resume_download=False,
    user_agent: Union[Dict, str, None] = None,
    extract_compressed_file=False,
    force_extract=False,
    use_auth_token: Union[bool, str, None] = None,
    local_files_only=False,
) -> Optional[str]:
    """Resolve a URL or a local path to a local file (or extracted directory).

    A remote URL is fetched through ``get_from_cache``; an existing local
    path is returned as is. With ``extract_compressed_file`` a zip or tar
    archive is unpacked next to itself into ``<name-with-dashes>-extracted``
    and that directory is returned.

    Args:
        url_or_filename: ``http``/``https`` URL or local path (``str`` or ``Path``).
        cache_dir: cache directory; defaults to ``TRANSFORMERS_CACHE``.
        force_download, proxies, resume_download, user_agent, use_auth_token,
            local_files_only: forwarded to ``get_from_cache``.
        extract_compressed_file: unpack the file when it is an archive.
        force_extract: unpack again even if the target directory is non-empty.

    Raises:
        EnvironmentError: the local file does not exist, or the archive
            format cannot be identified.
        ValueError: the argument is neither a supported URL nor a local path.
    """
    # Imported locally: the module's top-level imports do not include these,
    # so extract_compressed_file=True used to fail with NameError.
    import tarfile
    from zipfile import ZipFile, is_zipfile

    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    if is_remote_url(url_or_filename):
        # URL, so get it from the cache (downloading if necessary)
        output_path = get_from_cache(
            url_or_filename,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            user_agent=user_agent,
            use_auth_token=use_auth_token,
            local_files_only=local_files_only,
        )
    elif os.path.exists(url_or_filename):
        # File, and it exists.
        output_path = url_or_filename
    elif urlparse(url_or_filename).scheme == "":
        # File, but it doesn't exist.
        raise EnvironmentError("file {} not found".format(url_or_filename))
    else:
        # Something unknown
        raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))

    if not extract_compressed_file:
        return output_path

    if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path):
        return output_path

    # Path where we extract compressed archives
    # We avoid '.' in dir name and add "-extracted" at the end: "./model.zip" => "./model-zip-extracted/"
    output_dir, output_file = os.path.split(output_path)
    output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
    output_path_extracted = os.path.join(output_dir, output_extract_dir_name)

    if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract:
        return output_path_extracted

    # Prevent parallel extractions
    lock_path = output_path + ".lock"
    with FileLock(lock_path):
        shutil.rmtree(output_path_extracted, ignore_errors=True)
        os.makedirs(output_path_extracted)
        # NOTE(security): extractall() trusts the member paths stored in the
        # archive; only extract archives that come from a trusted source.
        if is_zipfile(output_path):
            with ZipFile(output_path, "r") as zip_file:
                zip_file.extractall(output_path_extracted)
        elif tarfile.is_tarfile(output_path):
            # ``with`` closes the archive even when extraction fails.
            with tarfile.open(output_path) as tar_file:
                tar_file.extractall(output_path_extracted)
        else:
            raise EnvironmentError("Archive format of {} could not be identified".format(output_path))

    return output_path_extracted
|
325 |
+
|
326 |
+
|
327 |
+
def get_parameter_dtype(parameter: Union[nn.Module]):
    """Return the dtype of the module's first parameter.

    When the module yields no parameters, the dtype of the first plain
    tensor attribute found through ``_named_members`` is returned instead.
    """
    for weight in parameter.parameters():
        return weight.dtype

    # For nn.DataParallel compatibility in PyTorch 1.5
    def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
        return [(name, value) for name, value in module.__dict__.items() if torch.is_tensor(value)]

    _, first_tensor = next(parameter._named_members(get_members_fn=find_tensor_attributes))
    return first_tensor.dtype
|
340 |
+
|
341 |
+
|
342 |
+
def get_extended_attention_mask(attention_mask: Tensor, dtype) -> Tensor:
    """Turn a 2-D attention mask into an additive mask with two extra axes.

    Args:
        attention_mask: tensor of shape ``[batch_size, seq_length]``.
        dtype: dtype the mask is cast to before the arithmetic.

    Returns:
        Tensor of shape ``[batch_size, 1, 1, seq_length]`` holding
        ``(1.0 - mask) * -10000.0``, i.e. ``-10000.0`` where the mask is 0.
    """
    # attention_mask [batch_size, seq_length]
    assert attention_mask.dim() == 2
    # Insert two singleton axes so the mask broadcasts over attention heads.
    broadcastable = attention_mask.unsqueeze(1).unsqueeze(2)
    broadcastable = broadcastable.to(dtype=dtype)  # fp16 compatibility
    return (1.0 - broadcastable) * -10000.0
|