GlowCheese committed on
Commit 354d4d3 · 1 Parent(s): dfcb0a5

Remove everything

base_bert.py DELETED
@@ -1,248 +0,0 @@
-import re
-from torch import device, dtype
-from config import BertConfig, PretrainedConfig
-from utils import *
-
-
-class BertPreTrainedModel(nn.Module):
-  config_class = BertConfig
-  base_model_prefix = "bert"
-  _keys_to_ignore_on_load_missing = [r"position_ids"]
-  _keys_to_ignore_on_load_unexpected = None
-
-  def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
-    super().__init__()
-    self.config = config
-    self.name_or_path = config.name_or_path
-
-  def init_weights(self):
-    # Initialize weights.
-    self.apply(self._init_weights)
-
-  def _init_weights(self, module):
-    """Initialize the weights."""
-    if isinstance(module, (nn.Linear, nn.Embedding)):
-      # Slightly different from the TF version, which uses truncated_normal for initialization;
-      # cf. https://github.com/pytorch/pytorch/pull/5617
-      module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-    elif isinstance(module, nn.LayerNorm):
-      module.bias.data.zero_()
-      module.weight.data.fill_(1.0)
-    if isinstance(module, nn.Linear) and module.bias is not None:
-      module.bias.data.zero_()
-
-  @property
-  def dtype(self) -> dtype:
-    return get_parameter_dtype(self)
-
-  @classmethod
-  def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
-    config = kwargs.pop("config", None)
-    state_dict = kwargs.pop("state_dict", None)
-    cache_dir = kwargs.pop("cache_dir", None)
-    force_download = kwargs.pop("force_download", False)
-    resume_download = kwargs.pop("resume_download", False)
-    proxies = kwargs.pop("proxies", None)
-    output_loading_info = kwargs.pop("output_loading_info", False)
-    local_files_only = kwargs.pop("local_files_only", False)
-    use_auth_token = kwargs.pop("use_auth_token", None)
-    revision = kwargs.pop("revision", None)
-    mirror = kwargs.pop("mirror", None)
-
-    # Load the config if we aren't provided a configuration.
-    if not isinstance(config, PretrainedConfig):
-      config_path = config if config is not None else pretrained_model_name_or_path
-      config, model_kwargs = cls.config_class.from_pretrained(
-        config_path,
-        *model_args,
-        cache_dir=cache_dir,
-        return_unused_kwargs=True,
-        force_download=force_download,
-        resume_download=resume_download,
-        proxies=proxies,
-        local_files_only=local_files_only,
-        use_auth_token=use_auth_token,
-        revision=revision,
-        **kwargs,
-      )
-    else:
-      model_kwargs = kwargs
-
-    # Load the model.
-    if pretrained_model_name_or_path is not None:
-      pretrained_model_name_or_path = str(pretrained_model_name_or_path)
-      if os.path.isdir(pretrained_model_name_or_path):
-        # Load from a PyTorch checkpoint.
-        archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
-      elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
-        archive_file = pretrained_model_name_or_path
-      else:
-        archive_file = hf_bucket_url(
-          pretrained_model_name_or_path,
-          filename=WEIGHTS_NAME,
-          revision=revision,
-          mirror=mirror,
-        )
-      try:
-        # Load from the URL, or from the cache if already cached.
-        resolved_archive_file = cached_path(
-          archive_file,
-          cache_dir=cache_dir,
-          force_download=force_download,
-          proxies=proxies,
-          resume_download=resume_download,
-          local_files_only=local_files_only,
-          use_auth_token=use_auth_token,
-        )
-      except EnvironmentError as err:
-        # logger.error(err)
-        msg = (
-          f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
-          f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
-          f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {WEIGHTS_NAME}.\n\n"
-        )
-        raise EnvironmentError(msg)
-    else:
-      resolved_archive_file = None
-
-    config.name_or_path = pretrained_model_name_or_path
-
-    # Instantiate the model.
-    model = cls(config, *model_args, **model_kwargs)
-
-    if state_dict is None:
-      try:
-        state_dict = torch.load(resolved_archive_file, map_location="cpu", weights_only=True)
-      except Exception:
-        raise OSError(
-          f"Unable to load weights from PyTorch checkpoint file for '{pretrained_model_name_or_path}' "
-          f"at '{resolved_archive_file}'"
-        )
-
-    missing_keys = []
-    unexpected_keys = []
-    error_msgs = []
-
-    # Convert the old format to the new format if needed from a PyTorch state_dict.
-    old_keys = []
-    new_keys = []
-    m = {'embeddings.word_embeddings': 'word_embedding',
-         'embeddings.position_embeddings': 'pos_embedding',
-         'embeddings.token_type_embeddings': 'tk_type_embedding',
-         'embeddings.LayerNorm': 'embed_layer_norm',
-         'embeddings.dropout': 'embed_dropout',
-         'encoder.layer': 'bert_layers',
-         'pooler.dense': 'pooler_dense',
-         'pooler.activation': 'pooler_af',
-         'attention.self': 'self_attention',
-         'attention.output.dense': 'attention_dense',
-         'attention.output.LayerNorm': 'attention_layer_norm',
-         'attention.output.dropout': 'attention_dropout',
-         'intermediate.dense': 'interm_dense',
-         'intermediate.intermediate_act_fn': 'interm_af',
-         'output.dense': 'out_dense',
-         'output.LayerNorm': 'out_layer_norm',
-         'output.dropout': 'out_dropout'}
-
-    for key in state_dict.keys():
-      new_key = None
-      if "gamma" in key:
-        new_key = key.replace("gamma", "weight")
-      if "beta" in key:
-        new_key = key.replace("beta", "bias")
-      for x, y in m.items():
-        if new_key is not None:
-          _key = new_key
-        else:
-          _key = key
-        if x in key:
-          new_key = _key.replace(x, y)
-      if new_key:
-        old_keys.append(key)
-        new_keys.append(new_key)
-
-    for old_key, new_key in zip(old_keys, new_keys):
-      # print(old_key, new_key)
-      state_dict[new_key] = state_dict.pop(old_key)
-
-    # Copy state_dict so _load_from_state_dict can modify it.
-    metadata = getattr(state_dict, "_metadata", None)
-    state_dict = state_dict.copy()
-    if metadata is not None:
-      state_dict._metadata = metadata
-
-    your_bert_params = [f"bert.{x[0]}" for x in model.named_parameters()]
-    for k in state_dict:
-      if k not in your_bert_params and not k.startswith("cls."):
-        possible_rename = [x for x in k.split(".")[1:-1] if x in m.values()]
-        raise ValueError(f"{k} cannot be reloaded into your model; one/some of {possible_rename} we provided have been renamed")
-
-    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants,
-    # so we need to apply the function recursively.
-    def load(module: nn.Module, prefix=""):
-      local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
-      module._load_from_state_dict(
-        state_dict,
-        prefix,
-        local_metadata,
-        True,
-        missing_keys,
-        unexpected_keys,
-        error_msgs,
-      )
-      for name, child in module._modules.items():
-        if child is not None:
-          load(child, prefix + name + ".")
-
-    # Make sure we are able to load base models as well as derived models (with heads).
-    start_prefix = ""
-    model_to_load = model
-    has_prefix_module = any(s.startswith(cls.base_model_prefix) for s in state_dict.keys())
-    if not hasattr(model, cls.base_model_prefix) and has_prefix_module:
-      start_prefix = cls.base_model_prefix + "."
-    if hasattr(model, cls.base_model_prefix) and not has_prefix_module:
-      model_to_load = getattr(model, cls.base_model_prefix)
-    load(model_to_load, prefix=start_prefix)
-
-    if model.__class__.__name__ != model_to_load.__class__.__name__:
-      base_model_state_dict = model_to_load.state_dict().keys()
-      head_model_state_dict_without_base_prefix = [
-        key.split(cls.base_model_prefix + ".")[-1] for key in model.state_dict().keys()
-      ]
-      missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict)
-
-    # Some models may have keys that are not in the state by design; remove them before
-    # needlessly warning the user.
-    if cls._keys_to_ignore_on_load_missing is not None:
-      for pat in cls._keys_to_ignore_on_load_missing:
-        missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
-
-    if cls._keys_to_ignore_on_load_unexpected is not None:
-      for pat in cls._keys_to_ignore_on_load_unexpected:
-        unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
-
-    if len(error_msgs) > 0:
-      raise RuntimeError(
-        "Error(s) in loading state_dict for {}:\n\t{}".format(
-          model.__class__.__name__, "\n\t".join(error_msgs)
-        )
-      )
-
-    # Set the model in evaluation mode to deactivate dropout modules by default.
-    model.eval()
-
-    if output_loading_info:
-      loading_info = {
-        "missing_keys": missing_keys,
-        "unexpected_keys": unexpected_keys,
-        "error_msgs": error_msgs,
-      }
-      return model, loading_info
-
-    if hasattr(config, "xla_device") and config.xla_device and is_torch_tpu_available():
-      import torch_xla.core.xla_model as xm
-
-      model = xm.send_cpu_data_to_device(model, xm.xla_device())
-      model.to(xm.xla_device())
-
-    return model
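
The renaming pass above is the core of from_pretrained: it maps the original Hugging Face parameter names (and the old TF-style gamma/beta LayerNorm names) onto this repository's module names. Below is a minimal, self-contained sketch of the same loop on a toy state_dict; the two keys and the two-entry mapping are hypothetical, chosen only to exercise both the gamma/beta branch and the submodule-renaming branch.

import torch

m = {'embeddings.word_embeddings': 'word_embedding',
     'attention.output.LayerNorm': 'attention_layer_norm'}

state_dict = {
    'bert.embeddings.word_embeddings.weight': torch.zeros(2, 2),
    'bert.encoder.layer.0.attention.output.LayerNorm.gamma': torch.ones(2),
}

old_keys, new_keys = [], []
for key in state_dict.keys():
    new_key = None
    if "gamma" in key:  # old TF-style name for the LayerNorm scale
        new_key = key.replace("gamma", "weight")
    if "beta" in key:   # old TF-style name for the LayerNorm bias
        new_key = key.replace("beta", "bias")
    for x, y in m.items():
        _key = new_key if new_key is not None else key
        if x in key:
            new_key = _key.replace(x, y)
    if new_key:
        old_keys.append(key)
        new_keys.append(new_key)

for old_key, new_key in zip(old_keys, new_keys):
    state_dict[new_key] = state_dict.pop(old_key)

print(sorted(state_dict))
# ['bert.encoder.layer.0.attention_layer_norm.weight', 'bert.word_embedding.weight']
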
bert.py DELETED
@@ -1,225 +0,0 @@
-import math
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from base_bert import BertPreTrainedModel
-from utils import *
-
-
-class BertSelfAttention(nn.Module):
-  def __init__(self, config):
-    super().__init__()
-
-    self.num_attention_heads = config.num_attention_heads
-    self.attention_head_size = config.hidden_size // config.num_attention_heads
-    self.all_head_size = self.num_attention_heads * self.attention_head_size
-
-    # Initialize the linear transformation layers for key, value, query.
-    self.query = nn.Linear(config.hidden_size, self.all_head_size)
-    self.key = nn.Linear(config.hidden_size, self.all_head_size)
-    self.value = nn.Linear(config.hidden_size, self.all_head_size)
-    # This dropout is applied to the normalized attention scores, following the original
-    # implementation of the transformer. Although it is a bit unusual, we empirically
-    # observe that it yields better performance.
-    self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
-
-  def transform(self, x, linear_layer):
-    # The corresponding linear_layer of k, v, q is used to project the hidden_state (x).
-    bs, seq_len = x.shape[:2]
-    proj = linear_layer(x)
-    # Next, we need to produce multiple heads for the proj. This is done by splitting the
-    # hidden state into self.num_attention_heads pieces, each of size self.attention_head_size.
-    proj = proj.view(bs, seq_len, self.num_attention_heads, self.attention_head_size)
-    # By a proper transpose, we get proj of size [bs, num_attention_heads, seq_len, attention_head_size].
-    proj = proj.transpose(1, 2)
-    return proj
-
-  def attention(self, key, query, value, attention_mask):
-    """
-    key, query, value: [batch_size, num_attention_heads, seq_len, attention_head_size]
-    attention_mask: [batch_size, 1, 1, seq_len], masks padding tokens in the input.
-    """
-
-    d_k = query.size(-1)  # attention_head_size
-    attention_scores = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(d_k)
-    # attention_scores shape: [batch_size, num_attention_heads, seq_len, seq_len]
-
-    # Apply the attention mask.
-    attention_scores = attention_scores + attention_mask
-
-    # Normalize the scores with softmax and apply dropout.
-    attention_probs = nn.functional.softmax(attention_scores, dim=-1)
-    attention_probs = self.dropout(attention_probs)
-
-    context = torch.matmul(attention_probs, value)
-    # context shape: [batch_size, num_attention_heads, seq_len, attention_head_size]
-
-    # Concatenate all attention heads to recover the original shape: [batch_size, seq_len, hidden_size]
-    context = context.transpose(1, 2).contiguous()
-    context = context.view(context.size(0), context.size(1), -1)
-
-    return context
-
-  def forward(self, hidden_states, attention_mask):
-    """
-    hidden_states: [bs, seq_len, hidden_state]
-    attention_mask: [bs, 1, 1, seq_len]
-    output: [bs, seq_len, hidden_state]
-    """
-    # First, we have to generate the key, value, and query for each token for multi-head attention
-    # using self.transform (more details inside the function).
-    # Size of *_layer is [bs, num_attention_heads, seq_len, attention_head_size].
-    key_layer = self.transform(hidden_states, self.key)
-    value_layer = self.transform(hidden_states, self.value)
-    query_layer = self.transform(hidden_states, self.query)
-    # Calculate the multi-head attention.
-    attn_value = self.attention(key_layer, query_layer, value_layer, attention_mask)
-    return attn_value
-
-
-class BertLayer(nn.Module):
-  def __init__(self, config):
-    super().__init__()
-    # Multi-head attention.
-    self.self_attention = BertSelfAttention(config)
-    # Add-norm for multi-head attention.
-    self.attention_dense = nn.Linear(config.hidden_size, config.hidden_size)
-    self.attention_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-    self.attention_dropout = nn.Dropout(config.hidden_dropout_prob)
-    # Feed forward.
-    self.interm_dense = nn.Linear(config.hidden_size, config.intermediate_size)
-    self.interm_af = F.gelu
-    # Add-norm for feed forward.
-    self.out_dense = nn.Linear(config.intermediate_size, config.hidden_size)
-    self.out_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-    self.out_dropout = nn.Dropout(config.hidden_dropout_prob)
-
-  def add_norm(self, input, output, dense_layer, dropout, ln_layer):
-    transformed_output = dense_layer(output)           # Transform the output with dense_layer.
-    transformed_output = dropout(transformed_output)   # Apply dropout.
-    added_output = input + transformed_output          # Add the residual input to the output.
-    normalized_output = ln_layer(added_output)         # Apply layer normalization.
-    return normalized_output
-
-  def forward(self, hidden_states, attention_mask):
-    # 1. Multi-head attention.
-    attention_output = self.self_attention(hidden_states, attention_mask)
-
-    # 2. Add-norm after attention.
-    attention_output = self.add_norm(
-      hidden_states,
-      attention_output,
-      self.attention_dense,
-      self.attention_dropout,
-      self.attention_layer_norm
-    )
-
-    # 3. Feed-forward network.
-    intermediate_output = self.interm_af(self.interm_dense(attention_output))
-
-    # 4. Add-norm after feed-forward.
-    layer_output = self.add_norm(
-      attention_output,
-      intermediate_output,
-      self.out_dense,
-      self.out_dropout,
-      self.out_layer_norm
-    )
-
-    return layer_output
-
-
-class BertModel(BertPreTrainedModel):
-  """
-  The BERT model returns the final embeddings for each token in a sentence.
-
-  The model consists of:
-  1. Embedding layers (used in self.embed).
-  2. A stack of n BERT layers (used in self.encode).
-  3. A linear transformation layer for the [CLS] token (used in self.forward, as given).
-  """
-  def __init__(self, config):
-    super().__init__(config)
-    self.config = config
-
-    # Embedding layers.
-    self.word_embedding = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
-    self.pos_embedding = nn.Embedding(config.max_position_embeddings, config.hidden_size)
-    self.tk_type_embedding = nn.Embedding(config.type_vocab_size, config.hidden_size)
-    self.embed_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-    self.embed_dropout = nn.Dropout(config.hidden_dropout_prob)
-    # Register position_ids (1, len position emb) as a buffer because it is a constant.
-    position_ids = torch.arange(config.max_position_embeddings).unsqueeze(0)
-    self.register_buffer('position_ids', position_ids)
-
-    # BERT encoder.
-    self.bert_layers = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
-
-    # [CLS] token transformations.
-    self.pooler_dense = nn.Linear(config.hidden_size, config.hidden_size)
-    self.pooler_af = nn.Tanh()
-
-    self.init_weights()
-
-  def embed(self, input_ids):
-    input_shape = input_ids.size()
-    seq_length = input_shape[1]
-
-    inputs_embeds = self.word_embedding(input_ids)
-
-    pos_ids = self.position_ids[:, :seq_length]
-    pos_embeds = self.pos_embedding(pos_ids)
-
-    # Since we are not considering the token type, this embedding is just a placeholder.
-    tk_type_ids = torch.zeros(input_shape, dtype=torch.long, device=input_ids.device)
-    tk_type_embeds = self.tk_type_embedding(tk_type_ids)
-
-    embeddings = inputs_embeds + pos_embeds + tk_type_embeds
-    embeddings = self.embed_layer_norm(embeddings)
-    embeddings = self.embed_dropout(embeddings)
-
-    return embeddings
-
-  def encode(self, hidden_states, attention_mask):
-    """
-    hidden_states: the output from the embedding layer [batch_size, seq_len, hidden_size]
-    attention_mask: [batch_size, seq_len]
-    """
-    # Get the extended attention mask for self-attention.
-    # Returns extended_attention_mask of size [batch_size, 1, 1, seq_len].
-    # Distinguishes between non-padding tokens (with a value of 0) and padding tokens
-    # (with a value of a large negative number).
-    extended_attention_mask: torch.Tensor = get_extended_attention_mask(attention_mask, self.dtype)
-
-    # Pass the hidden states through the encoder layers.
-    for i, layer_module in enumerate(self.bert_layers):
-      # Feed the encoding from the last bert_layer to the next.
-      hidden_states = layer_module(hidden_states, extended_attention_mask)
-
-    return hidden_states
-
-  def forward(self, input_ids, attention_mask):
-    """
-    input_ids: [batch_size, seq_len], seq_len is the max length of the batch
-    attention_mask: same size as input_ids, 1 represents non-padding tokens, 0 represents padding tokens
-    """
-    # Get the embedding for each input token.
-    embedding_output = self.embed(input_ids=input_ids)
-
-    # Feed to a transformer (a stack of BertLayers).
-    sequence_output = self.encode(embedding_output, attention_mask=attention_mask)
-
-    # Get the [CLS] token's hidden state.
-    first_tk = sequence_output[:, 0]
-    first_tk = self.pooler_dense(first_tk)
-    first_tk = self.pooler_af(first_tk)
-
-    return {'last_hidden_state': sequence_output, 'pooler_output': first_tk}
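
As a sanity check of the shapes documented above, here is a minimal forward pass with a tiny, randomly initialized model. It assumes this repository's config.py and utils.py (the latter is not shown in this commit) are importable; the config sizes are arbitrary small values, and setting name_or_path by hand works around the fact that the stripped-down PretrainedConfig shown here only stores _name_or_path.

import torch
from config import BertConfig
from bert import BertModel

config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=4, intermediate_size=64)
config.name_or_path = ''  # read by BertPreTrainedModel.__init__
model = BertModel(config)
model.eval()

input_ids = torch.randint(0, 100, (2, 8))            # [batch_size, seq_len]
attention_mask = torch.ones(2, 8, dtype=torch.long)  # 1 = real token, 0 = padding

with torch.no_grad():
    out = model(input_ids, attention_mask)

print(out['last_hidden_state'].shape)  # torch.Size([2, 8, 32])
print(out['pooler_output'].shape)      # torch.Size([2, 32])
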
classifier.py DELETED
@@ -1,411 +0,0 @@
-import random, numpy as np, argparse
-from types import SimpleNamespace
-import csv
-
-import torch
-from tqdm import tqdm
-import torch.nn.functional as F
-from torch.utils.data import Dataset, DataLoader
-from sklearn.metrics import f1_score, accuracy_score
-
-from tokenizer import BertTokenizer
-from bert import BertModel
-from optimizer import AdamW
-
-
-TQDM_DISABLE = True
-
-
-# Fix the random seed.
-def seed_everything(seed=11711):
-  random.seed(seed)
-  np.random.seed(seed)
-  torch.manual_seed(seed)
-  torch.cuda.manual_seed(seed)
-  torch.cuda.manual_seed_all(seed)
-  torch.backends.cudnn.benchmark = False
-  torch.backends.cudnn.deterministic = True
-
-
-class BertSentimentClassifier(torch.nn.Module):
-  '''
-  This module performs sentiment classification using BERT embeddings on the SST dataset.
-
-  In the SST dataset, there are 5 sentiment categories (from 0 - "negative" to 4 - "positive").
-  Thus, your forward() should return one logit for each of the 5 classes.
-  '''
-  def __init__(self, config, bert_model=None):
-    super(BertSentimentClassifier, self).__init__()
-    self.num_labels = config.num_labels
-    self.bert: BertModel = bert_model or BertModel.from_pretrained('bert-base-uncased')
-
-    # Pretrain mode does not require updating BERT parameters.
-    assert config.fine_tune_mode in ["last-linear-layer", "full-model"]
-    for param in self.bert.parameters():
-      if config.fine_tune_mode == 'last-linear-layer':
-        param.requires_grad = False
-      elif config.fine_tune_mode == 'full-model':
-        param.requires_grad = True
-
-    # Create any instance variables you need to classify the sentiment of BERT embeddings.
-    self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
-    self.classifier = torch.nn.Linear(config.hidden_size, self.num_labels)
-
-  def forward(self, input_ids, attention_mask):
-    '''Takes a batch of sentences and returns logits for the sentiment classes.'''
-    # The final BERT contextualized embedding is the hidden state of the [CLS] token (the first token).
-    # HINT: You should consider what is an appropriate return value given that
-    # the training loop currently uses F.cross_entropy as the loss function.
-
-    # Get the embedding for each input token.
-    outputs = self.bert(input_ids, attention_mask)
-    pooler_output = outputs['pooler_output']
-
-    # Pass the [CLS] token representation through the classifier.
-    logits = self.classifier(self.dropout(pooler_output))
-
-    return logits
-
-
-tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
-
-class SentimentDataset(Dataset):
-  def __init__(self, dataset, args):
-    self.dataset = dataset
-    self.p = args
-
-  def __len__(self):
-    return len(self.dataset)
-
-  def __getitem__(self, idx):
-    return self.dataset[idx]
-
-  def pad_data(self, data):
-    sents = [x[0] for x in data]
-    labels = [x[1] for x in data]
-    sent_ids = [x[2] for x in data]
-
-    encoding = tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
-    token_ids = torch.LongTensor(encoding['input_ids'])
-    attention_mask = torch.LongTensor(encoding['attention_mask'])
-    labels = torch.LongTensor(labels)
-
-    return token_ids, attention_mask, labels, sents, sent_ids
-
-  def collate_fn(self, all_data):
-    token_ids, attention_mask, labels, sents, sent_ids = self.pad_data(all_data)
-
-    batched_data = {
-      'token_ids': token_ids,
-      'attention_mask': attention_mask,
-      'labels': labels,
-      'sents': sents,
-      'sent_ids': sent_ids
-    }
-
-    return batched_data
-
-
-class SentimentTestDataset(Dataset):
-  def __init__(self, dataset, args):
-    self.dataset = dataset
-    self.p = args
-
-  def __len__(self):
-    return len(self.dataset)
-
-  def __getitem__(self, idx):
-    return self.dataset[idx]
-
-  def pad_data(self, data):
-    sents = [x[0] for x in data]
-    sent_ids = [x[1] for x in data]
-
-    encoding = tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
-    token_ids = torch.LongTensor(encoding['input_ids'])
-    attention_mask = torch.LongTensor(encoding['attention_mask'])
-
-    return token_ids, attention_mask, sents, sent_ids
-
-  def collate_fn(self, all_data):
-    token_ids, attention_mask, sents, sent_ids = self.pad_data(all_data)
-
-    batched_data = {
-      'token_ids': token_ids,
-      'attention_mask': attention_mask,
-      'sents': sents,
-      'sent_ids': sent_ids
-    }
-
-    return batched_data
-
-
-# Load the data: a list of (sentence, label).
-def load_data(filename, flag='train'):
-  num_labels = set()
-  data = []
-  with open(filename, 'r') as fp:
-    for record in csv.DictReader(fp, delimiter='\t'):
-      if flag == 'test':
-        sent = record['sentence'].lower().strip()
-        sent_id = record['id'].lower().strip()
-        data.append((sent, sent_id))
-      else:
-        sent = record['sentence'].lower().strip()
-        sent_id = record['id'].lower().strip()
-        label = int(record['sentiment'].strip())
-        num_labels.add(label)
-        data.append((sent, label, sent_id))
-  print(f"load {len(data)} data from {filename}")
-
-  if flag == 'train':
-    return data, len(num_labels)
-  else:
-    return data
-
-
-# Evaluate the model on dev examples.
-def model_eval(dataloader, model, device):
-  model.eval()  # Switch to eval mode, which turns off randomness like dropout.
-  y_true = []
-  y_pred = []
-  sents = []
-  sent_ids = []
-  for step, batch in enumerate(tqdm(dataloader, desc=f'eval', leave=False, disable=TQDM_DISABLE)):
-    b_ids, b_mask, b_labels, b_sents, b_sent_ids = batch['token_ids'], batch['attention_mask'], \
-                                                   batch['labels'], batch['sents'], batch['sent_ids']
-
-    b_ids = b_ids.to(device)
-    b_mask = b_mask.to(device)
-
-    logits = model(b_ids, b_mask)
-    logits = logits.detach().cpu().numpy()
-    preds = np.argmax(logits, axis=1).flatten()
-
-    b_labels = b_labels.flatten()
-    y_true.extend(b_labels)
-    y_pred.extend(preds)
-    sents.extend(b_sents)
-    sent_ids.extend(b_sent_ids)
-
-  f1 = f1_score(y_true, y_pred, average='macro')
-  acc = accuracy_score(y_true, y_pred)
-
-  return acc, f1, y_pred, y_true, sents, sent_ids
-
-
-# Evaluate the model on test examples.
-def model_test_eval(dataloader, model, device):
-  model.eval()  # Switch to eval mode, which turns off randomness like dropout.
-  y_pred = []
-  sents = []
-  sent_ids = []
-  for step, batch in enumerate(tqdm(dataloader, desc=f'eval', leave=False, disable=TQDM_DISABLE)):
-    b_ids, b_mask, b_sents, b_sent_ids = batch['token_ids'], batch['attention_mask'], \
-                                         batch['sents'], batch['sent_ids']
-
-    b_ids = b_ids.to(device)
-    b_mask = b_mask.to(device)
-
-    logits = model(b_ids, b_mask)
-    logits = logits.detach().cpu().numpy()
-    preds = np.argmax(logits, axis=1).flatten()
-
-    y_pred.extend(preds)
-    sents.extend(b_sents)
-    sent_ids.extend(b_sent_ids)
-
-  return y_pred, sents, sent_ids
-
-
-def save_model(model, optimizer, args, config, filepath):
-  save_info = {
-    'model': model.state_dict(),
-    'optim': optimizer.state_dict(),
-    'args': args,
-    'model_config': config,
-    'system_rng': random.getstate(),
-    'numpy_rng': np.random.get_state(),
-    'torch_rng': torch.random.get_rng_state(),
-  }
-
-  torch.save(save_info, filepath)
-  print(f"save the model to {filepath}")
-
-
-def train(args):
-  device = torch.device('cuda') if args.use_gpu else torch.device('cpu')
-  # Create the data and its corresponding datasets and dataloaders.
-  train_data, num_labels = load_data(args.train, 'train')
-  dev_data = load_data(args.dev, 'valid')
-
-  train_dataset = SentimentDataset(train_data, args)
-  dev_dataset = SentimentDataset(dev_data, args)
-
-  train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size,
-                                num_workers=args.num_cpu_cores, collate_fn=train_dataset.collate_fn)
-  dev_dataloader = DataLoader(dev_dataset, shuffle=False, batch_size=args.batch_size,
-                              num_workers=args.num_cpu_cores, collate_fn=dev_dataset.collate_fn)
-
-  # Init the model.
-  config = {'hidden_dropout_prob': args.hidden_dropout_prob,
-            'num_labels': num_labels,
-            'hidden_size': 768,
-            'data_dir': '.',
-            'fine_tune_mode': args.fine_tune_mode}
-
-  config = SimpleNamespace(**config)
-
-  model = BertSentimentClassifier(config)
-  model = model.to(device)
-
-  lr = args.lr
-  optimizer = AdamW(model.parameters(), lr=lr)
-  best_dev_acc = 0
-
-  # Run for the specified number of epochs.
-  for epoch in range(args.epochs):
-    model.train()
-    train_loss = 0
-    num_batches = 0
-    for batch in tqdm(train_dataloader, desc=f'train-{epoch}', leave=False, disable=TQDM_DISABLE):
-      b_ids, b_mask, b_labels = (batch['token_ids'],
-                                 batch['attention_mask'], batch['labels'])
-
-      b_ids = b_ids.to(device)
-      b_mask = b_mask.to(device)
-      b_labels = b_labels.to(device)
-
-      optimizer.zero_grad()
-      logits = model(b_ids, b_mask)
-      loss = F.cross_entropy(logits, b_labels.view(-1), reduction='sum') / args.batch_size
-
-      loss.backward()
-      optimizer.step()
-
-      train_loss += loss.item()
-      num_batches += 1
-
-    train_loss = train_loss / (num_batches)
-
-    train_acc, train_f1, *_ = model_eval(train_dataloader, model, device)
-    dev_acc, dev_f1, *_ = model_eval(dev_dataloader, model, device)
-
-    if dev_acc > best_dev_acc:
-      best_dev_acc = dev_acc
-      save_model(model, optimizer, args, config, args.filepath)
-
-    print(f"Epoch {epoch}: train loss :: {train_loss :.3f}, train acc :: {train_acc :.3f}, dev acc :: {dev_acc :.3f}")
-
-
-def test(args):
-  with torch.no_grad():
-    device = torch.device('cuda') if args.use_gpu else torch.device('cpu')
-    saved = torch.load(args.filepath, weights_only=False)
-    config = saved['model_config']
-    model = BertSentimentClassifier(config)
-    model.load_state_dict(saved['model'])
-    model = model.to(device)
-    print(f"load model from {args.filepath}")
-
-    dev_data = load_data(args.dev, 'valid')
-    dev_dataset = SentimentDataset(dev_data, args)
-    dev_dataloader = DataLoader(dev_dataset, shuffle=False, batch_size=args.batch_size,
-                                num_workers=args.num_cpu_cores, collate_fn=dev_dataset.collate_fn)
-
-    dev_acc, dev_f1, dev_pred, dev_true, dev_sents, dev_sent_ids = model_eval(dev_dataloader, model, device)
-    print('DONE DEV')
-    print(f"dev acc :: {dev_acc :.3f}")
-
-    # ---- SKIP RUNNING ON TEST DATASET ---- #
-    # test_data = load_data(args.test, 'test')
-    # test_dataset = SentimentTestDataset(test_data, args)
-    # test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=args.batch_size,
-    #                              num_workers=args.num_cpu_cores, collate_fn=test_dataset.collate_fn)
-    # test_pred, test_sents, test_sent_ids = model_test_eval(test_dataloader, model, device)
-    # print('DONE TEST')
-
-    # ---- SKIP SAVING PREDICTIONS ----
-    # with open(args.dev_out, "w+") as f:
-    #   f.write(f"id \t Predicted_Sentiment \n")
-    #   for p, s in zip(dev_sent_ids, dev_pred):
-    #     f.write(f"{p} , {s} \n")
-    # with open(args.test_out, "w+") as f:
-    #   f.write(f"id \t Predicted_Sentiment \n")
-    #   for p, s in zip(test_sent_ids, test_pred):
-    #     f.write(f"{p} , {s} \n")
-
-
-def get_args():
-  parser = argparse.ArgumentParser()
-  parser.add_argument("--seed", type=int, default=11711)
-  parser.add_argument("--num-cpu-cores", type=int, default=8)
-  parser.add_argument("--epochs", type=int, default=10)
-  parser.add_argument("--fine-tune-mode", type=str,
-                      help='last-linear-layer: the BERT parameters are frozen and the task-specific head parameters are updated; full-model: BERT parameters are updated as well',
-                      choices=('last-linear-layer', 'full-model'), default="last-linear-layer")
-  parser.add_argument("--use_gpu", action='store_true')
-
-  parser.add_argument("--batch_size_sst", help='64 can fit a 12GB GPU', type=int, default=64)
-  parser.add_argument("--batch_size_cfimdb", help='8 can fit a 12GB GPU', type=int, default=8)
-  parser.add_argument("--hidden_dropout_prob", type=float, default=0.3)
-  parser.add_argument("--lr", type=float, help="learning rate, default lr for 'pretrain': 1e-3, 'finetune': 1e-5",
-                      default=1e-3)
-
-  args = parser.parse_args()
-  return args
-
-
-def main():
-  args = get_args()
-  seed_everything(args.seed)
-  torch.set_num_threads(args.num_cpu_cores)
-
-  print('Training Sentiment Classifier on SST...')
-  config = SimpleNamespace(
-    filepath='sst-classifier.pt',
-    lr=args.lr,
-    num_cpu_cores=args.num_cpu_cores,
-    use_gpu=args.use_gpu,
-    epochs=args.epochs,
-    batch_size=args.batch_size_sst,
-    hidden_dropout_prob=args.hidden_dropout_prob,
-    train='data/ids-sst-train.csv',
-    dev='data/ids-sst-dev.csv',
-    test='data/ids-sst-test-student.csv',
-    fine_tune_mode=args.fine_tune_mode,
-    dev_out='predictions/' + args.fine_tune_mode + '-sst-dev-out.csv',
-    test_out='predictions/' + args.fine_tune_mode + '-sst-test-out.csv'
-  )
-
-  train(config)
-
-  print('Evaluating on SST...')
-  test(config)
-
-  print('Training Sentiment Classifier on cfimdb...')
-  config = SimpleNamespace(
-    filepath='cfimdb-classifier.pt',
-    lr=args.lr,
-    num_cpu_cores=args.num_cpu_cores,
-    use_gpu=args.use_gpu,
-    epochs=args.epochs,
-    batch_size=args.batch_size_cfimdb,
-    hidden_dropout_prob=args.hidden_dropout_prob,
-    train='data/ids-cfimdb-train.csv',
-    dev='data/ids-cfimdb-dev.csv',
-    test='data/ids-cfimdb-test-student.csv',
-    fine_tune_mode=args.fine_tune_mode,
-    dev_out='predictions/' + args.fine_tune_mode + '-cfimdb-dev-out.csv',
-    test_out='predictions/' + args.fine_tune_mode + '-cfimdb-test-out.csv'
-  )
-
-  train(config)
-
-  print('Evaluating on cfimdb...')
-  test(config)
-
-
-if __name__ == "__main__":
-  main()
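
One detail worth noting in train() above: the loss is summed over the batch and divided by the configured args.batch_size instead of using cross-entropy's default mean reduction. The two only differ on the final, smaller batch of an epoch, where the sum is divided by the full batch size rather than the actual number of examples. A small self-contained check (random logits, hypothetical sizes):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(3, 5)  # a last partial batch: 3 examples, 5 classes
labels = torch.tensor([0, 2, 4])
batch_size = 8              # the configured batch size

loss_sum = F.cross_entropy(logits, labels, reduction='sum') / batch_size
loss_mean = F.cross_entropy(logits, labels)  # reduction='mean' divides by 3, not 8

print(loss_sum.item(), loss_mean.item())  # the first is smaller on partial batches
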
config.py DELETED
@@ -1,222 +0,0 @@
-from typing import Union, Tuple, Dict, Any, Optional
-import os
-import json
-from collections import OrderedDict
-import torch
-from utils import CONFIG_NAME, hf_bucket_url, cached_path, is_remote_url
-
-class PretrainedConfig(object):
-  model_type: str = ""
-  is_composition: bool = False
-
-  def __init__(self, **kwargs):
-    # Attributes with defaults.
-    self.return_dict = kwargs.pop("return_dict", True)
-    self.output_hidden_states = kwargs.pop("output_hidden_states", False)
-    self.output_attentions = kwargs.pop("output_attentions", False)
-    self.torchscript = kwargs.pop("torchscript", False)  # Only used by PyTorch models.
-    self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
-    self.pruned_heads = kwargs.pop("pruned_heads", {})
-    self.tie_word_embeddings = kwargs.pop(
-      "tie_word_embeddings", True
-    )  # Whether input and output word embeddings should be tied for all MLM, LM and Seq2Seq models.
-
-    # `is_decoder` is used in encoder-decoder models to differentiate the encoder from the decoder.
-    self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
-    self.is_decoder = kwargs.pop("is_decoder", False)
-    self.add_cross_attention = kwargs.pop("add_cross_attention", False)
-    self.tie_encoder_decoder = kwargs.pop("tie_encoder_decoder", False)
-
-    # Parameters for sequence generation.
-    self.max_length = kwargs.pop("max_length", 20)
-    self.min_length = kwargs.pop("min_length", 0)
-    self.do_sample = kwargs.pop("do_sample", False)
-    self.early_stopping = kwargs.pop("early_stopping", False)
-    self.num_beams = kwargs.pop("num_beams", 1)
-    self.num_beam_groups = kwargs.pop("num_beam_groups", 1)
-    self.diversity_penalty = kwargs.pop("diversity_penalty", 0.0)
-    self.temperature = kwargs.pop("temperature", 1.0)
-    self.top_k = kwargs.pop("top_k", 50)
-    self.top_p = kwargs.pop("top_p", 1.0)
-    self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
-    self.length_penalty = kwargs.pop("length_penalty", 1.0)
-    self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
-    self.encoder_no_repeat_ngram_size = kwargs.pop("encoder_no_repeat_ngram_size", 0)
-    self.bad_words_ids = kwargs.pop("bad_words_ids", None)
-    self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
-    self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)
-    self.output_scores = kwargs.pop("output_scores", False)
-    self.return_dict_in_generate = kwargs.pop("return_dict_in_generate", False)
-    self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None)
-    self.forced_eos_token_id = kwargs.pop("forced_eos_token_id", None)
-
-    # Fine-tuning task arguments.
-    self.architectures = kwargs.pop("architectures", None)
-    self.finetuning_task = kwargs.pop("finetuning_task", None)
-    self.id2label = kwargs.pop("id2label", None)
-    self.label2id = kwargs.pop("label2id", None)
-    if self.id2label is not None:
-      kwargs.pop("num_labels", None)
-      self.id2label = dict((int(key), value) for key, value in self.id2label.items())
-      # Keys are always strings in JSON, so convert the ids to int here.
-    else:
-      self.num_labels = kwargs.pop("num_labels", 2)
-
-    # Tokenizer arguments.
-    self.tokenizer_class = kwargs.pop("tokenizer_class", None)
-    self.prefix = kwargs.pop("prefix", None)
-    self.bos_token_id = kwargs.pop("bos_token_id", None)
-    self.pad_token_id = kwargs.pop("pad_token_id", None)
-    self.eos_token_id = kwargs.pop("eos_token_id", None)
-    self.sep_token_id = kwargs.pop("sep_token_id", None)
-
-    self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
-
-    # Task-specific arguments.
-    self.task_specific_params = kwargs.pop("task_specific_params", None)
-
-    # TPU arguments.
-    self.xla_device = kwargs.pop("xla_device", None)
-
-    # Name or path to the pretrained checkpoint.
-    self._name_or_path = str(kwargs.pop("name_or_path", ""))
-
-    # Drop the transformers version info.
-    kwargs.pop("transformers_version", None)
-
-    # Additional attributes without default values.
-    for key, value in kwargs.items():
-      try:
-        setattr(self, key, value)
-      except AttributeError as err:
-        raise err
-
-  @classmethod
-  def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
-    config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
-    return cls.from_dict(config_dict, **kwargs)
-
-  @classmethod
-  def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
-    with open(json_file, "r", encoding="utf-8") as reader:
-      text = reader.read()
-    return json.loads(text)
-
-  @classmethod
-  def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig":
-    return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
-
-    config = cls(**config_dict)
-
-    if hasattr(config, "pruned_heads"):
-      config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())
-
-    # Update the config with kwargs if needed.
-    to_remove = []
-    for key, value in kwargs.items():
-      if hasattr(config, key):
-        setattr(config, key, value)
-        to_remove.append(key)
-    for key in to_remove:
-      kwargs.pop(key, None)
-
-    if return_unused_kwargs:
-      return config, kwargs
-    else:
-      return config
-
-  @classmethod
-  def get_config_dict(
-    cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
-  ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
-    cache_dir = kwargs.pop("cache_dir", None)
-    force_download = kwargs.pop("force_download", False)
-    resume_download = kwargs.pop("resume_download", False)
-    proxies = kwargs.pop("proxies", None)
-    use_auth_token = kwargs.pop("use_auth_token", None)
-    local_files_only = kwargs.pop("local_files_only", False)
-    revision = kwargs.pop("revision", None)
-
-    pretrained_model_name_or_path = str(pretrained_model_name_or_path)
-    if os.path.isdir(pretrained_model_name_or_path):
-      config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
-    elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
-      config_file = pretrained_model_name_or_path
-    else:
-      config_file = hf_bucket_url(
-        pretrained_model_name_or_path, filename=CONFIG_NAME, revision=revision, mirror=None
-      )
-
-    try:
-      # Load from the URL, or from the cache if already cached.
-      resolved_config_file = cached_path(
-        config_file,
-        cache_dir=cache_dir,
-        force_download=force_download,
-        proxies=proxies,
-        resume_download=resume_download,
-        local_files_only=local_files_only,
-        use_auth_token=use_auth_token,
-      )
-      # Load the config dict.
-      config_dict = cls._dict_from_json_file(resolved_config_file)
-
-    except EnvironmentError as err:
-      msg = (
-        f"Can't load config for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
-        f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
-        f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a {CONFIG_NAME} file\n\n"
-      )
-      raise EnvironmentError(msg)
-
-    except json.JSONDecodeError:
-      msg = (
-        "Couldn't reach the server at '{}' to download the configuration file, or "
-        "the configuration file is not a valid JSON file. "
-        "Please check the network or the file content here: {}.".format(config_file, resolved_config_file)
-      )
-      raise EnvironmentError(msg)
-
-    return config_dict, kwargs
-
-
-class BertConfig(PretrainedConfig):
-  model_type = "bert"
-
-  def __init__(
-    self,
-    vocab_size=30522,
-    hidden_size=768,
-    num_hidden_layers=12,
-    num_attention_heads=12,
-    intermediate_size=3072,
-    hidden_act="gelu",
-    hidden_dropout_prob=0.1,
-    attention_probs_dropout_prob=0.1,
-    max_position_embeddings=512,
-    type_vocab_size=2,
-    initializer_range=0.02,
-    layer_norm_eps=1e-12,
-    pad_token_id=0,
-    gradient_checkpointing=False,
-    position_embedding_type="absolute",
-    use_cache=True,
-    **kwargs
-  ):
-    super().__init__(pad_token_id=pad_token_id, **kwargs)
-
-    self.vocab_size = vocab_size
-    self.hidden_size = hidden_size
-    self.num_hidden_layers = num_hidden_layers
-    self.num_attention_heads = num_attention_heads
-    self.hidden_act = hidden_act
-    self.intermediate_size = intermediate_size
-    self.hidden_dropout_prob = hidden_dropout_prob
-    self.attention_probs_dropout_prob = attention_probs_dropout_prob
-    self.max_position_embeddings = max_position_embeddings
-    self.type_vocab_size = type_vocab_size
-    self.initializer_range = initializer_range
-    self.layer_norm_eps = layer_norm_eps
-    self.gradient_checkpointing = gradient_checkpointing
-    self.position_embedding_type = position_embedding_type
-    self.use_cache = use_cache
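
A short usage sketch of the kwarg handling in from_dict above: keys that match existing config attributes are applied and consumed, and anything else is handed back when return_unused_kwargs=True. The dict values below are arbitrary, and the import assumes the repository's utils.py is present.

from config import BertConfig

cfg, unused = BertConfig.from_dict(
    {"hidden_size": 256, "num_attention_heads": 4},
    num_hidden_layers=2,     # matches an attribute: applied, then consumed
    not_a_config_key=True,   # no matching attribute: returned as unused
    return_unused_kwargs=True,
)
print(cfg.hidden_size, cfg.num_hidden_layers)  # 256 2
print(unused)                                  # {'not_a_config_key': True}
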
data/amazon-polarity.parquet DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:dbe4770cfa6be45add6c9a322044bd4c1901520dde5a2707eca402a74fbe854e
-size 870289

data/ids-cfimdb-dev.csv DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:3087f571b66860fe5d035b5a018d08202ad3fd3720e4821c04b2acf6c7ded559
-size 249095

data/ids-cfimdb-test-student.csv DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:7ae611548c9eac879e9ebb406cc9f8ae68ff12f78090e4965af5cbdfa06240f4
-size 495595

data/ids-cfimdb-train.csv DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:140fc513045a966109faed46a5c7a898767b96714d71bcb9c15f659129fadcea
-size 1693182

data/ids-sst-dev.csv DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:a186ce94577635fbe10beaaddd50f16cccf6c30973221cefdf90deed2a584bfe
-size 151384

data/ids-sst-test-student.csv DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:bdd5a767faa0c26782117e37767ece154c30d5d04fb8727d09c71e3850a55c7b
-size 313202

data/ids-sst-train.csv DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:03b2b625c090f94a6afd59f114cde5282e2053aab0b101e87ed695d8a0c5b1df
-size 1175139

data/nli-dev.parquet DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:c267496435885e724abc71e53669fae59db875bfa13389eab8f9b0b2dfb2b32e
-size 782233

data/nli-test.parquet DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:01688df43ae4c019a86144a0d2351146b124688a55f285071cccd156225a5fdf
-size 810423

data/nli-train.parquet DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:f9aeca80b1bda983ee316f854ebc37af8341877fb932dd6a2c6aba978ad112a5
-size 38396324

data/stsb-dev.parquet DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:9c6e0e9881f1b398abe3e439a482f4686305c3784568c462f6bba58bdff03b0a
-size 142187

data/stsb-test.parquet DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:8acbc291c50977d8655934952956016c3e049c2fe04f8a6c454c1bf6acc42ca1
-size 108100

data/stsb-train.parquet DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:eae324ff1eac2d0ba769851736eb7232eda64f370a16eb20e74a2c5f8f5fafe0
-size 470612
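
The entries above are Git LFS pointer files rather than the datasets themselves; only each blob's hash and size live in the repository. With Git LFS installed, the actual files are fetched with:

git lfs pull
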
evaluation.py DELETED
@@ -1,205 +0,0 @@
-#!/usr/bin/env python3
-
-'''
-Multitask BERT evaluation functions.
-
-When training your multitask model, you will find it useful to call
-model_eval_multitask to evaluate your model on the 3 tasks' dev sets.
-'''
-
-import torch
-from sklearn.metrics import f1_score, accuracy_score
-from tqdm import tqdm
-import numpy as np
-
-
-TQDM_DISABLE = False
-
-
-# Evaluate the multitask model on SST only.
-def model_eval_sst(dataloader, model, device):
-  model.eval()  # Switch to eval mode, which turns off randomness like dropout.
-  y_true = []
-  y_pred = []
-  sents = []
-  sent_ids = []
-  for step, batch in enumerate(tqdm(dataloader, desc=f'eval', disable=TQDM_DISABLE)):
-    b_ids, b_mask, b_labels, b_sents, b_sent_ids = batch['token_ids'], batch['attention_mask'], \
-                                                   batch['labels'], batch['sents'], batch['sent_ids']
-
-    b_ids = b_ids.to(device)
-    b_mask = b_mask.to(device)
-
-    logits = model.predict_sentiment(b_ids, b_mask)
-    logits = logits.detach().cpu().numpy()
-    preds = np.argmax(logits, axis=1).flatten()
-
-    b_labels = b_labels.flatten()
-    y_true.extend(b_labels)
-    y_pred.extend(preds)
-    sents.extend(b_sents)
-    sent_ids.extend(b_sent_ids)
-
-  f1 = f1_score(y_true, y_pred, average='macro')
-  acc = accuracy_score(y_true, y_pred)
-
-  return acc, f1, y_pred, y_true, sents, sent_ids
-
-
-# Evaluate the multitask model on the dev sets.
-def model_eval_multitask(sentiment_dataloader,
-                         paraphrase_dataloader,
-                         sts_dataloader,
-                         model, device):
-  model.eval()  # Switch to eval mode, which turns off randomness like dropout.
-
-  with torch.no_grad():
-    # Evaluate sentiment classification.
-    sst_y_true = []
-    sst_y_pred = []
-    sst_sent_ids = []
-    for step, batch in enumerate(tqdm(sentiment_dataloader, desc=f'eval', disable=TQDM_DISABLE)):
-      b_ids, b_mask, b_labels, b_sent_ids = batch['token_ids'], batch['attention_mask'], batch['labels'], batch['sent_ids']
-
-      b_ids = b_ids.to(device)
-      b_mask = b_mask.to(device)
-
-      logits = model.predict_sentiment(b_ids, b_mask)
-      y_hat = logits.argmax(dim=-1).flatten().cpu().numpy()
-      b_labels = b_labels.flatten().cpu().numpy()
-
-      sst_y_pred.extend(y_hat)
-      sst_y_true.extend(b_labels)
-      sst_sent_ids.extend(b_sent_ids)
-
-    sentiment_accuracy = np.mean(np.array(sst_y_pred) == np.array(sst_y_true))
-
-    # Evaluate paraphrase detection.
-    para_y_true = []
-    para_y_pred = []
-    para_sent_ids = []
-    for step, batch in enumerate(tqdm(paraphrase_dataloader, desc=f'eval', disable=TQDM_DISABLE)):
-      (b_ids1, b_mask1,
-       b_ids2, b_mask2,
-       b_labels, b_sent_ids) = (batch['token_ids_1'], batch['attention_mask_1'],
-                                batch['token_ids_2'], batch['attention_mask_2'],
-                                batch['labels'], batch['sent_ids'])
-
-      b_ids1 = b_ids1.to(device)
-      b_mask1 = b_mask1.to(device)
-      b_ids2 = b_ids2.to(device)
-      b_mask2 = b_mask2.to(device)
-
-      logits = model.predict_paraphrase(b_ids1, b_mask1, b_ids2, b_mask2)
-      y_hat = logits.sigmoid().round().flatten().cpu().numpy()
-      b_labels = b_labels.flatten().cpu().numpy()
-
-      para_y_pred.extend(y_hat)
-      para_y_true.extend(b_labels)
-      para_sent_ids.extend(b_sent_ids)
-
-    paraphrase_accuracy = np.mean(np.array(para_y_pred) == np.array(para_y_true))
-
-    # Evaluate semantic textual similarity.
-    sts_y_true = []
-    sts_y_pred = []
-    sts_sent_ids = []
-    for step, batch in enumerate(tqdm(sts_dataloader, desc=f'eval', disable=TQDM_DISABLE)):
-      (b_ids1, b_mask1,
-       b_ids2, b_mask2,
-       b_labels, b_sent_ids) = (batch['token_ids_1'], batch['attention_mask_1'],
-                                batch['token_ids_2'], batch['attention_mask_2'],
-                                batch['labels'], batch['sent_ids'])
-
-      b_ids1 = b_ids1.to(device)
-      b_mask1 = b_mask1.to(device)
-      b_ids2 = b_ids2.to(device)
-      b_mask2 = b_mask2.to(device)
-
-      logits = model.predict_similarity(b_ids1, b_mask1, b_ids2, b_mask2)
-      y_hat = logits.flatten().cpu().numpy()
-      b_labels = b_labels.flatten().cpu().numpy()
-
-      sts_y_pred.extend(y_hat)
-      sts_y_true.extend(b_labels)
-      sts_sent_ids.extend(b_sent_ids)
-
-    pearson_mat = np.corrcoef(sts_y_pred, sts_y_true)
-    sts_corr = pearson_mat[1][0]
-
-    print(f'Sentiment classification accuracy: {sentiment_accuracy:.3f}')
-    print(f'Paraphrase detection accuracy: {paraphrase_accuracy:.3f}')
-    print(f'Semantic Textual Similarity correlation: {sts_corr:.3f}')
-
-    return (sentiment_accuracy, sst_y_pred, sst_sent_ids,
-            paraphrase_accuracy, para_y_pred, para_sent_ids,
-            sts_corr, sts_y_pred, sts_sent_ids)
-
-
-# Evaluate the multitask model on the test sets.
-def model_eval_test_multitask(sentiment_dataloader,
-                              paraphrase_dataloader,
-                              sts_dataloader,
-                              model, device):
-  model.eval()  # Switch to eval mode, which turns off randomness like dropout.
-
-  with torch.no_grad():
-    # Evaluate sentiment classification.
-    sst_y_pred = []
-    sst_sent_ids = []
-    for step, batch in enumerate(tqdm(sentiment_dataloader, desc=f'eval', disable=TQDM_DISABLE)):
-      b_ids, b_mask, b_sent_ids = batch['token_ids'], batch['attention_mask'], batch['sent_ids']
-
-      b_ids = b_ids.to(device)
-      b_mask = b_mask.to(device)
-
-      logits = model.predict_sentiment(b_ids, b_mask)
-      y_hat = logits.argmax(dim=-1).flatten().cpu().numpy()
-
-      sst_y_pred.extend(y_hat)
-      sst_sent_ids.extend(b_sent_ids)
-
-    # Evaluate paraphrase detection.
-    para_y_pred = []
-    para_sent_ids = []
-    for step, batch in enumerate(tqdm(paraphrase_dataloader, desc=f'eval', disable=TQDM_DISABLE)):
-      (b_ids1, b_mask1,
-       b_ids2, b_mask2,
-       b_sent_ids) = (batch['token_ids_1'], batch['attention_mask_1'],
-                      batch['token_ids_2'], batch['attention_mask_2'],
-                      batch['sent_ids'])
-
-      b_ids1 = b_ids1.to(device)
-      b_mask1 = b_mask1.to(device)
-      b_ids2 = b_ids2.to(device)
-      b_mask2 = b_mask2.to(device)
-
-      logits = model.predict_paraphrase(b_ids1, b_mask1, b_ids2, b_mask2)
-      y_hat = logits.sigmoid().round().flatten().cpu().numpy()
-
-      para_y_pred.extend(y_hat)
-      para_sent_ids.extend(b_sent_ids)
-
-    # Evaluate semantic textual similarity.
-    sts_y_pred = []
-    sts_sent_ids = []
-    for step, batch in enumerate(tqdm(sts_dataloader, desc=f'eval', disable=TQDM_DISABLE)):
-      (b_ids1, b_mask1,
-       b_ids2, b_mask2,
-       b_sent_ids) = (batch['token_ids_1'], batch['attention_mask_1'],
-                      batch['token_ids_2'], batch['attention_mask_2'],
-                      batch['sent_ids'])
-
-      b_ids1 = b_ids1.to(device)
-      b_mask1 = b_mask1.to(device)
-      b_ids2 = b_ids2.to(device)
-      b_mask2 = b_mask2.to(device)
-
-      logits = model.predict_similarity(b_ids1, b_mask1, b_ids2, b_mask2)
-      y_hat = logits.flatten().cpu().numpy()
-
-      sts_y_pred.extend(y_hat)
-      sts_sent_ids.extend(b_sent_ids)
-
-    return (sst_y_pred, sst_sent_ids,
-            para_y_pred, para_sent_ids,
-            sts_y_pred, sts_sent_ids)
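
For reference, the STS correlation computed above is the off-diagonal entry of the 2x2 Pearson matrix returned by np.corrcoef. A minimal sketch with made-up scores:

import numpy as np

sts_y_pred = [0.1, 2.4, 3.9, 1.2]  # hypothetical model similarity scores
sts_y_true = [0.0, 2.5, 4.0, 1.0]  # hypothetical gold labels

pearson_mat = np.corrcoef(sts_y_pred, sts_y_true)
sts_corr = pearson_mat[1][0]  # off-diagonal entry = correlation of the two lists
print(f'{sts_corr:.3f}')      # close to 1.0 for these near-identical lists
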
justfile DELETED
@@ -1,14 +0,0 @@
-# Testing on a Google Cloud VM with no GPU and 8 CPU cores.
-
-# If this doesn't meet your needs, add the --use_gpu
-# or --num-cpu-cores arguments to the existing commands.
-
-
-default:
-  @just --list
-
-last-linear:
-  python classifier.py
-
-full-model:
-  python classifier.py --fine-tune-mode full-model --lr 1e-5
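
For example, a GPU run with a custom core count would combine the flags as follows (flag names as defined in classifier.py's get_args; the values are illustrative):

python classifier.py --fine-tune-mode full-model --lr 1e-5 --use_gpu --num-cpu-cores 4
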
optimizer.py DELETED
@@ -1,90 +0,0 @@
- from typing import Callable, Iterable, Tuple
-
- import torch
- from torch.optim import Optimizer
-
-
- class AdamW(Optimizer):
-     def __init__(
-         self,
-         params: Iterable[torch.nn.parameter.Parameter],
-         lr: float = 1e-3,
-         betas: Tuple[float, float] = (0.9, 0.999),
-         eps: float = 1e-6,
-         weight_decay: float = 0.0,
-         correct_bias: bool = True,
-     ):
-         if lr < 0.0:
-             raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
-         if not 0.0 <= betas[0] < 1.0:
-             raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0)".format(betas[0]))
-         if not 0.0 <= betas[1] < 1.0:
-             raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0)".format(betas[1]))
-         if not 0.0 <= eps:
-             raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
-         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias)
-         super().__init__(params, defaults)
-
-     def step(self, closure: Callable = None):
-         loss = None
-         if closure is not None:
-             loss = closure()
-
-         for group in self.param_groups:
-             for p in group["params"]:
-                 if p.grad is None:
-                     continue
-                 grad = p.grad.data
-                 if grad.is_sparse:
-                     raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")
-
-                 # Access state
-                 state = self.state[p]
-
-                 # Initialize state if not already done
-                 if len(state) == 0:
-                     state["step"] = 0
-                     state["exp_avg"] = torch.zeros_like(p.data)
-                     state["exp_avg_sq"] = torch.zeros_like(p.data)
-
-                 # Hyperparameters
-                 alpha = group["lr"]
-                 beta1, beta2 = group["betas"]
-                 eps = group["eps"]
-                 weight_decay = group["weight_decay"]
-                 correct_bias = group["correct_bias"]
-
-                 # Retrieve state variables
-                 exp_avg = state["exp_avg"]
-                 exp_avg_sq = state["exp_avg_sq"]
-                 step = state["step"]
-
-                 # Update step count
-                 step += 1
-                 state["step"] = step
-
-                 # Update biased first and second moment estimates
-                 exp_avg.mul_(beta1).add_(grad, alpha=(1 - beta1))
-                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))
-
-                 # Compute bias-corrected moments
-                 if correct_bias:
-                     bias_correction1 = 1 - beta1 ** step
-                     bias_correction2 = 1 - beta2 ** step
-                     exp_avg_corr = exp_avg / bias_correction1
-                     exp_avg_sq_corr = exp_avg_sq / bias_correction2
-                 else:
-                     exp_avg_corr = exp_avg
-                     exp_avg_sq_corr = exp_avg_sq
-
-                 # Update parameters
-                 denom = exp_avg_sq_corr.sqrt().add_(eps)
-                 step_size = alpha
-                 p.data.addcdiv_(exp_avg_corr, denom, value=-step_size)
-
-                 # Apply decoupled weight decay (after the Adam update)
-                 if weight_decay != 0:
-                     p.data.add_(p.data, alpha=-alpha * weight_decay)
-
-         return loss
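
A minimal usage sketch of this AdamW; the linear model and random data below are placeholders, not part of the repo:

    import torch
    from optimizer import AdamW

    model = torch.nn.Linear(10, 1)
    opt = AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)

    x, y = torch.randn(4, 10), torch.randn(4, 1)
    loss = ((model(x) - y) ** 2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()

The main behavioral difference from torch.optim.AdamW is that the decoupled weight decay is applied after the Adam update here rather than before it.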
optimizer_test.npy DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:77b817e0dce16a9bc8d3a6bcb88035db68f7d783dc8a565737581fadd05db815
- size 152
optimizer_test.py DELETED
@@ -1,34 +0,0 @@
- import torch
- import numpy as np
- from optimizer import AdamW
-
- seed = 0
-
-
- def test_optimizer(opt_class) -> torch.Tensor:
-     rng = np.random.default_rng(seed)
-     torch.manual_seed(seed)
-     model = torch.nn.Linear(3, 2, bias=False)
-     opt = opt_class(
-         model.parameters(),
-         lr=1e-3,
-         weight_decay=1e-4,
-         correct_bias=True,
-     )
-     for _ in range(1000):
-         opt.zero_grad()
-         x = torch.FloatTensor(rng.uniform(size=[model.in_features]))
-         y_hat = model(x)
-         y = torch.Tensor([x[0] + x[1], -x[2]])
-         loss = ((y - y_hat) ** 2).sum()
-         loss.backward()
-         opt.step()
-     return model.weight.detach()
-
-
- ref = torch.tensor(np.load("optimizer_test.npy"))
- actual = test_optimizer(AdamW)
- print(ref)
- print(actual)
- assert torch.allclose(ref, actual, atol=1e-6, rtol=1e-4)
- print("Optimizer test passed!")
predictions/README DELETED
@@ -1,2 +0,0 @@
- By default, `classifier.py` and `multitask_classifier.py` write your model predictions into this folder.
- Before running prepare_submit.py, make sure that this directory has been populated!
prepare_submit.py DELETED
@@ -1,18 +0,0 @@
- # Creates a zip file for submission on Gradescope.
-
- import os
- import zipfile
-
- required_files = [p for p in os.listdir('.') if p.endswith('.py')] + \
-                  [f'predictions/{p}' for p in os.listdir('predictions')]
-
- def main():
-     aid = 'cs224n_default_final_project_submission'
-     with zipfile.ZipFile(f"{aid}.zip", 'w') as zz:
-         for file in required_files:
-             zz.write(file, os.path.join(".", file))
-     print(f"Submission zip file created: {aid}.zip")
-
- if __name__ == '__main__':
-     main()
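
Run it from the repository root once predictions/ has been populated; it writes cs224n_default_final_project_submission.zip next to the script:

    python prepare_submit.py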
prompt DELETED
@@ -1,3 +0,0 @@
- I want to finetune minBERT with Unsupervised SimCSE to do sentiment analysis, but I don't yet know how. As I understand it, I would finetune the minBERT model with SimCSE to get better embeddings, then pass those embeddings through a SentimentClassifier for classification. But what is the correct approach?
- I have thought of the two options below (there may be others I haven't thought of). Please take a look!
- 1. Finetune minBERT with SimCSE first, then finetune the SentimentClassifier: use the STS-B dataset or the Twitter Sentiment Dataset to finetune minBERT, then evaluate the
sanity_check.data DELETED
Binary file (56.4 kB)
 
sanity_check.py DELETED
@@ -1,19 +0,0 @@
- import torch
- from bert import BertModel
-
-
- sanity_data = torch.load("./sanity_check.data", weights_only=True)
- sent_ids = torch.tensor([[101, 7592, 2088, 102, 0, 0, 0, 0],
-                          [101, 7592, 15756, 2897, 2005, 17953, 2361, 102]])
- att_mask = torch.tensor([[1, 1, 1, 1, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1]])
-
- # Load model.
- bert = BertModel.from_pretrained('bert-base-uncased')
- outputs = bert(sent_ids, att_mask)
- att_mask = att_mask.unsqueeze(-1)
- outputs['last_hidden_state'] = outputs['last_hidden_state'] * att_mask
- sanity_data['last_hidden_state'] = sanity_data['last_hidden_state'] * att_mask
-
- for k in ['last_hidden_state', 'pooler_output']:
-     assert torch.allclose(outputs[k], sanity_data[k], atol=1e-5, rtol=1e-3)
- print("Your BERT implementation is correct!")
setup.sh DELETED
@@ -1,13 +0,0 @@
- #!/usr/bin/env bash
-
- conda create -n cs224n_dfp python=3.8
- conda activate cs224n_dfp
-
- pip install torch torchvision torchaudio
- pip install tqdm==4.58.0
- pip install requests==2.25.1
- pip install importlib-metadata==3.7.0
- pip install filelock==3.0.12
- pip install scikit-learn
- pip install tokenizers==0.15
- pip install explainaboard_client==0.0.7
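
Note that `conda activate` only takes effect in a conda-initialized interactive shell, so run these commands line by line, or source the file rather than executing it:

    source setup.sh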
tokenizer.py DELETED
The diff for this file is too large to render. See raw diff
 
unsup_simcse.py DELETED
@@ -1,252 +0,0 @@
- import csv
- import torch
- import random
- import argparse
- import numpy as np
- import pandas as pd
- import torch.nn.functional as F
-
- from tqdm import tqdm
- from torch import Tensor
- from types import SimpleNamespace
- from torch.utils.data import Dataset, DataLoader
- from sklearn.metrics import f1_score, accuracy_score
-
- from bert import BertModel
- from optimizer import AdamW
- from classifier import seed_everything, tokenizer
- from classifier import SentimentDataset, BertSentimentClassifier
-
-
- TQDM_DISABLE = False
-
-
- class AmazonDataset(Dataset):
-     def __init__(self, dataset, args):
-         self.dataset = dataset
-         self.p = args
-
-     def __len__(self):
-         return len(self.dataset)
-
-     def __getitem__(self, idx):
-         return self.dataset[idx]
-
-     def pad_data(self, data):
-         sents = [x[0] for x in data]
-         sent_ids = [x[1] for x in data]
-         encoding = tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
-         token_ids = torch.LongTensor(encoding['input_ids'])
-         attention_mask = torch.LongTensor(encoding['attention_mask'])
-
-         return token_ids, attention_mask, sent_ids
-
-     def collate_fn(self, data):
-         token_ids, attention_mask, sent_ids = self.pad_data(data)
-
-         batched_data = {
-             'token_ids': token_ids,
-             'attention_mask': attention_mask,
-             'sent_ids': sent_ids
-         }
-
-         return batched_data
-
-
- def load_data(filename, flag='train'):
-     '''
-     - for amazon dataset: list of (sent, sent_id)
-     - for test dataset: list of (sent, sent_id)
-     - for train dataset: list of (sent, label, sent_id)
-     '''
-
-     if flag == 'amazon':
-         df = pd.read_parquet(filename)
-         data = list(zip(df['content'], df.index))
-     else:
-         data, num_labels = [], set()
-
-         with open(filename, 'r') as fp:
-             if flag == 'test':
-                 for record in csv.DictReader(fp, delimiter='\t'):
-                     sent = record['sentence'].lower().strip()
-                     sent_id = record['id'].lower().strip()
-                     data.append((sent, sent_id))
-             else:
-                 for record in csv.DictReader(fp, delimiter='\t'):
-                     sent = record['sentence'].lower().strip()
-                     sent_id = record['id'].lower().strip()
-                     label = int(record['sentiment'].strip())
-                     num_labels.add(label)
-                     data.append((sent, label, sent_id))
-
-     print(f"Loaded {len(data)} examples from {filename}")
-     if flag in ['test', 'amazon']:
-         return data
-     else:
-         return data, len(num_labels)
-
-
- def save_model(model, optimizer, args, config, filepath):
-     save_info = {
-         'model': model.state_dict(),
-         'optim': optimizer.state_dict(),
-         'args': args,
-         'model_config': config,
-         'system_rng': random.getstate(),
-         'numpy_rng': np.random.get_state(),
-         'torch_rng': torch.random.get_rng_state(),
-     }
-
-     torch.save(save_info, filepath)
-     print(f"Saved the model to {filepath}")
-
-
- def contrastive_loss(embeds_1: Tensor, embeds_2: Tensor, temp=0.05):
-     '''
-     embeds_1: [batch_size, hidden_size]
-     embeds_2: [batch_size, hidden_size]
-     '''
-
-     # [batch_size, batch_size]
-     sim_matrix = F.cosine_similarity(embeds_1.unsqueeze(1), embeds_2.unsqueeze(0), dim=-1) / temp
-
-     # [batch_size]
-     positive_sim = torch.diagonal(sim_matrix)
-
-     # [batch_size]
-     nume = torch.exp(positive_sim)
-
-     # [batch_size]
-     deno = torch.exp(sim_matrix).sum(1)
-
-     # [batch_size]
-     loss_per_batch = -torch.log(nume / deno)
-
-     return loss_per_batch.mean()
-
-
- def train(args):
-     '''
-     Training Pipeline
-     -----------------
-     1. Load the Amazon Polarity and SST datasets.
-     2. Determine batch_size (64) and number of batches (?).
-     3. Initialize the SentimentClassifier (including BERT).
-     4. Loop over 10 epochs.
-     5. Finetune minBERT with the SimCSE loss function.
-     6. Finetune the classifier with the cross-entropy loss.
-     7. Backpropagate through the AdamW optimizer for both.
-     8. Evaluate the model on the dev dataset.
-     9. If dev_acc > best_dev_acc: save_model(...)
-     '''
-
-     amazon_data = load_data(args.train_bert, 'amazon')
-     train_data, num_labels = load_data(args.train, 'train')
-     dev_data, _ = load_data(args.dev, 'valid')
-
-     amazon_dataset = AmazonDataset(amazon_data, args)
-     train_dataset = SentimentDataset(train_data, args)
-     dev_dataset = SentimentDataset(dev_data, args)
-
-     amazon_dataloader = DataLoader(amazon_dataset, shuffle=True, batch_size=args.batch_size_cse,
-                                    num_workers=args.num_cpu_cores, collate_fn=amazon_dataset.collate_fn)
-     train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size_classifier,
-                                   num_workers=args.num_cpu_cores, collate_fn=train_dataset.collate_fn)
-     dev_dataloader = DataLoader(dev_dataset, shuffle=False, batch_size=args.batch_size_classifier,
-                                 num_workers=args.num_cpu_cores, collate_fn=dev_dataset.collate_fn)
-
-     config = SimpleNamespace(
-         hidden_dropout_prob=args.hidden_dropout_prob,
-         num_labels=num_labels,
-         hidden_size=768,
-         data_dir='.',
-         fine_tune_mode='full-model'
-     )
-
-     device = torch.device('cuda') if args.use_gpu else torch.device('cpu')
-     model = BertSentimentClassifier(config)
-     model = model.to(device)
-
-     optimizer_cse = AdamW(model.bert.parameters(), lr=args.lr_cse)
-     optimizer_classifier = AdamW(model.parameters(), lr=args.lr_classifier)
-     best_dev_acc = 0
-
-     # ---- Training minBERT using SimCSE ---- #
-     for epoch in range(args.epochs):
-         model.bert.train()
-         train_loss = num_batches = 0
-         for batch in tqdm(amazon_dataloader, f'train-amazon-{epoch}', leave=False, disable=TQDM_DISABLE):
-             b_ids, b_mask = batch['token_ids'], batch['attention_mask']
-             b_ids = b_ids.to(device)
-             b_mask = b_mask.to(device)
-
-             # Encode the same batch twice; dropout gives two different embeddings
-             logits_1 = model.bert(b_ids, b_mask)['pooler_output']
-             logits_2 = model.bert(b_ids, b_mask)['pooler_output']
-
-             # Mean SimCSE loss over the batch
-             loss = contrastive_loss(logits_1, logits_2)
-
-             # Backpropagation
-             optimizer_cse.zero_grad()
-             loss.backward()
-             optimizer_cse.step()
-
-             train_loss += loss.item()
-             num_batches += 1
-
-         train_loss = train_loss / num_batches
-         print(f"Epoch {epoch}: train loss :: {train_loss:.3f}")
-
-
- def get_args():
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--seed", type=int, default=11711)
-     parser.add_argument("--num-cpu-cores", type=int, default=8)
-     parser.add_argument("--epochs", type=int, default=10)
-     parser.add_argument("--use_gpu", action='store_true')
-     parser.add_argument("--batch_size_cse", type=int, default=8)
-     parser.add_argument("--batch_size_sst", type=int, default=64)
-     parser.add_argument("--batch_size_cfimdb", type=int, default=8)
-     parser.add_argument("--hidden_dropout_prob", type=float, default=0.3)
-     parser.add_argument("--lr_cse", type=float, default=1e-5)
-     parser.add_argument("--lr_classifier", type=float, default=1e-5)
-
-     args = parser.parse_args()
-     return args
-
-
- if __name__ == "__main__":
-     args = get_args()
-     seed_everything(args.seed)
-     torch.set_num_threads(args.num_cpu_cores)
-
-     print('Finetuning minBERT with Unsupervised SimCSE...')
-     config = SimpleNamespace(
-         filepath='contrastive-nli.pt',
-         lr_cse=args.lr_cse,
-         lr_classifier=args.lr_classifier,
-         num_cpu_cores=args.num_cpu_cores,
-         use_gpu=args.use_gpu,
-         epochs=args.epochs,
-         batch_size_cse=args.batch_size_cse,
-         batch_size_classifier=args.batch_size_sst,
-         hidden_dropout_prob=args.hidden_dropout_prob,
-         train_bert='data/amazon-polarity.parquet',
-         train='data/ids-sst-train.csv',
-         dev='data/ids-sst-dev.csv',
-         test='data/ids-sst-test-student.csv'
-     )
-
-     train(config)
-
-     # model = BertModel.from_pretrained('bert-base-uncased')
-
-     # model.eval()
-
-     # s = set()
-     # for param in model.parameters():
-     #     s.add(param.requires_grad)
-
-     # print(s)
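
For reference, contrastive_loss implements the standard unsupervised SimCSE objective (Gao et al., 2021). With h_i and h_i' the two dropout-noised pooler outputs of sentence i, tau the temperature (temp), and N the batch size:

    \ell_i = -\log \frac{\exp(\cos(h_i, h_i') / \tau)}{\sum_{j=1}^{N} \exp(\cos(h_i, h_j') / \tau)}

The code computes exactly this: the diagonal of the similarity matrix supplies the numerator, each row sum the denominator, and the per-sentence losses are averaged over the batch.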
utils.py DELETED
@@ -1,347 +0,0 @@
- from typing import Dict, List, Optional, Union, Tuple, BinaryIO
- import os
- import sys
- import json
- import shutil
- import tarfile
- import tempfile
- import copy
- import fnmatch
- from contextlib import contextmanager
- from zipfile import ZipFile, is_zipfile
- from tqdm.auto import tqdm
- from functools import partial
- from urllib.parse import urlparse
- from pathlib import Path
- import requests
- from hashlib import sha256
- from filelock import FileLock
- from huggingface_hub import HfFolder  # only needed when use_auth_token=True
- import importlib_metadata
- import torch
- import torch.nn as nn
- from torch import Tensor
-
- __version__ = "4.0.0"
- _torch_version = importlib_metadata.version("torch")
-
- hf_cache_home = os.path.expanduser(os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface")))
- default_cache_path = os.path.join(hf_cache_home, "transformers")
- PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
- PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
- TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)
-
- PRESET_MIRROR_DICT = {
-     "tuna": "https://mirrors.tuna.tsinghua.edu.cn/hugging-face-models",
-     "bfsu": "https://mirrors.bfsu.edu.cn/hugging-face-models",
- }
- HUGGINGFACE_CO_PREFIX = "https://huggingface.co/{model_id}/resolve/{revision}/{filename}"
- WEIGHTS_NAME = "pytorch_model.bin"
- CONFIG_NAME = "config.json"
-
-
- def is_torch_available():
-     return True
-
-
- def is_tf_available():
-     return False
-
-
- def is_remote_url(url_or_filename):
-     parsed = urlparse(url_or_filename)
-     return parsed.scheme in ("http", "https")
-
-
- def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers: Optional[Dict[str, str]] = None):
-     headers = copy.deepcopy(headers)
-     if resume_size > 0:
-         headers["Range"] = "bytes=%d-" % (resume_size,)
-     r = requests.get(url, stream=True, proxies=proxies, headers=headers)
-     r.raise_for_status()
-     content_length = r.headers.get("Content-Length")
-     total = resume_size + int(content_length) if content_length is not None else None
-     progress = tqdm(
-         unit="B",
-         unit_scale=True,
-         total=total,
-         initial=resume_size,
-         desc="Downloading",
-         disable=False,
-     )
-     for chunk in r.iter_content(chunk_size=1024):
-         if chunk:  # filter out keep-alive new chunks
-             progress.update(len(chunk))
-             temp_file.write(chunk)
-     progress.close()
-
-
- def url_to_filename(url: str, etag: Optional[str] = None) -> str:
-     url_bytes = url.encode("utf-8")
-     filename = sha256(url_bytes).hexdigest()
-
-     if etag:
-         etag_bytes = etag.encode("utf-8")
-         filename += "." + sha256(etag_bytes).hexdigest()
-
-     if url.endswith(".h5"):
-         filename += ".h5"
-
-     return filename
-
-
- def hf_bucket_url(
-     model_id: str, filename: str, subfolder: Optional[str] = None, revision: Optional[str] = None, mirror=None
- ) -> str:
-     if subfolder is not None:
-         filename = f"{subfolder}/{filename}"
-
-     if mirror:
-         endpoint = PRESET_MIRROR_DICT.get(mirror, mirror)
-         legacy_format = "/" not in model_id
-         if legacy_format:
-             return f"{endpoint}/{model_id}-{filename}"
-         else:
-             return f"{endpoint}/{model_id}/{filename}"
-
-     if revision is None:
-         revision = "main"
-     return HUGGINGFACE_CO_PREFIX.format(model_id=model_id, revision=revision, filename=filename)
-
-
- def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
-     ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
-     if is_torch_available():
-         ua += f"; torch/{_torch_version}"
-     if is_tf_available():
-         ua += f"; tensorflow/{_tf_version}"
-     if isinstance(user_agent, dict):
-         ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items())
-     elif isinstance(user_agent, str):
-         ua += "; " + user_agent
-     return ua
-
-
- def get_from_cache(
-     url: str,
-     cache_dir=None,
-     force_download=False,
-     proxies=None,
-     etag_timeout=10,
-     resume_download=False,
-     user_agent: Union[Dict, str, None] = None,
-     use_auth_token: Union[bool, str, None] = None,
-     local_files_only=False,
- ) -> Optional[str]:
-     if cache_dir is None:
-         cache_dir = TRANSFORMERS_CACHE
-     if isinstance(cache_dir, Path):
-         cache_dir = str(cache_dir)
-
-     os.makedirs(cache_dir, exist_ok=True)
-
-     headers = {"user-agent": http_user_agent(user_agent)}
-     if isinstance(use_auth_token, str):
-         headers["authorization"] = "Bearer {}".format(use_auth_token)
-     elif use_auth_token:
-         token = HfFolder.get_token()
-         if token is None:
-             raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
-         headers["authorization"] = "Bearer {}".format(token)
-
-     url_to_download = url
-     etag = None
-     if not local_files_only:
-         try:
-             r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=etag_timeout)
-             r.raise_for_status()
-             etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
-             # We favor a custom header indicating the etag of the linked resource, and
-             # we fall back to the regular etag header.
-             # If we don't have either of those, raise an error.
-             if etag is None:
-                 raise OSError(
-                     "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility."
-                 )
-             # In case of a redirect, save an extra redirect on the request.get call,
-             # and ensure we download the exact atomic version even if it changed
-             # between the HEAD and the GET (unlikely, but hey).
-             if 300 <= r.status_code <= 399:
-                 url_to_download = r.headers["Location"]
-         except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
-             # etag is already None
-             pass
-
-     filename = url_to_filename(url, etag)
-
-     # Get the cache path to put the file
-     cache_path = os.path.join(cache_dir, filename)
-
-     # etag is None == we don't have a connection or we passed local_files_only.
-     # Try to get the last downloaded one.
-     if etag is None:
-         if os.path.exists(cache_path):
-             return cache_path
-         else:
-             matching_files = [
-                 file
-                 for file in fnmatch.filter(os.listdir(cache_dir), filename.split(".")[0] + ".*")
-                 if not file.endswith(".json") and not file.endswith(".lock")
-             ]
-             if len(matching_files) > 0:
-                 return os.path.join(cache_dir, matching_files[-1])
-             else:
-                 # If files cannot be found and local_files_only=True,
-                 # the models might've been found if local_files_only=False.
-                 # Notify the user about that.
-                 if local_files_only:
-                     raise FileNotFoundError(
-                         "Cannot find the requested files in the cached path and outgoing traffic has been"
-                         " disabled. To enable model look-ups and downloads online, set 'local_files_only'"
-                         " to False."
-                     )
-                 else:
-                     raise ValueError(
-                         "Connection error, and we cannot find the requested files in the cached path."
-                         " Please try again or make sure your Internet connection is on."
-                     )
-
-     # From now on, etag is not None.
-     if os.path.exists(cache_path) and not force_download:
-         return cache_path
-
-     # Prevent parallel downloads of the same file with a lock.
-     lock_path = cache_path + ".lock"
-     with FileLock(lock_path):
-
-         # If the download just completed while the lock was activated.
-         if os.path.exists(cache_path) and not force_download:
-             # Even if returning early like here, the lock will be released.
-             return cache_path
-
-         if resume_download:
-             incomplete_path = cache_path + ".incomplete"
-
-             @contextmanager
-             def _resumable_file_manager() -> "io.BufferedWriter":
-                 with open(incomplete_path, "ab") as f:
-                     yield f
-
-             temp_file_manager = _resumable_file_manager
-             if os.path.exists(incomplete_path):
-                 resume_size = os.stat(incomplete_path).st_size
-             else:
-                 resume_size = 0
-         else:
-             temp_file_manager = partial(tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False)
-             resume_size = 0
-
-         # Download to a temporary file, then copy to the cache dir once finished.
-         # Otherwise you get corrupt cache entries if the download gets interrupted.
-         with temp_file_manager() as temp_file:
-             http_get(url_to_download, temp_file, proxies=proxies, resume_size=resume_size, headers=headers)
-
-         os.replace(temp_file.name, cache_path)
-
-         meta = {"url": url, "etag": etag}
-         meta_path = cache_path + ".json"
-         with open(meta_path, "w") as meta_file:
-             json.dump(meta, meta_file)
-
-     return cache_path
-
-
- def cached_path(
-     url_or_filename,
-     cache_dir=None,
-     force_download=False,
-     proxies=None,
-     resume_download=False,
-     user_agent: Union[Dict, str, None] = None,
-     extract_compressed_file=False,
-     force_extract=False,
-     use_auth_token: Union[bool, str, None] = None,
-     local_files_only=False,
- ) -> Optional[str]:
-     if cache_dir is None:
-         cache_dir = TRANSFORMERS_CACHE
-     if isinstance(url_or_filename, Path):
-         url_or_filename = str(url_or_filename)
-     if isinstance(cache_dir, Path):
-         cache_dir = str(cache_dir)
-
-     if is_remote_url(url_or_filename):
-         # URL, so get it from the cache (downloading if necessary)
-         output_path = get_from_cache(
-             url_or_filename,
-             cache_dir=cache_dir,
-             force_download=force_download,
-             proxies=proxies,
-             resume_download=resume_download,
-             user_agent=user_agent,
-             use_auth_token=use_auth_token,
-             local_files_only=local_files_only,
-         )
-     elif os.path.exists(url_or_filename):
-         # File, and it exists.
-         output_path = url_or_filename
-     elif urlparse(url_or_filename).scheme == "":
-         # File, but it doesn't exist.
-         raise EnvironmentError("file {} not found".format(url_or_filename))
-     else:
-         # Something unknown
-         raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
-
-     if extract_compressed_file:
-         if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path):
-             return output_path
-
-         # Path where we extract compressed archives
-         # We avoid '.' in dir name and add "-extracted" at the end: "./model.zip" => "./model-zip-extracted/"
-         output_dir, output_file = os.path.split(output_path)
-         output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
-         output_path_extracted = os.path.join(output_dir, output_extract_dir_name)
-
-         if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract:
-             return output_path_extracted
-
-         # Prevent parallel extractions
-         lock_path = output_path + ".lock"
-         with FileLock(lock_path):
-             shutil.rmtree(output_path_extracted, ignore_errors=True)
-             os.makedirs(output_path_extracted)
-             if is_zipfile(output_path):
-                 with ZipFile(output_path, "r") as zip_file:
-                     zip_file.extractall(output_path_extracted)
-             elif tarfile.is_tarfile(output_path):
-                 tar_file = tarfile.open(output_path)
-                 tar_file.extractall(output_path_extracted)
-                 tar_file.close()
-             else:
-                 raise EnvironmentError("Archive format of {} could not be identified".format(output_path))
-
-         return output_path_extracted
-
-     return output_path
-
-
- def get_parameter_dtype(parameter: nn.Module):
-     try:
-         return next(parameter.parameters()).dtype
-     except StopIteration:
-         # For nn.DataParallel compatibility in PyTorch 1.5
-
-         def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
-             tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
-             return tuples
-
-         gen = parameter._named_members(get_members_fn=find_tensor_attributes)
-         first_tuple = next(gen)
-         return first_tuple[1].dtype
-
-
- def get_extended_attention_mask(attention_mask: Tensor, dtype) -> Tensor:
-     # attention_mask: [batch_size, seq_length]
-     assert attention_mask.dim() == 2
-     # [batch_size, 1, 1, seq_length] for multi-head attention
-     extended_attention_mask = attention_mask[:, None, None, :]
-     extended_attention_mask = extended_attention_mask.to(dtype=dtype)  # fp16 compatibility
-     extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
-     return extended_attention_mask
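
A quick illustration of get_extended_attention_mask; the values follow directly from the code above:

    import torch
    from utils import get_extended_attention_mask

    mask = torch.tensor([[1, 1, 0]])  # last position is padding
    ext = get_extended_attention_mask(mask, torch.float32)
    # ext has shape [1, 1, 1, 3]; kept positions map to 0.0 and the padded
    # position to -10000.0, an additive bias applied before the attention softmax.
    print(ext)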