GlowCheese commited on
Commit
a0b398e
·
1 Parent(s): 354d4d3

Transfer code from Kaggle

Browse files
.python-version CHANGED
@@ -1 +1 @@
1
- 3.8.20
 
1
+ 3.10.15
base_bert.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from torch import device, dtype, nn
3
+ from config import BertConfig, PretrainedConfig
4
+ from utils import *
5
+
6
+ class BertPreTrainedModel(nn.Module):
7
+ config_class = BertConfig
8
+ base_model_prefix = "bert"
9
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
10
+ _keys_to_ignore_on_load_unexpected = None
11
+
12
+ def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
13
+ super().__init__()
14
+ self.config = config
15
+ self.name_or_path = config.name_or_path
16
+
17
+ def init_weights(self):
18
+ # Initialize weights
19
+ self.apply(self._init_weights)
20
+
21
+ def _init_weights(self, module):
22
+ """ Initialize the weights """
23
+ if isinstance(module, (nn.Linear, nn.Embedding)):
24
+ # Slightly different from the TF version which uses truncated_normal for initialization
25
+ # cf https://github.com/pytorch/pytorch/pull/5617
26
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
27
+ elif isinstance(module, nn.LayerNorm):
28
+ module.bias.data.zero_()
29
+ module.weight.data.fill_(1.0)
30
+ if isinstance(module, nn.Linear) and module.bias is not None:
31
+ module.bias.data.zero_()
32
+
33
+ @property
34
+ def dtype(self) -> dtype:
35
+ return get_parameter_dtype(self)
36
+
37
+ @classmethod
38
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
39
+ config = kwargs.pop("config", None)
40
+ state_dict = kwargs.pop("state_dict", None)
41
+ cache_dir = kwargs.pop("cache_dir", None)
42
+ force_download = kwargs.pop("force_download", False)
43
+ resume_download = kwargs.pop("resume_download", False)
44
+ proxies = kwargs.pop("proxies", None)
45
+ output_loading_info = kwargs.pop("output_loading_info", False)
46
+ local_files_only = kwargs.pop("local_files_only", False)
47
+ use_auth_token = kwargs.pop("use_auth_token", None)
48
+ revision = kwargs.pop("revision", None)
49
+ mirror = kwargs.pop("mirror", None)
50
+
51
+ # Load config if we don't provide a configuration
52
+ if not isinstance(config, PretrainedConfig):
53
+ config_path = config if config is not None else pretrained_model_name_or_path
54
+ config, model_kwargs = cls.config_class.from_pretrained(
55
+ config_path,
56
+ *model_args,
57
+ cache_dir=cache_dir,
58
+ return_unused_kwargs=True,
59
+ force_download=force_download,
60
+ resume_download=resume_download,
61
+ proxies=proxies,
62
+ local_files_only=local_files_only,
63
+ use_auth_token=use_auth_token,
64
+ revision=revision,
65
+ **kwargs,
66
+ )
67
+ else:
68
+ model_kwargs = kwargs
69
+
70
+ # Load model
71
+ if pretrained_model_name_or_path is not None:
72
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
73
+ if os.path.isdir(pretrained_model_name_or_path):
74
+ # Load from a PyTorch checkpoint
75
+ archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
76
+ elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
77
+ archive_file = pretrained_model_name_or_path
78
+ else:
79
+ archive_file = hf_bucket_url(
80
+ pretrained_model_name_or_path,
81
+ filename=WEIGHTS_NAME,
82
+ revision=revision,
83
+ mirror=mirror,
84
+ )
85
+ try:
86
+ # Load from URL or cache if already cached
87
+ resolved_archive_file = cached_path(
88
+ archive_file,
89
+ cache_dir=cache_dir,
90
+ force_download=force_download,
91
+ proxies=proxies,
92
+ resume_download=resume_download,
93
+ local_files_only=local_files_only,
94
+ use_auth_token=use_auth_token,
95
+ )
96
+ except EnvironmentError as err:
97
+ #logger.error(err)
98
+ msg = (
99
+ f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
100
+ f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
101
+ f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {WEIGHTS_NAME}.\n\n"
102
+ )
103
+ raise EnvironmentError(msg)
104
+ else:
105
+ resolved_archive_file = None
106
+
107
+ config.name_or_path = pretrained_model_name_or_path
108
+
109
+ # Instantiate model.
110
+ model = cls(config, *model_args, **model_kwargs)
111
+
112
+ if state_dict is None:
113
+ try:
114
+ state_dict = torch.load(resolved_archive_file, map_location="cpu", weights_only=True)
115
+ except Exception:
116
+ raise OSError(
117
+ f"Unable to load weights from pytorch checkpoint file for '{pretrained_model_name_or_path}' "
118
+ f"at '{resolved_archive_file}'"
119
+ )
120
+
121
+ missing_keys = []
122
+ unexpected_keys = []
123
+ error_msgs = []
124
+
125
+ # Convert old format to new format if needed from a PyTorch state_dict
126
+ old_keys = []
127
+ new_keys = []
128
+ m = {'embeddings.word_embeddings': 'word_embedding',
129
+ 'embeddings.position_embeddings': 'pos_embedding',
130
+ 'embeddings.token_type_embeddings': 'tk_type_embedding',
131
+ 'embeddings.LayerNorm': 'embed_layer_norm',
132
+ 'embeddings.dropout': 'embed_dropout',
133
+ 'encoder.layer': 'bert_layers',
134
+ 'pooler.dense': 'pooler_dense',
135
+ 'pooler.activation': 'pooler_af',
136
+ 'attention.self': "self_attention",
137
+ 'attention.output.dense': 'attention_dense',
138
+ 'attention.output.LayerNorm': 'attention_layer_norm',
139
+ 'attention.output.dropout': 'attention_dropout',
140
+ 'intermediate.dense': 'interm_dense',
141
+ 'intermediate.intermediate_act_fn': 'interm_af',
142
+ 'output.dense': 'out_dense',
143
+ 'output.LayerNorm': 'out_layer_norm',
144
+ 'output.dropout': 'out_dropout'}
145
+
146
+ for key in state_dict.keys():
147
+ new_key = None
148
+ if "gamma" in key:
149
+ new_key = key.replace("gamma", "weight")
150
+ if "beta" in key:
151
+ new_key = key.replace("beta", "bias")
152
+ for x, y in m.items():
153
+ if new_key is not None:
154
+ _key = new_key
155
+ else:
156
+ _key = key
157
+ if x in key:
158
+ new_key = _key.replace(x, y)
159
+ if new_key:
160
+ old_keys.append(key)
161
+ new_keys.append(new_key)
162
+
163
+ for old_key, new_key in zip(old_keys, new_keys):
164
+ # print(old_key, new_key)
165
+ state_dict[new_key] = state_dict.pop(old_key)
166
+
167
+ # copy state_dict so _load_from_state_dict can modify it
168
+ metadata = getattr(state_dict, "_metadata", None)
169
+ state_dict = state_dict.copy()
170
+ if metadata is not None:
171
+ state_dict._metadata = metadata
172
+
173
+ your_bert_params = [f"bert.{x[0]}" for x in model.named_parameters()]
174
+ for k in state_dict:
175
+ if k not in your_bert_params and not k.startswith("cls."):
176
+ possible_rename = [x for x in k.split(".")[1:-1] if x in m.values()]
177
+ raise ValueError(f"{k} cannot be reload to your model, one/some of {possible_rename} we provided have been renamed")
178
+
179
+ # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
180
+ # so we need to apply the function recursively.
181
+ def load(module: nn.Module, prefix=""):
182
+ local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
183
+ module._load_from_state_dict(
184
+ state_dict,
185
+ prefix,
186
+ local_metadata,
187
+ True,
188
+ missing_keys,
189
+ unexpected_keys,
190
+ error_msgs,
191
+ )
192
+ for name, child in module._modules.items():
193
+ if child is not None:
194
+ load(child, prefix + name + ".")
195
+
196
+ # Make sure we are able to load base models as well as derived models (with heads)
197
+ start_prefix = ""
198
+ model_to_load = model
199
+ has_prefix_module = any(s.startswith(cls.base_model_prefix) for s in state_dict.keys())
200
+ if not hasattr(model, cls.base_model_prefix) and has_prefix_module:
201
+ start_prefix = cls.base_model_prefix + "."
202
+ if hasattr(model, cls.base_model_prefix) and not has_prefix_module:
203
+ model_to_load = getattr(model, cls.base_model_prefix)
204
+ load(model_to_load, prefix=start_prefix)
205
+
206
+ if model.__class__.__name__ != model_to_load.__class__.__name__:
207
+ base_model_state_dict = model_to_load.state_dict().keys()
208
+ head_model_state_dict_without_base_prefix = [
209
+ key.split(cls.base_model_prefix + ".")[-1] for key in model.state_dict().keys()
210
+ ]
211
+ missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict)
212
+
213
+ # Some models may have keys that are not in the state by design, removing them before needlessly warning
214
+ # the user.
215
+ if cls._keys_to_ignore_on_load_missing is not None:
216
+ for pat in cls._keys_to_ignore_on_load_missing:
217
+ missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
218
+
219
+ if cls._keys_to_ignore_on_load_unexpected is not None:
220
+ for pat in cls._keys_to_ignore_on_load_unexpected:
221
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
222
+
223
+ if len(error_msgs) > 0:
224
+ raise RuntimeError(
225
+ "Error(s) in loading state_dict for {}:\n\t{}".format(
226
+ model.__class__.__name__, "\n\t".join(error_msgs)
227
+ )
228
+ )
229
+
230
+ # Set model in evaluation mode to deactivate DropOut modules by default
231
+ model.eval()
232
+
233
+ if output_loading_info:
234
+ loading_info = {
235
+ "missing_keys": missing_keys,
236
+ "unexpected_keys": unexpected_keys,
237
+ "error_msgs": error_msgs,
238
+ }
239
+ return model, loading_info
240
+
241
+ if hasattr(config, "xla_device") and config.xla_device and is_torch_tpu_available():
242
+ import torch_xla.core.xla_model as xm
243
+
244
+ model = xm.send_cpu_data_to_device(model, xm.xla_device())
245
+ model.to(xm.xla_device())
246
+
247
+ return model
bert.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+
7
+ class BertSelfAttention(nn.Module):
8
+ def __init__(self, config):
9
+ super().__init__()
10
+
11
+ self.num_attention_heads = config.num_attention_heads
12
+ self.attention_head_size = config.hidden_size // config.num_attention_heads
13
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
14
+
15
+ # Initialize the linear transformation layers for key, value, query.
16
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
17
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
18
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
19
+ # This dropout is applied to normalized attention scores following the original
20
+ # implementation of transformer. Although it is a bit unusual, we empirically
21
+ # observe that it yields better performance.
22
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
23
+
24
+ def transform(self, x, linear_layer):
25
+ # The corresponding linear_layer of k, v, q are used to project the hidden_state (x).
26
+ bs, seq_len = x.shape[:2]
27
+ proj = linear_layer(x)
28
+ # Next, we need to produce multiple heads for the proj. This is done by spliting the
29
+ # hidden state to self.num_attention_heads, each of size self.attention_head_size.
30
+ proj = proj.view(bs, seq_len, self.num_attention_heads, self.attention_head_size)
31
+ # By proper transpose, we have proj of size [bs, num_attention_heads, seq_len, attention_head_size].
32
+ proj = proj.transpose(1, 2)
33
+ return proj
34
+
35
+ def attention(self, key, query, value, attention_mask):
36
+ """
37
+ key, query, value: [batch_size, num_attention_heads, seq_len, attention_head_size]
38
+ attention_mask: [batch_size, 1, 1, seq_len], masks padding tokens in the input.
39
+ """
40
+
41
+ d_k = query.size(-1) # attention_head_size
42
+ attention_scores = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(d_k)
43
+ # attention_scores shape: [batch_size, num_attention_heads, seq_len, seq_len]
44
+
45
+ # Apply attention mask
46
+ attention_scores = attention_scores + attention_mask
47
+
48
+ # Normalize scores with softmax and apply dropout.
49
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
50
+ attention_probs = self.dropout(attention_probs)
51
+
52
+ context = torch.matmul(attention_probs, value)
53
+ # context shape: [batch_size, num_attention_heads, seq_len, attention_head_size]
54
+
55
+ # Concatenate all attention heads to recover original shape: [batch_size, seq_len, hidden_size]
56
+ context = context.transpose(1, 2).contiguous()
57
+ context = context.view(context.size(0), context.size(1), -1)
58
+
59
+ return context
60
+
61
+
62
+ def forward(self, hidden_states, attention_mask):
63
+ """
64
+ hidden_states: [bs, seq_len, hidden_state]
65
+ attention_mask: [bs, 1, 1, seq_len]
66
+ output: [bs, seq_len, hidden_state]
67
+ """
68
+ # First, we have to generate the key, value, query for each token for multi-head attention
69
+ # using self.transform (more details inside the function).
70
+ # Size of *_layer is [bs, num_attention_heads, seq_len, attention_head_size].
71
+ key_layer = self.transform(hidden_states, self.key)
72
+ value_layer = self.transform(hidden_states, self.value)
73
+ query_layer = self.transform(hidden_states, self.query)
74
+ # Calculate the multi-head attention.
75
+ attn_value = self.attention(key_layer, query_layer, value_layer, attention_mask)
76
+ return attn_value
77
+
78
+
79
+ class BertLayer(nn.Module):
80
+ def __init__(self, config):
81
+ super().__init__()
82
+ # Multi-head attention.
83
+ self.self_attention = BertSelfAttention(config)
84
+ # Add-norm for multi-head attention.
85
+ self.attention_dense = nn.Linear(config.hidden_size, config.hidden_size)
86
+ self.attention_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
87
+ self.attention_dropout = nn.Dropout(config.hidden_dropout_prob)
88
+ # Feed forward.
89
+ self.interm_dense = nn.Linear(config.hidden_size, config.intermediate_size)
90
+ self.interm_af = F.gelu
91
+ # Add-norm for feed forward.
92
+ self.out_dense = nn.Linear(config.intermediate_size, config.hidden_size)
93
+ self.out_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
94
+ self.out_dropout = nn.Dropout(config.hidden_dropout_prob)
95
+
96
+
97
+ def add_norm(self, input, output, dense_layer, dropout, ln_layer):
98
+ transformed_output = dense_layer(output) # Biến đổi output bằng dense_layer
99
+ transformed_output = dropout(transformed_output) # Áp dụng dropout
100
+ added_output = input + transformed_output # Kết hợp input và output
101
+ normalized_output = ln_layer(added_output) # Áp dụng chuẩn hóa
102
+ return normalized_output
103
+
104
+
105
+ def forward(self, hidden_states, attention_mask):
106
+ # 1. Multi-head attention
107
+ attention_output = self.self_attention(hidden_states, attention_mask)
108
+
109
+ # 2. Add-norm after attention
110
+ attention_output = self.add_norm(
111
+ hidden_states,
112
+ attention_output,
113
+ self.attention_dense,
114
+ self.attention_dropout,
115
+ self.attention_layer_norm
116
+ )
117
+
118
+ # 3. Feed-forward network
119
+ intermediate_output = self.interm_af(self.interm_dense(attention_output))
120
+
121
+ # 4. Add-norm after feed-forward
122
+ layer_output = self.add_norm(
123
+ attention_output,
124
+ intermediate_output,
125
+ self.out_dense,
126
+ self.out_dropout,
127
+ self.out_layer_norm
128
+ )
129
+
130
+ return layer_output
131
+
132
+
133
+ class BertModel(BertPreTrainedModel):
134
+ """
135
+ The BERT model returns the final embeddings for each token in a sentence.
136
+
137
+ The model consists of:
138
+ 1. Embedding layers (used in self.embed).
139
+ 2. A stack of n BERT layers (used in self.encode).
140
+ 3. A linear transformation layer for the [CLS] token (used in self.forward, as given).
141
+ """
142
+ def __init__(self, config):
143
+ super().__init__(config)
144
+ self.config = config
145
+
146
+ # Embedding layers.
147
+ self.word_embedding = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
148
+ self.pos_embedding = nn.Embedding(config.max_position_embeddings, config.hidden_size)
149
+ self.tk_type_embedding = nn.Embedding(config.type_vocab_size, config.hidden_size)
150
+ self.embed_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
151
+ self.embed_dropout = nn.Dropout(config.hidden_dropout_prob)
152
+ # Register position_ids (1, len position emb) to buffer because it is a constant.
153
+ position_ids = torch.arange(config.max_position_embeddings).unsqueeze(0)
154
+ self.register_buffer('position_ids', position_ids)
155
+
156
+ # BERT encoder.
157
+ self.bert_layers = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
158
+
159
+ # [CLS] token transformations.
160
+ self.pooler_dense = nn.Linear(config.hidden_size, config.hidden_size)
161
+ self.pooler_af = nn.Tanh()
162
+
163
+ self.init_weights()
164
+
165
+
166
+ def embed(self, input_ids):
167
+ input_shape = input_ids.size()
168
+ seq_length = input_shape[1]
169
+
170
+ inputs_embeds = self.word_embedding(input_ids)
171
+
172
+ pos_ids = self.position_ids[:, :seq_length]
173
+ pos_embeds = self.pos_embedding(pos_ids)
174
+
175
+ # Since we are not considering token type, this embedding is just a placeholder.
176
+ tk_type_ids = torch.zeros(input_shape, dtype=torch.long, device=input_ids.device)
177
+ tk_type_embeds = self.tk_type_embedding(tk_type_ids)
178
+
179
+ embeddings = inputs_embeds + pos_embeds + tk_type_embeds
180
+ embeddings = self.embed_layer_norm(embeddings)
181
+ embeddings = self.embed_dropout(embeddings)
182
+
183
+ return embeddings
184
+
185
+
186
+ def encode(self, hidden_states, attention_mask):
187
+ """
188
+ hidden_states: the output from the embedding layer [batch_size, seq_len, hidden_size]
189
+ attention_mask: [batch_size, seq_len]
190
+ """
191
+ # Get the extended attention mask for self-attention.
192
+ # Returns extended_attention_mask of size [batch_size, 1, 1, seq_len].
193
+ # Distinguishes between non-padding tokens (with a value of 0) and padding tokens
194
+ # (with a value of a large negative number).
195
+ extended_attention_mask: torch.Tensor = get_extended_attention_mask(attention_mask, self.dtype)
196
+
197
+ # Pass the hidden states through the encoder layers.
198
+ for i, layer_module in enumerate(self.bert_layers):
199
+ # Feed the encoding from the last bert_layer to the next.
200
+ hidden_states = layer_module(hidden_states, extended_attention_mask)
201
+
202
+ return hidden_states
203
+
204
+
205
+ def forward(self, input_ids, attention_mask):
206
+ """
207
+ input_ids: [batch_size, seq_len], seq_len is the max length of the batch
208
+ attention_mask: same size as input_ids, 1 represents non-padding tokens, 0 represents padding tokens
209
+ """
210
+ # Get the embedding for each input token.
211
+ embedding_output = self.embed(input_ids=input_ids)
212
+
213
+ # Feed to a transformer (a stack of BertLayers).
214
+ sequence_output = self.encode(embedding_output, attention_mask=attention_mask)
215
+
216
+ # Get cls token hidden state.
217
+ first_tk = sequence_output[:, 0]
218
+ first_tk = self.pooler_dense(first_tk)
219
+ first_tk = self.pooler_af(first_tk)
220
+
221
+ return {'last_hidden_state': sequence_output, 'pooler_output': first_tk}
classifier.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from classifier_utils import *
2
+
3
+
4
+ TQDM_DISABLE=True
5
+
6
+
7
+ class BertSentimentClassifier(torch.nn.Module):
8
+ def __init__(self, config, custom_bert = None):
9
+ super(BertSentimentClassifier, self).__init__()
10
+ self.num_labels = config.num_labels
11
+ self.bert: BertModel = custom_bert or BertModel.from_pretrained('bert-base-uncased')
12
+
13
+ # Pretrain mode does not require updating BERT paramters.
14
+ assert config.fine_tune_mode in ["last-linear-layer", "full-model"]
15
+ for param in self.bert.parameters():
16
+ if config.fine_tune_mode == 'last-linear-layer':
17
+ param.requires_grad = False
18
+ elif config.fine_tune_mode == 'full-model':
19
+ param.requires_grad = True
20
+
21
+ # Classifier = Dropout + Linear
22
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
23
+ self.classifier = torch.nn.Linear(config.hidden_size, self.num_labels)
24
+
25
+
26
+ def forward(self, input_ids, attention_mask):
27
+ outputs = self.bert(input_ids, attention_mask)
28
+ pooler_output = outputs['pooler_output']
29
+
30
+ return self.classifier(self.dropout(pooler_output))
31
+
32
+
33
+ # Evaluate the model on dev examples.
34
+ def model_eval(dataloader, model: BertSentimentClassifier, device):
35
+ model.eval() # Switch to eval model, will turn off randomness like dropout.
36
+ y_true = []
37
+ y_pred = []
38
+ sents = []
39
+ sent_ids = []
40
+ for step, batch in enumerate(tqdm(dataloader, desc=f'eval', leave=False, disable=TQDM_DISABLE)):
41
+ b_labels, b_sents, b_sent_ids = batch['labels'], batch['sents'], batch['sent_ids']
42
+
43
+ b_ids = batch['token_ids'].to(device)
44
+ b_mask = batch['attention_mask'].to(device)
45
+
46
+ logits = model(b_ids, b_mask)
47
+ logits = logits.detach().cpu().numpy()
48
+ preds = np.argmax(logits, axis=1).flatten()
49
+
50
+ b_labels = b_labels.flatten()
51
+ y_true.extend(b_labels)
52
+ y_pred.extend(preds)
53
+ sents.extend(b_sents)
54
+ sent_ids.extend(b_sent_ids)
55
+
56
+ f1 = f1_score(y_true, y_pred, average='macro')
57
+ acc = accuracy_score(y_true, y_pred)
58
+
59
+ return acc, f1, y_pred, y_true, sents, sent_ids
60
+
61
+
62
+ # Evaluate the model on test examples.
63
+ def model_test_eval(dataloader, model, device):
64
+ model.eval() # Switch to eval model, will turn off randomness like dropout.
65
+ y_pred = []
66
+ sents = []
67
+ sent_ids = []
68
+ for step, batch in enumerate(tqdm(dataloader, desc=f'eval', leave=False, disable=TQDM_DISABLE)):
69
+ b_sents, b_sent_ids = batch['sents'], batch['sent_ids']
70
+
71
+ b_ids = batch['token_ids'].to(device)
72
+ b_mask = batch['attention_mask'].to(device)
73
+
74
+ logits = model(b_ids, b_mask)
75
+ logits = logits.detach().cpu().numpy()
76
+ preds = np.argmax(logits, axis=1).flatten()
77
+
78
+ y_pred.extend(preds)
79
+ sents.extend(b_sents)
80
+ sent_ids.extend(b_sent_ids)
81
+
82
+ return y_pred, sents, sent_ids
83
+
84
+
85
+ def save_model(model, args, config, filepath):
86
+ save_info = {
87
+ 'model': model.state_dict(),
88
+ 'args': args,
89
+ 'model_config': config,
90
+ 'system_rng': random.getstate(),
91
+ 'numpy_rng': np.random.get_state(),
92
+ 'torch_rng': torch.random.get_rng_state(),
93
+ }
94
+
95
+ torch.save(save_info, filepath)
96
+ print(f"save the model to {filepath}")
97
+
98
+
99
+ def train(args, custom_bert=None):
100
+ device = torch.device('cuda') if USE_GPU else torch.device('cpu')
101
+ # Create the data and its corresponding datasets and dataloader.
102
+ train_data, num_labels = load_data(args.train, 'train')
103
+ dev_data = load_data(args.dev, 'valid')
104
+
105
+ train_dataset = SentimentDataset(train_data)
106
+ dev_dataset = SentimentDataset(dev_data)
107
+
108
+ train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size,
109
+ num_workers=NUM_CPU_CORES, collate_fn=train_dataset.collate_fn)
110
+ dev_dataloader = DataLoader(dev_dataset, shuffle=False, batch_size=args.batch_size,
111
+ num_workers=NUM_CPU_CORES, collate_fn=dev_dataset.collate_fn)
112
+
113
+ # Init model.
114
+ config = {'hidden_dropout_prob': HIDDEN_DROPOUT_PROB,
115
+ 'num_labels': num_labels,
116
+ 'hidden_size': 768,
117
+ 'data_dir': '.',
118
+ 'fine_tune_mode': args.fine_tune_mode}
119
+
120
+ config = SimpleNamespace(**config)
121
+
122
+ model = BertSentimentClassifier(config, custom_bert)
123
+ model = model.to(device)
124
+
125
+ lr = args.lr
126
+ optimizer = AdamW(model.parameters(), lr=lr)
127
+ best_dev_acc = 0
128
+
129
+ # Run for the specified number of epochs.
130
+ for epoch in range(EPOCHS):
131
+ model.train()
132
+ train_loss = 0
133
+ num_batches = 0
134
+ for batch in tqdm(train_dataloader, desc=f'train-{epoch}', leave=False, disable=TQDM_DISABLE):
135
+ b_ids = batch['token_ids'].to(device)
136
+ b_mask = batch['attention_mask'].to(device)
137
+ b_labels = batch['labels'].to(device)
138
+
139
+ optimizer.zero_grad()
140
+ logits = model(b_ids, b_mask)
141
+ loss = F.cross_entropy(logits, b_labels.view(-1), reduction='sum') / args.batch_size
142
+
143
+ loss.backward()
144
+ optimizer.step()
145
+
146
+ train_loss += loss.item()
147
+ num_batches += 1
148
+
149
+ train_loss = train_loss / (num_batches)
150
+
151
+ train_acc, train_f1, *_ = model_eval(train_dataloader, model, device)
152
+ dev_acc, dev_f1, *_ = model_eval(dev_dataloader, model, device)
153
+
154
+ if dev_acc > best_dev_acc:
155
+ best_dev_acc = dev_acc
156
+ save_model(model, args, config, args.filepath)
157
+
158
+ print(f"Epoch {epoch}: train loss :: {train_loss :.3f}, train acc :: {train_acc :.3f}, dev acc :: {dev_acc :.3f}")
159
+
160
+
161
+ def test(args):
162
+ with torch.no_grad():
163
+ device = torch.device('cuda') if USE_GPU else torch.device('cpu')
164
+ saved = torch.load(args.filepath, weights_only=False)
165
+ config = saved['model_config']
166
+ model = BertSentimentClassifier(config)
167
+ model.load_state_dict(saved['model'])
168
+ model = model.to(device)
169
+ print(f"load model from {args.filepath}")
170
+
171
+ dev_data = load_data(args.dev, 'valid')
172
+ dev_dataset = SentimentDataset(dev_data)
173
+ dev_dataloader = DataLoader(dev_dataset, shuffle=False, batch_size=args.batch_size,
174
+ num_workers=NUM_CPU_CORES, collate_fn=dev_dataset.collate_fn)
175
+
176
+ dev_acc, dev_f1, dev_pred, dev_true, dev_sents, dev_sent_ids = model_eval(dev_dataloader, model, device)
177
+ print('DONE DEV')
178
+ print(f"dev acc :: {dev_acc :.3f}")
179
+
180
+
181
+ def classifier_run(args, custom_bert=None):
182
+ seed_everything(SEED)
183
+ torch.set_num_threads(NUM_CPU_CORES)
184
+
185
+ print(f'Training Sentiment Classifier on {args.dataset}...')
186
+ config = SimpleNamespace(
187
+ filepath=f'/kaggle/working/{args.dataset}-classifier.pt',
188
+ lr=args.lr,
189
+ batch_size=args.batch_size,
190
+ fine_tune_mode=args.fine_tune_mode,
191
+ train=args.train, dev=args.dev, test=args.test,
192
+ dev_out = f'/kaggle/working/predictions/{args.fine_tune_mode}-{args.dataset}-dev-out.csv',
193
+ test_out = f'/kaggle/working/predictions/{args.fine_tune_mode}-{args.dataset}-test-out.csv'
194
+ )
195
+
196
+ train(config, custom_bert)
197
+
198
+ print(f'Evaluating on {args.dataset}...')
199
+ test(config)
classifier_utils.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from everything import *
2
+ from bert import BertModel
3
+ from optimizer import AdamW
4
+ from tokenizer import BertTokenizer
5
+
6
+
7
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
8
+
9
+
10
+ class SentimentDataset(Dataset):
11
+ def __init__(self, dataset):
12
+ self.dataset = dataset
13
+
14
+ def __len__(self):
15
+ return len(self.dataset)
16
+
17
+ def __getitem__(self, idx):
18
+ return self.dataset[idx]
19
+
20
+ def pad_data(self, data):
21
+ sents = [x[0] for x in data]
22
+ labels = [x[1] for x in data]
23
+ sent_ids = [x[2] for x in data]
24
+
25
+ encoding = tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
26
+ token_ids = torch.LongTensor(encoding['input_ids'])
27
+ attention_mask = torch.LongTensor(encoding['attention_mask'])
28
+ labels = torch.LongTensor(labels)
29
+
30
+ return token_ids, attention_mask, labels, sents, sent_ids
31
+
32
+ def collate_fn(self, all_data):
33
+ token_ids, attention_mask, labels, sents, sent_ids = self.pad_data(all_data)
34
+
35
+ batched_data = {
36
+ 'token_ids': token_ids,
37
+ 'attention_mask': attention_mask,
38
+ 'labels': labels,
39
+ 'sents': sents,
40
+ 'sent_ids': sent_ids
41
+ }
42
+
43
+ return batched_data
44
+
45
+
46
+ class SentimentTestDataset(Dataset):
47
+ def __init__(self, dataset):
48
+ self.dataset = dataset
49
+
50
+ def __len__(self):
51
+ return len(self.dataset)
52
+
53
+ def __getitem__(self, idx):
54
+ return self.dataset[idx]
55
+
56
+ def pad_data(self, data):
57
+ sents = [x[0] for x in data]
58
+ sent_ids = [x[1] for x in data]
59
+
60
+ encoding = tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
61
+ token_ids = torch.LongTensor(encoding['input_ids'])
62
+ attention_mask = torch.LongTensor(encoding['attention_mask'])
63
+
64
+ return token_ids, attention_mask, sents, sent_ids
65
+
66
+ def collate_fn(self, all_data):
67
+ token_ids, attention_mask, sents, sent_ids= self.pad_data(all_data)
68
+
69
+ batched_data = {
70
+ 'token_ids': token_ids,
71
+ 'attention_mask': attention_mask,
72
+ 'sents': sents,
73
+ 'sent_ids': sent_ids
74
+ }
75
+
76
+ return batched_data
77
+
78
+
79
+ class AmazonDataset(Dataset):
80
+ def __init__(self, dataset):
81
+ self.dataset = dataset
82
+
83
+ def __len__(self):
84
+ return len(self.dataset)
85
+
86
+ def __getitem__(self, idx):
87
+ return self.dataset[idx]
88
+
89
+ def pad_data(self, data):
90
+ sents = [x[0] for x in data]
91
+ sent_ids = [x[1] for x in data]
92
+ encoding = tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
93
+ token_ids = torch.LongTensor(encoding['input_ids'])
94
+ attension_mask = torch.LongTensor(encoding['attention_mask'])
95
+
96
+ return token_ids, attension_mask, sent_ids
97
+
98
+ def collate_fn(self, data):
99
+ token_ids, attention_mask, sent_ids = self.pad_data(data)
100
+
101
+ batched_data = {
102
+ 'token_ids': token_ids,
103
+ 'attention_mask': attention_mask,
104
+ 'sent_ids': sent_ids
105
+ }
106
+
107
+ return batched_data
108
+
109
+
110
+ class SemanticDataset(Dataset):
111
+ def __init__(self, dataset):
112
+ self.dataset = dataset
113
+
114
+ def __len__(self):
115
+ return len(self.dataset)
116
+
117
+ def __getitem__(self, idx):
118
+ return self.dataset[idx]
119
+
120
+ def pad_data(self, data):
121
+ sents1 = [x[0] for x in data]
122
+ sents2 = [x[1] for x in data]
123
+ score = [x[2] for x in data]
124
+ sent_ids = [x[3] for x in data]
125
+ encoding = tokenizer(sents1 + sents2, return_tensors='pt', padding=True, truncation=True)
126
+ token_ids = torch.LongTensor(encoding['input_ids'])
127
+ attension_mask = torch.LongTensor(encoding['attention_mask'])
128
+
129
+ return token_ids, attension_mask, score, sent_ids
130
+
131
+ def collate_fn(self, data):
132
+ token_ids, attention_mask, score, sent_ids = self.pad_data(data)
133
+ n = len(sent_ids)
134
+
135
+ batched_data = {
136
+ 'token_ids_1': token_ids[:n],
137
+ 'token_ids_2': token_ids[n:],
138
+ 'attention_mask_1': attention_mask[:n],
139
+ 'attention_mask_2': attention_mask[n:],
140
+ 'score': score,
141
+ 'sent_ids': sent_ids
142
+ }
143
+
144
+ return batched_data
145
+
146
+
147
+ class InferenceDataset(Dataset):
148
+ def __init__(self, dataset):
149
+ self.dataset = dataset
150
+
151
+ def __len__(self):
152
+ return len(self.dataset)
153
+
154
+ def __getitem__(self, idx):
155
+ return self.dataset[idx]
156
+
157
+ def pad_data(self, data):
158
+ anchor = [x[0] for x in data]
159
+ positive = [x[1] for x in data]
160
+ negative = [x[2] for x in data]
161
+ sent_ids = [x[3] for x in data]
162
+ encoding = tokenizer(anchor + positive + negative, return_tensors='pt', padding=True, truncation=True)
163
+ token_ids = torch.LongTensor(encoding['input_ids'])
164
+ attension_mask = torch.LongTensor(encoding['attention_mask'])
165
+
166
+ return token_ids, attension_mask, sent_ids
167
+
168
+ def collate_fn(self, data):
169
+ token_ids, attention_mask, sent_ids = self.pad_data(data)
170
+ n = len(sent_ids)
171
+
172
+ batched_data = {
173
+ 'anchor_ids': token_ids[:n],
174
+ 'positive_ids': token_ids[n:2*n],
175
+ 'negative_ids': token_ids[2*n:],
176
+ 'anchor_masks': attention_mask[:n],
177
+ 'positive_masks': attention_mask[n:2*n],
178
+ 'negative_masks': attention_mask[2*n:],
179
+ 'sent_ids': sent_ids
180
+ }
181
+
182
+ return batched_data
183
+
184
+
185
+ def load_data(filename, flag='train'):
186
+ '''
187
+ - for amazon dataset: list of (sent, id)
188
+ - for nli dataset: list of (anchor, positive, negative, id)
189
+ - for stsb dataset: list of (sentence1, sentence2, score, id)
190
+
191
+ - for test dataset: list of (sent, id)
192
+ - for train dataset: list of (sent, label, id)
193
+ '''
194
+
195
+ if flag == 'amazon':
196
+ df = pd.read_parquet(filename)
197
+ data = list(zip(df['content'], df.index))
198
+ elif flag == 'nli':
199
+ df = pd.read_parquet(filename)
200
+ data = list(zip(df['anchor'], df['positive'], df['negative'], df.index))
201
+ elif flag == 'stsb':
202
+ df = pd.read_parquet(filename)
203
+ data = list(zip(df['sentence1'], df['sentence2'], df['score'], df.index))
204
+ else:
205
+ data, num_labels = [], set()
206
+
207
+ with open(filename, 'r') as fp:
208
+ if flag == 'test':
209
+ for record in csv.DictReader(fp, delimiter = '\t'):
210
+ sent = record['sentence'].lower().strip()
211
+ sent_id = record['id'].lower().strip()
212
+ data.append((sent,sent_id))
213
+ else:
214
+ for record in csv.DictReader(fp, delimiter = '\t'):
215
+ sent = record['sentence'].lower().strip()
216
+ sent_id = record['id'].lower().strip()
217
+ label = int(record['sentiment'].strip())
218
+ num_labels.add(label)
219
+ data.append((sent, label, sent_id))
220
+
221
+ print(f"load {len(data)} data from {filename}")
222
+ if flag == "train":
223
+ return data, len(num_labels)
224
+ else:
225
+ return data
226
+
227
+
228
+ def seed_everything(seed=11711):
229
+ random.seed(seed)
230
+ np.random.seed(seed)
231
+ torch.manual_seed(seed)
232
+ torch.cuda.manual_seed(seed)
233
+ torch.cuda.manual_seed_all(seed)
234
+ torch.backends.cudnn.benchmark = False
235
+ torch.backends.cudnn.deterministic = True
config.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union, Tuple, Dict, Any, Optional
2
+ import os
3
+ import json
4
+ from collections import OrderedDict
5
+ import torch
6
+ from utils import CONFIG_NAME, hf_bucket_url, cached_path, is_remote_url
7
+
8
+ class PretrainedConfig(object):
9
+ model_type: str = ""
10
+ is_composition: bool = False
11
+
12
+ def __init__(self, **kwargs):
13
+ # Attributes with defaults
14
+ self.return_dict = kwargs.pop("return_dict", True)
15
+ self.output_hidden_states = kwargs.pop("output_hidden_states", False)
16
+ self.output_attentions = kwargs.pop("output_attentions", False)
17
+ self.torchscript = kwargs.pop("torchscript", False) # Only used by PyTorch models
18
+ self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
19
+ self.pruned_heads = kwargs.pop("pruned_heads", {})
20
+ self.tie_word_embeddings = kwargs.pop(
21
+ "tie_word_embeddings", True
22
+ ) # Whether input and output word embeddings should be tied for all MLM, LM and Seq2Seq models.
23
+
24
+ # Is decoder is used in encoder-decoder models to differentiate encoder from decoder
25
+ self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
26
+ self.is_decoder = kwargs.pop("is_decoder", False)
27
+ self.add_cross_attention = kwargs.pop("add_cross_attention", False)
28
+ self.tie_encoder_decoder = kwargs.pop("tie_encoder_decoder", False)
29
+
30
+ # Parameters for sequence generation
31
+ self.max_length = kwargs.pop("max_length", 20)
32
+ self.min_length = kwargs.pop("min_length", 0)
33
+ self.do_sample = kwargs.pop("do_sample", False)
34
+ self.early_stopping = kwargs.pop("early_stopping", False)
35
+ self.num_beams = kwargs.pop("num_beams", 1)
36
+ self.num_beam_groups = kwargs.pop("num_beam_groups", 1)
37
+ self.diversity_penalty = kwargs.pop("diversity_penalty", 0.0)
38
+ self.temperature = kwargs.pop("temperature", 1.0)
39
+ self.top_k = kwargs.pop("top_k", 50)
40
+ self.top_p = kwargs.pop("top_p", 1.0)
41
+ self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
42
+ self.length_penalty = kwargs.pop("length_penalty", 1.0)
43
+ self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
44
+ self.encoder_no_repeat_ngram_size = kwargs.pop("encoder_no_repeat_ngram_size", 0)
45
+ self.bad_words_ids = kwargs.pop("bad_words_ids", None)
46
+ self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
47
+ self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)
48
+ self.output_scores = kwargs.pop("output_scores", False)
49
+ self.return_dict_in_generate = kwargs.pop("return_dict_in_generate", False)
50
+ self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None)
51
+ self.forced_eos_token_id = kwargs.pop("forced_eos_token_id", None)
52
+
53
+ # Fine-tuning task arguments
54
+ self.architectures = kwargs.pop("architectures", None)
55
+ self.finetuning_task = kwargs.pop("finetuning_task", None)
56
+ self.id2label = kwargs.pop("id2label", None)
57
+ self.label2id = kwargs.pop("label2id", None)
58
+ if self.id2label is not None:
59
+ kwargs.pop("num_labels", None)
60
+ self.id2label = dict((int(key), value) for key, value in self.id2label.items())
61
+ # Keys are always strings in JSON so convert ids to int here.
62
+ else:
63
+ self.num_labels = kwargs.pop("num_labels", 2)
64
+
65
+ # Tokenizer arguments
66
+ self.tokenizer_class = kwargs.pop("tokenizer_class", None)
67
+ self.prefix = kwargs.pop("prefix", None)
68
+ self.bos_token_id = kwargs.pop("bos_token_id", None)
69
+ self.pad_token_id = kwargs.pop("pad_token_id", None)
70
+ self.eos_token_id = kwargs.pop("eos_token_id", None)
71
+ self.sep_token_id = kwargs.pop("sep_token_id", None)
72
+
73
+ self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
74
+
75
+ # task specific arguments
76
+ self.task_specific_params = kwargs.pop("task_specific_params", None)
77
+
78
+ # TPU arguments
79
+ self.xla_device = kwargs.pop("xla_device", None)
80
+
81
+ # Name or path to the pretrained checkpoint
82
+ self._name_or_path = str(kwargs.pop("name_or_path", ""))
83
+
84
+ # Drop the transformers version info
85
+ kwargs.pop("transformers_version", None)
86
+
87
+ # Additional attributes without default values
88
+ for key, value in kwargs.items():
89
+ try:
90
+ setattr(self, key, value)
91
+ except AttributeError as err:
92
+ raise err
93
+
94
+ @classmethod
95
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
96
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
97
+ return cls.from_dict(config_dict, **kwargs)
98
+
99
+ @classmethod
100
+ def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
101
+ with open(json_file, "r", encoding="utf-8") as reader:
102
+ text = reader.read()
103
+ return json.loads(text)
104
+
105
+ @classmethod
106
+ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig":
107
+ return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
108
+
109
+ config = cls(**config_dict)
110
+
111
+ if hasattr(config, "pruned_heads"):
112
+ config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())
113
+
114
+ # Update config with kwargs if needed
115
+ to_remove = []
116
+ for key, value in kwargs.items():
117
+ if hasattr(config, key):
118
+ setattr(config, key, value)
119
+ to_remove.append(key)
120
+ for key in to_remove:
121
+ kwargs.pop(key, None)
122
+
123
+ if return_unused_kwargs:
124
+ return config, kwargs
125
+ else:
126
+ return config
127
+
128
+ @classmethod
129
+ def get_config_dict(
130
+ cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
131
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
132
+ cache_dir = kwargs.pop("cache_dir", None)
133
+ force_download = kwargs.pop("force_download", False)
134
+ resume_download = kwargs.pop("resume_download", False)
135
+ proxies = kwargs.pop("proxies", None)
136
+ use_auth_token = kwargs.pop("use_auth_token", None)
137
+ local_files_only = kwargs.pop("local_files_only", False)
138
+ revision = kwargs.pop("revision", None)
139
+
140
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
141
+ if os.path.isdir(pretrained_model_name_or_path):
142
+ config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
143
+ elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
144
+ config_file = pretrained_model_name_or_path
145
+ else:
146
+ config_file = hf_bucket_url(
147
+ pretrained_model_name_or_path, filename=CONFIG_NAME, revision=revision, mirror=None
148
+ )
149
+
150
+ try:
151
+ # Load from URL or cache if already cached
152
+ resolved_config_file = cached_path(
153
+ config_file,
154
+ cache_dir=cache_dir,
155
+ force_download=force_download,
156
+ proxies=proxies,
157
+ resume_download=resume_download,
158
+ local_files_only=local_files_only,
159
+ use_auth_token=use_auth_token,
160
+ )
161
+ # Load config dict
162
+ config_dict = cls._dict_from_json_file(resolved_config_file)
163
+
164
+ except EnvironmentError as err:
165
+ msg = (
166
+ f"Can't load config for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
167
+ f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
168
+ f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a {CONFIG_NAME} file\n\n"
169
+ )
170
+ raise EnvironmentError(msg)
171
+
172
+ except json.JSONDecodeError:
173
+ msg = (
174
+ "Couldn't reach server at '{}' to download configuration file or "
175
+ "configuration file is not a valid JSON file. "
176
+ "Please check network or file content here: {}.".format(config_file, resolved_config_file)
177
+ )
178
+ raise EnvironmentError(msg)
179
+
180
+ return config_dict, kwargs
181
+
182
+
183
+ class BertConfig(PretrainedConfig):
184
+ model_type = "bert"
185
+
186
+ def __init__(
187
+ self,
188
+ vocab_size=30522,
189
+ hidden_size=768,
190
+ num_hidden_layers=12,
191
+ num_attention_heads=12,
192
+ intermediate_size=3072,
193
+ hidden_act="gelu",
194
+ hidden_dropout_prob=0.1,
195
+ attention_probs_dropout_prob=0.1,
196
+ max_position_embeddings=512,
197
+ type_vocab_size=2,
198
+ initializer_range=0.02,
199
+ layer_norm_eps=1e-12,
200
+ pad_token_id=0,
201
+ gradient_checkpointing=False,
202
+ position_embedding_type="absolute",
203
+ use_cache=True,
204
+ **kwargs
205
+ ):
206
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
207
+
208
+ self.vocab_size = vocab_size
209
+ self.hidden_size = hidden_size
210
+ self.num_hidden_layers = num_hidden_layers
211
+ self.num_attention_heads = num_attention_heads
212
+ self.hidden_act = hidden_act
213
+ self.intermediate_size = intermediate_size
214
+ self.hidden_dropout_prob = hidden_dropout_prob
215
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
216
+ self.max_position_embeddings = max_position_embeddings
217
+ self.type_vocab_size = type_vocab_size
218
+ self.initializer_range = initializer_range
219
+ self.layer_norm_eps = layer_norm_eps
220
+ self.gradient_checkpointing = gradient_checkpointing
221
+ self.position_embedding_type = position_embedding_type
222
+ self.use_cache = use_cache
constants.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ DATA_DIR = 'minbert-data'
4
+ MODEL_DIR = 'minbert-model'
5
+
6
+
7
+ # Pretrained weights
8
+ SUP_BERT = os.path.join(MODEL_DIR, 'sup-cse-bert.pth')
9
+ UNSUP_BERT = os.path.join(MODEL_DIR, 'unsup-cse-bert.pth')
10
+
11
+
12
+ # CFIMDB dataset
13
+ IDS_CFIMDB_DEV = os.path.join(DATA_DIR, 'ids-cfimdb-dev.csv')
14
+ IDS_CFIMDB_TEST = os.path.join(DATA_DIR, 'ids-cfimdb-test-student.csv')
15
+ IDS_CFIMDB_TRAIN = os.path.join(DATA_DIR, 'ids-cfimdb-train.csv')
16
+
17
+ # SST dataset
18
+ IDS_SST_DEV = os.path.join(DATA_DIR, 'ids-sst-dev.csv')
19
+ IDS_SST_TEST = os.path.join(DATA_DIR, 'ids-sst-test-student.csv')
20
+ IDS_SST_TRAIN = os.path.join(DATA_DIR, 'ids-sst-train.csv')
21
+
22
+ # SimCSE train/dev dataset
23
+ NLI_TRAIN = os.path.join(DATA_DIR, 'nli-train.parquet')
24
+ AMAZON_POLARITY = os.path.join(DATA_DIR, 'amazon-polarity.parquet')
25
+ STSB_DEV = os.path.join(DATA_DIR, 'stsb-dev.parquet')
26
+
27
+
28
+ # Training-specific constants
29
+ SEED=11711
30
+ NUM_CPU_CORES=4
31
+ EPOCHS=10
32
+ USE_GPU=True
33
+ BATCH_SIZE_CSE=8
34
+ BATCH_SIZE_SST=64
35
+ BATCH_SIZE_CFIMDB=8
36
+ HIDDEN_DROPOUT_PROB=0.3
everything.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import torch
3
+ import random
4
+ import numpy as np
5
+ import pandas as pd
6
+
7
+ import torch.nn.functional as F
8
+
9
+ from tqdm import tqdm
10
+ from torch import nn, Tensor
11
+ from types import SimpleNamespace
12
+ from scipy.stats import spearmanr
13
+ from torch.utils.data import Dataset, DataLoader
14
+ from sklearn.metrics import f1_score, accuracy_score
15
+
16
+ from constants import *
17
+
18
+ import random, numpy as np, argparse
19
+ from types import SimpleNamespace
20
+ import csv
finetune-scripts/sup.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from constants import *
2
+ from types import SimpleNamespace
3
+ from finetuning import finetune_bert
4
+
5
+ ARGUMENTS = SimpleNamespace(
6
+ mode='sup',
7
+ filepath='/minbert-model/sup-cse-bert.pth',
8
+ batch_size_train=12,
9
+ batch_size_dev=64,
10
+ temp=0.05, lr=1e-5,
11
+ )
12
+ finetune_bert(ARGUMENTS)
finetune-scripts/unsup.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from constants import *
2
+ from types import SimpleNamespace
3
+ from finetuning import finetune_bert
4
+
5
+ ARGUMENTS = SimpleNamespace(
6
+ mode='unsup',
7
+ filepath='/minbert-model/unsup-cse-bert.pth',
8
+ batch_size_train=8,
9
+ batch_size_dev=64,
10
+ temp=0.05, lr=1e-5,
11
+ )
12
+ finetune_bert(ARGUMENTS)
finetuning.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from classifier_utils import *
2
+
3
+
4
+ TQDM_DISABLE=True
5
+
6
+
7
+ def unsup_contrastive_loss(embeds_1: Tensor, embeds_2: Tensor, temp=0.05):
8
+ '''
9
+ embeds_1: [batch_size, hidden_size]
10
+ embeds_2: [batch_size, hidden_size]
11
+ '''
12
+
13
+ # [batch_size, batch_size]
14
+ sim_matrix = F.cosine_similarity(embeds_1.unsqueeze(1), embeds_2.unsqueeze(0), dim=-1) / temp
15
+
16
+ # [batch_size]
17
+ positive_sim = torch.diagonal(sim_matrix)
18
+
19
+ # [batch_size]
20
+ nume = torch.exp(positive_sim)
21
+
22
+ # [batch_size]
23
+ deno = torch.exp(sim_matrix).sum(1)
24
+
25
+ # [batch_size]
26
+ loss_per_batch = -torch.log(nume / deno)
27
+
28
+ return loss_per_batch.sum()
29
+
30
+
31
+ def sup_contrastive_loss(embeds_1: Tensor, embeds_2: Tensor, embeds_3: Tensor, temp=0.05):
32
+ '''
33
+ embeds_1: [batch_size, hidden_size]
34
+ embeds_2: [batch_size, hidden_size]
35
+ embeds_3: [batch_size, hidden_size]
36
+ '''
37
+
38
+ # [batch_size, batch_size]
39
+ pos_sim_matrix = F.cosine_similarity(embeds_1.unsqueeze(1), embeds_2.unsqueeze(0), dim=-1) / temp
40
+ neg_sim_matrix = F.cosine_similarity(embeds_1.unsqueeze(1), embeds_3.unsqueeze(0), dim=-1) / temp
41
+
42
+ # [batch_size]
43
+ positive_sim = torch.diagonal(pos_sim_matrix)
44
+
45
+ # [batch_size]
46
+ nume = torch.exp(positive_sim)
47
+
48
+ # [batch_size]
49
+ deno = (torch.exp(pos_sim_matrix) + torch.exp(neg_sim_matrix)).sum(1)
50
+
51
+ # [batch_size]
52
+ loss_per_batch = -torch.log(nume / deno)
53
+
54
+ return loss_per_batch.sum()
55
+
56
+
57
+ def sts_eval(dataloader, model: BertModel, device):
58
+ model.eval()
59
+ y_true = []
60
+ y_pred = []
61
+ sent_ids = []
62
+
63
+ with torch.no_grad():
64
+ for batch in tqdm(dataloader, desc='eval', leave=False, disable=TQDM_DISABLE):
65
+ token_ids_1 = batch['token_ids_1'].to(device)
66
+ token_ids_2 = batch['token_ids_2'].to(device)
67
+ attention_mask_1 = batch['attention_mask_1'].to(device)
68
+ attention_mask_2 = batch['attention_mask_2'].to(device)
69
+
70
+ scores = batch['score']
71
+ b_sent_ids = batch['sent_ids']
72
+
73
+ logits_1 = model(token_ids_1, attention_mask_1)['pooler_output']
74
+ logits_2 = model(token_ids_2, attention_mask_2)['pooler_output']
75
+
76
+ sim = F.cosine_similarity(logits_1, logits_2)
77
+ y_true.extend(scores)
78
+ y_pred.extend(sim.cpu().tolist())
79
+ sent_ids.extend(b_sent_ids)
80
+
81
+ spearman_corr, _ = spearmanr(y_pred, y_true)
82
+ return spearman_corr, b_sent_ids
83
+
84
+
85
+ def finetune_bert(args):
86
+ '''
87
+ Finetuning Baseline
88
+ -------------------
89
+ 1. Load the Amazon Polarity (train) and STS Dataset (dev).
90
+ 2. Initialize pretrained minBERT
91
+ 3. Looping through 10 epoches.
92
+ 4. Calculate batches' SimCSE loss function.
93
+ 5. Backpropagation using Adam Optimizer.
94
+ 6. Evaluation on dev dataset:
95
+ 6.1. Create two [CLS] embeddings for given pair.
96
+ 6.2. Calculate their cosine similarity (0 <= sim <= 1).
97
+ 6.3. Spearman's correlation between calculated sim and expected sim.
98
+ 7. Better spearman's correlation (dev_acc > best_dev_acc) -> save_model(...).
99
+ '''
100
+
101
+ assert args.mode in ['unsup', 'sup']
102
+
103
+ seed_everything(SEED)
104
+ torch.set_num_threads(NUM_CPU_CORES)
105
+
106
+ if args.mode == 'unsup':
107
+ train_dataset = AmazonDataset(load_data(AMAZON_POLARITY, 'amazon'))
108
+ else:
109
+ train_dataset = InferenceDataset(load_data(NLI_TRAIN, 'nli'))
110
+
111
+ train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size_train,
112
+ num_workers=NUM_CPU_CORES, collate_fn=train_dataset.collate_fn)
113
+
114
+ sts_dataset = SemanticDataset(load_data(STSB_DEV, 'stsb'))
115
+ sts_dataloader = DataLoader(sts_dataset, shuffle=False, batch_size=args.batch_size_dev,
116
+ num_workers=NUM_CPU_CORES, collate_fn=sts_dataset.collate_fn)
117
+
118
+ device = torch.device('cuda') if USE_GPU else torch.device('cpu')
119
+ model = BertModel.from_pretrained('bert-base-uncased')
120
+ model.to(device)
121
+
122
+ best_dev_acc = 0
123
+ optimizer = AdamW(model.parameters(), lr=args.lr)
124
+
125
+ print(f'Finetuning minBERT with {args.mode}ervised method...')
126
+
127
+ for epoch in range(EPOCHS):
128
+ model.train()
129
+ train_loss = num_batches = 0
130
+
131
+ for batch in tqdm(train_dataloader, f'train-{epoch}', leave=False, disable=TQDM_DISABLE):
132
+ if args.mode == 'unsup':
133
+ b_ids = batch['token_ids'].to(device)
134
+ b_mask = batch['attention_mask'].to(device)
135
+
136
+ # Get different embeddings with different dropout masks
137
+ logits_1 = model(b_ids, b_mask)['pooler_output']
138
+ logits_2 = model(b_ids, b_mask)['pooler_output']
139
+
140
+ # Calculate mean SimCSE loss function
141
+ loss = unsup_contrastive_loss(logits_1, logits_2, args.temp)
142
+
143
+ else:
144
+ b_anchor_ids = batch['anchor_ids'].to(device)
145
+ b_positive_ids = batch['positive_ids'].to(device)
146
+ b_negative_ids = batch['negative_ids'].to(device)
147
+ b_anchor_masks = batch['anchor_masks'].to(device)
148
+ b_positive_masks = batch['positive_masks'].to(device)
149
+ b_negative_masks = batch['negative_masks'].to(device)
150
+
151
+ logits_1 = model(b_anchor_ids, b_anchor_masks)['pooler_output']
152
+ logits_2 = model(b_positive_ids, b_positive_masks)['pooler_output']
153
+ logits_3 = model(b_negative_ids, b_negative_masks)['pooler_output']
154
+
155
+ loss = sup_contrastive_loss(logits_1, logits_2, logits_3, args.temp)
156
+
157
+ # Back propagation
158
+ optimizer.zero_grad()
159
+ loss.backward()
160
+ optimizer.step()
161
+
162
+ train_loss += loss.item()
163
+ num_batches += 1
164
+
165
+ train_loss /= num_batches
166
+ dev_acc, _ = sts_eval(sts_dataloader, model, device)
167
+
168
+ if dev_acc > best_dev_acc:
169
+ best_dev_acc = dev_acc
170
+ torch.save(model.state_dict(), args.filepath)
171
+ print(f"save the model to {args.filepath}")
172
+
173
+ print(f"Epoch {epoch}: train loss :: {train_loss :.3f}, dev acc :: {dev_acc :.3f}")
minbert-data/amazon-polarity.parquet ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7cbb5fc18093875baac0f49ae926f76aa4938e3e9b8114d9a1d95fa05810d8e4
3
+ size 31246252
minbert-data/ids-cfimdb-dev.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3087f571b66860fe5d035b5a018d08202ad3fd3720e4821c04b2acf6c7ded559
3
+ size 249095
minbert-data/ids-cfimdb-test-student.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7ae611548c9eac879e9ebb406cc9f8ae68ff12f78090e4965af5cbdfa06240f4
3
+ size 495595
minbert-data/ids-cfimdb-train.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:140fc513045a966109faed46a5c7a898767b96714d71bcb9c15f659129fadcea
3
+ size 1693182
minbert-data/ids-sst-dev.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a186ce94577635fbe10beaaddd50f16cccf6c30973221cefdf90deed2a584bfe
3
+ size 151384
minbert-data/ids-sst-test-student.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bdd5a767faa0c26782117e37767ece154c30d5d04fb8727d09c71e3850a55c7b
3
+ size 313202
minbert-data/ids-sst-train.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:03b2b625c090f94a6afd59f114cde5282e2053aab0b101e87ed695d8a0c5b1df
3
+ size 1175139
minbert-data/nli-train.parquet ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bc8e9bffa6e24c1175be20aa2d064dae99501937d7dec52a341f23f75eaeaec8
3
+ size 28964735
minbert-data/optimizer_test.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:77b817e0dce16a9bc8d3a6bcb88035db68f7d783dc8a565737581fadd05db815
3
+ size 152
minbert-data/sanity_check.data ADDED
Binary file (56.4 kB). View file
 
minbert-data/stsb-dev.parquet ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9c6e0e9881f1b398abe3e439a482f4686305c3784568c462f6bba58bdff03b0a
3
+ size 142187
minbert-model/sup-cse-bert.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8ba7bedbe15db3ce7345fa9cba47dc281d4ddb34512fe745445468adbe6abd08
3
+ size 438007966
minbert-model/unsup-cse-bert.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3a8cfeab46f5903b3297f9536c29516ec96c9bef525ecf13d33494dbd0edebd7
3
+ size 438008374
optimizer.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, Iterable, Tuple
2
+ import math
3
+
4
+ import torch
5
+ from torch.optim import Optimizer
6
+
7
+
8
+ class AdamW(Optimizer):
9
+ def __init__(
10
+ self,
11
+ params: Iterable[torch.nn.parameter.Parameter],
12
+ lr: float = 1e-3,
13
+ betas: Tuple[float, float] = (0.9, 0.999),
14
+ eps: float = 1e-6,
15
+ weight_decay: float = 0.0,
16
+ correct_bias: bool = True,
17
+ ):
18
+ if lr < 0.0:
19
+ raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
20
+ if not 0.0 <= betas[0] < 1.0:
21
+ raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0]))
22
+ if not 0.0 <= betas[1] < 1.0:
23
+ raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1]))
24
+ if not 0.0 <= eps:
25
+ raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
26
+ defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias)
27
+ super().__init__(params, defaults)
28
+
29
+ def step(self, closure: Callable = None):
30
+ loss = None
31
+ if closure is not None:
32
+ loss = closure()
33
+
34
+ for group in self.param_groups:
35
+ for p in group["params"]:
36
+ if p.grad is None:
37
+ continue
38
+ grad = p.grad.data
39
+ if grad.is_sparse:
40
+ raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")
41
+
42
+ # Access state
43
+ state = self.state[p]
44
+
45
+ # Initialize state if not already done
46
+ if len(state) == 0:
47
+ state["step"] = 0
48
+ state["exp_avg"] = torch.zeros_like(p.data)
49
+ state["exp_avg_sq"] = torch.zeros_like(p.data)
50
+
51
+ # Hyperparameters
52
+ alpha = group["lr"]
53
+ beta1, beta2 = group["betas"]
54
+ eps = group["eps"]
55
+ weight_decay = group["weight_decay"]
56
+ correct_bias = group["correct_bias"]
57
+
58
+ # Retrieve state variables
59
+ exp_avg = state["exp_avg"]
60
+ exp_avg_sq = state["exp_avg_sq"]
61
+ step = state["step"]
62
+
63
+ # Update step
64
+ step += 1
65
+ state["step"] = step
66
+
67
+ # Update biased first and second moment estimates
68
+ exp_avg.mul_(beta1).add_(grad, alpha=(1 - beta1))
69
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))
70
+
71
+ # Compute bias-corrected moments
72
+ if correct_bias:
73
+ bias_correction1 = 1 - beta1 ** step
74
+ bias_correction2 = 1 - beta2 ** step
75
+ exp_avg_corr = exp_avg / bias_correction1
76
+ exp_avg_sq_corr = exp_avg_sq / bias_correction2
77
+ else:
78
+ exp_avg_corr = exp_avg
79
+ exp_avg_sq_corr = exp_avg_sq
80
+
81
+ # Update parameters
82
+ denom = exp_avg_sq_corr.sqrt().add_(eps)
83
+ step_size = alpha
84
+ p.data.addcdiv_(exp_avg_corr, denom, value=-step_size)
85
+
86
+ # Apply weight decay
87
+ if weight_decay != 0:
88
+ p.data.add_(p.data, alpha=-alpha * weight_decay)
89
+
90
+ return loss
optimizer_test.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ from optimizer import AdamW
5
+ from constants import DATA_DIR
6
+
7
+ seed = 0
8
+
9
+
10
+ def test_optimizer(opt_class) -> torch.Tensor:
11
+ rng = np.random.default_rng(seed)
12
+ torch.manual_seed(seed)
13
+ model = torch.nn.Linear(3, 2, bias=False)
14
+ opt = opt_class(
15
+ model.parameters(),
16
+ lr=1e-3,
17
+ weight_decay=1e-4,
18
+ correct_bias=True,
19
+ )
20
+ for i in range(1000):
21
+ opt.zero_grad()
22
+ x = torch.FloatTensor(rng.uniform(size=[model.in_features]))
23
+ y_hat = model(x)
24
+ y = torch.Tensor([x[0] + x[1], -x[2]])
25
+ loss = ((y - y_hat) ** 2).sum()
26
+ loss.backward()
27
+ opt.step()
28
+ return model.weight.detach()
29
+
30
+
31
+ ref = torch.tensor(np.load(os.path.join(DATA_DIR, "optimizer_test.npy")))
32
+ actual = test_optimizer(AdamW)
33
+ print(ref)
34
+ print(actual)
35
+ assert torch.allclose(ref, actual, atol=1e-6, rtol=1e-4)
36
+ print("Optimizer test passed!")
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.4.1
2
+ tqdm==4.58.0
3
+ requests==2.25.1
4
+ importlib-metadata==3.7.0
5
+ filelock==3.16.1
6
+ sklearn==0.0
7
+ tokenizers==0.15
8
+ explainaboard_client==0.0.7
9
+ pandas==2.2.3
sanity_check.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+
4
+ from bert import BertModel
5
+ from constants import DATA_DIR
6
+
7
+ sanity_data = torch.load(os.path.join(DATA_DIR, "sanity_check.data"), weights_only=True)
8
+ sent_ids = torch.tensor([[101, 7592, 2088, 102, 0, 0, 0, 0],
9
+ [101, 7592, 15756, 2897, 2005, 17953, 2361, 102]])
10
+ att_mask = torch.tensor([[1, 1, 1, 1, 0, 0, 0, 0],[1, 1, 1, 1, 1, 1, 1, 1]])
11
+
12
+ # Load model.
13
+ bert = BertModel.from_pretrained('bert-base-uncased')
14
+ outputs = bert(sent_ids, att_mask)
15
+ att_mask = att_mask.unsqueeze(-1)
16
+ outputs['last_hidden_state'] = outputs['last_hidden_state'] * att_mask
17
+ sanity_data['last_hidden_state'] = sanity_data['last_hidden_state'] * att_mask
18
+
19
+ for k in ['last_hidden_state', 'pooler_output']:
20
+ assert torch.allclose(outputs[k], sanity_data[k], atol=1e-5, rtol=1e-3)
21
+ print("Your BERT implementation is correct!")
tokenizer.py ADDED
The diff for this file is too large to render. See raw diff
 
train-scripts/base_cfimdb_onfm.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from constants import *
2
+ from types import SimpleNamespace
3
+ from classifier import classifier_run
4
+
5
+ ARGUMENTS = SimpleNamespace(
6
+ dataset='cfimdb',
7
+ batch_size=BATCH_SIZE_CFIMDB,
8
+ train=IDS_CFIMDB_TRAIN,
9
+ dev=IDS_CFIMDB_DEV,
10
+ test=IDS_CFIMDB_TEST,
11
+ lr=1e-5,
12
+ fine_tune_mode='full-model'
13
+ )
14
+
15
+ classifier_run(ARGUMENTS)
train-scripts/base_cfimdb_onll.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from constants import *
2
+ from types import SimpleNamespace
3
+ from classifier import classifier_run
4
+
5
+ ARGUMENTS = SimpleNamespace(
6
+ dataset='cfimdb',
7
+ batch_size=BATCH_SIZE_CFIMDB,
8
+ train=IDS_CFIMDB_TRAIN,
9
+ dev=IDS_CFIMDB_DEV,
10
+ test=IDS_CFIMDB_TEST,
11
+ lr=1e-3,
12
+ fine_tune_mode='last-linear-layer'
13
+ )
14
+
15
+ classifier_run(ARGUMENTS)
train-scripts/base_sst_onfm.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from constants import *
2
+ from types import SimpleNamespace
3
+ from classifier import classifier_run
4
+
5
+ ARGUMENTS = SimpleNamespace(
6
+ dataset='sst',
7
+ batch_size=BATCH_SIZE_SST,
8
+ train=IDS_SST_TRAIN,
9
+ dev=IDS_SST_DEV,
10
+ test=IDS_SST_TEST,
11
+ lr=1e-5,
12
+ fine_tune_mode='full-model'
13
+ )
14
+
15
+ classifier_run(ARGUMENTS)
train-scripts/base_sst_onll.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from constants import *
2
+ from types import SimpleNamespace
3
+ from classifier import classifier_run
4
+
5
+ ARGUMENTS = SimpleNamespace(
6
+ dataset='sst',
7
+ batch_size=BATCH_SIZE_SST,
8
+ train=IDS_SST_TRAIN,
9
+ dev=IDS_SST_DEV,
10
+ test=IDS_SST_TEST,
11
+ lr=1e-3,
12
+ fine_tune_mode='last-linear-layer'
13
+ )
14
+
15
+ classifier_run(ARGUMENTS)
train-scripts/finetuned_bert.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from everything import *
2
+ from bert import BertModel
3
+
4
+ def get_finetuned_bert(mode: str):
5
+ assert mode in ['sup', 'unsup']
6
+
7
+ bert = BertModel.from_pretrained('bert-base-uncased')
8
+ if mode == 'sup':
9
+ state_dict = torch.load(SUP_BERT, weights_only=True)
10
+ else:
11
+ state_dict = torch.load(UNSUP_BERT, weights_only=True)
12
+ device = torch.device('cuda') if USE_GPU else torch.device('cpu')
13
+
14
+ bert.load_state_dict(state_dict)
15
+ return bert.to(device)
train-scripts/sup_cfimdb_onfm.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from constants import *
2
+ from types import SimpleNamespace
3
+ from classifier import classifier_run
4
+ from .finetuned_bert import get_finetuned_bert
5
+
6
+ ARGUMENTS = SimpleNamespace(
7
+ dataset='cfimdb',
8
+ batch_size=BATCH_SIZE_CFIMDB,
9
+ train=IDS_CFIMDB_TRAIN,
10
+ dev=IDS_CFIMDB_DEV,
11
+ test=IDS_CFIMDB_TEST,
12
+ lr=1e-5,
13
+ fine_tune_mode='full-model'
14
+ )
15
+
16
+ bert = get_finetuned_bert('sup')
17
+ classifier_run(ARGUMENTS, bert)
train-scripts/sup_cfimdb_onll.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from constants import *
2
+ from types import SimpleNamespace
3
+ from classifier import classifier_run
4
+ from .finetuned_bert import get_finetuned_bert
5
+
6
+ ARGUMENTS = SimpleNamespace(
7
+ dataset='cfimdb',
8
+ batch_size=BATCH_SIZE_CFIMDB,
9
+ train=IDS_CFIMDB_TRAIN,
10
+ dev=IDS_CFIMDB_DEV,
11
+ test=IDS_CFIMDB_TEST,
12
+ lr=1e-3,
13
+ fine_tune_mode='last-linear-layer'
14
+ )
15
+
16
+ bert = get_finetuned_bert('sup')
17
+ classifier_run(ARGUMENTS, bert)
train-scripts/sup_sst_onfm.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from constants import *
2
+ from types import SimpleNamespace
3
+ from classifier import classifier_run
4
+ from .finetuned_bert import get_finetuned_bert
5
+
6
+ ARGUMENTS = SimpleNamespace(
7
+ dataset='sst',
8
+ batch_size=BATCH_SIZE_SST,
9
+ train=IDS_SST_TRAIN,
10
+ dev=IDS_SST_DEV,
11
+ test=IDS_SST_TEST,
12
+ lr=1e-5,
13
+ fine_tune_mode='full-model'
14
+ )
15
+
16
+ bert = get_finetuned_bert('sup')
17
+ classifier_run(ARGUMENTS, bert)
train-scripts/sup_sst_onll.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from constants import *
2
+ from types import SimpleNamespace
3
+ from classifier import classifier_run
4
+ from .finetuned_bert import get_finetuned_bert
5
+
6
+ ARGUMENTS = SimpleNamespace(
7
+ dataset='sst',
8
+ batch_size=BATCH_SIZE_SST,
9
+ train=IDS_SST_TRAIN,
10
+ dev=IDS_SST_DEV,
11
+ test=IDS_SST_TEST,
12
+ lr=1e-3,
13
+ fine_tune_mode='last-linear-layer'
14
+ )
15
+
16
+ bert = get_finetuned_bert('sup')
17
+ classifier_run(ARGUMENTS, bert)
train-scripts/unsup_cfimdb_onfm.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from constants import *
2
+ from types import SimpleNamespace
3
+ from classifier import classifier_run
4
+ from .finetuned_bert import get_finetuned_bert
5
+
6
+ ARGUMENTS = SimpleNamespace(
7
+ dataset='cfimdb',
8
+ batch_size=BATCH_SIZE_CFIMDB,
9
+ train=IDS_CFIMDB_TRAIN,
10
+ dev=IDS_CFIMDB_DEV,
11
+ test=IDS_CFIMDB_TEST,
12
+ lr=1e-5,
13
+ fine_tune_mode='full-model'
14
+ )
15
+
16
+ bert = get_finetuned_bert('unsup')
17
+ classifier_run(ARGUMENTS, bert)
train-scripts/unsup_cfimdb_onll.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from constants import *
2
+ from types import SimpleNamespace
3
+ from classifier import classifier_run
4
+ from .finetuned_bert import get_finetuned_bert
5
+
6
+ ARGUMENTS = SimpleNamespace(
7
+ dataset='cfimdb',
8
+ batch_size=BATCH_SIZE_CFIMDB,
9
+ train=IDS_CFIMDB_TRAIN,
10
+ dev=IDS_CFIMDB_DEV,
11
+ test=IDS_CFIMDB_TEST,
12
+ lr=1e-3,
13
+ fine_tune_mode='last-linear-layer'
14
+ )
15
+
16
+ bert = get_finetuned_bert('unsup')
17
+ classifier_run(ARGUMENTS, bert)
train-scripts/unsup_sst_onfm.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from constants import *
2
+ from types import SimpleNamespace
3
+ from classifier import classifier_run
4
+ from .finetuned_bert import get_finetuned_bert
5
+
6
+ ARGUMENTS = SimpleNamespace(
7
+ dataset='sst',
8
+ batch_size=BATCH_SIZE_SST,
9
+ train=IDS_SST_TRAIN,
10
+ dev=IDS_SST_DEV,
11
+ test=IDS_SST_TEST,
12
+ lr=1e-5,
13
+ fine_tune_mode='full-model'
14
+ )
15
+
16
+ bert = get_finetuned_bert('unsup')
17
+ classifier_run(ARGUMENTS, bert)
train-scripts/unsup_sst_onll.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from constants import *
2
+ from types import SimpleNamespace
3
+ from classifier import classifier_run
4
+ from .finetuned_bert import get_finetuned_bert
5
+
6
+ ARGUMENTS = SimpleNamespace(
7
+ dataset='sst',
8
+ batch_size=BATCH_SIZE_SST,
9
+ train=IDS_SST_TRAIN,
10
+ dev=IDS_SST_DEV,
11
+ test=IDS_SST_TEST,
12
+ lr=1e-3,
13
+ fine_tune_mode='last-linear-layer'
14
+ )
15
+
16
+ bert = get_finetuned_bert('unsup')
17
+ classifier_run(ARGUMENTS, bert)
utils.py ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Optional, Union, Tuple, BinaryIO
2
+ import os
3
+ import sys
4
+ import json
5
+ import shutil
6
+ import tempfile
7
+ import copy
8
+ from tqdm.auto import tqdm
9
+ from functools import partial
10
+ from urllib.parse import urlparse
11
+ from pathlib import Path
12
+ import requests
13
+ from hashlib import sha256
14
+ from filelock import FileLock
15
+ import importlib_metadata
16
+ import torch
17
+ import torch.nn as nn
18
+ from torch import Tensor
19
+ import fnmatch
20
+
21
+
22
+ __version__ = "4.0.0"
23
+ _torch_version = importlib_metadata.version("torch")
24
+
25
+ hf_cache_home = os.path.expanduser(os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface")))
26
+ default_cache_path = os.path.join(hf_cache_home, "transformers")
27
+ PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
28
+ PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
29
+ TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)
30
+
31
+ PRESET_MIRROR_DICT = {
32
+ "tuna": "https://mirrors.tuna.tsinghua.edu.cn/hugging-face-models",
33
+ "bfsu": "https://mirrors.bfsu.edu.cn/hugging-face-models",
34
+ }
35
+ HUGGINGFACE_CO_PREFIX = "https://huggingface.co/{model_id}/resolve/{revision}/{filename}"
36
+ WEIGHTS_NAME = "pytorch_model.bin"
37
+ CONFIG_NAME = "config.json"
38
+
39
+
40
+ def is_torch_available():
41
+ return True
42
+
43
+
44
+ def is_tf_available():
45
+ return False
46
+
47
+
48
+ def is_remote_url(url_or_filename):
49
+ parsed = urlparse(url_or_filename)
50
+ return parsed.scheme in ("http", "https")
51
+
52
+
53
+ def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers: Optional[Dict[str, str]] = None):
54
+ headers = copy.deepcopy(headers)
55
+ if resume_size > 0:
56
+ headers["Range"] = "bytes=%d-" % (resume_size,)
57
+ r = requests.get(url, stream=True, proxies=proxies, headers=headers)
58
+ r.raise_for_status()
59
+ content_length = r.headers.get("Content-Length")
60
+ total = resume_size + int(content_length) if content_length is not None else None
61
+ progress = tqdm(
62
+ unit="B",
63
+ unit_scale=True,
64
+ total=total,
65
+ initial=resume_size,
66
+ desc="Downloading",
67
+ disable=False,
68
+ )
69
+ for chunk in r.iter_content(chunk_size=1024):
70
+ if chunk: # filter out keep-alive new chunks
71
+ progress.update(len(chunk))
72
+ temp_file.write(chunk)
73
+ progress.close()
74
+
75
+
76
+ def url_to_filename(url: str, etag: Optional[str] = None) -> str:
77
+ url_bytes = url.encode("utf-8")
78
+ filename = sha256(url_bytes).hexdigest()
79
+
80
+ if etag:
81
+ etag_bytes = etag.encode("utf-8")
82
+ filename += "." + sha256(etag_bytes).hexdigest()
83
+
84
+ if url.endswith(".h5"):
85
+ filename += ".h5"
86
+
87
+ return filename
88
+
89
+
90
+ def hf_bucket_url(
91
+ model_id: str, filename: str, subfolder: Optional[str] = None, revision: Optional[str] = None, mirror=None
92
+ ) -> str:
93
+ if subfolder is not None:
94
+ filename = f"{subfolder}/{filename}"
95
+
96
+ if mirror:
97
+ endpoint = PRESET_MIRROR_DICT.get(mirror, mirror)
98
+ legacy_format = "/" not in model_id
99
+ if legacy_format:
100
+ return f"{endpoint}/{model_id}-{filename}"
101
+ else:
102
+ return f"{endpoint}/{model_id}/{filename}"
103
+
104
+ if revision is None:
105
+ revision = "main"
106
+ return HUGGINGFACE_CO_PREFIX.format(model_id=model_id, revision=revision, filename=filename)
107
+
108
+
109
+ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
110
+ ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
111
+ if is_torch_available():
112
+ ua += f"; torch/{_torch_version}"
113
+ if is_tf_available():
114
+ ua += f"; tensorflow/{_tf_version}"
115
+ if isinstance(user_agent, dict):
116
+ ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items())
117
+ elif isinstance(user_agent, str):
118
+ ua += "; " + user_agent
119
+ return ua
120
+
121
+
122
+ def get_from_cache(
123
+ url: str,
124
+ cache_dir=None,
125
+ force_download=False,
126
+ proxies=None,
127
+ etag_timeout=10,
128
+ resume_download=False,
129
+ user_agent: Union[Dict, str, None] = None,
130
+ use_auth_token: Union[bool, str, None] = None,
131
+ local_files_only=False,
132
+ ) -> Optional[str]:
133
+ if cache_dir is None:
134
+ cache_dir = TRANSFORMERS_CACHE
135
+ if isinstance(cache_dir, Path):
136
+ cache_dir = str(cache_dir)
137
+
138
+ os.makedirs(cache_dir, exist_ok=True)
139
+
140
+ headers = {"user-agent": http_user_agent(user_agent)}
141
+ if isinstance(use_auth_token, str):
142
+ headers["authorization"] = "Bearer {}".format(use_auth_token)
143
+ elif use_auth_token:
144
+ token = HfFolder.get_token()
145
+ if token is None:
146
+ raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
147
+ headers["authorization"] = "Bearer {}".format(token)
148
+
149
+ url_to_download = url
150
+ etag = None
151
+ if not local_files_only:
152
+ try:
153
+ r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=etag_timeout)
154
+ r.raise_for_status()
155
+ etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
156
+ # We favor a custom header indicating the etag of the linked resource, and
157
+ # we fallback to the regular etag header.
158
+ # If we don't have any of those, raise an error.
159
+ if etag is None:
160
+ raise OSError(
161
+ "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility."
162
+ )
163
+ # In case of a redirect,
164
+ # save an extra redirect on the request.get call,
165
+ # and ensure we download the exact atomic version even if it changed
166
+ # between the HEAD and the GET (unlikely, but hey).
167
+ if 300 <= r.status_code <= 399:
168
+ url_to_download = r.headers["Location"]
169
+ except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
170
+ # etag is already None
171
+ pass
172
+
173
+ filename = url_to_filename(url, etag)
174
+
175
+ # get cache path to put the file
176
+ cache_path = os.path.join(cache_dir, filename)
177
+
178
+ # etag is None == we don't have a connection or we passed local_files_only.
179
+ # try to get the last downloaded one
180
+ if etag is None:
181
+ if os.path.exists(cache_path):
182
+ return cache_path
183
+ else:
184
+ matching_files = [
185
+ file
186
+ for file in fnmatch.filter(os.listdir(cache_dir), filename.split(".")[0] + ".*")
187
+ if not file.endswith(".json") and not file.endswith(".lock")
188
+ ]
189
+ if len(matching_files) > 0:
190
+ return os.path.join(cache_dir, matching_files[-1])
191
+ else:
192
+ # If files cannot be found and local_files_only=True,
193
+ # the models might've been found if local_files_only=False
194
+ # Notify the user about that
195
+ if local_files_only:
196
+ raise FileNotFoundError(
197
+ "Cannot find the requested files in the cached path and outgoing traffic has been"
198
+ " disabled. To enable model look-ups and downloads online, set 'local_files_only'"
199
+ " to False."
200
+ )
201
+ else:
202
+ raise ValueError(
203
+ "Connection error, and we cannot find the requested files in the cached path."
204
+ " Please try again or make sure your Internet connection is on."
205
+ )
206
+
207
+ # From now on, etag is not None.
208
+ if os.path.exists(cache_path) and not force_download:
209
+ return cache_path
210
+
211
+ # Prevent parallel downloads of the same file with a lock.
212
+ lock_path = cache_path + ".lock"
213
+ with FileLock(lock_path):
214
+
215
+ # If the download just completed while the lock was activated.
216
+ if os.path.exists(cache_path) and not force_download:
217
+ # Even if returning early like here, the lock will be released.
218
+ return cache_path
219
+
220
+ if resume_download:
221
+ incomplete_path = cache_path + ".incomplete"
222
+
223
+ @contextmanager
224
+ def _resumable_file_manager() -> "io.BufferedWriter":
225
+ with open(incomplete_path, "ab") as f:
226
+ yield f
227
+
228
+ temp_file_manager = _resumable_file_manager
229
+ if os.path.exists(incomplete_path):
230
+ resume_size = os.stat(incomplete_path).st_size
231
+ else:
232
+ resume_size = 0
233
+ else:
234
+ temp_file_manager = partial(tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False)
235
+ resume_size = 0
236
+
237
+ # Download to temporary file, then copy to cache dir once finished.
238
+ # Otherwise you get corrupt cache entries if the download gets interrupted.
239
+ with temp_file_manager() as temp_file:
240
+ http_get(url_to_download, temp_file, proxies=proxies, resume_size=resume_size, headers=headers)
241
+
242
+ os.replace(temp_file.name, cache_path)
243
+
244
+ meta = {"url": url, "etag": etag}
245
+ meta_path = cache_path + ".json"
246
+ with open(meta_path, "w") as meta_file:
247
+ json.dump(meta, meta_file)
248
+
249
+ return cache_path
250
+
251
+
252
+ def cached_path(
253
+ url_or_filename,
254
+ cache_dir=None,
255
+ force_download=False,
256
+ proxies=None,
257
+ resume_download=False,
258
+ user_agent: Union[Dict, str, None] = None,
259
+ extract_compressed_file=False,
260
+ force_extract=False,
261
+ use_auth_token: Union[bool, str, None] = None,
262
+ local_files_only=False,
263
+ ) -> Optional[str]:
264
+ if cache_dir is None:
265
+ cache_dir = TRANSFORMERS_CACHE
266
+ if isinstance(url_or_filename, Path):
267
+ url_or_filename = str(url_or_filename)
268
+ if isinstance(cache_dir, Path):
269
+ cache_dir = str(cache_dir)
270
+
271
+ if is_remote_url(url_or_filename):
272
+ # URL, so get it from the cache (downloading if necessary)
273
+ output_path = get_from_cache(
274
+ url_or_filename,
275
+ cache_dir=cache_dir,
276
+ force_download=force_download,
277
+ proxies=proxies,
278
+ resume_download=resume_download,
279
+ user_agent=user_agent,
280
+ use_auth_token=use_auth_token,
281
+ local_files_only=local_files_only,
282
+ )
283
+ elif os.path.exists(url_or_filename):
284
+ # File, and it exists.
285
+ output_path = url_or_filename
286
+ elif urlparse(url_or_filename).scheme == "":
287
+ # File, but it doesn't exist.
288
+ raise EnvironmentError("file {} not found".format(url_or_filename))
289
+ else:
290
+ # Something unknown
291
+ raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
292
+
293
+ if extract_compressed_file:
294
+ if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path):
295
+ return output_path
296
+
297
+ # Path where we extract compressed archives
298
+ # We avoid '.' in dir name and add "-extracted" at the end: "./model.zip" => "./model-zip-extracted/"
299
+ output_dir, output_file = os.path.split(output_path)
300
+ output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
301
+ output_path_extracted = os.path.join(output_dir, output_extract_dir_name)
302
+
303
+ if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract:
304
+ return output_path_extracted
305
+
306
+ # Prevent parallel extractions
307
+ lock_path = output_path + ".lock"
308
+ with FileLock(lock_path):
309
+ shutil.rmtree(output_path_extracted, ignore_errors=True)
310
+ os.makedirs(output_path_extracted)
311
+ if is_zipfile(output_path):
312
+ with ZipFile(output_path, "r") as zip_file:
313
+ zip_file.extractall(output_path_extracted)
314
+ zip_file.close()
315
+ elif tarfile.is_tarfile(output_path):
316
+ tar_file = tarfile.open(output_path)
317
+ tar_file.extractall(output_path_extracted)
318
+ tar_file.close()
319
+ else:
320
+ raise EnvironmentError("Archive format of {} could not be identified".format(output_path))
321
+
322
+ return output_path_extracted
323
+
324
+ return output_path
325
+
326
+
327
+ def get_parameter_dtype(parameter: Union[nn.Module]):
328
+ try:
329
+ return next(parameter.parameters()).dtype
330
+ except StopIteration:
331
+ # For nn.DataParallel compatibility in PyTorch 1.5
332
+
333
+ def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
334
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
335
+ return tuples
336
+
337
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
338
+ first_tuple = next(gen)
339
+ return first_tuple[1].dtype
340
+
341
+
342
+ def get_extended_attention_mask(attention_mask: Tensor, dtype) -> Tensor:
343
+ # attention_mask [batch_size, seq_length]
344
+ assert attention_mask.dim() == 2
345
+ # [batch_size, 1, 1, seq_length] for multi-head attention
346
+ extended_attention_mask = attention_mask[:, None, None, :]
347
+ extended_attention_mask = extended_attention_mask.to(dtype=dtype) # fp16 compatibility
348
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
349
+ return extended_attention_mask